import logging
from typing import Iterator, List, Optional
import os
import csv
import numpy as np
import polars as pl
import pgenlib as pg
from snputils.snp.genobj.snpobj import SNPObject
from snputils.snp.io.read.base import SNPBaseReader
from snputils.snp.io.read._pgenlib import (
estimate_separate_strands_peak_bytes,
read_separate_strands,
)
log = logging.getLogger(__name__)
def _open_textfile(filename):
if filename.endswith(".zst"):
import zstandard as zstd
return zstd.open(filename, "rt")
return open(filename, "rt")
@SNPBaseReader.register
class PGENReader(SNPBaseReader):
def read(
self,
fields: Optional[List[str]] = None,
exclude_fields: Optional[List[str]] = None,
sample_ids: Optional[np.ndarray] = None,
sample_idxs: Optional[np.ndarray] = None,
variant_ids: Optional[np.ndarray] = None,
variant_idxs: Optional[np.ndarray] = None,
sum_strands: bool = False,
separator: str = None,
) -> SNPObject:
"""
Read a pgen fileset (pgen, psam, pvar) into a SNPObject.
Args:
fields (str, None, or list of str, optional): Fields to extract data for that should be included in the returned SNPObject.
Available fields are 'GT', 'IID', 'REF', 'ALT', '#CHROM', 'CM', 'ID', 'POS', 'FILTER', 'QUAL', 'INFO'.
To extract all fields, set fields to None. Defaults to None.
exclude_fields (str, None, or list of str, optional): Fields to exclude from the returned SNPObject.
Available fields are 'GT', 'IID', 'REF', 'ALT', '#CHROM', 'CM', 'ID', 'POS', 'FILTER', 'QUAL', 'INFO'.
To exclude no fields, set exclude_fields to None. Defaults to None.
sample_ids: List of sample IDs to read. If None and sample_idxs is None, all samples are read.
sample_idxs: List of sample indices to read. If None and sample_ids is None, all samples are read.
variant_ids: List of variant IDs to read. If None and variant_idxs is None, all variants are read.
variant_idxs: List of variant indices to read. If None and variant_ids is None, all variants are read.
sum_strands: If True, maternal and paternal strands are combined into a single `int8` array with values `{0, 1, 2}`.
If False, strands are stored separately as an `int8` array with values `{0, 1}` for each strand.
Note: With the pgenlib backend, `False` uses a temporary `int32` allele buffer.
separator: Separator used in the pvar file. If None, the separator is automatically detected.
If the automatic detection fails, please specify the separator manually.
Returns:
**SNPObject**:
A SNPObject instance.
"""
assert (
sample_idxs is None or sample_ids is None
), "Only one of sample_idxs and sample_ids can be specified"
assert (
variant_idxs is None or variant_ids is None
), "Only one of variant_idxs and variant_ids can be specified"
if isinstance(fields, str):
fields = [fields]
if isinstance(exclude_fields, str):
exclude_fields = [exclude_fields]
fields = fields or ["GT", "IID", "REF", "ALT", "#CHROM", "CM", "ID", "POS", "FILTER", "QUAL", "INFO"]
exclude_fields = exclude_fields or []
fields = [field for field in fields if field not in exclude_fields]
only_read_pgen = fields == ["GT"] and variant_idxs is None and sample_idxs is None
filename_noext = str(self.filename)
for ext in [".pgen", ".pvar", ".pvar.zst", ".psam"]:
if filename_noext.endswith(ext):
filename_noext = filename_noext[:-len(ext)]
break
if only_read_pgen:
file_num_samples = None # Not needed for pgen
file_num_variants = None # Not needed
else:
pvar_extensions = [".pvar", ".pvar.zst"]
pvar_filename = None
for ext in pvar_extensions:
possible_pvar = filename_noext + ext
if os.path.exists(possible_pvar):
pvar_filename = possible_pvar
break
if pvar_filename is None:
raise FileNotFoundError(f"No .pvar or .pvar.zst file found for {filename_noext}")
log.info(f"Reading {pvar_filename}")
pvar_has_header = True
pvar_header_line_num = 0
with _open_textfile(pvar_filename) as file:
for line_num, line in enumerate(file):
if line.startswith("##"): # Metadata
continue
else:
if separator is None:
separator = csv.Sniffer().sniff(file.readline()).delimiter
if line.startswith("#CHROM"): # Header
pvar_header_line_num = line_num
header = line.strip().split()
break
elif not line.startswith("#"): # If no header, look at line 1
pvar_has_header = False
cols_in_pvar = len(line.strip().split(separator))
if cols_in_pvar == 5:
header = ["#CHROM", "ID", "POS", "ALT", "REF"]
elif cols_in_pvar == 6:
header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"]
else:
raise ValueError(
f"{pvar_filename} is not a valid pvar file."
)
break
pvar_reading_args = {
'separator': separator,
'skip_rows': pvar_header_line_num,
'has_header': pvar_has_header,
'new_columns': None if pvar_has_header else header,
'schema_overrides': {
"#CHROM": pl.String,
"CM": pl.Float64,
"POS": pl.UInt32,
"ID": pl.String,
"REF": pl.String,
"ALT": pl.String,
"QUAL": pl.String,
"FILTER": pl.String,
"INFO": pl.String,
},
'null_values': ["NA"],
}
if pvar_filename.endswith('.zst'):
pvar = pl.read_csv(pvar_filename, **pvar_reading_args).lazy()
else:
pvar = pl.scan_csv(pvar_filename, **pvar_reading_args)
# We need to map requested IDs to row positions before reading genotypes.
variant_meta = pvar.select(["ID", "#CHROM", "POS"]).with_row_index().collect()
file_num_variants = variant_meta.height
if variant_ids is not None:
variant_id_values = [str(v) for v in np.atleast_1d(variant_ids)]
variant_id_or_pos = (
pl.col("ID").is_in(variant_id_values)
| pl.concat_str(
[pl.col("#CHROM"), pl.lit(":"), pl.col("POS").cast(pl.String)]
).is_in(variant_id_values)
)
variant_idxs = (
variant_meta.filter(variant_id_or_pos)
.select("index")
.to_series()
.to_numpy()
)
if variant_idxs is None:
num_variants = file_num_variants
variant_idxs = np.arange(num_variants, dtype=np.uint32)
pvar = pvar.collect()
else:
pvar = (
pvar.with_row_index()
.filter(pl.col("index").is_in(np.asarray(variant_idxs, dtype=np.uint32).ravel()))
.collect()
)
variant_idxs = pvar.select("index").to_series().to_numpy()
variant_idxs = np.asarray(variant_idxs, dtype=np.uint32)
num_variants = np.size(variant_idxs)
pvar = pvar.drop("index")
log.info(f"Reading {filename_noext}.psam")
with open(filename_noext + ".psam") as file:
first_line = file.readline().strip()
psam_has_header = first_line.startswith(("#FID", "FID", "#IID", "IID"))
psam = pl.read_csv(
filename_noext + ".psam",
separator=separator,
has_header=psam_has_header,
new_columns=None if psam_has_header else ["FID", "IID", "PAT", "MAT", "SEX", "PHENO1"],
schema_overrides={
"#FID": pl.String,
"FID": pl.String,
"#IID": pl.String,
"IID": pl.String,
"PAT": pl.String,
"MAT": pl.String,
"SEX": pl.String,
"PHENO1": pl.String,
},
null_values=["NA"],
).with_row_index()
if "#IID" in psam.columns:
psam = psam.rename({"#IID": "IID"})
if "#FID" in psam.columns:
psam = psam.rename({"#FID": "FID"})
file_num_samples = psam.height
if sample_ids is not None:
psam = psam.filter(pl.col("IID").is_in(sample_ids))
sample_idxs = psam.select("index").to_series().to_numpy()
num_samples = np.size(sample_idxs)
elif sample_idxs is not None:
num_samples = np.size(sample_idxs)
sample_idxs = np.array(sample_idxs, dtype=np.uint32)
psam = psam.filter(pl.col("index").is_in(sample_idxs))
else:
num_samples = file_num_samples
if "GT" in fields:
log.info(f"Reading {filename_noext}.pgen")
pgen_reader = pg.PgenReader(
str.encode(filename_noext + ".pgen"),
raw_sample_ct=file_num_samples,
variant_ct=file_num_variants,
sample_subset=sample_idxs,
)
if only_read_pgen:
num_samples = pgen_reader.get_raw_sample_ct()
num_variants = pgen_reader.get_variant_ct()
variant_idxs = np.arange(num_variants, dtype=np.uint32)
# required arrays: variant_idxs + sample_idxs + genotypes
if not sum_strands:
required_ram = (
(num_samples + num_variants) * 4
+ estimate_separate_strands_peak_bytes(num_variants, num_samples)
)
else:
required_ram = (num_samples + num_variants) * 4 + num_variants * num_samples
log.info(f">{required_ram / 1024**3:.2f} GiB of RAM are required to process {num_samples} samples with {num_variants} variants each")
if not sum_strands:
genotypes = read_separate_strands(
pgen_reader,
variant_idxs,
num_variants,
num_samples,
)
else:
genotypes = np.empty((num_variants, num_samples), dtype=np.int8)
pgen_reader.read_list(variant_idxs, genotypes)
pgen_reader.close()
else:
genotypes = None
log.info("Constructing SNPObject")
fid_col = None
if "IID" in fields and "FID" in psam.columns:
fid_col = psam.get_column("FID").fill_null("NA").cast(pl.String).to_numpy()
sex_col = None
if "IID" in fields and "SEX" in psam.columns:
sex_col = psam.get_column("SEX").fill_null("NA").cast(pl.String).to_numpy()
snpobj = SNPObject(
calldata_gt=genotypes if "GT" in fields else None,
samples=psam.get_column("IID").to_numpy() if "IID" in fields and "IID" in psam.columns else None,
sample_fid=fid_col,
sample_sex=sex_col,
**{f'variants_{k.lower()}': pvar.get_column(v).to_numpy() if v in fields and v in pvar.columns else None
for k, v in {'ref': 'REF', 'alt': 'ALT', 'chrom': '#CHROM', 'cm': 'CM', 'id': 'ID', 'pos': 'POS', 'filter_pass': 'FILTER', 'qual': 'QUAL', 'info': 'INFO'}.items()}
)
log.info("Finished constructing SNPObject")
return snpobj
    def _resolve_variant_idxs_for_iter(
        self,
        *,
        variant_ids: Optional[np.ndarray],
        variant_idxs: Optional[np.ndarray],
        separator: Optional[str] = None,
    ) -> np.ndarray:
        """
        Resolve variant selectors to canonical file-order row indices.

        Exactly one of `variant_ids` / `variant_idxs` (or neither) is expected;
        the caller (`iter_read`) validates mutual exclusivity. IDs match either
        the pvar ID column or a "CHROM:POS" spelling. The returned uint32 array
        is in pvar file order; unknown IDs and out-of-range indices are
        silently dropped, and duplicates are collapsed.
        """
        # Strip a known fileset extension (if any) to get the shared prefix.
        filename_noext = str(self.filename)
        for ext in [".pgen", ".pvar", ".pvar.zst", ".psam"]:
            if filename_noext.endswith(ext):
                filename_noext = filename_noext[:-len(ext)]
                break
        # Locate the variant metadata file (plain or zstd-compressed).
        pvar_filename = None
        for ext in [".pvar", ".pvar.zst"]:
            candidate = filename_noext + ext
            if os.path.exists(candidate):
                pvar_filename = candidate
                break
        if pvar_filename is None:
            raise FileNotFoundError(f"No .pvar or .pvar.zst file found for {filename_noext}")
        # Header/separator detection mirrors PGENReader.read: skip "##"
        # metadata, then either find a "#CHROM" header or infer a headerless
        # 5/6-column layout from the first data line.
        local_separator = separator
        pvar_has_header = True
        pvar_header_line_num = 0
        with _open_textfile(pvar_filename) as file:
            for line_num, line in enumerate(file):
                if line.startswith("##"):
                    continue
                if local_separator is None:
                    # Sniffs the *next* line (consumed from the iterator);
                    # safe because every path below breaks out of the loop.
                    local_separator = csv.Sniffer().sniff(file.readline()).delimiter
                if line.startswith("#CHROM"):
                    pvar_header_line_num = line_num
                    header = line.strip().split()
                    break
                if not line.startswith("#"):
                    pvar_has_header = False
                    cols_in_pvar = len(line.strip().split(local_separator))
                    if cols_in_pvar == 5:
                        header = ["#CHROM", "ID", "POS", "ALT", "REF"]
                    elif cols_in_pvar == 6:
                        header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"]
                    else:
                        raise ValueError(f"{pvar_filename} is not a valid pvar file.")
                    break
        pvar_reading_args = {
            "separator": local_separator,
            "skip_rows": pvar_header_line_num,
            "has_header": pvar_has_header,
            "new_columns": None if pvar_has_header else header,
            "schema_overrides": {
                "#CHROM": pl.String,
                "CM": pl.Float64,
                "POS": pl.UInt32,
                "ID": pl.String,
                "REF": pl.String,
                "ALT": pl.String,
                "QUAL": pl.String,
                "FILTER": pl.String,
                "INFO": pl.String,
            },
            "null_values": ["NA"],
        }
        # scan_csv cannot stream zstd files, so read those eagerly.
        if pvar_filename.endswith(".zst"):
            pvar = pl.read_csv(pvar_filename, **pvar_reading_args)
        else:
            pvar = pl.scan_csv(pvar_filename, **pvar_reading_args).collect()
        # Row index = canonical position of each variant in the pgen file.
        variant_meta = pvar.select(["ID", "#CHROM", "POS"]).with_row_index()
        if variant_ids is not None:
            variant_id_values = [str(v) for v in np.atleast_1d(variant_ids)]
            # Match the literal ID column or a "CHROM:POS" spelling.
            variant_id_or_pos = (
                pl.col("ID").is_in(variant_id_values)
                | pl.concat_str([pl.col("#CHROM"), pl.lit(":"), pl.col("POS").cast(pl.String)]).is_in(
                    variant_id_values
                )
            )
            resolved = (
                variant_meta.filter(variant_id_or_pos)
                .select("index")
                .to_series()
                .to_numpy()
            )
            return np.asarray(resolved, dtype=np.uint32)
        if variant_idxs is not None:
            # Filtering through is_in normalizes to file order and drops
            # out-of-range indices and duplicates.
            requested = np.asarray(variant_idxs, dtype=np.uint32).ravel()
            resolved = (
                variant_meta.filter(pl.col("index").is_in(requested))
                .select("index")
                .to_series()
                .to_numpy()
            )
            return np.asarray(resolved, dtype=np.uint32)
        # No selector: every variant in the file.
        return np.arange(variant_meta.height, dtype=np.uint32)
