from __future__ import annotations
import gzip
import logging
import re
import struct
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from snputils.snp.genobj.snpobj import SNPObject
from snputils.snp.io.read.base import SNPBaseReader
from snputils.snp.io.read.vcf import _parse_vcf_region, _vcf_region_matches
log = logging.getLogger(__name__)
# Fields loaded when the caller does not pass an explicit ``fields`` selection.
_DEFAULT_FIELDS = ["GT", "IID", "REF", "ALT", "#CHROM", "ID", "POS", "QUAL", "FILTER"]
# Every field name the reader accepts; this is also what ``"*"`` expands to.
_ALL_FIELDS = ["GT", "GP", "IID", "REF", "ALT", "#CHROM", "ID", "POS", "QUAL", "FILTER", "INFO"]
# Leading bytes the decompressed stream must start with to be accepted.
_BCF_MAGIC = b"BCF\x02\x02"
# Precompiled little-endian codecs for single 32-bit values.
_U32 = struct.Struct("<I")
_I32 = struct.Struct("<i")
_F32 = struct.Struct("<f")
# Bytes per element for each typed-value type code handled in this module
# (codes 1/2/3 are read as integers, 5 as float32, 7 as string bytes, 0 as empty).
_TYPE_SIZES = {0: 0, 1: 1, 2: 2, 3: 4, 5: 4, 7: 1}
# Unsigned little-endian numpy dtypes keyed by integer byte width.
_INT_UNSIGNED_DTYPES = {1: np.uint8, 2: np.dtype("<u2"), 4: np.dtype("<u4")}
# Raw float32 bit patterns treated as "missing" and "end of vector" by the decoders below.
_FLOAT_MISSING = 0x7F800001
_FLOAT_VECTOR_END = 0x7F800002
# Matches structured ##contig/##INFO/##FORMAT/##FILTER header lines and captures the <...> payload.
_HEADER_META_RE = re.compile(r"^##(contig|INFO|FORMAT|FILTER)=<(.*)>$")
# Extracts key=value pairs from a structured header payload; values may be double-quoted.
_HEADER_KV_RE = re.compile(r'([^=,]+)=(".*?"|[^,<>]+)')
# Struct for reading the six fixed 32-bit fields from a BCF record's shared section.
# Layout at base (= record_offset + 8):
# [0] chrom_id (i32), [1] pos (i32), [2] rlen (unpacked as u32 by this format), [3] qual (f32),
# [4] n_alleles_info (u32: top 16 = n_alleles, bottom 16 = n_info),
# [5] n_fmt_n_samples (u32: top 8 = n_fmt, bottom 24 = n_samples)
_FIXED_FIELDS = struct.Struct("<iiIfII")
@dataclass(frozen=True)
class _BCFHeader:
samples: np.ndarray
contigs: Dict[int, str]
filters: Dict[int, str]
info: Dict[int, Dict[str, str]]
formats: Dict[int, Dict[str, str]]
def _as_field_list(fields: Optional[Union[str, Sequence[str]]]) -> Optional[List[str]]:
if fields is None:
return None
if isinstance(fields, str):
return [fields]
return list(fields)
def _normalize_fields(
    fields: Optional[Union[str, Sequence[str]]],
    exclude_fields: Optional[Union[str, Sequence[str]]],
) -> list[str]:
    """Resolve the include/exclude selectors into canonical field names.

    ``"CHROM"`` is treated as an alias of ``"#CHROM"`` in both selectors.
    ``fields=None`` selects the default set and ``fields="*"`` (or ``["*"]``)
    selects every supported field.

    Args:
        fields: Requested field name(s), ``"*"``, or ``None``.
        exclude_fields: Field name(s) to drop from the result, or ``None``.

    Returns:
        The requested canonical names, in request order, minus the excluded ones.

    Raises:
        ValueError: If a requested name is not a supported field.
    """

    def canonical_name(name):
        return "#CHROM" if name == "CHROM" else name

    excluded = {canonical_name(name) for name in (_as_field_list(exclude_fields) or [])}

    wanted = _as_field_list(fields)
    if wanted is None:
        wanted = list(_DEFAULT_FIELDS)
    elif wanted == ["*"]:
        wanted = list(_ALL_FIELDS)

    kept = []
    for field in wanted:
        canonical = canonical_name(field)
        if canonical not in _ALL_FIELDS:
            raise ValueError(
                f"Unsupported BCF field: {field}. "
                f"Supported fields are {', '.join(_ALL_FIELDS)} and '*'."
            )
        if canonical in excluded:
            continue
        kept.append(canonical)
    return kept
def _resolve_sample_indices(
file_samples: np.ndarray,
sample_ids: Optional[Sequence[str]],
sample_idxs: Optional[Sequence[int]],
) -> np.ndarray:
if sample_idxs is not None and sample_ids is not None:
raise ValueError("Only one of sample_idxs and sample_ids can be specified.")
if sample_idxs is not None:
idx = np.asarray(sample_idxs, dtype=int).ravel()
n_samples = len(file_samples)
if np.any((idx < -n_samples) | (idx >= n_samples)):
raise ValueError("One or more sample indexes are out of bounds.")
return np.mod(idx, n_samples)
if sample_ids is None:
return np.arange(len(file_samples), dtype=int)
requested = np.asarray(sample_ids, dtype=object).ravel()
sample_lookup = {str(sample): i for i, sample in enumerate(file_samples)}
missing = [str(sample) for sample in requested if str(sample) not in sample_lookup]
if missing:
raise ValueError(f"The following specified samples were not found: {missing}")
return np.asarray([sample_lookup[str(sample)] for sample in requested], dtype=int)
def _parse_header_fields(text: str) -> Dict[str, str]:
    """Split a structured header payload into a ``key -> value`` mapping.

    Surrounding double quotes are stripped from each value. When a key occurs
    more than once, the last occurrence wins.
    """
    parsed: Dict[str, str] = {}
    for key, raw_value in _HEADER_KV_RE.findall(text):
        parsed[key] = raw_value.strip('"')
    return parsed
def _parse_bcf_header(text: str) -> _BCFHeader:
    """Parse the BCF text header into lookup tables used while decoding records.

    Args:
        text: Decoded header text; trailing NUL characters are ignored.

    Returns:
        A ``_BCFHeader`` holding the sample names from the ``#CHROM`` line and
        the index -> entry dictionaries for contigs, filters, INFO and FORMAT.

    Notes:
        * Contigs use the explicit ``IDX`` attribute when a ``##contig`` line
          carries a parseable one; otherwise they are numbered by order of
          appearance.
        * FILTER/INFO/FORMAT lines without a non-negative ``IDX`` are skipped.
        * Index 0 of the filter table is pre-seeded with ``"PASS"``.
    """
    samples: list[str] = []
    contigs: Dict[int, str] = {}
    filters: Dict[int, str] = {0: "PASS"}
    info: Dict[int, Dict[str, str]] = {}
    formats: Dict[int, Dict[str, str]] = {}
    contig_idx = 0
    for line in text.rstrip("\0").splitlines():
        if not line:
            continue
        if line.startswith("#CHROM"):
            # Columns 0-8 are the fixed VCF columns; everything after is a sample.
            samples = line.split("\t")[9:]
            continue
        match = _HEADER_META_RE.match(line)
        if match is None:
            continue
        kind, payload = match.groups()
        fields = _parse_header_fields(payload)
        if kind == "contig":
            # Records refer to contigs by dictionary index. Honour an explicit
            # IDX so headers whose contig lines are not in index order still
            # resolve correctly; fall back to positional numbering otherwise.
            try:
                explicit_idx = int(fields["IDX"])
            except (KeyError, ValueError):
                explicit_idx = contig_idx
            contigs[explicit_idx] = fields["ID"]
            contig_idx += 1
            continue
        idx = int(fields.get("IDX", "-1"))
        if idx < 0:
            continue
        if kind == "FILTER":
            filters[idx] = fields["ID"]
        elif kind == "INFO":
            info[idx] = fields
        elif kind == "FORMAT":
            formats[idx] = fields
    return _BCFHeader(
        samples=np.asarray(samples, dtype=object),
        contigs=contigs,
        filters=filters,
        info=info,
        formats=formats,
    )
