Source code for snputils.snp.io.read.bgen

from __future__ import annotations

import logging
import os
from typing import List, Optional, Sequence, Union

import numpy as np
from bgen import BgenReader as _BgenReader

from snputils.snp.genobj.snpobj import SNPObject
from snputils.snp.io.read.base import SNPBaseReader

log = logging.getLogger(__name__)


def _as_field_list(fields: Optional[Union[str, Sequence[str]]]) -> Optional[List[str]]:
    if fields is None:
        return None
    if isinstance(fields, str):
        return [fields]
    return list(fields)


def _variant_identifier(var) -> str:
    varid = str(var.varid)
    if varid and varid != ".":
        return varid
    return str(var.rsid)


@SNPBaseReader.register
class BGENReader(SNPBaseReader):
    """Reader that loads BGEN genotype probabilities into a ``SNPObject``."""

    def read(
        self,
        fields: Optional[Union[str, Sequence[str]]] = None,
        exclude_fields: Optional[Union[str, Sequence[str]]] = None,
        sample_path: Optional[Union[str, bytes]] = None,
        sample_ids: Optional[Sequence[str]] = None,
        sample_idxs: Optional[Sequence[int]] = None,
        variant_ids: Optional[Sequence[str]] = None,
        variant_idxs: Optional[Sequence[int]] = None,
    ) -> SNPObject:
        """
        Read a BGEN file into a SNPObject.

        Args:
            fields: Fields to include. Available fields are ``GP``, ``IID``, ``REF``, ``ALT``,
                ``#CHROM``, ``ID``, and ``POS``. ``GT`` is intentionally unsupported because this
                reader preserves BGEN genotype probabilities instead of converting them to hard
                calls. If None (or empty), all available fields are read.
            exclude_fields: Fields to exclude from the returned SNPObject.
            sample_path: Optional Oxford ``.sample`` file for BGEN files without embedded sample
                identifiers. ``str``, ``bytes`` and path-like values are accepted.
            sample_ids: Sample IDs to read. If None and sample_idxs is None, all samples are read.
            sample_idxs: Sample indices to read (negative indices count from the end). If None and
                sample_ids is None, all samples are read.
            variant_ids: Variant IDs to read. Matches BGEN varid, rsid, or ``chrom:pos``.
            variant_idxs: Variant indices to read (negative indices count from the end). If None
                and variant_ids is None, all variants are read.

        Returns:
            SNPObject: A SNPObject with genotype probabilities in ``calldata_gp``. Mixed
            probability widths are padded with NaN columns. Samples follow the order in which
            they were requested; variants are always returned in file order, and a variant
            requested more than once is returned once.

        Raises:
            ValueError: If both ``sample_idxs`` and ``sample_ids`` (or both ``variant_idxs`` and
                ``variant_ids``) are given, if an index is out of bounds, or if a requested
                sample/variant ID is not present in the file.
            NotImplementedError: If ``GT`` is requested.
        """
        if sample_idxs is not None and sample_ids is not None:
            raise ValueError("Only one of sample_idxs and sample_ids can be specified.")
        if variant_idxs is not None and variant_ids is not None:
            raise ValueError("Only one of variant_idxs and variant_ids can be specified.")

        fields_list = _as_field_list(fields) or ["GP", "IID", "REF", "ALT", "#CHROM", "ID", "POS"]
        exclude = set(_as_field_list(exclude_fields) or [])
        fields_set = {field for field in fields_list if field not in exclude}
        if "GT" in fields_set:
            raise NotImplementedError(
                "BGENReader preserves genotype probabilities in `calldata_gp` and does not hard-call GT."
            )

        want_gp = "GP" in fields_set
        want_iid = "IID" in fields_set
        want_ref = "REF" in fields_set
        want_alt = "ALT" in fields_set
        want_chrom = "#CHROM" in fields_set
        want_id = "ID" in fields_set
        want_pos = "POS" in fields_set

        # os.fsdecode (rather than str) so that a bytes path is decoded instead
        # of being turned into its "b'...'" repr.
        sample_path_str = "" if sample_path is None else os.fsdecode(sample_path)

        variants_ref = []
        variants_alt = []
        variants_chrom = []
        variants_id = []
        variants_pos = []
        probabilities = []
        prob_width = 0  # widest probability row seen so far

        log.info("Reading %s", self.filename)
        with _BgenReader(str(self.filename), sample_path_str, delay_parsing=True) as bfile:
            file_samples = np.asarray(bfile.samples, dtype=object)
            sample_indices = self._resolve_sample_indices(file_samples, sample_ids, sample_idxs)
            variant_indices = self._resolve_variant_indices(bfile, variant_ids, variant_idxs)
            samples = file_samples[sample_indices] if want_iid else None

            pending = {int(idx) for idx in variant_indices}
            # Nothing selected: skip the scan entirely, since the early exit
            # below is only reachable after at least one hit.
            if pending:
                for idx, var in enumerate(bfile):
                    if idx not in pending:
                        continue
                    pending.discard(idx)

                    if want_ref or want_alt:
                        alleles = [str(allele) for allele in var.alleles]
                        if want_ref:
                            variants_ref.append(alleles[0] if alleles else "")
                        if want_alt:
                            variants_alt.append(",".join(alleles[1:]) if len(alleles) > 1 else "")
                    if want_chrom:
                        variants_chrom.append(str(var.chrom))
                    if want_id:
                        variants_id.append(_variant_identifier(var))
                    if want_pos:
                        variants_pos.append(int(var.pos))

                    if want_gp:
                        probs = np.asarray(var.probabilities, dtype=np.float32)[sample_indices, :]
                        width = probs.shape[1]
                        if width > prob_width:
                            # A wider variant appeared: widen everything collected so far.
                            probabilities = [self._pad_probabilities(p, width) for p in probabilities]
                            prob_width = width
                        # No-op when the row already has the current width.
                        probabilities.append(self._pad_probabilities(probs, prob_width))

                    if not pending:
                        # Every requested variant has been read; stop scanning.
                        break

        if not want_gp:
            calldata_gp = None
        elif probabilities:
            calldata_gp = np.stack(probabilities, axis=0)
        else:
            calldata_gp = np.empty((0, len(sample_indices), 0), dtype=np.float32)

        return SNPObject(
            genotypes=None,
            calldata_gp=calldata_gp,
            samples=samples,
            variants_ref=np.asarray(variants_ref, dtype=object) if want_ref else None,
            variants_alt=np.asarray(variants_alt, dtype=object) if want_alt else None,
            variants_chrom=np.asarray(variants_chrom, dtype=object) if want_chrom else None,
            variants_id=np.asarray(variants_id, dtype=object) if want_id else None,
            variants_pos=np.asarray(variants_pos, dtype=np.int64) if want_pos else None,
        )

    @staticmethod
    def _normalize_indices(indices: Sequence[int], size: int, label: str) -> np.ndarray:
        """Validate ``indices`` against ``size`` and map negative values to positive ones.

        Args:
            indices: Requested indices; flattened to one dimension.
            size: Number of available items.
            label: Item kind (``"sample"`` or ``"variant"``) used in the error message.

        Returns:
            Integer array of non-negative indices, in the requested order.

        Raises:
            ValueError: If any index lies outside ``[-size, size)``.
        """
        idx = np.asarray(indices, dtype=int).ravel()
        if np.any((idx < -size) | (idx >= size)):
            raise ValueError(f"One or more {label} indexes are out of bounds.")
        return np.mod(idx, size)

    @staticmethod
    def _resolve_sample_indices(
        file_samples: np.ndarray,
        sample_ids: Optional[Sequence[str]],
        sample_idxs: Optional[Sequence[int]],
    ) -> np.ndarray:
        """Translate a sample selection into positional indices into ``file_samples``.

        Args:
            file_samples: Sample identifiers as stored in the file.
            sample_ids: Sample IDs to select, or None.
            sample_idxs: Sample indices to select, or None. Takes precedence over ``sample_ids``.

        Returns:
            Integer index array in the requested order; all samples when neither
            selector is given.

        Raises:
            ValueError: If an index is out of bounds or an ID is not in the file.
        """
        if sample_idxs is not None:
            return BGENReader._normalize_indices(sample_idxs, len(file_samples), "sample")
        if sample_ids is None:
            return np.arange(len(file_samples), dtype=int)

        requested = [str(sample) for sample in np.asarray(sample_ids, dtype=object).ravel()]
        # If the file repeats a sample ID, the last occurrence wins.
        sample_lookup = {str(sample): i for i, sample in enumerate(file_samples)}
        missing = [sample for sample in requested if sample not in sample_lookup]
        if missing:
            raise ValueError(f"The following specified samples were not found: {missing}")
        return np.asarray([sample_lookup[sample] for sample in requested], dtype=int)

    @staticmethod
    def _pad_probabilities(probabilities: np.ndarray, width: int) -> np.ndarray:
        """Right-pad a 2-D probability array with NaN columns up to ``width``.

        Args:
            probabilities: Array whose second axis holds the probability columns.
            width: Target number of columns.

        Returns:
            ``probabilities`` itself when it already has ``width`` columns,
            otherwise a new array of the same dtype with NaN in the added columns.
        """
        if probabilities.shape[1] == width:
            return probabilities
        padded = np.full((probabilities.shape[0], width), np.nan, dtype=probabilities.dtype)
        padded[:, : probabilities.shape[1]] = probabilities
        return padded

    @staticmethod
    def _resolve_variant_indices(
        bfile,
        variant_ids: Optional[Sequence[str]],
        variant_idxs: Optional[Sequence[int]],
    ) -> np.ndarray:
        """Translate a variant selection into positional indices into the BGEN file.

        Args:
            bfile: Open BGEN reader; must support ``len()`` and provide ``varids()``,
                ``rsids()``, ``chroms()`` and ``positions()``.
            variant_ids: Identifiers to select, or None. A variant matches when its
                varid, rsid, or ``chrom:pos`` string is among the requested values.
            variant_idxs: Variant indices to select, or None. Takes precedence over
                ``variant_ids``.

        Returns:
            Integer index array. For ``variant_idxs`` the requested order is kept;
            for ``variant_ids`` indices are in file order. All variants when neither
            selector is given.

        Raises:
            ValueError: If an index is out of bounds or a requested ID matches no variant.
        """
        if variant_idxs is not None:
            return BGENReader._normalize_indices(variant_idxs, len(bfile), "variant")
        if variant_ids is None:
            return np.arange(len(bfile), dtype=int)

        requested = {str(value) for value in np.asarray(variant_ids, dtype=object).ravel()}
        selected = []
        found = set()
        columns = zip(bfile.varids(), bfile.rsids(), bfile.chroms(), bfile.positions())
        for idx, (varid, rsid, chrom, pos) in enumerate(columns):
            identifiers = {str(varid), str(rsid), f"{chrom}:{pos}"}
            if not requested.isdisjoint(identifiers):
                selected.append(idx)
                # Remember every alias of a matched variant so that requested
                # values can be checked for misses afterwards.
                found |= identifiers

        missing = sorted(requested - found)
        if missing:
            raise ValueError(f"The following specified variants were not found: {missing}")
        return np.asarray(selected, dtype=int)