import argparse
import csv
import gzip
import logging
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
import pandas as pd
from snputils.phenotype.io.read import PhenotypeReader
from snputils.phenotype.genobj import CovariateObject, PhenotypeObject
from snputils.snp.genobj import SNPObject
from snputils.snp.io.read import BEDReader, PGENReader, SNPReader, VCFReader
from snputils.snp.io.read.vcf import VCFReaderPolars
from ._association import (
_apply_multiple_testing_adjustment,
_compute_effective_chunk_size,
_compute_group_counts_batch,
_compute_linear_ci_beta,
_compute_logistic_ci_or,
_compute_multiple_testing_adjustments,
_coerce_covar_source,
_confidence_interval_label,
_enforce_memory_budget,
_fit_linear_batch,
_fit_linear_batch_with_covariates,
_fit_logistic_batch,
_fit_logistic_batch_with_covariates,
_get_process_rss_mb,
_normalize_chromosome,
_odds_ratio_batch,
_open_tsv_for_write,
_prepare_fwl,
_read_sample_list,
_resolve_output_path,
)
log = logging.getLogger(__name__)
def add_gwas_arguments(parser: argparse.ArgumentParser) -> None:
    """Register every GWAS command-line option on ``parser``.

    Optional flags are attached directly to ``parser`` in a fixed order; the
    three mandatory inputs (``--phe-id``, ``--phe-path``, ``--snp-path``) are
    placed in a separate "required arguments" group so they are listed under
    their own heading in ``--help``.
    """
    # (flag, add_argument keyword options) in the order they appear in --help.
    optional_specs = (
        ("--batch-size", {
            "dest": "batch_size",
            "required": False,
            "default": 32768,
            "type": int,
            "help": "Maximum number of variants processed per chunk.",
        }),
        ("--memory", {
            "dest": "memory",
            "required": False,
            "default": None,
            "type": int,
            "help": "Peak RSS-delta memory cap in MiB for internal chunked processing.",
        }),
        # default=None (not False) so "flag absent" stays distinguishable from "forced".
        ("--quantitative", {
            "dest": "quantitative",
            "required": False,
            "action": "store_true",
            "default": None,
            "help": "Optional override to force quantitative (linear) mode. Default is automatic trait detection.",
        }),
        ("--verbose", {
            "dest": "verbose",
            "required": False,
            "action": "store_true",
            "help": "Print progress (variants processed, elapsed time, rate) during GWAS.",
        }),
        ("--covar-path", {
            "dest": "covar_path",
            "required": False,
            "type": str,
            "default": None,
            "help": "Path to covariate file (whitespace-delimited, header with #FID IID or #IID plus covariate columns).",
        }),
        ("--covar-col-nums", {
            "dest": "covar_col_nums",
            "required": False,
            "type": str,
            "default": None,
            "help": 'Covariate column numbers relative to first covariate column (e.g. "1-5,7").',
        }),
        ("--covar-variance-standardize", {
            "dest": "covar_variance_standardize",
            "required": False,
            "action": "store_true",
            "help": "Center and variance-standardize selected covariates.",
        }),
        ("--variant-exclude", {
            "dest": "exclude_path",
            "required": False,
            "type": str,
            "default": None,
            "help": "Path to variant exclusion file (one or more variant selectors per line).",
        }),
        ("--ci", {
            "dest": "ci",
            "required": False,
            "type": float,
            "default": None,
            "help": "Confidence level in (0, 1), e.g. 0.95 for L95/U95 columns.",
        }),
        ("--adjust", {
            "dest": "adjust",
            "required": False,
            "action": "store_true",
            "help": "Add Bonferroni and Benjamini-Hochberg FDR adjusted p-values.",
        }),
        ("--keep-path", {
            "dest": "keep_path",
            "required": False,
            "type": str,
            "default": None,
            "help": "Path to keep file (FID IID or IID per line) for sample inclusion.",
        }),
        ("--sample-remove", {
            "dest": "remove_path",
            "required": False,
            "type": str,
            "default": None,
            "help": "Path to remove file (FID IID or IID per line) for sample exclusion.",
        }),
        ("--vcf-backend", {
            "dest": "vcf_backend",
            "required": False,
            "choices": ("polars", "scikit-allel"),
            "default": "polars",
            "help": "VCF reader backend (used only when --snp-path is VCF).",
        }),
        ("--results-path", {
            "dest": "results_path",
            "required": False,
            "type": str,
            "default": "gwas.tsv.gz",
            "help": "Path used to save resulting data in compressed .tsv file (default: gwas.tsv.gz).",
        }),
        ("--manhattan-plot", {
            "dest": "manhattan_plot",
            "required": False,
            "default": None,
            "type": str,
            "help": "Optional path to save a Manhattan plot after GWAS completes (.pdf / .svg / .png, ...).",
        }),
        ("--qq-plot", {
            "dest": "qq_plot",
            "required": False,
            "default": None,
            "type": str,
            "help": "Optional path to save a Q-Q plot after GWAS completes (.pdf / .svg / .png, ...).",
        }),
    )
    for flag, options in optional_specs:
        parser.add_argument(flag, **options)

    # (flag, dest, help) for the mandatory string inputs.
    mandatory_specs = (
        ("--phe-id", "phe_id", "Phenotype ID / column name to analyze."),
        (
            "--phe-path",
            "phe_path",
            "Path to phenotype file (headered text with IID column and one or more phenotype columns; e.g. .txt, .phe, .pheno).",
        ),
        ("--snp-path", "snp_path", "Path to genotype input (VCF/BED/PGEN)."),
    )
    required_argv = parser.add_argument_group("required arguments")
    for flag, dest, help_text in mandatory_specs:
        required_argv.add_argument(flag, dest=dest, required=True, type=str, help=help_text)