def _load_bcf_data(filename: Union[str, bytes]) -> Tuple[bytes, int, _BCFHeader]:
    """Decompress a BCF file fully into memory and parse its header.

    Args:
        filename: Path of the gzip-compatible compressed BCF file.

    Returns:
        ``(data, body_offset, header)`` where ``data`` is the whole
        decompressed stream, ``body_offset`` is the byte position just past
        the text header, and ``header`` is the parsed header.

    Raises:
        ValueError: If the stream does not start with the expected magic bytes.
    """
    with gzip.open(filename, "rb") as stream:
        data = stream.read()
    if not data.startswith(_BCF_MAGIC):
        raise ValueError(f"{filename!r} does not look like a BCF2.2 file.")
    # The u32 header length sits right after the magic; the text follows it.
    length_pos = len(_BCF_MAGIC)
    (text_len,) = _U32.unpack_from(data, length_pos)
    text_start = length_pos + _U32.size
    body_offset = text_start + text_len
    header = _parse_bcf_header(data[text_start:body_offset].decode("utf-8", "replace"))
    return data, body_offset, header
def _read_typed_descriptor(data: bytes, offset: int) -> Tuple[int, int, int, int]:
    """Decode the descriptor that precedes a typed value.

    The low nibble of the descriptor byte is the type code and the high
    nibble is the element count; a count of 15 means the real count follows
    as a separately encoded integer.

    Args:
        data: Decompressed BCF bytes.
        offset: Position of the descriptor byte.

    Returns:
        ``(n_vals, type_code, type_size, payload_offset)``. A zero count is
        reported as ``(0, 0, 0, payload_offset)`` regardless of the type code.

    Raises:
        ValueError: If the overflow length encoding or the type code is not
            recognised.
    """
    descriptor = data[offset]
    cursor = offset + 1
    type_code = descriptor & 0x0F
    count = descriptor >> 4
    if count == 0:
        return 0, 0, 0, cursor
    if count == 15:
        # Overflow form: the next byte describes an integer holding the count.
        length_type = data[cursor] & 0x0F
        cursor += 1
        length_size = _TYPE_SIZES.get(length_type, 0)
        if length_type not in (1, 2, 3) or length_size == 0:
            raise ValueError("Cannot identify the BCF typed-value length encoding.")
        count = int.from_bytes(data[cursor:cursor + length_size], "little", signed=False)
        cursor += length_size
    if type_code not in _TYPE_SIZES:
        raise ValueError(f"Unsupported BCF atomic type code: {type_code}")
    return count, type_code, _TYPE_SIZES[type_code], cursor
def _skip_typed_value_fast(data: bytes, offset: int) -> int:
    """Return the offset just past the typed value starting at ``offset``.

    This variant skips the validation done by ``_read_typed_descriptor``;
    an unknown type code surfaces as a ``KeyError``.
    """
    head = data[offset]
    elem_size = _TYPE_SIZES[head & 0x0F]
    inline_count = head >> 4
    if inline_count >= 15:
        # Overflow form: count is stored as an integer after a second descriptor byte.
        width = _TYPE_SIZES[data[offset + 1] & 0x0F]
        payload_start = offset + 2 + width
        count = int.from_bytes(data[offset + 2 : payload_start], "little")
        return payload_start + count * elem_size
    return offset + 1 + inline_count * elem_size
def _read_scalar_typed_int(data: bytes, offset: int) -> Tuple[int, int]:
    """Read a typed value that must be exactly one integer.

    Returns:
        ``(value, next_offset)``; the value is interpreted as unsigned.

    Raises:
        ValueError: If the typed value is not a single integer.
    """
    n_vals, type_code, type_size, start = _read_typed_descriptor(data, offset)
    is_scalar_int = n_vals == 1 and type_code in (1, 2, 3)
    if not is_scalar_int:
        raise ValueError("Expected a scalar integer typed value in the BCF record.")
    stop = start + type_size
    return int.from_bytes(data[start:stop], "little", signed=False), stop
def _read_typed_string(data: bytes, offset: int) -> Tuple[str, int]:
    """Read a typed character vector as a UTF-8 string.

    Args:
        data: Decompressed BCF bytes.
        offset: Position of the typed-value descriptor.

    Returns:
        ``(text, next_offset)``. The text stops at the first NUL byte, if any.
        A zero-length value yields the empty string.

    Raises:
        ValueError: If a non-empty value is not of the character type.
    """
    n_vals, type_code, type_size, offset = _read_typed_descriptor(data, offset)
    if n_vals == 0 and type_code == 0:
        # _read_typed_descriptor reports every zero-length value with type
        # code 0, so an empty string (e.g. a missing variant ID) cannot be
        # told apart by type; treat it as "" instead of rejecting it.
        return "", offset
    if type_code != 7:
        raise ValueError("Expected a typed string in the BCF record.")
    end = offset + n_vals * type_size
    value = data[offset:end].split(b"\0", 1)[0].decode("utf-8")
    return value, end
def _skip_typed_value(data: bytes, offset: int) -> int:
    """Return the offset just past the typed value whose descriptor is at ``offset``."""
    count, _type_code, width, payload = _read_typed_descriptor(data, offset)
    return payload + count * width
def _read_int_list(data: bytes, offset: int) -> Tuple[List[Optional[int]], int]:
    """Read a typed integer vector.

    Args:
        data: Decompressed BCF bytes.
        offset: Position of the typed-value descriptor.

    Returns:
        ``(values, next_offset)``. Missing entries are returned as ``None``
        and decoding stops at the first end-of-vector marker. ``next_offset``
        always points past the full encoded vector, including any padding
        after an end-of-vector marker.

    Raises:
        ValueError: If the value is neither empty nor an integer type.
    """
    n_vals, type_code, type_size, offset = _read_typed_descriptor(data, offset)
    if type_code == 0:
        return [], offset
    if type_code not in (1, 2, 3):
        raise ValueError(f"Expected an integer typed value, found atomic type {type_code}.")
    # The vector occupies n_vals slots no matter where the end marker sits,
    # so the next value starts here even when we stop decoding early.
    end = offset + n_vals * type_size
    missing = 1 << ((type_size * 8) - 1)
    vector_end = missing | 0x1
    values: List[Optional[int]] = []
    for pos in range(offset, end, type_size):
        chunk = data[pos:pos + type_size]
        raw = int.from_bytes(chunk, "little", signed=False)
        if raw == vector_end:
            break
        if raw == missing:
            values.append(None)
            continue
        values.append(int.from_bytes(chunk, "little", signed=True))
    return values, end
def _read_float_list(data: bytes, offset: int) -> Tuple[List[float], int]:
    """Read a typed float32 vector.

    Args:
        data: Decompressed BCF bytes.
        offset: Position of the typed-value descriptor.

    Returns:
        ``(values, next_offset)``. Missing entries are returned as NaN and
        decoding stops at the first end-of-vector marker. ``next_offset``
        always points past the full encoded vector, including any padding
        after an end-of-vector marker.

    Raises:
        ValueError: If the value is neither empty nor float32.
    """
    n_vals, type_code, type_size, offset = _read_typed_descriptor(data, offset)
    if type_code == 0:
        return [], offset
    if type_code != 5 or type_size != 4:
        raise ValueError(f"Expected a float typed value, found atomic type {type_code}.")
    # Fixed-width vector: the following value starts after all n_vals slots.
    end = offset + n_vals * 4
    values: List[float] = []
    for pos in range(offset, end, 4):
        # Compare raw bit patterns; the sentinels are NaN payloads.
        raw = _U32.unpack_from(data, pos)[0]
        if raw == _FLOAT_VECTOR_END:
            break
        if raw == _FLOAT_MISSING:
            values.append(np.nan)
            continue
        values.append(_F32.unpack_from(data, pos)[0])
    return values, end
def _render_info_value(value: Any) -> str:
if isinstance(value, list):
rendered = []
for item in value:
if item is None:
rendered.append(".")
elif isinstance(item, float) and np.isnan(item):
rendered.append(".")
else:
rendered.append(str(item))
return ",".join(rendered)
if isinstance(value, float) and np.isnan(value):
return "."
return str(value)
def _variant_qual(data: bytes, base_offset: int) -> float:
    """Return the QUAL of the record whose fixed section starts at ``base_offset``.

    The float32 is stored 12 bytes into the fixed section; the missing
    bit pattern is reported as NaN.
    """
    qual_offset = base_offset + 12
    (bits,) = _U32.unpack_from(data, qual_offset)
    if bits != _FLOAT_MISSING:
        return _F32.unpack_from(data, qual_offset)[0]
    return np.nan
