from __future__ import annotations
import logging
import math
from pathlib import Path
from typing import Optional, Union
import numpy as np
from bgen import BgenWriter as _BgenWriter
from snputils.snp.genobj.snpobj import SNPObject
log = logging.getLogger(__name__)
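# (Sphinx "[docs]" source-link marker removed here; it is a documentation-viewer artifact, not module code.)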
class BGENWriter:
    """
    Write a SNPObject to BGEN format.

    ``calldata_gp`` is written directly when present. If it is absent and
    ``genotypes`` is present, hard calls are encoded as one-hot genotype
    probabilities so SNPObjects created from VCF/BED/PGEN can still be exported.
    """

    def __init__(self, snpobj: "SNPObject", filename: Union[str, Path]):
        """
        Initialize the BGENWriter.

        Args:
            snpobj: SNPObject containing genotype probabilities or hard-call genotypes.
            filename: Output path. If it does not already end in ``.bgen``, its
                suffix is replaced by ``.bgen`` (or set, when there is none)
                at write time.
        """
        self.__snpobj = snpobj
        self.__filename = Path(filename)

    def write(
        self,
        compression: Optional[str] = "zstd",
        layout: int = 2,
        bit_depth: int = 8,
        phased: Optional[bool] = None,
        metadata: Optional[str] = None,
    ) -> None:
        """
        Write the SNPObject to a BGEN file.

        Args:
            compression: BGEN compression type. Supported by the backend: ``None``,
                ``"zlib"``, and ``"zstd"``.
            layout: BGEN layout version. The backend supports layouts 1 and 2.
            bit_depth: Number of bits used to store each probability.
            phased: Whether probabilities are phased. If None, inferred per variant
                from ``calldata_gp`` width and NaN padding when possible.
            metadata: Optional free-form BGEN metadata string.

        Raises:
            ValueError: If the probability array is not 3-dimensional, a required
                variant column is missing or has the wrong length, an allele is
                missing, or a variant's NaN layout / width cannot be interpreted.
        """
        output = self.__filename
        if output.suffix != ".bgen":
            # NOTE: ``with_suffix`` replaces an existing suffix ("x.vcf" -> "x.bgen").
            output = output.with_suffix(".bgen")

        probabilities = self._probabilities(phased)
        if probabilities.ndim != 3:
            raise ValueError("BGEN genotype probabilities must have shape (n_snps, n_samples, n_probabilities).")
        n_variants, n_samples, _ = probabilities.shape

        samples = self._samples(n_samples)
        variants_ref = self._required_variant_column("variants_ref", n_variants)
        variants_alt = self._required_variant_column("variants_alt", n_variants)
        variants_chrom = self._required_variant_column("variants_chrom", n_variants)
        variants_pos = self._required_variant_column("variants_pos", n_variants)
        variants_id = self._variant_ids(n_variants)

        # Loop-invariant conversion, done once instead of per variant.
        bit_depth = int(bit_depth)

        log.info("Writing to %s", output)
        with _BgenWriter(
            str(output),
            n_samples=n_samples,
            samples=[str(sample) for sample in samples],
            compression=compression,
            layout=layout,
            metadata=metadata,
        ) as bfile:
            for idx in range(n_variants):
                alleles = self._alleles(variants_ref[idx], variants_alt[idx])
                # Drop columns that exist only as padding up to the widest variant.
                variant_probabilities = self._trim_trailing_nan_probability_columns(
                    np.asarray(probabilities[idx], dtype=np.float64)
                )
                variant_phased = self._variant_phased(variant_probabilities, alleles, phased)
                ploidy = self._infer_ploidy(variant_probabilities, alleles, variant_phased)
                self._validate_probability_width(variant_probabilities, alleles, variant_phased)
                variant_id = str(variants_id[idx])
                bfile.add_variant(
                    varid=variant_id,
                    rsid=variant_id,
                    chrom=str(variants_chrom[idx]),
                    pos=int(variants_pos[idx]),
                    alleles=alleles,
                    genotypes=variant_probabilities,
                    ploidy=ploidy,
                    phased=variant_phased,
                    bit_depth=bit_depth,
                )

    def _probabilities(self, phased: Optional[bool]) -> np.ndarray:
        """
        Return the probability array to export.

        ``calldata_gp`` takes precedence; otherwise hard calls from ``genotypes``
        are converted to one-hot probabilities.

        Raises:
            ValueError: If the SNPObject has neither ``calldata_gp`` nor ``genotypes``.
        """
        if self.__snpobj.calldata_gp is not None:
            return np.asarray(self.__snpobj.calldata_gp, dtype=np.float64)
        if self.__snpobj.genotypes is None:
            raise ValueError("BGENWriter requires either `calldata_gp` or `genotypes`.")
        return self._hardcalls_to_probabilities(np.asarray(self.__snpobj.genotypes), phased=phased)

    @staticmethod
    def _hardcalls_to_probabilities(genotypes: np.ndarray, phased: Optional[bool]) -> np.ndarray:
        """
        Encode biallelic hard calls as one-hot probabilities.

        Args:
            genotypes: Either a 2D array of per-sample values compared against
                0/1/2, or a 3D array whose last axis holds per-haplotype allele
                codes. Negative values mark missing calls.
            phased: When truthy and ``genotypes`` is 3D, emit four columns per
                sample (two allele probabilities per haplotype). Otherwise emit
                three columns, one per value 0/1/2 of the (summed) call.

        Returns:
            Float array of shape ``(n_snps, n_samples, 4)`` (phased) or
            ``(n_snps, n_samples, 3)`` (unphased); missing calls are all-NaN rows.

        Raises:
            ValueError: If a phased 3D input does not have exactly two haplotypes,
                or if ``genotypes`` is neither 2D nor 3D.
        """
        if genotypes.ndim == 3 and phased:
            n_variants, n_samples, n_alleles = genotypes.shape
            if n_alleles != 2:
                raise ValueError("Phased BGEN export expects genotype shape (n_snps, n_samples, 2).")
            probabilities = np.zeros((n_variants, n_samples, 4), dtype=np.float64)
            missing = np.any(genotypes < 0, axis=2)
            # Columns: hap1==ref, hap1==alt, hap2==ref, hap2==alt.
            probabilities[:, :, 0] = genotypes[:, :, 0] == 0
            probabilities[:, :, 1] = genotypes[:, :, 0] == 1
            probabilities[:, :, 2] = genotypes[:, :, 1] == 0
            probabilities[:, :, 3] = genotypes[:, :, 1] == 1
            probabilities[missing, :] = np.nan
            return probabilities

        if genotypes.ndim == 3:
            dosage = genotypes.sum(axis=2, dtype=np.int16)
            missing = np.any(genotypes < 0, axis=2)
        elif genotypes.ndim == 2:
            dosage = genotypes
            missing = genotypes < 0
        else:
            raise ValueError("`genotypes` must be a 2D hard-call or 3D allele array.")

        n_variants, n_samples = dosage.shape
        probabilities = np.zeros((n_variants, n_samples, 3), dtype=np.float64)
        for genotype_value in (0, 1, 2):
            probabilities[:, :, genotype_value] = dosage == genotype_value
        probabilities[missing, :] = np.nan
        return probabilities

    def _samples(self, n_samples: int) -> np.ndarray:
        """
        Return sample identifiers, defaulting to ``"0" .. "n_samples-1"``.

        Raises:
            ValueError: If the stored sample list length differs from ``n_samples``.
        """
        if self.__snpobj.samples is None:
            return np.asarray([str(i) for i in range(n_samples)], dtype=object)
        samples = np.asarray(self.__snpobj.samples, dtype=object)
        if samples.shape[0] != n_samples:
            raise ValueError(f"samples length ({samples.shape[0]}) must match genotype sample count ({n_samples}).")
        return samples

    def _required_variant_column(self, attr: str, n_variants: int) -> np.ndarray:
        """
        Fetch a mandatory per-variant attribute of the SNPObject as an array.

        Raises:
            ValueError: If the attribute is None or its length is not ``n_variants``.
        """
        values = getattr(self.__snpobj, attr)
        if values is None:
            raise ValueError(f"BGENWriter requires `{attr}`.")
        arr = np.asarray(values)
        if arr.shape[0] != n_variants:
            raise ValueError(f"{attr} length ({arr.shape[0]}) must match number of variants ({n_variants}).")
        return arr

    def _variant_ids(self, n_variants: int) -> np.ndarray:
        """
        Return variant identifiers, defaulting to ``"variant_<index>"``.

        Raises:
            ValueError: If the stored id list length differs from ``n_variants``.
        """
        if self.__snpobj.variants_id is None:
            return np.asarray([f"variant_{idx}" for idx in range(n_variants)], dtype=object)
        arr = np.asarray(self.__snpobj.variants_id, dtype=object)
        if arr.shape[0] != n_variants:
            raise ValueError(f"variants_id length ({arr.shape[0]}) must match number of variants ({n_variants}).")
        return arr

    @staticmethod
    def _variant_phased(probabilities: np.ndarray, alleles: list[str], phased: Optional[bool]) -> bool:
        """
        Decide whether one variant's probability matrix is phased.

        An explicit ``phased`` value always wins. Otherwise the per-sample
        probability widths are tested against the phased and unphased layouts;
        when both fit, the values themselves are inspected.

        Args:
            probabilities: ``(n_samples, n_probabilities)`` matrix for one variant.
            alleles: Allele strings for the variant.
            phased: Caller override, or None to infer.

        Returns:
            True if the variant should be written as phased.

        Raises:
            ValueError: If a row has NaN values in a non-trailing position.
        """
        if phased is not None:
            return bool(phased)

        n_alleles = len(alleles)
        counts = BGENWriter._nonmissing_probability_counts(probabilities)
        nonmissing = counts > 0
        if not nonmissing.any():
            # Nothing to inspect: fall back on the matrix width alone.
            return probabilities.shape[1] == 4 and n_alleles == 2

        # Keep rows and widths filtered by the SAME mask so they stay aligned
        # row-for-row when handed to ``_looks_phased``.
        observed = probabilities[nonmissing]
        widths = [int(count) for count in counts[nonmissing]]

        phased_ploidy = [BGENWriter._phased_ploidy_from_width(width, n_alleles) for width in widths]
        unphased_ploidy = [BGENWriter._unphased_ploidy_from_width(width, n_alleles) for width in widths]
        phased_possible = all(ploidy is not None for ploidy in phased_ploidy)
        unphased_possible = all(ploidy is not None for ploidy in unphased_ploidy)

        if phased_possible and not unphased_possible:
            return True
        if unphased_possible and not phased_possible:
            return False
        if phased_possible and unphased_possible:
            # For phased data, each haplotype contributes one probability per allele
            # and each haplotype's allele probabilities sum to one. Unphased rows
            # instead sum to one across the full genotype distribution.
            if BGENWriter._looks_phased(observed, n_alleles, phased_ploidy):
                return True

        if probabilities.shape[1] == 4 and n_alleles == 2 and not np.isnan(probabilities[:, 3]).all():
            return True
        return False

    @staticmethod
    def _trim_trailing_nan_probability_columns(probabilities: np.ndarray) -> np.ndarray:
        """
        Drop trailing columns that are NaN for every sample.

        A matrix in which *every* column is NaN (an all-missing variant) is
        returned unchanged: its width is the only layout information left, and
        ``_variant_phased`` relies on it for all-missing variants.

        Args:
            probabilities: ``(n_samples, n_probabilities)`` matrix for one variant.

        Returns:
            A view of ``probabilities`` without the trailing all-NaN columns.
        """
        all_nan_columns = np.isnan(probabilities).all(axis=0)
        if all_nan_columns.all():
            return probabilities
        last_used = int(np.flatnonzero(~all_nan_columns)[-1])
        return probabilities[:, : last_used + 1]

    @staticmethod
    def _nonmissing_probability_counts(probabilities: np.ndarray) -> np.ndarray:
        """
        Count the finite probabilities of each sample and validate the NaN layout.

        Each row must consist of a run of finite values followed only by
        non-finite padding (an entirely non-finite row is a missing sample).

        Args:
            probabilities: ``(n_samples, n_probabilities)`` matrix for one variant.

        Returns:
            Integer array of length ``n_samples`` with the finite-value count per row.

        Raises:
            ValueError: If any row has a non-finite value before a finite one.
        """
        finite = np.isfinite(probabilities)
        counts = finite.sum(axis=1)
        # A valid row is exactly "first `count` entries finite, the rest not".
        expected = np.arange(finite.shape[1]) < counts[:, None]
        if not np.array_equal(finite, expected):
            raise ValueError(
                "BGEN probability rows may only contain NaN values as all-missing rows "
                "or as trailing padding for lower-ploidy samples."
            )
        return counts

    @staticmethod
    def _phased_ploidy_from_width(width: int, n_alleles: int) -> Optional[int]:
        """
        Return the ploidy for which phased data has ``width`` probabilities.

        Phased rows hold ``n_alleles`` probabilities per haplotype, so the width
        must be a positive multiple of ``n_alleles``; otherwise None is returned.
        """
        if n_alleles <= 0 or width <= 0 or width % n_alleles != 0:
            return None
        return width // n_alleles

    @staticmethod
    def _unphased_ploidy_from_width(width: int, n_alleles: int) -> Optional[int]:
        """
        Return the ploidy for which unphased data has ``width`` probabilities.

        The number of unordered genotypes is ``C(ploidy + n_alleles - 1,
        n_alleles - 1)``. Ploidies 0..63 are searched (presumably the range BGEN
        can encode - confirm against the format specification); None is returned
        when no ploidy in that range matches.
        """
        if width <= 0 or n_alleles <= 0:
            return None
        for ploidy in range(0, 64):
            n_genotypes = math.comb(ploidy + n_alleles - 1, n_alleles - 1)
            if n_genotypes == width:
                return ploidy
            if n_genotypes > width:
                # The genotype count never decreases with ploidy; no match is possible.
                break
        return None

    @staticmethod
    def _looks_phased(
        probabilities: np.ndarray,
        n_alleles: int,
        ploidies: list[Optional[int]],
    ) -> bool:
        """
        Check whether every row is consistent with a phased layout.

        Args:
            probabilities: Rows to inspect, one per sample.
            n_alleles: Number of alleles at the variant.
            ploidies: Candidate phased ploidy for each row; must be aligned
                row-for-row with ``probabilities``.

        Returns:
            True if, for every row that is not entirely NaN, each haplotype's
            allele probabilities sum to one (absolute tolerance 1e-4). False if
            any ploidy is None or any haplotype fails the check.
        """
        for sample_probabilities, ploidy in zip(probabilities, ploidies):
            if ploidy is None:
                return False
            if np.isnan(sample_probabilities).all():
                continue
            width = ploidy * n_alleles
            haplotypes = sample_probabilities[:width].reshape(ploidy, n_alleles)
            if not np.allclose(haplotypes.sum(axis=1), 1.0, atol=1e-4, rtol=0):
                return False
        return True

    @staticmethod
    def _infer_ploidy(probabilities: np.ndarray, alleles: list[str], phased: bool) -> Union[int, np.ndarray]:
        """
        Infer per-sample ploidy from the number of finite probabilities.

        Missing samples (no finite values) take the largest ploidy seen among the
        other samples, or 2 when every sample is missing.

        Args:
            probabilities: ``(n_samples, n_probabilities)`` matrix for one variant.
            alleles: Allele strings for the variant.
            phased: Whether the matrix uses the phased layout.

        Returns:
            A single int when all samples share one ploidy (including the case of
            zero samples), otherwise a ``uint8`` array with one ploidy per sample.

        Raises:
            ValueError: If the NaN layout is invalid or a row's width does not
                correspond to any ploidy for the given layout.
        """
        n_alleles = len(alleles)
        counts = BGENWriter._nonmissing_probability_counts(probabilities)
        if phased:
            resolve = BGENWriter._phased_ploidy_from_width
        else:
            resolve = BGENWriter._unphased_ploidy_from_width

        # Samples mostly share a handful of widths, so resolve each width once.
        ploidy_by_width: dict[int, int] = {}
        inferred: list[Optional[int]] = []
        for count in counts:
            width = int(count)
            if width == 0:
                inferred.append(None)
                continue
            if width not in ploidy_by_width:
                ploidy = resolve(width, n_alleles)
                if ploidy is None:
                    mode = "phased" if phased else "unphased"
                    if phased and n_alleles == 2:
                        raise ValueError(
                            f"Biallelic diploid phased BGEN probabilities require 4 columns; "
                            f"got {width}."
                        )
                    raise ValueError(
                        f"Cannot infer {mode} BGEN ploidy from {width} probability columns "
                        f"and {n_alleles} alleles."
                    )
                ploidy_by_width[width] = ploidy
            inferred.append(ploidy_by_width[width])

        fallback_ploidy = max((ploidy for ploidy in inferred if ploidy is not None), default=2)
        if not inferred:
            # No samples at all: nothing to index, report the default ploidy.
            return int(fallback_ploidy)

        ploidies = np.empty(len(inferred), dtype=np.uint8)
        for idx, ploidy in enumerate(inferred):
            ploidies[idx] = fallback_ploidy if ploidy is None else ploidy
        if np.all(ploidies == ploidies[0]):
            return int(ploidies[0])
        return ploidies

    @staticmethod
    def _validate_probability_width(probabilities: np.ndarray, alleles: list[str], phased: bool) -> None:
        """
        Validate the NaN layout of one variant's probability matrix.

        ``alleles`` and ``phased`` are accepted for interface symmetry but are
        not used; only the trailing-NaN rule is checked.

        Raises:
            ValueError: If a row has NaN values in a non-trailing position.
        """
        BGENWriter._nonmissing_probability_counts(probabilities)

    @staticmethod
    def _as_text(value: Union[str, bytes]) -> str:
        """Return ``value`` as text, decoding bytes as UTF-8 instead of taking their repr."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8")
        return str(value)

    @staticmethod
    def _alleles(ref: Union[str, bytes], alt: Union[str, bytes]) -> list[str]:
        """
        Build the allele list ``[ref, alt1, alt2, ...]`` for one variant.

        Args:
            ref: Reference allele (text or bytes).
            alt: Alternate allele(s); several alleles are comma-separated.

        Returns:
            The reference allele followed by each alternate allele.

        Raises:
            ValueError: If either allele is empty or ``"."``.
        """
        ref_text = BGENWriter._as_text(ref)
        alt_text = BGENWriter._as_text(alt)
        if not ref_text or ref_text == "." or not alt_text or alt_text == ".":
            raise ValueError("BGENWriter requires non-missing REF and ALT alleles.")
        return [ref_text] + alt_text.split(",")