# Source code for snputils.snp.io.read.pgen

import logging
from typing import Iterator, List, Optional
import os
import csv

import numpy as np
import polars as pl
import pgenlib as pg

from snputils.snp.genobj.snpobj import SNPObject
from snputils.snp.io.read.base import SNPBaseReader
from snputils.snp.io.read._pgenlib import (
    estimate_separate_strands_peak_bytes,
    read_separate_strands,
)

log = logging.getLogger(__name__)


def _open_textfile(filename):
    if filename.endswith(".zst"):
        import zstandard as zstd
        return zstd.open(filename, "rt")
    return open(filename, "rt")


@SNPBaseReader.register
class PGENReader(SNPBaseReader):
    """Reader for PLINK 2 PGEN filesets (``.pgen`` + ``.psam`` + ``.pvar``/``.pvar.zst``)."""

    def read(
        self,
        fields: Optional[List[str]] = None,
        exclude_fields: Optional[List[str]] = None,
        sample_ids: Optional[np.ndarray] = None,
        sample_idxs: Optional[np.ndarray] = None,
        variant_ids: Optional[np.ndarray] = None,
        variant_idxs: Optional[np.ndarray] = None,
        sum_strands: bool = False,
        separator: str = None,
    ) -> SNPObject:
        """
        Read a pgen fileset (pgen, psam, pvar) into a SNPObject.

        Args:
            fields (str, None, or list of str, optional): Fields to extract data for
                that should be included in the returned SNPObject. Available fields
                are 'GT', 'IID', 'REF', 'ALT', '#CHROM', 'CM', 'ID', 'POS',
                'FILTER', 'QUAL', 'INFO'. To extract all fields, set fields to
                None. Defaults to None.
            exclude_fields (str, None, or list of str, optional): Fields to exclude
                from the returned SNPObject. Available fields are 'GT', 'IID',
                'REF', 'ALT', '#CHROM', 'CM', 'ID', 'POS', 'FILTER', 'QUAL',
                'INFO'. To exclude no fields, set exclude_fields to None.
                Defaults to None.
            sample_ids: List of sample IDs to read. If None and sample_idxs is
                None, all samples are read.
            sample_idxs: List of sample indices to read. If None and sample_ids is
                None, all samples are read.
            variant_ids: List of variant IDs to read. If None and variant_idxs is
                None, all variants are read.
            variant_idxs: List of variant indices to read. If None and variant_ids
                is None, all variants are read.
            sum_strands: If True, maternal and paternal strands are combined into a
                single `int8` array with values `{0, 1, 2}`. If False, strands are
                stored separately as an `int8` array with values `{0, 1}` for each
                strand. Note: With the pgenlib backend, `False` uses a temporary
                `int32` allele buffer.
            separator: Separator used in the pvar file. If None, the separator is
                automatically detected. If the automatic detection fails, please
                specify the separator manually.

        Returns:
            **SNPObject**: A SNPObject instance.

        Raises:
            FileNotFoundError: If no .pvar or .pvar.zst file is found next to the
                .pgen file when variant/sample metadata must be read.
            ValueError: If a header-less pvar file does not have 5 or 6 columns.
        """
        assert (
            sample_idxs is None or sample_ids is None
        ), "Only one of sample_idxs and sample_ids can be specified"
        assert (
            variant_idxs is None or variant_ids is None
        ), "Only one of variant_idxs and variant_ids can be specified"

        if isinstance(fields, str):
            fields = [fields]
        if isinstance(exclude_fields, str):
            exclude_fields = [exclude_fields]
        fields = fields or ["GT", "IID", "REF", "ALT", "#CHROM", "CM", "ID", "POS", "FILTER", "QUAL", "INFO"]
        exclude_fields = exclude_fields or []
        fields = [field for field in fields if field not in exclude_fields]

        # Fast path: only genotypes requested and no subsetting of any kind, so
        # neither the .pvar nor the .psam file has to be parsed at all.
        # Bug fix: previously only variant_idxs/sample_idxs were checked here, so
        # calling read(fields="GT", sample_ids=[...]) or read(fields="GT",
        # variant_ids=[...]) took this fast path and silently ignored the
        # requested ID-based subset.
        only_read_pgen = (
            fields == ["GT"]
            and variant_idxs is None
            and sample_idxs is None
            and variant_ids is None
            and sample_ids is None
        )

        # Strip a known extension (if any) to get the fileset's common prefix.
        filename_noext = str(self.filename)
        for ext in [".pgen", ".pvar", ".pvar.zst", ".psam"]:
            if filename_noext.endswith(ext):
                filename_noext = filename_noext[:-len(ext)]
                break

        if only_read_pgen:
            file_num_samples = None  # Not needed for pgen
            file_num_variants = None  # Not needed
        else:
            # ---- Locate and parse the variant metadata (.pvar / .pvar.zst) ----
            pvar_extensions = [".pvar", ".pvar.zst"]
            pvar_filename = None
            for ext in pvar_extensions:
                possible_pvar = filename_noext + ext
                if os.path.exists(possible_pvar):
                    pvar_filename = possible_pvar
                    break
            if pvar_filename is None:
                raise FileNotFoundError(f"No .pvar or .pvar.zst file found for {filename_noext}")

            log.info(f"Reading {pvar_filename}")
            pvar_has_header = True
            pvar_header_line_num = 0
            with _open_textfile(pvar_filename) as file:
                for line_num, line in enumerate(file):
                    if line.startswith("##"):  # Metadata
                        continue
                    else:
                        if separator is None:
                            # NOTE: readline() advances the file, so the sniffer
                            # inspects the line *after* the current one.
                            separator = csv.Sniffer().sniff(file.readline()).delimiter
                        if line.startswith("#CHROM"):  # Header
                            pvar_header_line_num = line_num
                            header = line.strip().split()
                            break
                        elif not line.startswith("#"):  # If no header, look at line 1
                            pvar_has_header = False
                            cols_in_pvar = len(line.strip().split(separator))
                            if cols_in_pvar == 5:
                                header = ["#CHROM", "ID", "POS", "ALT", "REF"]
                            elif cols_in_pvar == 6:
                                header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"]
                            else:
                                raise ValueError(
                                    f"{pvar_filename} is not a valid pvar file."
                                )
                            break

            pvar_reading_args = {
                'separator': separator,
                'skip_rows': pvar_header_line_num,
                'has_header': pvar_has_header,
                'new_columns': None if pvar_has_header else header,
                'schema_overrides': {
                    "#CHROM": pl.String,
                    "CM": pl.Float64,
                    "POS": pl.UInt32,
                    "ID": pl.String,
                    "REF": pl.String,
                    "ALT": pl.String,
                    "QUAL": pl.String,
                    "FILTER": pl.String,
                    "INFO": pl.String,
                },
                'null_values': ["NA"],
            }
            # scan_csv cannot stream zstd-compressed files, so fall back to an
            # eager read wrapped in .lazy() for .zst inputs.
            if pvar_filename.endswith('.zst'):
                pvar = pl.read_csv(pvar_filename, **pvar_reading_args).lazy()
            else:
                pvar = pl.scan_csv(pvar_filename, **pvar_reading_args)

            # We need to map requested IDs to row positions before reading genotypes.
            variant_meta = pvar.select(["ID", "#CHROM", "POS"]).with_row_index().collect()
            file_num_variants = variant_meta.height

            if variant_ids is not None:
                # Match either the literal ID or a "CHROM:POS" spelling.
                variant_id_values = [str(v) for v in np.atleast_1d(variant_ids)]
                variant_id_or_pos = (
                    pl.col("ID").is_in(variant_id_values)
                    | pl.concat_str(
                        [pl.col("#CHROM"), pl.lit(":"), pl.col("POS").cast(pl.String)]
                    ).is_in(variant_id_values)
                )
                variant_idxs = (
                    variant_meta.filter(variant_id_or_pos)
                    .select("index")
                    .to_series()
                    .to_numpy()
                )

            if variant_idxs is None:
                num_variants = file_num_variants
                variant_idxs = np.arange(num_variants, dtype=np.uint32)
                pvar = pvar.collect()
            else:
                # Keep only the requested rows; canonical (file) order is preserved.
                pvar = (
                    pvar.with_row_index()
                    .filter(pl.col("index").is_in(np.asarray(variant_idxs, dtype=np.uint32).ravel()))
                    .collect()
                )
                variant_idxs = pvar.select("index").to_series().to_numpy()
                variant_idxs = np.asarray(variant_idxs, dtype=np.uint32)
                num_variants = np.size(variant_idxs)
                pvar = pvar.drop("index")

            # ---- Parse the sample metadata (.psam) ----
            log.info(f"Reading {filename_noext}.psam")
            with open(filename_noext + ".psam") as file:
                first_line = file.readline().strip()
                psam_has_header = first_line.startswith(("#FID", "FID", "#IID", "IID"))
            psam = pl.read_csv(
                filename_noext + ".psam",
                separator=separator,
                has_header=psam_has_header,
                new_columns=None if psam_has_header else ["FID", "IID", "PAT", "MAT", "SEX", "PHENO1"],
                schema_overrides={
                    "#FID": pl.String,
                    "FID": pl.String,
                    "#IID": pl.String,
                    "IID": pl.String,
                    "PAT": pl.String,
                    "MAT": pl.String,
                    "SEX": pl.String,
                    "PHENO1": pl.String,
                },
                null_values=["NA"],
            ).with_row_index()
            # Normalize the PLINK "#"-prefixed header variants.
            if "#IID" in psam.columns:
                psam = psam.rename({"#IID": "IID"})
            if "#FID" in psam.columns:
                psam = psam.rename({"#FID": "FID"})
            file_num_samples = psam.height

            if sample_ids is not None:
                psam = psam.filter(pl.col("IID").is_in(sample_ids))
                sample_idxs = psam.select("index").to_series().to_numpy()
                num_samples = np.size(sample_idxs)
            elif sample_idxs is not None:
                num_samples = np.size(sample_idxs)
                sample_idxs = np.array(sample_idxs, dtype=np.uint32)
                psam = psam.filter(pl.col("index").is_in(sample_idxs))
            else:
                num_samples = file_num_samples

        # ---- Read the genotype matrix (.pgen) ----
        if "GT" in fields:
            log.info(f"Reading {filename_noext}.pgen")
            pgen_reader = pg.PgenReader(
                str.encode(filename_noext + ".pgen"),
                raw_sample_ct=file_num_samples,
                variant_ct=file_num_variants,
                sample_subset=sample_idxs,
            )

            if only_read_pgen:
                # Counts come straight from the pgen header on the fast path.
                num_samples = pgen_reader.get_raw_sample_ct()
                num_variants = pgen_reader.get_variant_ct()
                variant_idxs = np.arange(num_variants, dtype=np.uint32)

            # required arrays: variant_idxs + sample_idxs + genotypes
            if not sum_strands:
                required_ram = (
                    (num_samples + num_variants) * 4
                    + estimate_separate_strands_peak_bytes(num_variants, num_samples)
                )
            else:
                required_ram = (num_samples + num_variants) * 4 + num_variants * num_samples
            log.info(f">{required_ram / 1024**3:.2f} GiB of RAM are required to process {num_samples} samples with {num_variants} variants each")

            if not sum_strands:
                genotypes = read_separate_strands(
                    pgen_reader,
                    variant_idxs,
                    num_variants,
                    num_samples,
                )
            else:
                genotypes = np.empty((num_variants, num_samples), dtype=np.int8)
                pgen_reader.read_list(variant_idxs, genotypes)
            pgen_reader.close()
        else:
            genotypes = None

        # ---- Assemble the SNPObject ----
        log.info("Constructing SNPObject")
        fid_col = None
        if "IID" in fields and "FID" in psam.columns:
            fid_col = psam.get_column("FID").fill_null("NA").cast(pl.String).to_numpy()
        sex_col = None
        if "IID" in fields and "SEX" in psam.columns:
            sex_col = psam.get_column("SEX").fill_null("NA").cast(pl.String).to_numpy()
        snpobj = SNPObject(
            calldata_gt=genotypes if "GT" in fields else None,
            samples=psam.get_column("IID").to_numpy() if "IID" in fields and "IID" in psam.columns else None,
            sample_fid=fid_col,
            sample_sex=sex_col,
            **{f'variants_{k.lower()}': pvar.get_column(v).to_numpy() if v in fields and v in pvar.columns else None
               for k, v in {'ref': 'REF', 'alt': 'ALT', 'chrom': '#CHROM', 'cm': 'CM', 'id': 'ID',
                            'pos': 'POS', 'filter_pass': 'FILTER', 'qual': 'QUAL', 'info': 'INFO'}.items()}
        )
        log.info("Finished constructing SNPObject")
        return snpobj
    def _resolve_variant_idxs_for_iter(
        self,
        *,
        variant_ids: Optional[np.ndarray],
        variant_idxs: Optional[np.ndarray],
        separator: str = None,
    ) -> np.ndarray:
        """
        Resolve variant selectors to canonical file-order row indices.

        Parses the fileset's .pvar/.pvar.zst metadata and maps the requested
        variants to `uint32` row positions.

        Args:
            variant_ids: Variant IDs (or "CHROM:POS" strings) to resolve; takes
                precedence over ``variant_idxs`` when both are given.
            variant_idxs: Explicit row indices to validate against the file.
            separator: pvar column separator; auto-detected when None.

        Returns:
            np.ndarray: `uint32` row indices in file order. When neither
            selector is given, all rows are returned.

        Raises:
            FileNotFoundError: If no .pvar or .pvar.zst file exists.
            ValueError: If a header-less pvar file has an unexpected column count.
        """
        # Strip a known extension (if any) to get the fileset's common prefix.
        filename_noext = str(self.filename)
        for ext in [".pgen", ".pvar", ".pvar.zst", ".psam"]:
            if filename_noext.endswith(ext):
                filename_noext = filename_noext[:-len(ext)]
                break
        pvar_filename = None
        for ext in [".pvar", ".pvar.zst"]:
            candidate = filename_noext + ext
            if os.path.exists(candidate):
                pvar_filename = candidate
                break
        if pvar_filename is None:
            raise FileNotFoundError(f"No .pvar or .pvar.zst file found for {filename_noext}")

        # Detect separator / header layout; mirrors the logic in read().
        local_separator = separator
        pvar_has_header = True
        pvar_header_line_num = 0
        with _open_textfile(pvar_filename) as file:
            for line_num, line in enumerate(file):
                if line.startswith("##"):  # skip "##" metadata lines
                    continue
                if local_separator is None:
                    # NOTE(review): readline() advances the file, so the sniffer
                    # inspects the line *after* the current one — same behavior
                    # as read(); confirm this is intentional.
                    local_separator = csv.Sniffer().sniff(file.readline()).delimiter
                if line.startswith("#CHROM"):
                    pvar_header_line_num = line_num
                    header = line.strip().split()
                    break
                if not line.startswith("#"):
                    # No header row: infer the column layout from the width of
                    # the first data line (5 or 6 columns are supported).
                    pvar_has_header = False
                    cols_in_pvar = len(line.strip().split(local_separator))
                    if cols_in_pvar == 5:
                        header = ["#CHROM", "ID", "POS", "ALT", "REF"]
                    elif cols_in_pvar == 6:
                        header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"]
                    else:
                        raise ValueError(f"{pvar_filename} is not a valid pvar file.")
                    break
        pvar_reading_args = {
            "separator": local_separator,
            "skip_rows": pvar_header_line_num,
            "has_header": pvar_has_header,
            "new_columns": None if pvar_has_header else header,
            "schema_overrides": {
                "#CHROM": pl.String,
                "CM": pl.Float64,
                "POS": pl.UInt32,
                "ID": pl.String,
                "REF": pl.String,
                "ALT": pl.String,
                "QUAL": pl.String,
                "FILTER": pl.String,
                "INFO": pl.String,
            },
            "null_values": ["NA"],
        }
        # .zst cannot be lazily scanned; read it eagerly instead.
        if pvar_filename.endswith(".zst"):
            pvar = pl.read_csv(pvar_filename, **pvar_reading_args)
        else:
            pvar = pl.scan_csv(pvar_filename, **pvar_reading_args).collect()
        variant_meta = pvar.select(["ID", "#CHROM", "POS"]).with_row_index()
        if variant_ids is not None:
            # Match either the literal ID or a "CHROM:POS" spelling.
            variant_id_values = [str(v) for v in np.atleast_1d(variant_ids)]
            variant_id_or_pos = (
                pl.col("ID").is_in(variant_id_values)
                | pl.concat_str([pl.col("#CHROM"), pl.lit(":"), pl.col("POS").cast(pl.String)]).is_in(
                    variant_id_values
                )
            )
            resolved = (
                variant_meta.filter(variant_id_or_pos)
                .select("index")
                .to_series()
                .to_numpy()
            )
            return np.asarray(resolved, dtype=np.uint32)
        if variant_idxs is not None:
            # Filtering through the frame drops out-of-range indices and
            # returns the survivors in file order.
            requested = np.asarray(variant_idxs, dtype=np.uint32).ravel()
            resolved = (
                variant_meta.filter(pl.col("index").is_in(requested))
                .select("index")
                .to_series()
                .to_numpy()
            )
            return np.asarray(resolved, dtype=np.uint32)
        # No selector given: every variant in the file.
        return np.arange(variant_meta.height, dtype=np.uint32)