def iter_read(
self,
fields: Optional[List[str]] = None,
exclude_fields: Optional[List[str]] = None,
sample_ids: Optional[np.ndarray] = None,
sample_idxs: Optional[np.ndarray] = None,
variant_ids: Optional[np.ndarray] = None,
variant_idxs: Optional[np.ndarray] = None,
sum_strands: bool = False,
separator: str = None,
chunk_size: int = 10_000,
) -> Iterator[SNPObject]:
"""
Stream the PGEN fileset in variant chunks.
This yields a sequence of SNPObject chunks along the SNP axis.
"""
if chunk_size < 1:
raise ValueError("chunk_size must be >= 1.")
if sample_idxs is not None and sample_ids is not None:
raise ValueError("Only one of sample_idxs and sample_ids can be specified.")
if variant_idxs is not None and variant_ids is not None:
raise ValueError("Only one of variant_idxs and variant_ids can be specified.")
selectors = self._resolve_variant_idxs_for_iter(
variant_ids=variant_ids,
variant_idxs=variant_idxs,
separator=separator,
)
n_selectors = int(selectors.size)
for start in range(0, n_selectors, int(chunk_size)):
stop = min(start + int(chunk_size), n_selectors)
selector_chunk = np.asarray(selectors[start:stop], dtype=np.uint32)
yield self.read(
fields=fields,
exclude_fields=exclude_fields,
sample_ids=sample_ids,
sample_idxs=sample_idxs,
variant_idxs=selector_chunk,
sum_strands=sum_strands,
separator=separator,
)