def parse_gwas_args(argv: Sequence[str]) -> argparse.Namespace:
    """Parse ``argv`` with a fresh ``gwas`` parser carrying all GWAS options."""
    cli = argparse.ArgumentParser(
        prog="gwas",
        description="Genome-wide association study (GWAS).",
    )
    add_gwas_arguments(cli)
    namespace = cli.parse_args(argv)
    return namespace
def _read_vcf_sample_ids(path: Union[str, Path]) -> List[str]:
    """Return the sample IDs listed on the ``#CHROM`` header line of a VCF.

    Only ``.vcf`` and ``.vcf.gz`` paths are accepted. The file is streamed
    line by line and reading stops at the first header line found.

    Raises:
        ValueError: on an unsupported extension, when no header line exists,
            or when the header has no columns beyond the nine fixed ones.
    """
    vcf_path = Path(path)
    if vcf_path.suffixes[-2:] == [".vcf", ".gz"]:
        opener = gzip.open
    elif vcf_path.suffix == ".vcf":
        opener = open
    else:
        raise ValueError(f"Unsupported VCF extension for sample parsing: {vcf_path.suffixes}")

    with opener(vcf_path, "rt", encoding="utf-8") as stream:
        for text in stream:
            # A header without the leading '#' is tolerated as well.
            if not text.startswith(("#CHROM", "CHROM")):
                continue
            header = text.rstrip("\n")
            columns = header.split("\t")
            if len(columns) == 1:
                # Not tab-delimited: fall back to generic whitespace splitting.
                columns = header.split()
            if len(columns) <= 9:
                raise ValueError("VCF header does not contain sample columns.")
            return [str(name) for name in columns[9:]]
    raise ValueError("VCF header line (#CHROM) not found.")
def _read_snp_samples(snp_reader: object) -> List[str]:
    """Return the sample IDs of a genotype source as plain strings.

    In-memory ``SNPObject`` inputs expose their samples directly, VCF readers
    are resolved by parsing the header of ``snp_reader.filename``, and BED/PGEN
    readers are asked to read the ``IID`` field only.

    Raises:
        ValueError: if sample IDs are unavailable or the reader type is unknown.
    """

    def as_text_list(samples) -> List[str]:
        return [str(entry) for entry in np.asarray(samples, dtype=object).tolist()]

    if isinstance(snp_reader, SNPObject):
        in_memory = snp_reader.samples
        if in_memory is None:
            raise ValueError("In-memory SNPObject must include sample IDs for GWAS.")
        return as_text_list(in_memory)

    if isinstance(snp_reader, (VCFReaderPolars, VCFReader)):
        return _read_vcf_sample_ids(getattr(snp_reader, "filename"))

    if isinstance(snp_reader, (BEDReader, PGENReader)):
        from_file = snp_reader.read(fields=["IID"]).samples
        if from_file is None:
            raise ValueError("Failed to read sample IDs from SNP input.")
        return as_text_list(from_file)

    raise ValueError(f"Unsupported SNP reader type: {type(snp_reader).__name__}")
def _coerce_phenotype_source(
    phe_source: Union[str, Path, PhenotypeObject],
    *,
    phe_id: Optional[str],
    quantitative: Optional[bool],
) -> Tuple[PhenotypeObject, str]:
    """Turn a phenotype path or object into ``(PhenotypeObject, output label)``.

    A path is read through ``PhenotypeReader`` using ``phe_id`` as the column,
    which makes ``phe_id`` mandatory in that case. An existing
    ``PhenotypeObject`` is reused as-is unless ``quantitative`` is set, in which
    case it is rebuilt with that override. The label is ``phe_id`` when given,
    otherwise the object's own ``phenotype_name``.

    Raises:
        TypeError: if ``phe_source`` is a path and ``phe_id`` is ``None``.
    """
    if not isinstance(phe_source, PhenotypeObject):
        if phe_id is None:
            raise TypeError("run_gwas() missing required argument: 'phe_id' when phenotype input is a path.")
        reader = PhenotypeReader(phe_source)
        loaded = reader.read(phenotype_col=phe_id, quantitative=quantitative)
        return loaded, str(phe_id)

    if quantitative is None:
        phenotype_obj = phe_source
    else:
        # Rebuild so the caller's trait-type override takes effect.
        phenotype_obj = PhenotypeObject(
            samples=phe_source.samples,
            values=phe_source.values,
            phenotype_name=phe_source.phenotype_name,
            quantitative=quantitative,
        )
    label = phe_id or phenotype_obj.phenotype_name
    return phenotype_obj, str(label)
def _coerce_snp_source(
    snp_source: Union[str, Path, SNPObject, SNPReader, BEDReader, PGENReader, VCFReader, VCFReaderPolars],
    *,
    vcf_backend: str,
) -> object:
    """Wrap a genotype path in ``SNPReader``; pass any other input through unchanged."""
    if not isinstance(snp_source, (str, Path)):
        # Already a reader or an in-memory object.
        return snp_source
    return SNPReader(snp_source, vcf_backend=vcf_backend)
def _read_variant_list(path: Union[str, Path]) -> Set[str]:
    """Collect variant selector tokens from a whitespace-delimited text file.

    Blank lines and lines starting with ``#`` are skipped. Every token on a
    line is added to the result; when a line has two or more tokens, the
    combined ``"<first>:<second>"`` token is added too.
    """
    tokens: Set[str] = set()
    with open(path, "r", encoding="utf-8") as stream:
        for text in stream:
            stripped = text.strip()
            if stripped == "" or stripped.startswith("#"):
                continue
            # A non-empty stripped line always yields at least one field.
            fields = stripped.split()
            tokens.update(fields)
            if len(fields) >= 2:
                tokens.add(f"{fields[0]}:{fields[1]}")
    return tokens