def _decode_record_identifiers(
    data: bytes,
    record_offset: int,
    header: _BCFHeader,
) -> Tuple[str, int, str, str, Tuple[str, ...]]:
    """Decode the identifying fields of one record.

    Args:
        data: Decompressed BCF bytes.
        record_offset: Position of the record's two length words.
        header: Parsed header providing the contig dictionary.

    Returns:
        ``(chrom, pos, variant_id, ref, alts)`` where ``pos`` is the stored
        position plus one and ``alts`` is a tuple of ALT allele strings.
    """
    fixed = record_offset + 8
    contig_id = _I32.unpack_from(data, fixed)[0]
    chrom = header.contigs[contig_id]
    pos = _I32.unpack_from(data, fixed + 4)[0] + 1
    # n_alleles lives in the upper 16 bits of the word at +16.
    n_alleles = _U32.unpack_from(data, fixed + 16)[0] >> 16
    cursor = fixed + 24
    variant_id, cursor = _read_typed_string(data, cursor)
    ref, cursor = _read_typed_string(data, cursor)
    alt_alleles = []
    for _ in range(n_alleles - 1):
        allele, cursor = _read_typed_string(data, cursor)
        alt_alleles.append(allele)
    return chrom, pos, variant_id, ref, tuple(alt_alleles)
def _record_identifiers(chrom: str, pos: int, variant_id: str, ref: str, alts: Sequence[str]) -> set[str]:
identifiers = {f"{chrom}:{pos}", f"{chrom}:{pos}:{ref}:{','.join(alts)}"}
if variant_id not in ("", "."):
identifiers.add(variant_id)
return identifiers
def _count_records(data: bytes, body_offset: int) -> int:
    """Count the records in the body by hopping over their length prefixes.

    Raises:
        ValueError: If the last record does not end exactly at the end of ``data``.
    """
    total_len = len(data)
    cursor = body_offset
    n_records = 0
    while cursor < total_len:
        (shared_len,) = _U32.unpack_from(data, cursor)
        (indiv_len,) = _U32.unpack_from(data, cursor + 4)
        # 8 bytes of length words, then the shared and per-sample sections.
        cursor += 8 + shared_len + indiv_len
        n_records += 1
    if cursor != total_len:
        raise ValueError("Malformed BCF: record boundaries do not consume the full file.")
    return n_records
# Precompiled reader for the two consecutive little-endian u32 lengths
# (l_shared, l_indiv) that open every record.
_U32_PAIR = struct.Struct("<II")
def _build_record_offsets(data: bytes, body_offset: int) -> np.ndarray:
    """Collect the starting byte offset of every record in a single pass.

    Returns:
        ``int64`` array with one offset per record, in file order.

    Raises:
        ValueError: If the last record does not end exactly at the end of ``data``.
    """
    starts = []
    total_len = len(data)
    read_lengths = _U32_PAIR.unpack_from
    cursor = body_offset
    while cursor < total_len:
        starts.append(cursor)
        shared_len, indiv_len = read_lengths(data, cursor)
        cursor += 8 + shared_len + indiv_len
    if cursor != total_len:
        raise ValueError("Malformed BCF: record boundaries do not consume the full file.")
    return np.asarray(starts, dtype=np.int64)
def _gather_u32(raw: np.ndarray, offsets: np.ndarray) -> np.ndarray:
"""Read little-endian uint32 values at given byte offsets using numpy gather.
Reads 4 consecutive bytes at each offset and assembles them into uint32 values
using vectorized shift-and-add instead of per-element struct.unpack_from.
"""
b0 = raw[offsets].astype(np.uint32)
b1 = raw[offsets + 1].astype(np.uint32)
b2 = raw[offsets + 2].astype(np.uint32)
b3 = raw[offsets + 3].astype(np.uint32)
return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
def _extract_fixed_fields(
    data: bytes,
    record_offsets: np.ndarray,
) -> Tuple[
    np.ndarray, np.ndarray, np.ndarray, np.ndarray,
    np.ndarray, np.ndarray, np.ndarray, np.ndarray,
]:
    """Vectorized extraction of fixed-layout fields from all records.

    Uses numpy byte-level gather to read the fixed fields of every record in
    bulk, avoiding per-record Python loops and ``struct.unpack_from`` calls.

    Args:
        data: Decompressed BCF bytes.
        record_offsets: Starting byte offset of each record.

    Returns:
        An 8-tuple of 1-D arrays with one element per record:
        ``(l_shared, l_indiv, contig_ids, positions, qual_raw, n_alleles,
        n_info, n_fmt)``. ``positions`` is the stored position plus one and
        ``qual_raw`` holds the raw uint32 bit patterns of the QUAL float.
    """
    raw = np.frombuffer(data, dtype=np.uint8)
    # Record header: l_shared (u32 at +0), l_indiv (u32 at +4)
    l_shared = _gather_u32(raw, record_offsets)
    l_indiv = _gather_u32(raw, record_offsets + 4)
    # Fixed section at base = offset + 8:
    # chrom_id (i32 at +0), pos (i32 at +4), rlen (i32 at +8), qual (u32 at +12),
    # n_alleles_info (u32 at +16), n_fmt_n_samples (u32 at +20)
    base_offsets = record_offsets + 8
    contig_ids = _gather_u32(raw, base_offsets).view(np.int32)
    # Widen before adding 1 so the shift to the caller-facing POS cannot overflow int32.
    positions = _gather_u32(raw, base_offsets + 4).view(np.int32).astype(np.int64) + 1
    qual_raw = _gather_u32(raw, base_offsets + 12)
    n_alleles_info = _gather_u32(raw, base_offsets + 16)
    n_fmt_n_samples = _gather_u32(raw, base_offsets + 20)
    n_alleles = (n_alleles_info >> 16).astype(np.uint16)
    n_info = (n_alleles_info & 0xFFFF).astype(np.uint16)
    n_fmt = (n_fmt_n_samples >> 24).astype(np.uint8)
    return l_shared, l_indiv, contig_ids, positions, qual_raw, n_alleles, n_info, n_fmt
def _resolve_variant_request(
    data: bytes,
    body_offset: int,
    header: _BCFHeader,
    region_filter: Optional[Tuple[str, Optional[int], Optional[int]]],
    variant_ids: Optional[Sequence[str]],
    variant_idxs: Optional[Sequence[int]],
) -> Tuple[int, Optional[List[int]], Optional[List[int]]]:
    """Turn a variant selection into the byte offsets of the matching records.

    Args:
        data: Decompressed BCF bytes.
        body_offset: Offset of the first record.
        header: Parsed header (used to resolve contig names).
        region_filter: Parsed region, or ``None``.
        variant_ids: Identifiers to keep, or ``None``.
        variant_idxs: Record indexes to keep (negative values count from the
            end), or ``None``.

    Returns:
        ``(n_records, selected_offsets, requested_offsets)``:

        * no selection at all -> ``(n_records, None, None)``;
        * selection without ``variant_idxs`` -> matching offsets in file
          order as ``selected_offsets``;
        * selection with ``variant_idxs`` -> matching offsets in the order
          the indexes were requested as ``requested_offsets``.

    Raises:
        ValueError: If an index is out of bounds or a requested identifier
            matches no record that passed the other filters.
    """
    n_records = _count_records(data, body_offset)

    ordered_rows = None
    if variant_idxs is not None:
        idx_array = np.asarray(variant_idxs, dtype=int).ravel()
        if np.any((idx_array < -n_records) | (idx_array >= n_records)):
            raise ValueError("One or more variant indexes are out of bounds.")
        ordered_rows = np.mod(idx_array, n_records).tolist()

    if variant_ids is None and ordered_rows is None and region_filter is None:
        return n_records, None, None

    row_filter = set(ordered_rows) if ordered_rows is not None else None
    id_filter = None
    if variant_ids is not None:
        id_filter = {str(value) for value in np.asarray(variant_ids, dtype=object).ravel()}
    # Decoding identifiers is only needed for region or ID matching.
    needs_decode = region_filter is not None or id_filter is not None

    matched_ids = set()
    offsets_in_file_order: List[int] = []
    offset_by_row: Dict[int, int] = {}
    cursor = body_offset
    data_end = len(data)
    row = 0
    while cursor < data_end:
        l_shared = _U32.unpack_from(data, cursor)[0]
        l_indiv = _U32.unpack_from(data, cursor + 4)[0]
        keep = row_filter is None or row in row_filter
        if keep and needs_decode:
            chrom, pos, variant_id, ref, alts = _decode_record_identifiers(data, cursor, header)
            if region_filter is not None and not _vcf_region_matches(chrom, pos, region_filter):
                keep = False
            if keep and id_filter is not None:
                hits = id_filter.intersection(_record_identifiers(chrom, pos, variant_id, ref, alts))
                if hits:
                    matched_ids.update(hits)
                else:
                    keep = False
        if keep:
            if ordered_rows is None:
                offsets_in_file_order.append(cursor)
            else:
                offset_by_row[row] = cursor
        cursor += 8 + l_shared + l_indiv
        row += 1

    if id_filter is not None:
        missing = sorted(id_filter - matched_ids)
        if missing:
            raise ValueError(f"The following specified variants were not found: {missing}")

    if ordered_rows is None:
        return n_records, offsets_in_file_order, None
    # Preserve the caller's index order; rows dropped by other filters are omitted.
    in_request_order = [offset_by_row[wanted] for wanted in ordered_rows if wanted in offset_by_row]
    return n_records, None, in_request_order
