import logging
from typing import Iterator, List, Optional
import csv
import numpy as np
import polars as pl
import pgenlib as pg
from snputils.snp.genobj.snpobj import SNPObject
from snputils.snp.io.read.base import SNPBaseReader
from snputils.snp.io.read._pgenlib import (
estimate_separate_strands_peak_bytes,
read_separate_strands,
)
log = logging.getLogger(__name__)
@SNPBaseReader.register
class BEDReader(SNPBaseReader):
    """Reader for PLINK 1 binary filesets (.bed + .bim + .fam).

    Genotypes are decoded through the pgenlib backend; variant metadata is
    parsed from the .bim file and sample metadata from the .fam file with
    polars.
    """

    def read(
        self,
        fields: Optional[List[str]] = None,
        exclude_fields: Optional[List[str]] = None,
        sample_ids: Optional[np.ndarray] = None,
        sample_idxs: Optional[np.ndarray] = None,
        variant_ids: Optional[np.ndarray] = None,
        variant_idxs: Optional[np.ndarray] = None,
        sum_strands: bool = False,
        separator: Optional[str] = None,
    ) -> SNPObject:
        """
        Read a bed fileset (bed, bim, fam) into a SNPObject.

        Args:
            fields (str, None, or list of str, optional): Fields to extract data for that should be
                included in the returned SNPObject. Available fields are 'GT', 'IID', 'REF', 'ALT',
                '#CHROM', 'CM', 'ID', 'POS'. To extract all fields, set fields to None. Defaults to None.
            exclude_fields (str, None, or list of str, optional): Fields to exclude from the returned
                SNPObject. Available fields are 'GT', 'IID', 'REF', 'ALT', '#CHROM', 'CM', 'ID', 'POS'.
                To exclude no fields, set exclude_fields to None. Defaults to None.
            sample_ids: List of sample IDs to read. If None and sample_idxs is None, all samples are read.
            sample_idxs: List of sample indices to read. If None and sample_ids is None, all samples are read.
            variant_ids: List of variant IDs to read. If None and variant_idxs is None, all variants are read.
            variant_idxs: List of variant indices to read. If None and variant_ids is None, all variants are read.
            sum_strands: If True, maternal and paternal strands are combined into a single `int8` array
                with values `{0, 1, 2}`. If False, strands are stored separately as an `int8` array with
                values `{0, 1}` for each strand.
                Note: With the pgenlib backend, `False` uses a temporary `int32` allele buffer.
            separator: Separator used in the pvar file. If None, the separator is automatically detected.
                If the automatic detection fails, please specify the separator manually.

        Returns:
            **SNPObject**:
                A SNPObject instance.

        Raises:
            ValueError: If both sample selectors, or both variant selectors, are specified.
        """
        # Use explicit raises instead of assert (asserts are stripped under
        # `python -O`), matching the validation style of iter_read below.
        if sample_idxs is not None and sample_ids is not None:
            raise ValueError("Only one of sample_idxs and sample_ids can be specified")
        if variant_idxs is not None and variant_ids is not None:
            raise ValueError("Only one of variant_idxs and variant_ids can be specified")

        if isinstance(fields, str):
            fields = [fields]
        if isinstance(exclude_fields, str):
            exclude_fields = [exclude_fields]

        fields = fields or ["GT", "IID", "REF", "ALT", "#CHROM", "CM", "ID", "POS"]
        exclude_fields = exclude_fields or []
        fields = [field for field in fields if field not in exclude_fields]

        # Fast path: only genotypes requested and no subsetting at all, so the
        # .bim parse can be skipped (just a sample count from .fam is needed).
        # BUGFIX: `sample_ids` / `variant_ids` must also disable the fast path;
        # previously they were silently ignored when fields == ["GT"].
        only_read_bed = (
            fields == ["GT"]
            and variant_idxs is None
            and sample_idxs is None
            and variant_ids is None
            and sample_ids is None
        )

        filename_noext = str(self.filename)
        if filename_noext[-4:].lower() in (".bed", ".bim", ".fam"):
            filename_noext = filename_noext[:-4]

        if only_read_bed:
            with open(filename_noext + '.fam', 'r') as f:
                file_num_samples = sum(1 for _ in f)  # Get sample count from fam file
            file_num_variants = None  # Not needed
        else:
            log.info(f"Reading {filename_noext}.bim")
            if separator is None:
                # Sniff the delimiter from the first .bim line. The same
                # separator is reused for the .fam file below — assumes both
                # files were written with one delimiter (typical for PLINK
                # filesets produced by a single tool).
                with open(filename_noext + ".bim", "r") as file:
                    separator = csv.Sniffer().sniff(file.readline()).delimiter
            bim = pl.read_csv(
                filename_noext + ".bim",
                separator=separator,
                has_header=False,
                new_columns=["#CHROM", "ID", "CM", "POS", "ALT", "REF"],
                schema_overrides={
                    "#CHROM": pl.String,
                    "ID": pl.String,
                    "CM": pl.Float64,
                    "POS": pl.Int64,
                    "ALT": pl.String,
                    "REF": pl.String
                },
                null_values=["NA"]
            ).with_row_index()
            file_num_variants = bim.height

            if variant_ids is not None:
                # A selector matches either the literal variant ID or a
                # "CHROM:POS" spelling of the variant.
                variant_id_values = [str(v) for v in np.atleast_1d(variant_ids)]
                variant_id_or_pos = (
                    pl.col("ID").is_in(variant_id_values)
                    | pl.concat_str(
                        [pl.col("#CHROM"), pl.lit(":"), pl.col("POS").cast(pl.String)]
                    ).is_in(variant_id_values)
                )
                variant_idxs = (
                    bim.filter(variant_id_or_pos)
                    .select("index")
                    .to_series()
                    .to_numpy()
                )

            if variant_idxs is None:
                num_variants = file_num_variants
                variant_idxs = np.arange(num_variants, dtype=np.uint32)
            else:
                requested_variant_idxs = np.asarray(variant_idxs, dtype=np.uint32).ravel()
                bim = bim.filter(pl.col("index").is_in(requested_variant_idxs))
                # Re-derive indices from the filtered frame: canonical file
                # order, deduplicated — the order pgenlib reads in.
                variant_idxs = bim.select("index").to_series().to_numpy()

            variant_idxs = np.asarray(variant_idxs, dtype=np.uint32)
            num_variants = np.size(variant_idxs)

            log.info(f"Reading {filename_noext}.fam")
            fam = pl.read_csv(
                filename_noext + ".fam",
                separator=separator,
                has_header=False,
                new_columns=["Family ID", "IID", "Father ID",
                             "Mother ID", "Sex code", "Phenotype value"],
                schema_overrides={
                    "Family ID": pl.String,
                    "IID": pl.String,
                    "Father ID": pl.String,
                    "Mother ID": pl.String,
                    "Sex code": pl.String,
                },
                null_values=["NA"]
            ).with_row_index()
            file_num_samples = fam.height

            if sample_ids is not None:
                sample_idxs = fam.filter(pl.col("IID").is_in(sample_ids)).select("index").to_series().to_numpy()

            if sample_idxs is None:
                num_samples = file_num_samples
            else:
                num_samples = np.size(sample_idxs)
                sample_idxs = np.array(sample_idxs, dtype=np.uint32)
                fam = fam.filter(pl.col("index").is_in(sample_idxs))

        if "GT" in fields:
            log.info(f"Reading {filename_noext}.bed")
            pgen_reader = pg.PgenReader(
                str.encode(filename_noext + ".bed"),
                raw_sample_ct=file_num_samples,
                variant_ct=file_num_variants,
                sample_subset=sample_idxs,
            )

            if only_read_bed:
                # Counts come straight from the reader on the fast path.
                num_samples = pgen_reader.get_raw_sample_ct()
                num_variants = pgen_reader.get_variant_ct()
                variant_idxs = np.arange(num_variants, dtype=np.uint32)

            # required arrays: variant_idxs + sample_idxs + genotypes
            if not sum_strands:
                required_ram = (
                    (num_samples + num_variants) * 4
                    + estimate_separate_strands_peak_bytes(num_variants, num_samples)
                )
            else:
                required_ram = (num_samples + num_variants) * 4 + num_variants * num_samples
            log.info(f">{required_ram / 1024**3:.2f} GiB of RAM are required to process {num_samples} samples with {num_variants} variants each")

            if not sum_strands:
                genotypes = read_separate_strands(
                    pgen_reader,
                    variant_idxs,
                    num_variants,
                    num_samples,
                )
            else:
                genotypes = np.empty((num_variants, num_samples), dtype=np.int8)
                pgen_reader.read_list(variant_idxs, genotypes)
            pgen_reader.close()
        else:
            genotypes = None

        log.info("Constructing SNPObject")
        # NOTE: on the fast path `fam`/`bim` are never created; the guards
        # below rely on short-circuiting ("IID"/metadata fields are absent
        # from `fields` when only_read_bed is True) so they are never touched.
        fid_col = None
        if "IID" in fields and "Family ID" in fam.columns:
            fid_col = fam.get_column("Family ID").to_numpy()
        sex_col = None
        if "IID" in fields and "Sex code" in fam.columns:
            sex_col = fam.get_column("Sex code").to_numpy()

        snpobj = SNPObject(
            calldata_gt=genotypes if "GT" in fields else None,
            samples=fam.get_column("IID").to_numpy() if "IID" in fields and "IID" in fam.columns else None,
            sample_fid=fid_col,
            sample_sex=sex_col,
            **{f'variants_{k.lower()}': bim.get_column(v).to_numpy() if v in fields and v in bim.columns else None
               for k, v in {'ref': 'REF', 'alt': 'ALT', 'chrom': '#CHROM', 'cm': 'CM', 'id': 'ID', 'pos': 'POS'}.items()}
        )
        log.info("Finished constructing SNPObject")
        return snpobj

    def _resolve_variant_idxs_for_iter(
        self,
        *,
        variant_ids: Optional[np.ndarray],
        variant_idxs: Optional[np.ndarray],
        separator: Optional[str],
    ) -> np.ndarray:
        """
        Resolve variant selectors to canonical file-order row indices.

        Parses the .bim file and maps `variant_ids` (matched against the ID
        column or a "CHROM:POS" spelling) or `variant_idxs` to sorted,
        deduplicated uint32 row indices. With neither selector, returns
        indices for every variant in the file.
        """
        filename_noext = str(self.filename)
        if filename_noext[-4:].lower() in (".bed", ".bim", ".fam"):
            filename_noext = filename_noext[:-4]

        local_separator = separator
        if local_separator is None:
            with open(filename_noext + ".bim", "r") as file:
                local_separator = csv.Sniffer().sniff(file.readline()).delimiter

        bim = pl.read_csv(
            filename_noext + ".bim",
            separator=local_separator,
            has_header=False,
            new_columns=["#CHROM", "ID", "CM", "POS", "ALT", "REF"],
            schema_overrides={
                "#CHROM": pl.String,
                "ID": pl.String,
                "CM": pl.Float64,
                "POS": pl.Int64,
                "ALT": pl.String,
                "REF": pl.String,
            },
            null_values=["NA"],
        ).with_row_index()

        if variant_ids is not None:
            variant_id_values = [str(v) for v in np.atleast_1d(variant_ids)]
            variant_id_or_pos = (
                pl.col("ID").is_in(variant_id_values)
                | pl.concat_str([pl.col("#CHROM"), pl.lit(":"), pl.col("POS").cast(pl.String)]).is_in(
                    variant_id_values
                )
            )
            resolved = (
                bim.filter(variant_id_or_pos)
                .select("index")
                .to_series()
                .to_numpy()
            )
            return np.asarray(resolved, dtype=np.uint32)

        if variant_idxs is not None:
            requested = np.asarray(variant_idxs, dtype=np.uint32).ravel()
            resolved = (
                bim.filter(pl.col("index").is_in(requested))
                .select("index")
                .to_series()
                .to_numpy()
            )
            return np.asarray(resolved, dtype=np.uint32)

        return np.arange(bim.height, dtype=np.uint32)

    def iter_read(
        self,
        fields: Optional[List[str]] = None,
        exclude_fields: Optional[List[str]] = None,
        sample_ids: Optional[np.ndarray] = None,
        sample_idxs: Optional[np.ndarray] = None,
        variant_ids: Optional[np.ndarray] = None,
        variant_idxs: Optional[np.ndarray] = None,
        sum_strands: bool = False,
        separator: Optional[str] = None,
        chunk_size: int = 10_000,
    ) -> Iterator[SNPObject]:
        """
        Stream the BED fileset in variant chunks.

        This yields a sequence of SNPObject chunks along the SNP axis.
        Parameters mirror :meth:`read`; `chunk_size` is the maximum number of
        variants per yielded chunk.

        Raises:
            ValueError: If chunk_size < 1, or if both sample selectors or both
                variant selectors are specified.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be >= 1.")
        if sample_idxs is not None and sample_ids is not None:
            raise ValueError("Only one of sample_idxs and sample_ids can be specified.")
        if variant_idxs is not None and variant_ids is not None:
            raise ValueError("Only one of variant_idxs and variant_ids can be specified.")

        # Resolve selectors once up front; each chunk is then a plain slice of
        # canonical file-order indices handed to read().
        selectors = self._resolve_variant_idxs_for_iter(
            variant_ids=variant_ids,
            variant_idxs=variant_idxs,
            separator=separator,
        )
        n_selectors = int(selectors.size)
        for start in range(0, n_selectors, int(chunk_size)):
            stop = min(start + int(chunk_size), n_selectors)
            selector_chunk = np.asarray(selectors[start:stop], dtype=np.uint32)
            # NOTE: read() re-parses .bim/.fam per chunk; acceptable for
            # streaming, but a known cost of delegating to read().
            yield self.read(
                fields=fields,
                exclude_fields=exclude_fields,
                sample_ids=sample_ids,
                sample_idxs=sample_idxs,
                variant_idxs=selector_chunk,
                sum_strands=sum_strands,
                separator=separator,
            )