def _align_samples_to_snp_order(
    snp_samples: Sequence[str],
    phe_samples: Sequence[str],
    y: np.ndarray,
    quantitative: bool = False,
    keep_ids: Optional[Set[str]] = None,
    remove_ids: Optional[Set[str]] = None,
    covar_samples: Optional[Sequence[str]] = None,
    covar_matrix: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str], Optional[np.ndarray]]:
    """Intersect genotype, phenotype and covariate samples in genotype order.

    A genotype sample is retained when it passes the keep/remove filters, has
    a phenotype entry and, if covariates are supplied, has a covariate row.

    Returns:
        ``(snp_indexes, y, samples, covariates)``: int64 positions of the
        retained samples within ``snp_samples``, the matching phenotype values
        (float64 when ``quantitative`` else int8), the retained IDs as strings,
        and the matching float64 covariate rows (``None`` without covariates).

    Raises:
        ValueError: if nothing overlaps, if a quantitative trait has zero
            variance, if a binary trait lacks cases or controls, or if
            ``covar_samples`` is given without ``covar_matrix``.
    """
    phe_lookup: Dict[str, int] = {name: row for row, name in enumerate(phe_samples)}
    cov_lookup: Optional[Dict[str, int]] = None
    if covar_samples is not None:
        if covar_matrix is None:
            raise ValueError("Internal error: covariate samples provided without covariate matrix.")
        cov_lookup = {name: row for row, name in enumerate(covar_samples)}
    has_covariates = cov_lookup is not None
    to_scalar = float if quantitative else int

    kept_positions: List[int] = []
    kept_y: List[Union[int, float]] = []
    kept_covar_rows: List[np.ndarray] = []
    kept_names: List[str] = []
    for position, raw_name in enumerate(snp_samples):
        name = str(raw_name)
        dropped_by_keep = keep_ids is not None and name not in keep_ids
        if dropped_by_keep:
            continue
        dropped_by_remove = remove_ids is not None and name in remove_ids
        if dropped_by_remove:
            continue
        phe_row = phe_lookup.get(name)
        if phe_row is None:
            continue
        cov_row = cov_lookup.get(name) if has_covariates else None
        if has_covariates and cov_row is None:
            continue

        kept_positions.append(position)
        kept_y.append(to_scalar(y[phe_row]))
        if has_covariates and covar_matrix is not None:
            kept_covar_rows.append(covar_matrix[cov_row].astype(np.float64, copy=False))
        kept_names.append(name)

    if len(kept_positions) == 0:
        raise ValueError("No overlapping samples between phenotype and SNP input.")

    if quantitative:
        y_arr = np.asarray(kept_y, dtype=np.float64)
        if np.var(y_arr) <= 0.0:
            raise ValueError("Quantitative phenotype has zero variance after SNP/PHE sample intersection.")
    else:
        y_arr = np.asarray(kept_y, dtype=np.int8)
        n_cases = int(np.sum(y_arr))
        if n_cases == 0:
            raise ValueError("No cases after SNP/PHE sample intersection.")
        if n_cases == len(y_arr):
            raise ValueError("No controls after SNP/PHE sample intersection.")

    covar_out = np.asarray(kept_covar_rows, dtype=np.float64) if has_covariates else None
    return np.asarray(kept_positions, dtype=np.int64), y_arr, kept_names, covar_out
def _iter_snp_chunks(
    snp_reader: object,
    chunk_size: int,
    sample_indices: np.ndarray,
    aligned_samples: Sequence[str],
) -> Iterator[Dict[str, Optional[np.ndarray]]]:
    """Yield genotype chunks as plain dicts, restricted to the aligned samples.

    Every yielded dict has the keys ``genotypes``, ``variants_chrom``,
    ``variants_pos``, ``variants_id``, ``variants_ref`` and ``variants_alt``.

    * ``SNPObject``: the in-memory genotype array is sliced by variant window
      and by ``sample_indices``.
    * BED/PGEN readers: ``iter_read`` is called with ``sample_idxs``.
    * ``VCFReaderPolars``: ``iter_read`` is called with sample names.
    * ``VCFReader``: the whole file is read once via ``read`` and then sliced
      into windows of ``chunk_size`` variants.

    Nothing is yielded when the source has no genotypes.

    Raises:
        ValueError: for an unsupported reader type or a genotype array that is
            neither 2-D nor 3-D.
    """

    def window(values, lo: int, hi: int):
        # Metadata arrays are optional on the source object.
        return None if values is None else values[lo:hi]

    def sliced_payload(source, genotypes, lo: int, hi: int) -> Dict[str, Optional[np.ndarray]]:
        return {
            "genotypes": genotypes,
            "variants_chrom": window(source.variants_chrom, lo, hi),
            "variants_pos": window(source.variants_pos, lo, hi),
            "variants_id": window(source.variants_id, lo, hi),
            "variants_ref": window(source.variants_ref, lo, hi),
            "variants_alt": window(source.variants_alt, lo, hi),
        }

    def reader_payload(part) -> Dict[str, Optional[np.ndarray]]:
        return {
            "genotypes": part.genotypes,
            "variants_chrom": part.variants_chrom,
            "variants_pos": part.variants_pos,
            "variants_id": part.variants_id,
            "variants_ref": part.variants_ref,
            "variants_alt": part.variants_alt,
        }

    if isinstance(snp_reader, SNPObject):
        if snp_reader.genotypes is None:
            return
        calls = np.asarray(snp_reader.genotypes)
        if calls.ndim not in (2, 3):
            raise ValueError("GWAS expects SNPObject.genotypes with shape (variants, samples[, strands]).")
        columns = np.asarray(sample_indices, dtype=np.int64)
        total = int(calls.shape[0])
        step = int(chunk_size)
        for lo in range(0, total, step):
            hi = min(lo + step, total)
            rows = calls[lo:hi]
            picked = rows[:, columns] if calls.ndim == 2 else rows[:, columns, :]
            yield sliced_payload(snp_reader, picked, lo, hi)
        return

    if isinstance(snp_reader, (BEDReader, PGENReader)):
        bed_like_chunks = snp_reader.iter_read(
            fields=["GT", "#CHROM", "POS", "ID", "REF", "ALT"],
            sample_idxs=np.asarray(sample_indices, dtype=np.uint32),
            sum_strands=True,
            chunk_size=chunk_size,
        )
        for part in bed_like_chunks:
            yield reader_payload(part)
        return

    if isinstance(snp_reader, VCFReaderPolars):
        polars_chunks = snp_reader.iter_read(
            fields=["#CHROM", "CHROM", "POS", "ID", "REF", "ALT"],
            samples=list(aligned_samples),
            sum_strands=True,
            chunk_size=chunk_size,
        )
        for part in polars_chunks:
            yield reader_payload(part)
        return

    if isinstance(snp_reader, VCFReader):
        # This reader is not iterated in chunks here: load once, then window.
        wanted_fields = [
            "variants/CHROM",
            "variants/POS",
            "variants/ID",
            "variants/REF",
            "variants/ALT",
            "calldata/GT",
        ]
        full = snp_reader.read(
            fields=wanted_fields,
            samples=list(aligned_samples),
            sum_strands=True,
        )
        if full.genotypes is None:
            return
        total = int(full.genotypes.shape[0])
        step = int(chunk_size)
        for lo in range(0, total, step):
            hi = min(lo + step, total)
            yield sliced_payload(full, full.genotypes[lo:hi], lo, hi)
        return

    raise ValueError(f"Unsupported SNP reader type for chunking: {type(snp_reader).__name__}")