def _decode_gt_array(data: bytes, offset: int, n_samples: int, n_vals: int, type_size: int) -> np.ndarray:
    """Decode the FORMAT/GT payload of one record into allele indexes.

    Each stored integer encodes ``(allele_index + 1) << 1 | phase``; the phase
    bit is dropped and a missing allele becomes ``-1``. The reserved
    "missing" and "end of vector" integers (used to pad samples with lower
    ploidy) are also mapped to ``-1`` instead of being decoded as alleles.

    Args:
        data: Decompressed BCF bytes.
        offset: Position of the first GT value.
        n_samples: Number of samples stored in the record.
        n_vals: Values per sample (ploidy of the encoded vector).
        type_size: Byte width of each stored integer.

    Returns:
        ``int8`` array of shape ``(n_samples, 2)``; haploid input is padded
        with ``-1`` in the second column. ``(n_samples, 0)`` when ``n_vals < 1``.

    Raises:
        ValueError: If the integer width is unsupported or ``n_vals > 2``.
    """
    if type_size not in _INT_UNSIGNED_DTYPES:
        raise ValueError(f"Unsupported GT integer width in BCF FORMAT/GT: {type_size}")
    if n_vals < 1:
        return np.empty((n_samples, 0), dtype=np.int8)
    raw = np.frombuffer(
        data,
        dtype=_INT_UNSIGNED_DTYPES[type_size],
        count=n_samples * n_vals,
        offset=offset,
    ).reshape(n_samples, n_vals)
    decoded = (raw.astype(np.int32, copy=False) >> 1) - 1
    # Reserved integers are not genotype codes. For 1-byte storage the plain
    # shift would turn 0x80/0x81 into allele 63, so mask them explicitly.
    missing = 1 << ((type_size * 8) - 1)
    vector_end = missing | 0x1
    decoded[(raw == missing) | (raw == vector_end)] = -1
    if n_vals == 1:
        padded = np.full((n_samples, 2), -1, dtype=np.int8)
        padded[:, 0] = decoded[:, 0].astype(np.int8, copy=False)
        return padded
    if n_vals != 2:
        raise ValueError("BCFReader currently supports haploid or diploid GT fields only.")
    return decoded.astype(np.int8, copy=False)
def _decode_gp_array(data: bytes, offset: int, n_samples: int, n_vals: int) -> np.ndarray:
if n_vals < 1:
return np.empty((n_samples, 0), dtype=np.float32)
return np.frombuffer(
data,
dtype=np.dtype("<f4"),
count=n_samples * n_vals,
offset=offset,
).reshape(n_samples, n_vals).copy()
def _parse_filter_pass(
    data: bytes,
    offset: int,
    header: _BCFHeader,
) -> Tuple[bool, int]:
    """Decide whether a record's FILTER value counts as passing.

    Returns:
        ``(passed, next_offset)``. A record passes when it lists no filters
        or exactly one filter named ``"PASS"``. Indexes absent from the
        header are represented by their decimal text.
    """
    filter_ids, next_offset = _read_int_list(data, offset)
    names = [
        header.filters.get(int(filter_id), str(filter_id))
        for filter_id in filter_ids
        if filter_id is not None
    ]
    passed = not names or names == ["PASS"]
    return passed, next_offset
def _parse_info_string(
    data: bytes,
    offset: int,
    n_info: int,
    header: _BCFHeader,
) -> str:
    """Render the INFO entries of one record as VCF-style text.

    Args:
        data: Decompressed BCF bytes.
        offset: Position of the first INFO key.
        n_info: Number of INFO entries in the record.
        header: Parsed header providing the INFO dictionary; unknown indexes
            are rendered as ``INFO_<index>``.

    Returns:
        ``key=value`` items (or bare keys for empty values) joined by ``;``,
        or ``"."`` when there are no entries.

    Raises:
        ValueError: If a value has an unsupported type code.
    """
    items = []
    for _ in range(n_info):
        info_idx, offset = _read_scalar_typed_int(data, offset)
        meta = header.info.get(info_idx, {"ID": f"INFO_{info_idx}", "Type": "", "Number": ""})
        key = meta["ID"]
        n_vals, type_code, type_size, offset = _read_typed_descriptor(data, offset)
        if type_code == 0:
            # Empty value: rendered as a bare flag key.
            items.append(key)
            continue
        # The value always occupies n_vals slots. Remember where it ends so an
        # early end-of-vector marker cannot leave the cursor inside the value
        # and misalign every following INFO entry.
        value_end = offset + n_vals * type_size
        if type_code == 7:
            value = data[offset:value_end].split(b"\0", 1)[0].decode("utf-8")
            items.append(f"{key}={value}")
        elif type_code == 5:
            values = []
            for pos in range(offset, value_end, type_size):
                raw = _U32.unpack_from(data, pos)[0]
                if raw == _FLOAT_VECTOR_END:
                    break
                if raw == _FLOAT_MISSING:
                    values.append(np.nan)
                    continue
                values.append(_F32.unpack_from(data, pos)[0])
            items.append(f"{key}={_render_info_value(values[0] if len(values) == 1 else values)}")
        elif type_code in (1, 2, 3):
            values = []
            missing = 1 << ((type_size * 8) - 1)
            vector_end = missing | 0x1
            for pos in range(offset, value_end, type_size):
                chunk = data[pos:pos + type_size]
                raw = int.from_bytes(chunk, "little", signed=False)
                if raw == vector_end:
                    break
                if raw == missing:
                    values.append(None)
                    continue
                values.append(int.from_bytes(chunk, "little", signed=True))
            items.append(f"{key}={_render_info_value(values[0] if len(values) == 1 else values)}")
        else:
            raise ValueError(f"Unsupported BCF INFO atomic type code: {type_code}")
        offset = value_end
    return ";".join(items) if items else "."
def _skip_info_block(data: bytes, offset: int, n_info: int) -> int:
    """Return the offset just past ``n_info`` INFO entries starting at ``offset``."""
    for _ in range(n_info):
        # Every entry is a typed key followed by a typed value.
        after_key = _skip_typed_value(data, offset)
        offset = _skip_typed_value(data, after_key)
    return offset
def _skip_shared_to_filter(data: bytes, base: int, n_alleles: int) -> int:
    """Return the offset of the FILTER field of a record.

    Starting 24 bytes past ``base`` (the end of the fixed fields), this steps
    over the ID string, the REF string and ``n_alleles - 1`` ALT strings.
    """
    # ID, then REF.
    cursor = _skip_typed_value(data, _skip_typed_value(data, base + 24))
    # One typed string per ALT allele.
    for _ in range(max(0, n_alleles - 1)):
        cursor = _skip_typed_value(data, cursor)
    return cursor
def _probe_gt_layout(
    data: bytes,
    indiv_offset: int,
    n_fmt: int,
    n_samples: int,
    header: _BCFHeader,
) -> Optional[Tuple[int, int, int, int]]:
    """Locate the GT entry in the FORMAT section of one record.

    Args:
        data: Decompressed BCF bytes.
        indiv_offset: Start of the record's per-sample section.
        n_fmt: Number of FORMAT entries in the record.
        n_samples: Number of samples stored per entry.
        header: Parsed header providing the FORMAT dictionary.

    Returns:
        ``(gt_data_offset, n_vals, type_size, gt_end_offset)`` with both
        offsets relative to ``indiv_offset``, or ``None`` when the record has
        no GT entry. ``gt_end_offset`` is where the GT sample data ends.
    """
    cursor = indiv_offset
    for _ in range(n_fmt):
        fmt_idx, cursor = _read_scalar_typed_int(data, cursor)
        n_vals, _type_code, type_size, payload = _read_typed_descriptor(data, cursor)
        payload_end = payload + n_samples * n_vals * type_size
        key = header.formats.get(fmt_idx, {"ID": f"FORMAT_{fmt_idx}"})["ID"]
        if key == "GT":
            return payload - indiv_offset, n_vals, type_size, payload_end - indiv_offset
        cursor = payload_end
    return None
