# Source code for snputils.snp.io.write.pgen

import logging
import numpy as np
import polars as pl
import pgenlib as pg
from pathlib import Path
import zstandard as zstd
from typing import Optional, Sequence, Union

from snputils.snp.genobj.snpobj import SNPObject
from snputils.snp.io.write._plink import coerce_sex_codes

# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)


class PGENWriter:
    """
    Writes a genotype object in PGEN format (.pgen, .psam, and .pvar files) in the specified output path.
    """

    # Output extensions that `write` strips from the user-supplied base path.
    # The compound ".pvar.zst" is listed explicitly because `Path.suffix` only
    # reports the last component (".zst") and would never match it.
    _OUTPUT_SUFFIXES = (".pvar.zst", ".pgen", ".psam", ".pvar")

    def __init__(self, snpobj: SNPObject, filename: str):
        """
        Initializes the PGENWriter instance.

        Args:
            snpobj (SNPObject): The SNPObject containing genotype data to be written.
            filename (str): Base path for the output files (excluding extension).
        """
        self.__snpobj = snpobj
        self.__filename = Path(filename)

    def write(
        self,
        vzs: bool = False,
        rename_missing_values: bool = True,
        before: Union[int, float, str] = -1,
        after: Union[int, float, str] = '.'
    ):
        """
        Writes the SNPObject data to .pgen, .psam, and .pvar files.

        A trailing ".pgen", ".psam", ".pvar" or ".pvar.zst" on the base path is
        removed first, so the three output files share one extension-free stem.

        Args:
            vzs (bool, optional):
                If True, compresses the .pvar file using zstd and saves it as .pvar.zst. Defaults to False.
            rename_missing_values (bool, optional):
                If True, renames potential missing values in `snpobj.calldata_gt` before writing
                (the SNPObject is modified in place). Defaults to True.
            before (int, float, or str, default=-1):
                The current representation of missing values in `calldata_gt`. Common values might be -1, '.', or NaN.
                Default is -1.
            after (int, float, or str, default='.'):
                The value that will replace `before`. Default is '.'.
        """
        self.__filename = self._strip_known_suffix(self.__filename)

        # Optionally rename potential missing values in `snpobj.calldata_gt` before writing
        if rename_missing_values:
            self.__snpobj.rename_missings(before=before, after=after, inplace=True)

        self.write_pvar(vzs=vzs)
        self.write_psam()
        self.write_pgen()

    @staticmethod
    def _strip_known_suffix(path: Path) -> Path:
        """
        Returns `path` without a trailing PGEN-family extension.

        Matching is done on the final path component and is case-sensitive.
        A name that consists solely of the extension (e.g. ".pgen") is left
        untouched, as is any path without a recognised extension.

        Args:
            path (Path): Candidate output path.

        Returns:
            Path: The path with the recognised extension removed, or `path` itself.
        """
        name = path.name
        for suffix in PGENWriter._OUTPUT_SUFFIXES:
            if len(name) > len(suffix) and name.endswith(suffix):
                return path.with_name(name[: -len(suffix)])
        return path

    def write_pvar(self, vzs: bool = False):
        """
        Writes variant data to the .pvar file.

        Args:
            vzs (bool, optional):
                If True, compresses the .pvar file using zstd and saves it as .pvar.zst. Defaults to False.
        """
        output_filename = f"{self.__filename}.pvar"
        if vzs:
            output_filename += ".zst"
            log.info(f"Writing to {output_filename} (compressed)")
        else:
            log.info(f"Writing to {output_filename}")

        n_snps = self.__snpobj.n_snps
        df = pl.DataFrame(
            {
                "#CHROM": self.__snpobj.variants_chrom,
                "POS": self.__snpobj.variants_pos,
                "ID": self.__snpobj.variants_id,
                "REF": self.__snpobj.variants_ref,
                "ALT": self.__snpobj.variants_alt,
                # Optional columns: absent or missing entries are written as ".".
                "QUAL": self._coerce_variant_column(self.__snpobj.variants_qual, n_snps),
                "FILTER": self._coerce_variant_column(self.__snpobj.variants_filter_pass, n_snps),
                "INFO": self._coerce_variant_column(self.__snpobj.variants_info, n_snps),
            }
        )

        # Header lines followed by the tab-separated table rendered as a string
        csv_data = "##fileformat=VCFv4.2\n##source=snputils\n" + df.write_csv(None, separator="\t")

        if vzs:
            # Compress the CSV data using zstd
            cctx = zstd.ZstdCompressor()
            compressed_data = cctx.compress(csv_data.encode('utf-8'))
            with open(output_filename, 'wb') as f:
                f.write(compressed_data)
        else:
            # Explicit UTF-8 so the plain file matches the bytes fed to the
            # compressor above instead of depending on the platform locale.
            with open(output_filename, 'w', encoding='utf-8') as f:
                f.write(csv_data)

    def write_psam(self):
        """
        Writes sample metadata to the .psam file.

        When `sample_fid` is set the columns are "#FID", "IID", "SEX";
        otherwise they are "#IID", "SEX". The SEX column comes from
        `coerce_sex_codes` with "NA" as the missing code.

        Raises:
            ValueError: If `sample_fid` is set and its length differs from the number of samples.
        """
        log.info(f"Writing {self.__filename}.psam")

        columns = {}
        if self.__snpobj.sample_fid is not None:
            columns["#FID"] = self._coerce_sample_column(
                self.__snpobj.sample_fid,
                self.__snpobj.n_samples,
                column_name="sample_fid",
            )
            columns["IID"] = self.__snpobj.samples
        else:
            # Without family IDs the sample ID column carries the header marker.
            columns["#IID"] = self.__snpobj.samples
        columns["SEX"] = coerce_sex_codes(self.__snpobj.sample_sex, self.__snpobj.n_samples, missing_code="NA")

        df = pl.DataFrame(columns)
        df.write_csv(f"{self.__filename}.psam", separator="\t")

    @staticmethod
    def _coerce_variant_column(
        values: Optional[Union[np.ndarray, Sequence[Union[str, int, float]]]],
        n_variants: int,
        default: str = ".",
    ) -> np.ndarray:
        """
        Normalises an optional per-variant metadata column to strings.

        Args:
            values: Per-variant values, or None when the column is absent.
            n_variants (int): Expected number of variants.
            default (str, optional): Placeholder for absent or missing entries. Defaults to ".".

        Returns:
            np.ndarray: `default` repeated `n_variants` times when `values` is None;
            otherwise an object array of strings with missing entries replaced by `default`.

        Raises:
            ValueError: If the length of `values` differs from `n_variants`.
        """
        if values is None:
            return np.repeat(default, n_variants)
        arr = np.asarray(values)
        if arr.shape[0] != n_variants:
            raise ValueError(f"variant metadata length ({arr.shape[0]}) must match number of variants ({n_variants}).")
        return np.asarray([PGENWriter._missing_to_default(value, default) for value in arr], dtype=object)

    @staticmethod
    def _coerce_sample_column(
        values: Union[np.ndarray, Sequence[Union[str, int, float]]],
        n_samples: int,
        *,
        column_name: str,
    ) -> np.ndarray:
        """
        Converts a per-sample column to an array and validates its length.

        Args:
            values: Per-sample values.
            n_samples (int): Expected number of samples.
            column_name (str): Name used in the error message.

        Returns:
            np.ndarray: `values` as an array, otherwise unchanged.

        Raises:
            ValueError: If the length of `values` differs from `n_samples`.
        """
        arr = np.asarray(values)
        if arr.shape[0] != n_samples:
            raise ValueError(f"{column_name} length ({arr.shape[0]}) must match number of samples ({n_samples}).")
        return arr

    @staticmethod
    def _missing_to_default(value, default: str) -> str:
        """
        Returns `str(value)`, or `default` when the value denotes "missing".

        A value counts as missing when it is None, or when its string form is
        empty or equals "nan"/"none" case-insensitively.
        """
        if value is None:
            return default
        text = str(value)
        if text == "" or text.lower() in {"nan", "none"}:
            return default
        return text

    def write_pgen(self):
        """
        Writes the genotype data to a .pgen file.

        A 3-dimensional `calldata_gt` is unpacked as (variants, samples, alleles)
        and written as phased allele codes; any other array is unpacked as
        (variants, samples) and written as unphased biallelic genotypes.
        In both cases a value of -1 is written as -9 (missing).
        """
        log.info(f"Writing to {self.__filename}.pgen")
        calldata_gt = self.__snpobj.calldata_gt
        pgen_path = f"{self.__filename}.pgen".encode('utf-8')

        if calldata_gt.ndim == 3:
            num_variants, num_samples, num_alleles = calldata_gt.shape

            # Flatten the genotype matrix for pgenlib
            flat_genotypes = calldata_gt.reshape(num_variants, num_samples * num_alleles)

            with pg.PgenWriter(
                filename=pgen_path,
                sample_ct=num_samples,
                variant_ct=num_variants,
                hardcall_phase_present=True,
            ) as writer:
                for variant_index in range(num_variants):
                    # `astype` copies, so the SNPObject's own array is not modified.
                    variant_alleles = flat_genotypes[variant_index].astype(np.int32)
                    # Map missing alleles to -9, consistent with the unphased branch
                    variant_alleles[variant_alleles == -1] = -9
                    writer.append_alleles(np.ascontiguousarray(variant_alleles), all_phased=True)
        else:
            num_variants, num_samples = calldata_gt.shape

            with pg.PgenWriter(
                filename=pgen_path,
                sample_ct=num_samples,
                variant_ct=num_variants,
                hardcall_phase_present=False,
            ) as writer:
                for variant_index in range(num_variants):
                    # One row holds every sample's genotype for this variant.
                    variant_genotypes = calldata_gt[variant_index].astype(np.int8)
                    # Map missing genotypes to -9 if necessary
                    variant_genotypes[variant_genotypes == -1] = -9
                    writer.append_biallelic(np.ascontiguousarray(variant_genotypes))