def _coerce_variant_text_array(
    values: Optional[np.ndarray],
    length: int,
    default: str,
) -> np.ndarray:
    """Normalize per-variant text metadata into an object array of strings.

    Args:
        values: Raw metadata for the chunk, or ``None`` when unavailable.
        length: Number of variants in the chunk.
        default: Placeholder used for missing entries.

    Returns:
        Object array of ``length`` entries. Each entry is the stripped text of
        the input, or ``default`` when the input is ``None``, empty, ``"."``
        or ``"nan"`` (case-insensitive). A ``None`` or empty ``values`` yields
        an array filled with ``default``.

    Raises:
        ValueError: if ``values`` is non-empty and its size differs from ``length``.
    """
    out = np.full(length, default, dtype=object)
    if values is None:
        return out
    arr = np.asarray(values, dtype=object).reshape(-1)
    if arr.size == 0:
        return out
    if arr.size != length:
        raise ValueError("Variant metadata length mismatch for GWAS chunk.")
    for idx, raw in enumerate(arr.tolist()):
        if raw is None:
            # str(None) would otherwise leak through as the literal text "None".
            continue
        text = str(raw).strip()
        if text and text != "." and text.upper() != "NAN":
            out[idx] = text
    return out
def _coerce_variant_chrom_array(
    values: Optional[np.ndarray],
    length: int,
) -> np.ndarray:
    """Normalize per-variant chromosome labels into an object array.

    Args:
        values: Raw chromosome labels for the chunk, or ``None`` when unavailable.
        length: Number of variants in the chunk.

    Returns:
        Object array of ``length`` entries. Present labels are stripped and
        passed through ``_normalize_chromosome``; entries that are ``None``,
        empty or ``"nan"`` (case-insensitive) stay as the placeholder ``"."``.

    Raises:
        ValueError: if ``values`` is non-empty and its size differs from ``length``.
    """
    out = np.full(length, ".", dtype=object)
    if values is None:
        return out
    arr = np.asarray(values, dtype=object).reshape(-1)
    if arr.size == 0:
        return out
    if arr.size != length:
        raise ValueError("Variant chromosome length mismatch for GWAS chunk.")
    for idx, raw in enumerate(arr.tolist()):
        if raw is None:
            # Without this guard str(None) == "None" would be normalized as a label.
            continue
        text = str(raw).strip()
        if text and text.upper() != "NAN":
            out[idx] = _normalize_chromosome(text)
    return out
def _coerce_variant_pos_array(
    values: Optional[np.ndarray],
    length: int,
    offset: int,
) -> np.ndarray:
    """Return int64 positions for a chunk, filling gaps with running indexes.

    The fallback for entry ``i`` is ``offset + i + 1``. Entries of ``values``
    that parse to a finite number overwrite the fallback (cast to int64);
    unparseable or non-finite entries keep it.

    Raises:
        ValueError: if ``values`` is non-empty and its size differs from ``length``.
    """
    positions = np.arange(offset + 1, offset + length + 1, dtype=np.int64)
    if values is None:
        return positions
    flat = np.asarray(values).reshape(-1)
    if flat.size == 0:
        return positions
    if flat.size != length:
        raise ValueError("Variant position length mismatch for GWAS chunk.")
    # errors="coerce" turns anything non-numeric into NaN instead of raising.
    parsed = pd.to_numeric(pd.Series(flat), errors="coerce")
    as_float = parsed.to_numpy(dtype=np.float64)
    usable = np.isfinite(as_float)
    positions[usable] = as_float[usable].astype(np.int64)
    return positions
