from __future__ import annotations
from os import PathLike
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
from snputils._utils.allele_freq import aggregate_pop_allele_freq
def _slice_variant_axis(arr: Any, start: int, stop: int) -> Any:
    """
    Return rows ``start:stop`` along the first (variant) axis of ``arr``.

    Args:
        arr: Array-like to slice, or None.
        start: First row to keep.
        stop: Row index one past the last row to keep.

    Returns:
        None when ``arr`` is None; the array itself when it is 0-dimensional
        (there is no axis to slice); otherwise the ``[start:stop]`` slice taken
        along axis 0 with all trailing axes kept whole.
    """
    if arr is None:
        return None
    values = np.asarray(arr)
    rows = slice(start, stop)
    # A 0-d array cannot be indexed along axis 0, so it is passed through untouched.
    return values if values.ndim == 0 else values[rows, ...]
def _iter_snpobject_chunks(snpobj: Any, chunk_size: int) -> Iterator[Any]:
    """
    Yield variant-axis SNPObject chunks from an in-memory SNPObject.

    Args:
        snpobj: Source object; its ``n_snps`` attribute gives the number of variants
            and its per-variant attributes are sliced for every chunk.
        chunk_size: Maximum number of variants per yielded chunk.

    Yields:
        A new SNPObject per consecutive ``[start, stop)`` range of variants.
        ``samples`` and ``ancestry_map`` are carried over unsliced.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1. Because this is a
            generator, the check runs when iteration starts.
    """
    from snputils.snp.genobj.snpobj import SNPObject

    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1.")

    step = int(chunk_size)
    total = snpobj.n_snps

    # Attributes that are sliced along the variant axis for each chunk.
    per_variant_fields = (
        "calldata_gt",
        "variants_ref",
        "variants_alt",
        "variants_chrom",
        "variants_filter_pass",
        "variants_id",
        "variants_pos",
        "variants_qual",
        "calldata_lai",
    )

    for start in range(0, total, step):
        stop = min(start + step, total)
        sliced = {
            name: _slice_variant_axis(getattr(snpobj, name), start, stop)
            for name in per_variant_fields
        }
        samples = snpobj.samples
        yield SNPObject(
            samples=None if samples is None else np.asarray(samples),
            ancestry_map=snpobj.ancestry_map,
            **sliced,
        )
def _canonical_chromosome(chromosome: Any) -> str:
    """
    Normalize a chromosome label so that SNP and LAI labels can be compared.

    The label is stripped of surrounding whitespace, a leading ``chr`` prefix
    (any case) is removed, purely numeric labels lose leading zeros, and
    everything else is lower-cased.

    Args:
        chromosome: Label to normalize. Bytes-like labels are decoded as UTF-8;
            any other value is converted with ``str()``.

    Returns:
        The canonical label, e.g. ``"chr01" -> "1"`` and ``"ChrX" -> "x"``.
    """
    if isinstance(chromosome, (bytes, bytearray)):
        # str(b"chr1") would produce "b'chr1'", which never matches a text label.
        text = bytes(chromosome).decode("utf-8", errors="replace")
    else:
        text = str(chromosome)
    text = text.strip()
    if text.lower().startswith("chr"):
        text = text[3:].strip()
    # isdecimal() only accepts characters int() can parse; isdigit() also accepts
    # e.g. superscript digits, for which int() raises ValueError.
    if text.isdecimal():
        return str(int(text))
    return text.lower()
class _IterWindowsLAIMapper:
    """
    Map streamed LAI windows onto SNP chunks by chromosome and physical position.

    Window chunks are pulled lazily from ``lai_reader.iter_windows(...)`` and are
    only ever consumed forward. SNPs must therefore arrive sorted by ascending
    position within a chromosome, must not revisit a chromosome, and must follow
    the same chromosome order as the LAI windows.
    """

    def __init__(
        self,
        lai_reader: Any,
        *,
        chunk_size: int,
        iter_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Start iterating LAI windows from ``lai_reader``.

        Args:
            lai_reader: Object exposing a callable ``iter_windows(...)``.
            chunk_size: Number of windows requested per chunk (passed as
                ``chunk_size=`` to ``iter_windows``).
            iter_kwargs: Extra keyword arguments forwarded to ``iter_windows``.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
            TypeError: If ``lai_reader`` has no callable ``iter_windows``.
        """
        if chunk_size < 1:
            raise ValueError("lai_chunk_size must be >= 1.")
        if not hasattr(lai_reader, "iter_windows") or not callable(getattr(lai_reader, "iter_windows")):
            raise TypeError("LAI reader must implement `iter_windows(...)`.")
        self._window_iter = iter(
            lai_reader.iter_windows(
                chunk_size=int(chunk_size),
                **(iter_kwargs or {}),
            )
        )
        # Currently loaded window chunk (None until the first load and after exhaustion).
        self._window_chrom: Optional[np.ndarray] = None
        self._window_pos: Optional[np.ndarray] = None
        self._window_lai: Optional[np.ndarray] = None
        # Cursor into the currently loaded chunk.
        self._row_idx = 0
        # Set once the window iterator has raised StopIteration.
        self._exhausted = False
        # Chromosome of the window under the cursor, and chromosomes already left behind.
        self._active_window_chrom: Optional[str] = None
        self._completed_window_chroms: Set[str] = set()
        # SNP ordering bookkeeping used by `_assert_snp_order`.
        self._seen_snp_chroms: Set[str] = set()
        self._last_snp_chrom: Optional[str] = None
        self._last_snp_pos: Optional[int] = None
        # Number of samples implied by the LAI haplotype columns (columns // 2).
        self._n_samples_lai: Optional[int] = None

    def _load_next_window_chunk(self) -> bool:
        """
        Load the next non-empty window chunk and reset the cursor to its first row.

        Returns:
            True if a chunk was loaded, False if the window stream is exhausted.

        Raises:
            ValueError: If a chunk has a malformed ``lai``, ``chromosomes`` or
                ``physical_pos`` entry, or a sample count that differs from
                earlier chunks.
        """
        if self._exhausted:
            # The stream already ended; do not poll the spent iterator again.
            return False
        prev_active = self._active_window_chrom
        while True:
            try:
                chunk = next(self._window_iter)
            except StopIteration:
                self._exhausted = True
                self._window_chrom = None
                self._window_pos = None
                self._window_lai = None
                self._active_window_chrom = None
                return False
            lai = np.asarray(chunk.get("lai", None))
            if lai.size == 0:
                # Skip empty chunks and keep pulling.
                continue
            if lai.ndim != 2:
                raise ValueError("LAI `iter_windows` chunks must provide 2D `lai` arrays.")
            if lai.shape[1] % 2 != 0:
                raise ValueError("LAI `iter_windows` chunks must contain an even number of haplotype columns.")
            chrom = np.asarray(chunk.get("chromosomes", None))
            if chrom.ndim != 1 or chrom.shape[0] != lai.shape[0]:
                raise ValueError(
                    "LAI `iter_windows` chunks must provide `chromosomes` with one value per window."
                )
            chrom = np.asarray([_canonical_chromosome(c) for c in chrom], dtype=object)
            physical_pos = chunk.get("physical_pos", None)
            if physical_pos is None:
                raise ValueError(
                    "LAI reader windows must include physical positions (`physical_pos`) "
                    "to map ancestry to SNP chunks."
                )
            physical_pos = np.asarray(physical_pos)
            if physical_pos.ndim != 2 or physical_pos.shape[1] != 2 or physical_pos.shape[0] != lai.shape[0]:
                raise ValueError(
                    "LAI `iter_windows` chunks must provide `physical_pos` with shape (n_windows, 2)."
                )
            # Two haplotype columns per sample.
            n_samples_lai = lai.shape[1] // 2
            if self._n_samples_lai is None:
                self._n_samples_lai = n_samples_lai
            elif self._n_samples_lai != n_samples_lai:
                raise ValueError("Inconsistent number of LAI samples across LAI window chunks.")
            self._window_chrom = chrom
            self._window_pos = physical_pos.astype(np.int64, copy=False)
            self._window_lai = lai
            self._row_idx = 0
            next_active = str(self._window_chrom[0])
            # Crossing a chunk boundary into a different chromosome completes the previous one.
            if prev_active is not None and prev_active != next_active:
                self._completed_window_chroms.add(prev_active)
            self._active_window_chrom = next_active
            return True

    def _current_window(self) -> Optional[Tuple[str, int, int, np.ndarray]]:
        """
        Return the window under the cursor, loading the first chunk if needed.

        Returns:
            ``(chromosome, start, end, lai_row)`` for the current window, or None
            when no window is available.
        """
        if self._window_chrom is None:
            if not self._load_next_window_chunk():
                return None
        if self._window_chrom is None or self._window_pos is None or self._window_lai is None:
            return None
        chrom = str(self._window_chrom[self._row_idx])
        start = int(self._window_pos[self._row_idx, 0])
        end = int(self._window_pos[self._row_idx, 1])
        lai_row = self._window_lai[self._row_idx]
        return chrom, start, end, lai_row

    def _advance_window(self) -> bool:
        """
        Move the cursor to the next window, loading a new chunk when required.

        Returns:
            True if a window is available after advancing, False if the stream ended.
        """
        if self._window_chrom is None:
            return self._load_next_window_chunk()
        current_chrom = str(self._window_chrom[self._row_idx])
        self._row_idx += 1
        if self._window_chrom is not None and self._row_idx < len(self._window_chrom):
            next_chrom = str(self._window_chrom[self._row_idx])
            if current_chrom != next_chrom:
                self._completed_window_chroms.add(current_chrom)
            self._active_window_chrom = next_chrom
            return True
        # End of the loaded chunk: remember where we were so the loader can detect
        # a chromosome change across the chunk boundary.
        self._active_window_chrom = current_chrom
        return self._load_next_window_chunk()

    def _assert_snp_order(self, chrom: str, pos: int) -> None:
        """
        Check that SNPs arrive in an order compatible with forward-only streaming.

        Args:
            chrom: Canonical chromosome label of the SNP.
            pos: Position of the SNP.

        Raises:
            ValueError: If the position decreases within a chromosome, or a
                previously seen chromosome is revisited.
        """
        if self._last_snp_chrom is None:
            self._last_snp_chrom = chrom
            self._last_snp_pos = pos
            self._seen_snp_chroms.add(chrom)
            return
        if chrom == self._last_snp_chrom:
            if self._last_snp_pos is not None and pos < self._last_snp_pos:
                raise ValueError(
                    "SNP chunks must be sorted by ascending position within chromosome "
                    "when using a streaming LAI reader."
                )
        else:
            if chrom in self._seen_snp_chroms:
                raise ValueError(
                    "SNP chunks cannot revisit chromosomes when using a streaming LAI reader "
                    "(pass an in-memory LocalAncestryObject instead)."
                )
            self._seen_snp_chroms.add(chrom)
        self._last_snp_chrom = chrom
        self._last_snp_pos = pos

    def map_chunk(
        self,
        *,
        variants_chrom: np.ndarray,
        variants_pos: np.ndarray,
        n_samples_expected: int,
    ) -> np.ndarray:
        """
        Build SNP-level LAI for one SNP chunk from the streamed windows.

        Args:
            variants_chrom: 1D array of chromosome labels, one per SNP.
            variants_pos: 1D array of positions, one per SNP.
            n_samples_expected: Number of samples in the SNP chunk; must equal
                the LAI sample count.

        Returns:
            ``int16`` array of shape ``(n_snps, n_samples_expected, 2)``. A SNP
            whose position lies in ``[start, end]`` (both inclusive) of a window
            on the same chromosome receives that window's LAI row; every other
            SNP keeps the fill value ``-1``.

        Raises:
            ValueError: On malformed inputs, when no LAI windows exist, on a
                sample-count mismatch, on non-integer positions, or when the SNP
                order is incompatible with the window order.
        """
        variants_chrom = np.asarray(variants_chrom)
        variants_pos = np.asarray(variants_pos)
        if variants_chrom.ndim != 1 or variants_pos.ndim != 1 or variants_chrom.shape[0] != variants_pos.shape[0]:
            raise ValueError("`variants_chrom` and `variants_pos` must be 1D arrays with matching length.")
        if self._n_samples_lai is None:
            # Force the first chunk load so the LAI sample count becomes known.
            _ = self._current_window()
        if self._n_samples_lai is None:
            raise ValueError("No LAI windows available in the provided LAI reader.")
        if self._n_samples_lai != int(n_samples_expected):
            raise ValueError(
                "LAI sample count does not match SNP sample count "
                f"({self._n_samples_lai} vs {n_samples_expected})."
            )
        n_snps = variants_pos.shape[0]
        calldata_lai = np.full((n_snps, n_samples_expected, 2), -1, dtype=np.int16)
        for snp_idx in range(n_snps):
            chrom = _canonical_chromosome(variants_chrom[snp_idx])
            try:
                pos = int(variants_pos[snp_idx])
            except Exception as exc:
                raise ValueError("SNP variant positions must be integer-like when using LAI reader streaming.") from exc
            self._assert_snp_order(chrom, pos)
            while True:
                window = self._current_window()
                if window is None:
                    # No windows left: this SNP keeps the -1 fill value.
                    break
                win_chrom, win_start, win_end, win_lai = window
                if win_chrom == chrom:
                    if pos < win_start:
                        # SNP lies before the current window; keep the window for later SNPs.
                        break
                    if pos <= win_end:
                        calldata_lai[snp_idx] = np.asarray(win_lai).reshape(n_samples_expected, 2)
                        break
                else:
                    if chrom in self._completed_window_chroms:
                        raise ValueError(
                            "SNP chunks are not in the same chromosome order as LAI windows. "
                            "Use an in-memory LocalAncestryObject for out-of-order SNP access."
                        )
                # The window ends before this SNP, or belongs to another chromosome: move on.
                if not self._advance_window():
                    break
        return calldata_lai
def _coerce_lai_source(
    laiobj: Any,
    *,
    lai_chunk_size: int,
    lai_iter_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[Any], Optional[_IterWindowsLAIMapper]]:
    """
    Resolve ``laiobj`` into either an in-memory LAI object or a streaming mapper.

    Args:
        laiobj: None, a path (``str``/``PathLike``), or an object exposing
            ``iter_windows``, ``convert_to_snp_level`` or ``read``.
        lai_chunk_size: Window chunk size handed to the streaming mapper.
        lai_iter_kwargs: Extra keyword arguments for the mapper's ``iter_windows`` call.

    Returns:
        A pair ``(lai_object, mapper)`` in which at most one entry is not None.
        Capabilities are probed in this order: ``iter_windows`` (streaming
        mapper), ``convert_to_snp_level`` (used as-is), ``read`` (its result is
        used if that result exposes ``convert_to_snp_level``).

    Raises:
        TypeError: If none of the supported capabilities is found.
    """
    if laiobj is None:
        return None, None

    def _has_method(obj: Any, name: str) -> bool:
        # True when `obj` exposes a callable attribute called `name`.
        return callable(getattr(obj, name, None))

    candidate = laiobj
    if isinstance(candidate, (str, PathLike)):
        # Paths are opened through the package's LAI reader.
        from snputils.ancestry.io.local.read import LAIReader

        candidate = LAIReader(candidate)

    if _has_method(candidate, "iter_windows"):
        mapper = _IterWindowsLAIMapper(
            candidate,
            chunk_size=lai_chunk_size,
            iter_kwargs=lai_iter_kwargs,
        )
        return None, mapper

    if _has_method(candidate, "convert_to_snp_level"):
        return candidate, None

    if _has_method(candidate, "read"):
        loaded = candidate.read()
        if _has_method(loaded, "convert_to_snp_level"):
            return loaded, None

    raise TypeError(
        "`laiobj` must be one of: LocalAncestryObject, LAI reader/path, "
        "or an object implementing `read()` that returns a LocalAncestryObject."
    )
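# [docs]  (appears to be a Sphinx "viewcode" link marker left over from HTML extraction; not part of the module)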
def allele_freq_stream(
    data: Any,
    *,
    chunk_size: int = 10_000,
    sample_labels: Optional[Sequence[Any]] = None,
    ancestry: Optional[Union[str, int]] = None,
    laiobj: Optional[Any] = None,
    lai_chunk_size: int = 1024,
    lai_iter_kwargs: Optional[Dict[str, Any]] = None,
    pseudohaploid: Union[bool, int] = False,
    return_counts: bool = False,
    as_dataframe: bool = False,
    **iter_kwargs,
) -> Any:
    """
    Compute allele frequencies in variant chunks.

    Args:
        data:
            One of:
            - in-memory SNPObject
            - a reader implementing `iter_read(...)`
            - an iterable yielding SNPObject chunks
        chunk_size:
            Number of SNPs per chunk.
        sample_labels:
            Population label per sample. If None, computes cohort-level frequencies.
        ancestry:
            Optional ancestry code for ancestry-specific masking.
        laiobj:
            Optional LAI source used to derive SNP-level LAI when missing in chunks.
            Only consulted when `ancestry` is given.
            Supported inputs:
            - in-memory LocalAncestryObject
            - LAI reader implementing `iter_windows(...)`
            - path to an LAI file understood by `snputils.ancestry.io.local.read.LAIReader`
        lai_chunk_size:
            Number of LAI windows per chunk when `laiobj` is a streaming LAI reader/path.
        lai_iter_kwargs:
            Optional kwargs forwarded to `laiobj.iter_windows(...)`.
        pseudohaploid:
            If True, detects pseudo-haploid samples (samples with no heterozygotes in the first 1000 SNPs)
            and treats them as haploid. If an integer `n` is provided, checks the first `n` SNPs.
            If False, treats all samples as diploid.
            The value is forwarded unchanged to the per-chunk aggregation.
        return_counts:
            If True, also return called haplotype counts.
        as_dataframe:
            If True, return pandas DataFrames.
        **iter_kwargs:
            Additional args forwarded to `iter_read(...)` when `data` is a reader.

    Returns:
        The per-chunk frequencies stacked along the variant axis. With
        `sample_labels` the result has one column per population; without it
        only the first column is returned (as a 1D array, or as a single-column
        DataFrame named ``allele_freq``). When `return_counts` is True a tuple
        ``(frequencies, counts)`` is returned, with counts shaped the same way.

    Raises:
        ValueError: If no chunk carries genotypes, `calldata_gt` is not 2D/3D,
            the sample count or population labels change between chunks,
            `sample_labels` has the wrong length, or ancestry masking is
            requested but SNP-level LAI is unavailable or cannot be derived
            from `laiobj`.
        TypeError: If `laiobj` is of an unsupported kind.
    """
    try:
        from snputils.snp.genobj.snpobj import SNPObject
    except Exception:
        # Best effort: without SNPObject, `data` is treated as a reader or an iterable.
        SNPObject = None  # type: ignore

    grouped_output = sample_labels is not None

    # Pick the chunk source: in-memory object, reader, or plain iterable of chunks.
    if SNPObject is not None and isinstance(data, SNPObject):
        chunk_iter: Iterable[Any] = _iter_snpobject_chunks(data, chunk_size=chunk_size)
    elif hasattr(data, "iter_read") and callable(getattr(data, "iter_read")):
        chunk_iter = data.iter_read(chunk_size=chunk_size, **iter_kwargs)
    else:
        chunk_iter = iter(data)

    lai_object = None
    lai_window_mapper: Optional[_IterWindowsLAIMapper] = None
    if ancestry is not None and laiobj is not None:
        lai_object, lai_window_mapper = _coerce_lai_source(
            laiobj,
            lai_chunk_size=lai_chunk_size,
            lai_iter_kwargs=lai_iter_kwargs,
        )

    labels = None
    n_samples_ref = None
    pops_ref: Optional[List[Any]] = None
    afs_parts: List[np.ndarray] = []
    counts_parts: List[np.ndarray] = []

    for chunk in chunk_iter:
        # Chunks without genotypes contribute nothing.
        if chunk is None or getattr(chunk, "calldata_gt", None) is None:
            continue
        gt_chunk = np.asarray(chunk.calldata_gt)
        if gt_chunk.ndim not in (2, 3):
            raise ValueError("'calldata_gt' must be 2D or 3D array")

        n_samples = gt_chunk.shape[1]
        if n_samples_ref is None:
            # First usable chunk fixes the sample count and the label vector.
            n_samples_ref = n_samples
            if sample_labels is None:
                labels = np.repeat("__all__", n_samples_ref)
            else:
                labels = np.asarray(sample_labels)
                if labels.ndim != 1:
                    labels = labels.ravel()
                if labels.shape[0] != n_samples_ref:
                    raise ValueError(
                        "'sample_labels' must have length equal to the number of samples in `calldata_gt`."
                    )
        elif n_samples != n_samples_ref:
            raise ValueError("All chunks must have the same number of samples.")

        calldata_lai = getattr(chunk, "calldata_lai", None)
        if ancestry is not None and calldata_lai is None:
            if lai_window_mapper is not None:
                variants_chrom = getattr(chunk, "variants_chrom", None)
                variants_pos = getattr(chunk, "variants_pos", None)
                if variants_chrom is None or variants_pos is None:
                    raise ValueError(
                        "Ancestry-specific masking with a streaming LAI reader requires "
                        "`variants_chrom` and `variants_pos` on SNP chunks."
                    )
                calldata_lai = lai_window_mapper.map_chunk(
                    variants_chrom=variants_chrom,
                    variants_pos=variants_pos,
                    n_samples_expected=n_samples,
                )
            elif lai_object is not None:
                try:
                    converted_lai = lai_object.convert_to_snp_level(snpobject=chunk, lai_format="3D")
                except Exception as exc:
                    # Surface the real cause instead of the generic "pass `laiobj`" hint below,
                    # which would be misleading when `laiobj` was in fact supplied.
                    raise ValueError(
                        "Failed to derive SNP-level LAI from `laiobj` for ancestry-specific masking."
                    ) from exc
                calldata_lai = getattr(converted_lai, "calldata_lai", None)

        if ancestry is not None and calldata_lai is None:
            raise ValueError(
                "Ancestry-specific masking requires SNP-level LAI "
                "(provide `calldata_lai` on the chunks or pass `laiobj`)."
            )

        afs_chunk, counts_chunk, pops = aggregate_pop_allele_freq(
            calldata_gt=gt_chunk,
            sample_labels=labels,
            ancestry=ancestry,
            calldata_lai=calldata_lai,
            pseudohaploid=pseudohaploid,
        )
        # Compare as plain lists so an array-valued `pops` cannot yield an
        # ambiguous element-wise comparison.
        pops = None if pops is None else list(pops)
        if pops_ref is None:
            pops_ref = pops
        elif pops_ref != pops:
            raise ValueError("Population labels must be consistent across chunks.")

        afs_parts.append(afs_chunk)
        counts_parts.append(counts_chunk)

    if not afs_parts:
        raise ValueError("No genotype chunks were provided.")

    afs = np.vstack(afs_parts)
    counts = np.vstack(counts_parts)

    if grouped_output:
        freq_out = afs
        count_out = counts
        if as_dataframe:
            import pandas as pd

            freq_out = pd.DataFrame(afs, columns=pops_ref)
            count_out = pd.DataFrame(counts, columns=pops_ref)
    else:
        # Cohort-level run: a single synthetic population, so keep only its column.
        freq_out = afs[:, 0]
        count_out = counts[:, 0]
        if as_dataframe:
            import pandas as pd

            freq_out = pd.DataFrame({"allele_freq": freq_out})
            count_out = pd.DataFrame({"called_alleles": count_out})

    if return_counts:
        return freq_out, count_out
    return freq_out
# Names exported by a star-import of this module; the underscore-prefixed helpers stay private.
__all__ = ["allele_freq_stream"]