[docs] def iter_read( self, fields: Optional[List[str]] = None, exclude_fields: Optional[List[str]] = None, sample_ids: Optional[np.ndarray] = None, sample_idxs: Optional[np.ndarray] = None, variant_ids: Optional[np.ndarray] = None, variant_idxs: Optional[np.ndarray] = None, sum_strands: bool = False, separator: str = None, chunk_size: int = 10_000, ) -> Iterator[SNPObject]: """ Stream the PGEN fileset in variant chunks. This yields a sequence of SNPObject chunks along the SNP axis. """ if chunk_size < 1: raise ValueError("chunk_size must be >= 1.") if sample_idxs is not None and sample_ids is not None: raise ValueError("Only one of sample_idxs and sample_ids can be specified.") if variant_idxs is not None and variant_ids is not None: raise ValueError("Only one of variant_idxs and variant_ids can be specified.") selectors = self._resolve_variant_idxs_for_iter( variant_ids=variant_ids, variant_idxs=variant_idxs, separator=separator, ) n_selectors = int(selectors.size) for start in range(0, n_selectors, int(chunk_size)): stop = min(start + int(chunk_size), n_selectors) selector_chunk = np.asarray(selectors[start:stop], dtype=np.uint32) yield self.read( fields=fields, exclude_fields=exclude_fields, sample_ids=sample_ids, sample_idxs=sample_idxs, variant_idxs=selector_chunk, sum_strands=sum_strands, separator=separator, )