def _build_variant_id_array(
    values: Optional[np.ndarray],
    chrom: np.ndarray,
    pos: np.ndarray,
    offset: int,
) -> np.ndarray:
    """Return one ID per variant, synthesizing IDs where none is supplied.

    Args:
        values: Raw variant IDs for the chunk, or ``None`` when unavailable.
        chrom: Chromosome label per variant (``"."`` meaning unknown).
        pos: Position per variant; its length defines the chunk length.
        offset: Number of variants that precede this chunk.

    Returns:
        Object array of IDs. A supplied ID is used (stripped) unless it is
        ``None``, empty, ``"."`` or ``"nan"`` (case-insensitive). Otherwise the
        ID is ``"<chrom>:<pos>"`` when the chromosome is known, else
        ``"v<offset + index + 1>"``.

    Raises:
        ValueError: if ``values`` is non-empty and its size differs from ``pos``.
    """
    length = int(pos.shape[0])
    out = np.empty(length, dtype=object)
    raw = None
    if values is not None:
        arr = np.asarray(values, dtype=object).reshape(-1)
        if arr.size not in (0, length):
            raise ValueError("Variant ID length mismatch for GWAS chunk.")
        raw = arr if arr.size == length else None
    for idx in range(length):
        # None must count as missing; str(None) would yield the literal ID "None".
        if raw is not None and raw[idx] is not None:
            text = str(raw[idx]).strip()
            if text and text != "." and text.upper() != "NAN":
                out[idx] = text
                continue
        chrom_text = str(chrom[idx])
        pos_val = int(pos[idx])
        out[idx] = f"{chrom_text}:{pos_val}" if chrom_text != "." else f"v{offset + idx + 1}"
    return out