def _probe_gp_layout(
    data: bytes,
    indiv_offset: int,
    n_fmt: int,
    n_samples: int,
    header: _BCFHeader,
) -> Optional[Tuple[int, int, int]]:
    """Locate the GP entry in the FORMAT section of one record.

    Returns:
        ``(gp_data_offset, n_vals, type_size)`` with the offset relative to
        ``indiv_offset``, or ``None`` when the record has no GP entry.

    Raises:
        ValueError: If GP is present but not stored as float32.
    """
    cursor = indiv_offset
    for _ in range(n_fmt):
        fmt_idx, cursor = _read_scalar_typed_int(data, cursor)
        n_vals, type_code, type_size, payload = _read_typed_descriptor(data, cursor)
        key = header.formats.get(fmt_idx, {"ID": f"FORMAT_{fmt_idx}"})["ID"]
        if key != "GP":
            # Not the entry we want: jump over its per-sample payload.
            cursor = payload + n_samples * n_vals * type_size
            continue
        if type_code != 5 or type_size != 4:
            raise ValueError("BCF FORMAT/GP is expected to be stored as float32 values.")
        return payload - indiv_offset, n_vals, type_size
    return None
def _batch_decode_gt(
    data: bytes,
    indiv_offsets: np.ndarray,
    gt_data_rel_offset: int,
    n_vals: int,
    type_size: int,
    n_samples: int,
    n_records: int,
    sample_index_array: np.ndarray,
    sum_strands: bool,
) -> np.ndarray:
    """Batch-decode GT data for all records using vectorized numpy operations.

    Instead of calling ``np.frombuffer`` per record, all GT bytes are gathered
    into one contiguous buffer and decoded at once. Every record is assumed
    to store GT at the same relative offset with the same width and ploidy.

    Args:
        data: Decompressed BCF bytes.
        indiv_offsets: Start of the per-sample section of each record.
        gt_data_rel_offset: Offset of the GT sample data within that section.
        n_vals: Values per sample (ploidy of the encoded vector).
        type_size: Byte width of each stored integer.
        n_samples: Number of samples stored in each record.
        n_records: Number of records.
        sample_index_array: Positions of the samples to return.
        sum_strands: If True, sum the two allele columns per sample.

    Returns:
        ``int8`` array of shape ``(n_records, n_selected, 2)``, or
        ``(n_records, n_selected)`` when ``sum_strands`` is True. Missing
        alleles and padding for lower-ploidy samples are ``-1``.

    Raises:
        ValueError: If the integer width is unsupported or ``n_vals > 2``.
    """
    if type_size not in _INT_UNSIGNED_DTYPES:
        raise ValueError(f"Unsupported GT integer width in BCF FORMAT/GT: {type_size}")
    if n_vals < 1:
        n_sel = len(sample_index_array)
        if sum_strands:
            return np.empty((n_records, n_sel), dtype=np.int8)
        return np.empty((n_records, n_sel, 0), dtype=np.int8)
    if n_vals > 2:
        # Fail before building the (potentially very large) gather index.
        raise ValueError("BCFReader currently supports haploid or diploid GT fields only.")
    dtype = _INT_UNSIGNED_DTYPES[type_size]
    gt_bytes_per_record = n_samples * n_vals * type_size
    # Build a (n_records, gt_bytes_per_record) table of absolute byte offsets.
    gt_starts = indiv_offsets + gt_data_rel_offset
    byte_offsets_per_sample = np.arange(gt_bytes_per_record, dtype=np.int64)
    all_byte_offsets = gt_starts[:, None] + byte_offsets_per_sample[None, :]
    # Gather all bytes into a contiguous buffer
    raw_bytes = np.frombuffer(data, dtype=np.uint8)
    gathered = raw_bytes[all_byte_offsets.ravel()]
    # Reinterpret as the correct integer type
    raw = np.frombuffer(gathered.tobytes(), dtype=dtype).reshape(n_records, n_samples, n_vals)
    # Decode: BCF GT encoding is (allele_index + 1) << 1 | phase
    decoded = (raw.astype(np.int32, copy=False) >> 1) - 1
    # Reserved "missing" / "end of vector" integers are not genotype codes.
    # For 1-byte storage the plain shift would turn 0x80/0x81 into allele 63,
    # so mask them explicitly.
    missing = 1 << ((type_size * 8) - 1)
    vector_end = missing | 0x1
    decoded[(raw == missing) | (raw == vector_end)] = -1
    if n_vals == 1:
        # Haploid: pad to diploid with -1
        padded = np.full((n_records, n_samples, 2), -1, dtype=np.int8)
        padded[:, :, 0] = decoded[:, :, 0].astype(np.int8, copy=False)
        selected = padded[:, sample_index_array, :]
    else:
        selected = decoded[:, sample_index_array, :].astype(np.int8, copy=False)
    if sum_strands:
        return selected.sum(axis=2, dtype=np.int8)
    return selected
def _batch_decode_gp(
data: bytes,
indiv_offsets: np.ndarray,
gp_data_rel_offset: int,
n_vals: int,
n_samples: int,
n_records: int,
sample_index_array: np.ndarray,
) -> np.ndarray:
"""Batch-decode GP data for all records using vectorized numpy operations."""
if n_vals < 1:
return np.empty((n_records, len(sample_index_array), 0), dtype=np.float32)
gp_bytes_per_record = n_samples * n_vals * 4 # float32
gp_starts = indiv_offsets + gp_data_rel_offset
byte_offsets_per_sample = np.arange(gp_bytes_per_record, dtype=np.int64)
all_byte_offsets = gp_starts[:, None] + byte_offsets_per_sample[None, :]
raw_bytes = np.frombuffer(data, dtype=np.uint8)
gathered = raw_bytes[all_byte_offsets.ravel()]
raw = np.frombuffer(gathered.tobytes(), dtype=np.dtype("<f4")).reshape(n_records, n_samples, n_vals)
return raw[:, sample_index_array, :].copy()
def _vectorized_qual(qual_raw: np.ndarray) -> np.ndarray:
    """Reinterpret raw uint32 QUAL bit patterns as float32 values.

    Entries equal to the missing bit pattern are returned as NaN.
    """
    quals = np.empty(len(qual_raw), dtype=np.float32)
    # Same bytes, read as float32 instead of uint32.
    quals[:] = np.frombuffer(qual_raw.tobytes(), dtype=np.float32)
    quals[qual_raw == _FLOAT_MISSING] = np.nan
    return quals
@SNPBaseReader.register
class BCFReader(SNPBaseReader):
def read(
    self,
    fields: Optional[Union[str, Sequence[str]]] = None,
    exclude_fields: Optional[Union[str, Sequence[str]]] = None,
    sample_ids: Optional[Sequence[str]] = None,
    sample_idxs: Optional[Sequence[int]] = None,
    variant_ids: Optional[Sequence[str]] = None,
    variant_idxs: Optional[Sequence[int]] = None,
    region: Optional[str] = None,
    sum_strands: bool = False,
) -> SNPObject:
    """
    Read a BCF file into a SNPObject.

    Args:
        fields: Fields to include. Supported fields are ``GT``, ``GP``, ``IID``,
            ``REF``, ``ALT``, ``#CHROM``, ``ID``, ``POS``, ``QUAL``,
            ``FILTER``, and ``INFO``. Use ``"*"`` to request the full set.
            If None, the default core fields are loaded.
        exclude_fields: Fields to exclude from the returned SNPObject.
        sample_ids: Sample IDs to read. If None and sample_idxs is None, all
            samples are read. Mutually exclusive with ``sample_idxs``.
        sample_idxs: Sample indices to read. Negative indexes follow NumPy
            conventions.
        variant_ids: Variant identifiers to read. Matches BCF ``ID``,
            ``chrom:pos``, or ``chrom:pos:ref:alt``. Mutually exclusive
            with ``variant_idxs``.
        variant_idxs: Variant indices to read. Negative indexes follow NumPy
            conventions.
        region: Optional genomic region, such as ``"22"`` or
            ``"22:100000-200000"``.
        sum_strands: If True, sum the two diploid alleles per sample and
            return dosages in ``genotypes``. If False, keep the two allele
            columns separate.

    Returns:
        SNPObject: Object containing selected genotype, sample, and variant
        fields. ``GP`` is stored on ``SNPObject.calldata_gp`` when present.

    Raises:
        ValueError: If both sample selectors or both variant selectors are
            given.
    """
    if sample_ids is not None and sample_idxs is not None:
        raise ValueError("Only one of sample_idxs and sample_ids can be specified.")
    if variant_ids is not None and variant_idxs is not None:
        raise ValueError("Only one of variant_idxs and variant_ids can be specified.")

    selected_fields = _normalize_fields(fields, exclude_fields)
    region_filter = _parse_vcf_region(region)

    data, body_offset, header = _load_bcf_data(str(self.filename))
    file_samples = np.asarray(header.samples, dtype=object)
    sample_index_array = _resolve_sample_indices(file_samples, sample_ids, sample_idxs)

    # With no variant-level selection the bulk path can decode every record at once.
    unfiltered = variant_ids is None and variant_idxs is None and region_filter is None
    if unfiltered:
        return self._read_all(
            data, body_offset, header, file_samples, sample_index_array,
            selected_fields, sum_strands,
        )
    return self._read_filtered(
        data, body_offset, header, file_samples, sample_index_array,
        selected_fields, region_filter, variant_ids, variant_idxs, sum_strands,
    )