def _extract_chunk_arrays(
    chunk: Dict[str, Optional[np.ndarray]],
    variant_offset: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Validate a genotype chunk and return dosages plus normalized metadata.

    Args:
        chunk: Mapping with a ``genotypes`` entry and optional ``variants_*``
            metadata entries.
        variant_offset: Number of variants that precede this chunk; used for
            synthetic positions and IDs when metadata is missing.

    Returns:
        ``(dosage, chrom, pos, variant_id, ref, alt)`` where ``dosage`` is a
        uint8 array of shape (variants, samples) holding 0/1/2.

    Raises:
        ValueError: if genotypes are missing, have the wrong rank, contain no
            samples, or contain anything other than whole-number 0/1/2 calls
            (negative, NaN/inf, fractional or > 2 values).
    """

    def has_bad_calls(calls: np.ndarray) -> bool:
        # NaN compares False against both bounds, so floats need explicit
        # finiteness and integrality checks before the range test.
        if calls.dtype.kind == "f":
            if not np.all(np.isfinite(calls)):
                return True
            if np.any(calls != np.rint(calls)):
                return True
        return bool(np.any((calls < 0) | (calls > 2)))

    invalid_message = "GWAS currently requires diploid dosages encoded as 0/1/2 with no missing values."

    gt = chunk.get("genotypes")
    if gt is None:
        raise ValueError("Missing genotype calls in GWAS chunk.")
    dosage = np.asarray(gt)
    if dosage.ndim == 3:
        # Validate per strand first: a half-missing call such as (1, -1)
        # would otherwise sum to a plausible dosage of 0.
        if has_bad_calls(dosage):
            raise ValueError(invalid_message)
        dosage = dosage.sum(axis=2, dtype=np.int16)
    if dosage.ndim != 2:
        raise ValueError("GWAS expects genotype chunks with shape (variants, samples).")
    if dosage.shape[1] == 0:
        raise ValueError("No samples available in GWAS genotype chunk.")
    if has_bad_calls(dosage):
        raise ValueError(invalid_message)
    dosage_uint8 = dosage.astype(np.uint8, copy=False)
    n_variants = int(dosage_uint8.shape[0])
    chrom = _coerce_variant_chrom_array(chunk.get("variants_chrom"), n_variants)
    pos = _coerce_variant_pos_array(chunk.get("variants_pos"), n_variants, offset=variant_offset)
    variant_id = _build_variant_id_array(chunk.get("variants_id"), chrom, pos, offset=variant_offset)
    ref = _coerce_variant_text_array(chunk.get("variants_ref"), n_variants, default="N")
    alt = _coerce_variant_text_array(chunk.get("variants_alt"), n_variants, default=".")
    return dosage_uint8, chrom, pos, variant_id, ref, alt
def _filter_chunk_variants(
    dosage_batch: np.ndarray,
    chrom: np.ndarray,
    pos: np.ndarray,
    variant_id: np.ndarray,
    ref: np.ndarray,
    alt: np.ndarray,
    *,
    exclude_variants: Optional[Set[str]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Drop variants whose ID, position or ``chrom:pos`` is in ``exclude_variants``.

    Each variant is matched on three tokens: its ID, its bare integer
    position, and ``"<chrom>:<pos>"`` when the chromosome is known (not empty
    and not ``"."``). The input arrays are returned untouched when there is no
    exclusion set or when nothing matches.
    """
    untouched = (dosage_batch, chrom, pos, variant_id, ref, alt)
    if not exclude_variants:
        return untouched

    row_count = int(pos.shape[0])
    keep_mask = np.ones(row_count, dtype=bool)
    for row in range(row_count):
        pos_text = str(int(pos[row]))
        tokens = {str(variant_id[row]), pos_text}
        chrom_text = str(chrom[row]).strip()
        if chrom_text and chrom_text != ".":
            tokens.add(f"{chrom_text}:{pos_text}")
        if not exclude_variants.isdisjoint(tokens):
            keep_mask[row] = False

    if keep_mask.all():
        return untouched
    return (
        dosage_batch[keep_mask],
        chrom[keep_mask],
        pos[keep_mask],
        variant_id[keep_mask],
        ref[keep_mask],
        alt[keep_mask],
    )
def _compute_linear_stats_from_dosage_batch(
    dosage_batch: np.ndarray,
    y: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Summarize the phenotype per dosage class (0, 1, 2) for every variant.

    Returns:
        Three float64 arrays of shape (variants, 3). Column ``k`` holds, for
        samples whose dosage equals ``k``: the sample count, the sum of ``y``
        and the sum of ``y`` squared.
    """
    response = y.astype(np.float64, copy=False)
    response_sq = response * response

    counts_cols = []
    sum_cols = []
    sum_sq_cols = []
    for code in (0, 1, 2):
        members = dosage_batch == code
        counts_cols.append(np.sum(members, axis=1, dtype=np.int64))
        # The boolean mask acts as a 0/1 matrix in the product.
        sum_cols.append(members @ response)
        sum_sq_cols.append(members @ response_sq)

    n_batch = np.stack(counts_cols, axis=1).astype(np.float64)
    sum_y_batch = np.stack(sum_cols, axis=1).astype(np.float64, copy=False)
    sum_y2_batch = np.stack(sum_sq_cols, axis=1).astype(np.float64, copy=False)
    return n_batch, sum_y_batch, sum_y2_batch
# Public entry point for programmatic GWAS runs.
def run_gwas(
    phe_path: Union[str, Path, PhenotypeObject],
    snp_path: Union[str, Path, SNPObject, SNPReader, BEDReader, PGENReader, VCFReader, VCFReaderPolars],
    results_path: Union[str, Path] = "gwas.tsv.gz",
    phe_id: Optional[str] = None,
    batch_size: int = 256,
    memory: Optional[int] = None,
    return_results: bool = True,
    quantitative: Optional[bool] = None,
    verbose: bool = False,
    covar: Optional[Union[str, Path, CovariateObject]] = None,
    covar_path: Optional[Union[str, Path]] = None,
    covar_col_nums: Optional[str] = None,
    covar_variance_standardize: bool = False,
    ci: Optional[float] = None,
    adjust: bool = False,
    keep_path: Optional[Union[str, Path]] = None,
    remove_path: Optional[Union[str, Path]] = None,
    exclude_path: Optional[Union[str, Path]] = None,
    vcf_backend: str = "polars",
) -> pd.DataFrame:
    """Run variant-level association testing.

    ``phe_path`` may be a phenotype file path or an in-memory
    :class:`PhenotypeObject`. ``snp_path`` may be a genotype file path, reader,
    or in-memory :class:`SNPObject`. ``covar`` may be a covariate file path or
    an in-memory :class:`CovariateObject`; ``covar_path`` is retained as a
    backward-compatible alias. ``phe_id`` is required only when the phenotype
    input is a file path. Results are written to ``results_path`` (default:
    gwas.tsv.gz).

    Genotypes are processed chunk by chunk. Quantitative traits produce
    ``LINEAR`` rows (BETA/SE/T_STAT/P); other traits produce logistic rows
    (BETA/OR/LOG(OR)_SE/Z_STAT/P). Every chunk is appended to the output file
    as soon as it is fitted.

    Args:
        phe_path: Phenotype source (path or object).
        snp_path: Genotype source (path, reader or object).
        results_path: Requested output location, resolved through
            ``_resolve_output_path``.
        phe_id: Phenotype column name; also used to label the output.
        batch_size: Upper bound on variants per chunk before memory capping.
        memory: Optional memory cap in MiB, enforced after every chunk.
        return_results: When false, an empty frame with the final columns is
            returned and per-variant rows are not kept in memory.
        quantitative: Optional trait-type override; ``None`` leaves the
            decision to the phenotype object.
        verbose: Print progress messages to stdout.
        covar: Covariate source (path or object).
        covar_path: Alias of ``covar``; the two are mutually exclusive.
        covar_col_nums: Covariate column selection string.
        covar_variance_standardize: Standardize the selected covariates.
        ci: Confidence level in (0, 1); adds lower/upper bound columns.
        adjust: Add ``BONF`` and ``FDR_BH`` columns.
        keep_path: File of sample IDs to include.
        remove_path: File of sample IDs to exclude.
        exclude_path: File of variant selectors to exclude.
        vcf_backend: Backend passed to ``SNPReader`` when ``snp_path`` is a path.

    Returns:
        A DataFrame with the final output columns.

    Raises:
        TypeError: if both ``covar`` and ``covar_path`` are given.
        MemoryError: if ``memory`` is below 2 MiB.
        ValueError: if ``ci`` is outside (0, 1) or the inputs cannot be aligned.

    On any failure during chunk processing the partially written output file
    is deleted before the exception is re-raised.
    """
    if covar is not None and covar_path is not None:
        raise TypeError("Pass only one of `covar` or `covar_path`.")
    covar_source = covar if covar is not None else covar_path
    if memory is not None and int(memory) < 2:
        # Exception type kept as-is so existing callers catching it still work.
        raise MemoryError("--memory must be >= 2 MiB for internal GWAS processing.")
    if ci is not None and (ci <= 0.0 or ci >= 1.0):
        raise ValueError("--ci must be in the open interval (0, 1).")

    phenotype_obj, output_phe_id = _coerce_phenotype_source(
        phe_path,
        phe_id=phe_id,
        quantitative=quantitative,
    )
    phe_samples = phenotype_obj.samples
    y = phenotype_obj.values
    trait_is_quantitative = bool(phenotype_obj.is_quantitative)

    keep_ids = _read_sample_list(keep_path) if keep_path is not None else None
    remove_ids = _read_sample_list(remove_path) if remove_path is not None else None
    exclude_variants = _read_variant_list(exclude_path) if exclude_path is not None else None
    covar_samples, _covar_names, covar_matrix = _coerce_covar_source(
        covar_source,
        col_nums=covar_col_nums,
        variance_standardize=covar_variance_standardize,
    )

    snp_reader = _coerce_snp_source(snp_path, vcf_backend=vcf_backend)
    snp_samples = _read_snp_samples(snp_reader)
    sample_indexes, y_aligned, aligned_samples, covar_aligned = _align_samples_to_snp_order(
        snp_samples=snp_samples,
        phe_samples=phe_samples,
        y=y,
        quantitative=trait_is_quantitative,
        keep_ids=keep_ids,
        remove_ids=remove_ids,
        covar_samples=covar_samples,
        covar_matrix=covar_matrix,
    )
    covariates_present = covar_aligned is not None
    n_covar = int(covar_aligned.shape[1]) if covariates_present else 0
    chunk_size = _compute_effective_chunk_size(
        batch_size=batch_size,
        n_samples=int(sample_indexes.size),
        memory_mib=memory,
        covariates_present=covariates_present,
        quantitative=trait_is_quantitative,
    )
    obs_ct = int(y_aligned.size)
    rss_baseline_mb = _get_process_rss_mb() if memory is not None else None
    output_file = _resolve_output_path(results_path, output_phe_id, default_suffix="_gwas.tsv.gz")

    ci_cols: List[str] = []
    if ci is not None:
        ci_suffix = _confidence_interval_label(ci)
        ci_cols = [f"L{ci_suffix}", f"U{ci_suffix}"]
    if trait_is_quantitative:
        core_columns = [
            "#CHROM", "POS", "END", "ID", "REF", "ALT", "A1",
            "TEST", "OBS_CT", "BETA", "SE", "T_STAT", "P",
        ]
    else:
        core_columns = [
            "#CHROM", "POS", "END", "ID", "REF", "ALT", "A1",
            "TEST", "OBS_CT", "BETA", "OR", "LOG(OR)_SE", "Z_STAT", "P",
        ]
    # Adjusted p-values need all variants, so they are added after streaming.
    columns_without_adjust = core_columns + ci_cols + ["ERRCODE"]
    final_columns = core_columns + ci_cols + (["BONF", "FDR_BH"] if adjust else []) + ["ERRCODE"]

    covar_f64: Optional[np.ndarray] = None
    y_resid: Optional[np.ndarray] = None
    q_fwl: Optional[np.ndarray] = None
    if trait_is_quantitative:
        y_f64 = y_aligned.astype(np.float64, copy=False)
        if covariates_present:
            covar_f64 = covar_aligned.astype(np.float64, copy=False)
            y_resid, q_fwl = _prepare_fwl(y_f64, covar_f64)
    else:
        y_binary = y_aligned.astype(np.int64, copy=False)
        if covariates_present:
            covar_f64 = covar_aligned.astype(np.float64, copy=False)

    records: List[Dict[str, object]] = []
    collected_p_values: Optional[List[float]] = [] if adjust else None
    # Variants read from the source so far, counted BEFORE exclusion filtering.
    # This drives synthetic POS/ID values, which must not repeat after a chunk
    # in which some variants were excluded.
    variants_seen_total = 0
    # Variants actually tested and written.
    variants_processed_total = 0
    chunk_index = 0
    try:
        with _open_tsv_for_write(output_file) as handle:
            # Explicit "\n": csv.writer defaults to "\r\n", which would leave
            # the header with a different line ending than the data rows.
            writer = csv.writer(handle, delimiter="\t", lineterminator="\n")
            writer.writerow(columns_without_adjust)
            if verbose:
                print("Reading SNP input...", flush=True)
            for chunk in _iter_snp_chunks(
                snp_reader=snp_reader,
                chunk_size=chunk_size,
                sample_indices=sample_indexes,
                aligned_samples=aligned_samples,
            ):
                dosage_batch, chrom_arr, pos_arr, id_arr, ref_arr, alt_arr = _extract_chunk_arrays(
                    chunk,
                    variant_offset=variants_seen_total,
                )
                variants_seen_total += int(dosage_batch.shape[0])
                dosage_batch, chrom_arr, pos_arr, id_arr, ref_arr, alt_arr = _filter_chunk_variants(
                    dosage_batch,
                    chrom_arr,
                    pos_arr,
                    id_arr,
                    ref_arr,
                    alt_arr,
                    exclude_variants=exclude_variants,
                )
                n_variants = int(dosage_batch.shape[0])
                if n_variants == 0:
                    continue
                if verbose and (chunk_index % 100 == 0):
                    print(
                        f"  Chunk {chunk_index}: {n_variants} variants "
                        f"(total so far: {variants_processed_total + n_variants:,})",
                        flush=True,
                    )
                chunk_index += 1

                ci_low_arr = ci_high_arr = None
                stat_columns: Dict[str, object]
                if trait_is_quantitative:
                    if covariates_present:
                        if covar_f64 is None or y_resid is None or q_fwl is None:
                            raise ValueError("Internal error: missing covariate projection state.")
                        beta_arr, se_arr, t_arr, p_arr, errcode_arr = _fit_linear_batch_with_covariates(
                            dosage_batch,
                            y_resid,
                            q_fwl,
                            n_covar=n_covar,
                        )
                        df_linear = float(obs_ct - (2 + n_covar))
                    else:
                        n_batch, sy_batch, sy2_batch = _compute_linear_stats_from_dosage_batch(dosage_batch, y_f64)
                        beta_arr, se_arr, t_arr, p_arr, errcode_arr = _fit_linear_batch(
                            n_batch, sy_batch, sy2_batch
                        )
                        df_linear = float(obs_ct - 2)
                    if ci is not None:
                        ci_low_arr, ci_high_arr = _compute_linear_ci_beta(
                            beta_arr,
                            se_arr,
                            ci=ci,
                            df=df_linear,
                        )
                    stat_columns = {
                        "TEST": "LINEAR",
                        "BETA": beta_arr,
                        "SE": se_arr,
                        "T_STAT": t_arr,
                        "P": p_arr,
                    }
                else:
                    if covariates_present:
                        if covar_f64 is None:
                            raise ValueError("Internal error: missing aligned covariate matrix.")
                        beta_arr, se_arr, z_arr, p_arr, test_arr, errcode_arr = _fit_logistic_batch_with_covariates(
                            dosage_batch,
                            y_binary,
                            covar_f64,
                        )
                    else:
                        n_counts_batch, c_counts_batch = _compute_group_counts_batch(
                            dosage_batch,
                            y_binary,
                        )
                        beta_arr, se_arr, z_arr, p_arr, test_arr, errcode_arr = _fit_logistic_batch(
                            n_counts_batch,
                            c_counts_batch,
                        )
                    or_arr = _odds_ratio_batch(beta_arr)
                    if ci is not None:
                        ci_low_arr, ci_high_arr = _compute_logistic_ci_or(beta_arr, se_arr, ci=ci)
                    stat_columns = {
                        "TEST": test_arr,
                        "BETA": beta_arr,
                        "OR": or_arr,
                        "LOG(OR)_SE": se_arr,
                        "Z_STAT": z_arr,
                        "P": p_arr,
                    }

                pos_i64 = pos_arr.astype(np.int64, copy=False)
                chunk_data: Dict[str, object] = {
                    "#CHROM": chrom_arr,
                    "POS": pos_i64,
                    "END": pos_i64,
                    "ID": id_arr,
                    "REF": ref_arr,
                    "ALT": alt_arr,
                    "A1": alt_arr,
                    "OBS_CT": obs_ct,
                }
                chunk_data.update(stat_columns)
                if ci is not None and ci_low_arr is not None and ci_high_arr is not None:
                    chunk_data[ci_cols[0]] = ci_low_arr
                    chunk_data[ci_cols[1]] = ci_high_arr
                chunk_data["ERRCODE"] = errcode_arr
                # Column selection fixes the on-disk column order.
                chunk_df = pd.DataFrame(chunk_data)[columns_without_adjust]
                chunk_df.to_csv(
                    handle,
                    sep="\t",
                    header=False,
                    index=False,
                    lineterminator="\n",
                )
                if return_results:
                    records.extend(chunk_df.to_dict("records"))
                if collected_p_values is not None:
                    collected_p_values.extend(float(x) for x in p_arr.ravel())
                variants_processed_total += n_variants
                _enforce_memory_budget(memory, rss_baseline_mb, context="GWAS chunk processing")
    except Exception:
        # Do not leave a truncated results file behind.
        try:
            output_file.unlink()
        except FileNotFoundError:
            pass
        raise

    if verbose and variants_processed_total > 0:
        print(f"  Done. Processed {variants_processed_total:,} variants.", flush=True)
    if adjust:
        if collected_p_values is None:
            raise ValueError("Internal error: adjusted p-values requested but collection is unavailable.")
        _apply_multiple_testing_adjustment(output_file, collected_p_values)
        if return_results and records:
            bonf_arr, fdr_arr = _compute_multiple_testing_adjustments(
                np.asarray(collected_p_values, dtype=np.float64)
            )
            for rec, bonf_val, fdr_val in zip(records, bonf_arr.tolist(), fdr_arr.tolist()):
                rec["BONF"] = bonf_val
                rec["FDR_BH"] = fdr_val
    if return_results:
        results = pd.DataFrame.from_records(records)
        results = results.reindex(columns=final_columns)
    else:
        results = pd.DataFrame(columns=final_columns)
    log.info("GWAS results written to %s", output_file)
    return results
def gwas(argv: Sequence[str]):
    """CLI entry point: parse ``argv`` and run the GWAS command on the result."""
    return run_gwas_command(parse_gwas_args(argv))
def run_gwas_command(args: argparse.Namespace) -> int:
    """Run a GWAS from parsed CLI arguments and optionally draw plots.

    ``run_gwas`` is called with ``return_results=False``. When
    ``--manhattan-plot`` and/or ``--qq-plot`` were supplied, the plots are
    then rendered from the results file using matplotlib's ``Agg`` backend.

    Returns:
        ``0`` once everything has completed.
    """
    gwas_options = {
        "phe_path": args.phe_path,
        "snp_path": args.snp_path,
        "results_path": args.results_path,
        "phe_id": args.phe_id,
        "batch_size": args.batch_size,
        "memory": args.memory,
        "return_results": False,
        "quantitative": args.quantitative,
        "verbose": args.verbose,
        "covar_path": args.covar_path,
        "covar_col_nums": args.covar_col_nums,
        "covar_variance_standardize": args.covar_variance_standardize,
        "ci": args.ci,
        "adjust": args.adjust,
        "keep_path": args.keep_path,
        "remove_path": args.remove_path,
        "exclude_path": args.exclude_path,
        "vcf_backend": args.vcf_backend,
    }
    run_gwas(**gwas_options)

    manhattan_target = getattr(args, "manhattan_plot", None)
    qq_target = getattr(args, "qq_plot", None)
    if manhattan_target is None and qq_target is None:
        return 0

    # Plotting dependencies are imported only when a plot was requested.
    import matplotlib

    matplotlib.use("Agg", force=True)
    # Resolve the path the same way run_gwas does to locate the written file.
    written_results = _resolve_output_path(args.results_path, args.phe_id, default_suffix="_gwas.tsv.gz")
    if manhattan_target is not None:
        from snputils.visualization.manhattan_plot import manhattan_plot

        manhattan_plot(str(written_results), save=True, output_filename=manhattan_target)
    if qq_target is not None:
        from snputils.visualization.qq_plot import qq_plot

        qq_plot(str(written_results), save=True, output_filename=qq_target)
    return 0