def _read_all(
self,
data: bytes,
body_offset: int,
header: _BCFHeader,
file_samples: np.ndarray,
sample_index_array: np.ndarray,
selected_fields: list[str],
sum_strands: bool,
) -> SNPObject:
"""Optimized bulk read of all records with no variant filtering."""
# Parameters:
#   data / body_offset: decompressed BCF bytes and the offset of the first record.
#   header: parsed header dictionaries (contigs, filters, INFO, FORMAT).
#   file_samples / sample_index_array: all sample names and the positions to keep.
#   selected_fields: canonical field names; only these are decoded and returned.
#   sum_strands: when True, genotypes hold per-sample allele sums instead of two columns.
# Returns a SNPObject; any field not in selected_fields is passed as None.
# Raises ValueError when GT is requested but a probed record has no GT entry.
# Pass 1: build record offset table and extract fixed fields
record_offsets = _build_record_offsets(data, body_offset)
n_records = len(record_offsets)
if n_records == 0:
return self._empty_snpobject(selected_fields, file_samples, sample_index_array, sum_strands)
l_shared, l_indiv, contig_ids, positions, qual_raw, n_alleles, n_info_arr, n_fmt_arr = \
_extract_fixed_fields(data, record_offsets)
n_file_samples = len(file_samples)
n_selected_samples = len(sample_index_array)
# Vectorized chrom
need_chrom = "#CHROM" in selected_fields
variants_chrom = None
if need_chrom:
variants_chrom = np.empty(n_records, dtype=object)
# One masked assignment per distinct contig id instead of one lookup per record.
unique_contig_ids = np.unique(contig_ids)
for cid in unique_contig_ids:
mask = contig_ids == cid
variants_chrom[mask] = header.contigs[int(cid)]
# Vectorized pos
variants_pos = positions if "POS" in selected_fields else None
# Vectorized qual
variants_qual = _vectorized_qual(qual_raw) if "QUAL" in selected_fields else None
# Samples
samples = file_samples[sample_index_array] if "IID" in selected_fields else None
# Compute indiv offsets for GT/GP decode
# (the per-sample section starts after the 8 length bytes and the shared section)
indiv_offsets = record_offsets + 8 + l_shared.astype(np.int64)
need_gt = "GT" in selected_fields
need_gp = "GP" in selected_fields
genotypes = None
calldata_gp = None
# Batch GT decode
if need_gt and n_records > 0:
# Probe the first record to determine GT layout
first_n_fmt = int(n_fmt_arr[0])
first_indiv_offset = int(indiv_offsets[0])
gt_layout = _probe_gt_layout(data, first_indiv_offset, first_n_fmt, n_file_samples, header)
if gt_layout is None:
raise ValueError("BCF FORMAT field does not contain GT for all selected records.")
gt_data_rel_offset, gt_n_vals, gt_type_size, _ = gt_layout
# Check if all records have uniform l_indiv (same FORMAT layout)
# NOTE(review): equal l_indiv is used as a proxy for "same layout as the first
# record"; records with equal length but a different FORMAT order would not be
# detected here - confirm this cannot occur for supported inputs.
uniform_indiv = np.all(l_indiv == l_indiv[0])
if uniform_indiv:
genotypes = _batch_decode_gt(
data, indiv_offsets, gt_data_rel_offset, gt_n_vals, gt_type_size,
n_file_samples, n_records, sample_index_array, sum_strands,
)
else:
# Fallback: per-record GT decode
if sum_strands:
genotypes = np.empty((n_records, n_selected_samples), dtype=np.int8)
else:
genotypes = np.empty((n_records, n_selected_samples, 2), dtype=np.int8)
for i in range(n_records):
cur_indiv = int(indiv_offsets[i])
cur_n_fmt = int(n_fmt_arr[i])
cur_gt_layout = _probe_gt_layout(data, cur_indiv, cur_n_fmt, n_file_samples, header)
if cur_gt_layout is None:
raise ValueError("BCF FORMAT field does not contain GT for all selected records.")
rel_off, nv, ts, _ = cur_gt_layout
gt = _decode_gt_array(data, cur_indiv + rel_off, n_file_samples, nv, ts)
gt = gt[sample_index_array]
if sum_strands:
genotypes[i] = gt.sum(axis=1, dtype=np.int8)
else:
genotypes[i] = gt
# Batch GP decode
if need_gp and n_records > 0:
first_indiv_offset = int(indiv_offsets[0])
first_n_fmt = int(n_fmt_arr[0])
gp_layout = _probe_gp_layout(data, first_indiv_offset, first_n_fmt, n_file_samples, header)
# NOTE(review): when the first record has no GP entry, calldata_gp stays None
# even if later records carry GP - confirm this is intended.
if gp_layout is not None:
gp_data_rel_offset, gp_n_vals, _ = gp_layout
uniform_indiv = np.all(l_indiv == l_indiv[0])
if uniform_indiv:
calldata_gp = _batch_decode_gp(
data, indiv_offsets, gp_data_rel_offset, gp_n_vals,
n_file_samples, n_records, sample_index_array,
)
else:
# Per-record fallback: rows may differ in width, so collect them and pad afterwards.
gp_rows: list[Optional[np.ndarray]] = [None] * n_records
gp_width = 0
for i in range(n_records):
cur_indiv = int(indiv_offsets[i])
cur_n_fmt = int(n_fmt_arr[i])
cur_gp = _probe_gp_layout(data, cur_indiv, cur_n_fmt, n_file_samples, header)
if cur_gp is not None:
rel_off, nv, _ = cur_gp
gp = _decode_gp_array(data, cur_indiv + rel_off, n_file_samples, nv)[sample_index_array]
gp_rows[i] = gp
gp_width = max(gp_width, gp.shape[1])
if gp_width > 0:
calldata_gp = self._pad_gp_rows(gp_rows, n_selected_samples, gp_width)
# String fields: ID, REF, ALT, FILTER, INFO - must iterate per-record
need_id = "ID" in selected_fields
need_ref = "REF" in selected_fields
need_alt = "ALT" in selected_fields
need_filter = "FILTER" in selected_fields
need_info = "INFO" in selected_fields
need_strings = need_id or need_ref or need_alt or need_filter or need_info
variants_id = np.empty(n_records, dtype=object) if need_id else None
variants_ref = np.empty(n_records, dtype=object) if need_ref else None
variants_alt = np.empty(n_records, dtype=object) if need_alt else None
variants_filter_pass = np.empty(n_records, dtype=bool) if need_filter else None
variants_info = np.empty(n_records, dtype=object) if need_info else None
if need_strings:
# Caches map raw allele bytes to their decoded strings so repeated alleles
# are decoded only once.
ref_cache = {}
alt_cache = {}
filters_dict = header.filters
for i in range(n_records):
base = int(record_offsets[i]) + 8
cur_n_alleles = int(n_alleles[i])
cur_n_info = int(n_info_arr[i])
# Variable-length fields start right after the 24 bytes of fixed fields and are
# stored in the order ID, REF, ALT..., FILTER, INFO; each step below either
# decodes the field or skips over it so `offset` stays aligned.
offset = base + 24
# 1. ID field
if need_id:
b = data[offset]
offset += 1
n_vals = b >> 4
# Short form: the high nibble is the byte length (one byte per character).
if n_vals < 15:
end = offset + n_vals
if n_vals == 0:
variant_id = "."
# 46 is ASCII "."
elif n_vals == 1 and data[offset] == 46:
variant_id = "."
else:
val = data[offset:end]
idx = val.find(b"\0")
variant_id = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
offset = end
else:
# Overflow length form: re-read from the descriptor byte with the generic decoder.
variant_id, offset = _read_typed_string(data, offset - 1)
variants_id[i] = variant_id if variant_id else "."
else:
offset = _skip_typed_value_fast(data, offset)
# 2. REF field
if need_ref:
b = data[offset]
offset += 1
n_vals = b >> 4
if n_vals < 15:
end = offset + n_vals
val = data[offset:end]
ref = ref_cache.get(val)
if ref is None:
idx = val.find(b"\0")
ref = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
ref_cache[val] = ref
offset = end
else:
ref, offset = _read_typed_string(data, offset - 1)
variants_ref[i] = ref
else:
offset = _skip_typed_value_fast(data, offset)
# 3. ALT field
if need_alt:
# No ALT strings are stored when the record has at most one allele.
if cur_n_alleles <= 1:
alt_str = ""
elif cur_n_alleles == 2:
b = data[offset]
offset += 1
n_vals = b >> 4
if n_vals < 15:
end = offset + n_vals
val = data[offset:end]
alt_str = alt_cache.get(val)
if alt_str is None:
idx = val.find(b"\0")
alt_str = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
alt_cache[val] = alt_str
offset = end
else:
alt_str, offset = _read_typed_string(data, offset - 1)
else:
# Multi-allelic: decode each ALT string and join them with commas.
alts = []
for _ in range(cur_n_alleles - 1):
b = data[offset]
offset += 1
n_vals = b >> 4
if n_vals < 15:
end = offset + n_vals
val = data[offset:end]
alt = alt_cache.get(val)
if alt is None:
idx = val.find(b"\0")
alt = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
alt_cache[val] = alt
offset = end
else:
alt, offset = _read_typed_string(data, offset - 1)
alts.append(alt)
alt_str = ",".join(alts)
variants_alt[i] = alt_str
else:
for _ in range(max(0, cur_n_alleles - 1)):
offset = _skip_typed_value_fast(data, offset)
# 4. FILTER field
if need_filter:
b = data[offset]
offset += 1
type_code = b & 0x0F
n_vals = b >> 4
if n_vals == 0:
filter_pass = True
# Fast path: a single one-byte filter index.
elif n_vals == 1 and type_code == 1:
val = data[offset]
offset += 1
# 128 / 129 are the one-byte missing and end-of-vector values: no filter listed.
if val == 128 or val == 129:
filter_pass = True
else:
filter_name = filters_dict.get(val, str(val))
filter_pass = (filter_name == "PASS")
else:
filter_pass, offset = _parse_filter_pass(data, offset - 1, header)
variants_filter_pass[i] = filter_pass
elif need_info:
# FILTER itself is not wanted, but INFO follows it, so step over it.
offset = _skip_typed_value_fast(data, offset)
# 5. INFO field
if need_info:
variants_info[i] = _parse_info_string(data, offset, cur_n_info, header)
return SNPObject(
genotypes=genotypes,
calldata_gp=calldata_gp,
samples=samples,
variants_ref=variants_ref,
variants_alt=variants_alt,
variants_chrom=variants_chrom,
variants_id=variants_id,
variants_pos=variants_pos,
variants_qual=variants_qual,
variants_filter_pass=variants_filter_pass,
variants_info=variants_info,
)
def _read_filtered(
    self,
    data: bytes,
    body_offset: int,
    header: _BCFHeader,
    file_samples: np.ndarray,
    sample_index_array: np.ndarray,
    selected_fields: list[str],
    region_filter: Optional[Tuple[str, Optional[int], Optional[int]]],
    variant_ids: Optional[Sequence[str]],
    variant_idxs: Optional[Sequence[int]],
    sum_strands: bool,
) -> SNPObject:
    """Read only the records selected by a variant filter, one record at a time.

    The record offsets to decode come from ``_resolve_variant_request``. When
    that helper reports no selection at all, the call is forwarded to
    ``self._read_all``.

    Args:
        data: Decompressed BCF byte buffer.
        body_offset: Offset in ``data`` passed on to the record resolver.
        header: Parsed BCF header (contigs, filters and FORMAT dictionaries are used).
        file_samples: Sample names as listed in the file header.
        sample_index_array: Indices into ``file_samples`` of the samples to keep.
        selected_fields: Field names to populate; every other output stays ``None``.
        region_filter: Region restriction, passed through to the resolver.
        variant_ids: Variant IDs to keep, passed through to the resolver.
        variant_idxs: Variant indices to keep, passed through to the resolver.
        sum_strands: If true, genotypes are summed over the allele axis and
            stored as ``(n_records, n_samples)`` instead of ``(n_records, n_samples, 2)``.

    Returns:
        An ``SNPObject`` with one row per selected record.

    Raises:
        ValueError: If a record's sample count differs from the header's, if
            FORMAT/GP is not stored as float32, or if GT is requested but a
            selected record carries no GT data.
    """
    _n_records, selected_offsets, requested_offsets = _resolve_variant_request(
        data, body_offset, header, region_filter, variant_ids, variant_idxs,
    )
    # Explicitly requested offsets take precedence over the filter-derived selection.
    if requested_offsets is not None:
        record_offsets_list = requested_offsets
    elif selected_offsets is not None:
        record_offsets_list = selected_offsets
    else:
        record_offsets_list = None
    if record_offsets_list is None:
        # No filtering was actually applied - redirect to fast path
        return self._read_all(data, body_offset, header, file_samples,
                              sample_index_array, selected_fields, sum_strands)

    n_selected_records = len(record_offsets_list)
    n_selected_samples = len(sample_index_array)
    n_file_samples = len(file_samples)
    samples = file_samples[sample_index_array] if "IID" in selected_fields else None

    # Output buffers: allocated only for the fields that were asked for.
    if "GT" in selected_fields:
        if sum_strands:
            genotypes = np.empty((n_selected_records, n_selected_samples), dtype=np.int8)
        else:
            genotypes = np.empty((n_selected_records, n_selected_samples, 2), dtype=np.int8)
    else:
        genotypes = None
    gp_rows: Optional[List[Optional[np.ndarray]]] = [None] * n_selected_records if "GP" in selected_fields else None
    gp_width = 0
    variants_ref = np.empty(n_selected_records, dtype=object) if "REF" in selected_fields else None
    variants_alt = np.empty(n_selected_records, dtype=object) if "ALT" in selected_fields else None
    variants_chrom = np.empty(n_selected_records, dtype=object) if "#CHROM" in selected_fields else None
    variants_id = np.empty(n_selected_records, dtype=object) if "ID" in selected_fields else None
    variants_pos = np.empty(n_selected_records, dtype=np.int64) if "POS" in selected_fields else None
    variants_qual = np.empty(n_selected_records, dtype=np.float32) if "QUAL" in selected_fields else None
    variants_filter_pass = np.empty(n_selected_records, dtype=bool) if "FILTER" in selected_fields else None
    variants_info = np.empty(n_selected_records, dtype=object) if "INFO" in selected_fields else None

    # The variable-length part of the shared section is walked only when one
    # of these fields is requested.
    need_strings = any(field in selected_fields for field in ("ID", "REF", "ALT", "FILTER", "INFO"))
    need_info = "INFO" in selected_fields
    need_filter = "FILTER" in selected_fields
    need_gt = "GT" in selected_fields
    need_gp = "GP" in selected_fields
    need_id = variants_id is not None
    need_ref = variants_ref is not None
    need_alt = variants_alt is not None
    # Raw-bytes -> decoded-string caches; allele strings repeat heavily.
    ref_cache = {}
    alt_cache = {}
    filters_dict = header.filters

    for out_idx, record_offset in enumerate(record_offsets_list):
        l_shared = _U32.unpack_from(data, record_offset)[0]
        l_indiv = _U32.unpack_from(data, record_offset + 4)[0]
        base = record_offset + 8
        indiv_offset = base + l_shared
        contig_id = _I32.unpack_from(data, base)[0]
        pos = _I32.unpack_from(data, base + 4)[0] + 1  # stored 0-based, reported 1-based
        # Each packed word is unpacked once and then split into its two counts.
        alleles_info = _U32.unpack_from(data, base + 16)[0]
        n_alleles = alleles_info >> 16
        n_info = alleles_info & 0xFFFF
        fmt_samples = _U32.unpack_from(data, base + 20)[0]
        n_fmt = fmt_samples >> 24
        n_samples = fmt_samples & 0xFFFFFF
        if n_samples != n_file_samples:
            raise ValueError(
                f"BCF record sample count ({n_samples}) does not match header sample count "
                f"({n_file_samples})."
            )
        if variants_chrom is not None:
            variants_chrom[out_idx] = header.contigs[contig_id]
        if variants_pos is not None:
            variants_pos[out_idx] = pos
        if variants_qual is not None:
            variants_qual[out_idx] = _variant_qual(data, base)

        if need_strings:
            # Variable-length fields start right after the 24 fixed bytes.
            offset = base + 24
            # 1. ID field
            if need_id:
                b = data[offset]
                offset += 1
                n_vals = b >> 4
                if n_vals < 15:
                    end = offset + n_vals
                    if n_vals == 0:
                        variant_id = "."
                    elif n_vals == 1 and data[offset] == 46:  # 46 == ord(".")
                        variant_id = "."
                    else:
                        val = data[offset:end]
                        idx = val.find(b"\0")
                        variant_id = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
                    offset = end
                else:
                    # Count nibble of 15: hand the value (from its descriptor
                    # byte) to the generic typed-string reader.
                    variant_id, offset = _read_typed_string(data, offset - 1)
                variants_id[out_idx] = variant_id if variant_id else "."
            else:
                offset = _skip_typed_value_fast(data, offset)
            # 2. REF field
            if need_ref:
                b = data[offset]
                offset += 1
                n_vals = b >> 4
                if n_vals < 15:
                    end = offset + n_vals
                    val = data[offset:end]
                    ref = ref_cache.get(val)
                    if ref is None:
                        idx = val.find(b"\0")
                        ref = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
                        ref_cache[val] = ref
                    offset = end
                else:
                    ref, offset = _read_typed_string(data, offset - 1)
                variants_ref[out_idx] = ref
            else:
                offset = _skip_typed_value_fast(data, offset)
            # 3. ALT field
            if need_alt:
                if n_alleles <= 1:
                    alt_str = ""
                elif n_alleles == 2:
                    b = data[offset]
                    offset += 1
                    n_vals = b >> 4
                    if n_vals < 15:
                        end = offset + n_vals
                        val = data[offset:end]
                        alt_str = alt_cache.get(val)
                        if alt_str is None:
                            idx = val.find(b"\0")
                            alt_str = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
                            alt_cache[val] = alt_str
                        offset = end
                    else:
                        alt_str, offset = _read_typed_string(data, offset - 1)
                else:
                    # Multi-allelic: decode each ALT allele and join with commas.
                    alts = []
                    for _ in range(n_alleles - 1):
                        b = data[offset]
                        offset += 1
                        n_vals = b >> 4
                        if n_vals < 15:
                            end = offset + n_vals
                            val = data[offset:end]
                            alt = alt_cache.get(val)
                            if alt is None:
                                idx = val.find(b"\0")
                                alt = val[:idx].decode("utf-8") if idx != -1 else val.decode("utf-8")
                                alt_cache[val] = alt
                            offset = end
                        else:
                            alt, offset = _read_typed_string(data, offset - 1)
                        alts.append(alt)
                    alt_str = ",".join(alts)
                variants_alt[out_idx] = alt_str
            else:
                for _ in range(max(0, n_alleles - 1)):
                    offset = _skip_typed_value_fast(data, offset)
            # 4. FILTER field
            if need_filter:
                b = data[offset]
                offset += 1
                type_code = b & 0x0F
                n_vals = b >> 4
                if n_vals == 0:
                    # An empty FILTER vector is treated as passing.
                    filter_pass = True
                elif n_vals == 1 and type_code == 1:
                    val = data[offset]
                    offset += 1
                    # NOTE(review): 128 / 129 look like the BCF int8 missing /
                    # end-of-vector sentinels (0x80 / 0x81) - confirm against the spec.
                    if val == 128 or val == 129:
                        filter_pass = True
                    else:
                        filter_name = filters_dict.get(val, str(val))
                        filter_pass = (filter_name == "PASS")
                else:
                    filter_pass, offset = _parse_filter_pass(data, offset - 1, header)
                variants_filter_pass[out_idx] = filter_pass
            elif need_info:
                # FILTER only needs stepping over when INFO (which follows it) is read.
                offset = _skip_typed_value_fast(data, offset)
            # 5. INFO field
            if need_info:
                variants_info[out_idx] = _parse_info_string(data, offset, n_info, header)

        if not (need_gt or need_gp):
            continue
        if l_indiv == 0:
            # The record has no per-sample section. Without this check the GT
            # row would be handed back as uninitialised np.empty memory, so
            # fail the same way as a record whose FORMAT fields lack GT.
            # With zero selected samples there is nothing to fill in.
            if need_gt and n_selected_samples > 0:
                raise ValueError("BCF FORMAT field does not contain GT for all selected records.")
            continue

        format_offset = indiv_offset
        gt_seen = False
        for _ in range(n_fmt):
            fmt_idx, format_offset = _read_scalar_typed_int(data, format_offset)
            n_vals, type_code, type_size, values_offset = _read_typed_descriptor(data, format_offset)
            key = header.formats.get(fmt_idx, {"ID": f"FORMAT_{fmt_idx}"})["ID"]
            values_nbytes = n_samples * n_vals * type_size
            if key == "GT" and need_gt:
                gt = _decode_gt_array(data, values_offset, n_samples, n_vals, type_size)
                gt = gt[sample_index_array]
                if sum_strands:
                    genotypes[out_idx] = gt.sum(axis=1, dtype=np.int8)
                else:
                    genotypes[out_idx] = gt
                gt_seen = True
            elif key == "GP" and need_gp:
                if type_code != 5 or type_size != 4:
                    raise ValueError("BCF FORMAT/GP is expected to be stored as float32 values.")
                gp = _decode_gp_array(data, values_offset, n_samples, n_vals)[sample_index_array]
                gp_rows[out_idx] = gp
                gp_width = max(gp_width, gp.shape[1])
            # Step over this FORMAT field's payload whether or not it was decoded.
            format_offset = values_offset + values_nbytes
        if need_gt and not gt_seen:
            raise ValueError("BCF FORMAT field does not contain GT for all selected records.")

    # GP rows may differ in width between records; pad them to a common width.
    calldata_gp = None
    if gp_rows is not None and gp_width > 0:
        calldata_gp = self._pad_gp_rows(gp_rows, n_selected_samples, gp_width)
    return SNPObject(
        genotypes=genotypes,
        calldata_gp=calldata_gp,
        samples=samples,
        variants_ref=variants_ref,
        variants_alt=variants_alt,
        variants_chrom=variants_chrom,
        variants_id=variants_id,
        variants_pos=variants_pos,
        variants_qual=variants_qual,
        variants_filter_pass=variants_filter_pass,
        variants_info=variants_info,
    )
@staticmethod
def _pad_gp_rows(
    gp_rows: List[Optional[np.ndarray]],
    n_selected_samples: int,
    gp_width: int,
) -> np.ndarray:
    """Stack per-record GP matrices into one array, NaN-padding to ``gp_width``.

    Args:
        gp_rows: One entry per record; either a 2-D array indexed as
            ``[sample, value]`` or ``None`` when the record had no GP data.
        n_selected_samples: Number of sample rows in every padded matrix.
        gp_width: Number of value columns in the output.

    Returns:
        The rows stacked along a new leading axis. A ``None`` entry becomes an
        all-NaN float32 matrix, a narrower entry is copied into the leading
        columns of a NaN-filled float32 matrix, and an entry that is already
        ``gp_width`` wide is used as is. An empty ``gp_rows`` yields a float32
        array of shape ``(0, n_selected_samples, gp_width)``.
    """
    target_shape = (n_selected_samples, gp_width)

    def _widen(matrix: Optional[np.ndarray]) -> np.ndarray:
        # Matrices that already have the full width go through untouched.
        if matrix is not None and matrix.shape[1] == gp_width:
            return matrix
        canvas = np.full(target_shape, np.nan, dtype=np.float32)
        if matrix is not None:
            canvas[:, : matrix.shape[1]] = matrix
        return canvas

    widened = [_widen(matrix) for matrix in gp_rows]
    if not widened:
        return np.empty((0, n_selected_samples, gp_width), dtype=np.float32)
    return np.stack(widened, axis=0)
@staticmethod
def _empty_snpobject(
    selected_fields: list[str],
    file_samples: np.ndarray,
    sample_index_array: np.ndarray,
    sum_strands: bool,
) -> SNPObject:
    """Build an ``SNPObject`` that holds zero variants.

    Each requested per-variant field becomes a zero-length array and every
    field that was not requested is ``None``. Genotypes, when requested, have
    shape ``(0, n_samples)`` if ``sum_strands`` is true and ``(0, n_samples, 2)``
    otherwise, where ``n_samples`` is ``len(sample_index_array)``.
    ``calldata_gp`` is always ``None``.
    """

    def _blank(field: str, dtype: Any) -> Optional[np.ndarray]:
        # Zero-length placeholder for a requested field, None otherwise.
        if field in selected_fields:
            return np.empty(0, dtype=dtype)
        return None

    n_kept = len(sample_index_array)

    if "GT" in selected_fields:
        gt_shape = (0, n_kept) if sum_strands else (0, n_kept, 2)
        empty_genotypes = np.empty(gt_shape, dtype=np.int8)
    else:
        empty_genotypes = None

    if "IID" in selected_fields:
        kept_samples = file_samples[sample_index_array]
    else:
        kept_samples = None

    return SNPObject(
        genotypes=empty_genotypes,
        calldata_gp=None,
        samples=kept_samples,
        variants_ref=_blank("REF", object),
        variants_alt=_blank("ALT", object),
        variants_chrom=_blank("#CHROM", object),
        variants_id=_blank("ID", object),
        variants_pos=_blank("POS", np.int64),
        variants_qual=_blank("QUAL", np.float32),
        variants_filter_pass=_blank("FILTER", bool),
        variants_info=_blank("INFO", object),
    )