import argparse
import sys
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
from snputils import __version__
from snputils.visualization.constants import get_palette_color
# Project links interpolated into the top-level parser description.
DOCS_URL = "https://docs.snputils.org"
SOURCE_URL = "https://github.com/AI-sandbox/snputils"
@dataclass(frozen=True)
class _Command:
    """Immutable description of one CLI subcommand.

    Bundles the help text with the two callables used to register the
    subcommand's options and to execute it.
    """

    # Help text for the subcommand.
    help: str
    # Callback that registers the subcommand's options on its parser.
    add_arguments: Callable[[argparse.ArgumentParser], None]
    # Callback that executes the subcommand and returns an integer status.
    run: Callable[[argparse.Namespace], int]
def _positive_int(value: str) -> int:
parsed = int(value)
if parsed <= 0:
raise argparse.ArgumentTypeError("Value must be a positive integer.")
return parsed
def _run_pca(args: argparse.Namespace) -> int:
    """Dispatch the ``pca`` subcommand.

    The sibling ``pca`` module is imported inside the function, and the value
    returned by its ``run_pca_command`` is coerced to ``int``.
    """
    from . import pca as _pca_cli

    status = _pca_cli.run_pca_command(args)
    return int(status)
def _run_admixture_map(args: argparse.Namespace) -> int:
    """Dispatch the ``admixture-map`` subcommand.

    The sibling ``admixture_mapping`` module is imported inside the function,
    and the value returned by its ``run_admixmap_command`` is coerced to ``int``.
    """
    from . import admixture_mapping as _admixmap_cli

    status = _admixmap_cli.run_admixmap_command(args)
    return int(status)
def _run_gwas(args: argparse.Namespace) -> int:
    """Dispatch the ``gwas`` subcommand.

    The sibling ``gwas`` module is imported inside the function, and the value
    returned by its ``run_gwas_command`` is coerced to ``int``.
    """
    from . import gwas as _gwas_cli

    status = _gwas_cli.run_gwas_command(args)
    return int(status)
def _run_simulate(args: argparse.Namespace) -> int:
    """Dispatch the ``simulate`` subcommand.

    Returns:
        The simulator's result coerced to ``int``, or ``2`` (after printing an
        installation hint to stderr) when the import fails because the
        ``torch`` module is missing.

    Raises:
        ModuleNotFoundError: If any module other than ``torch`` is missing.
    """
    try:
        from snputils.simulation.simulator_cli import run_simulator_command
    except ModuleNotFoundError as missing:
        # Only a missing torch gets the friendly hint; anything else propagates.
        if missing.name != "torch":
            raise
        hint = "snputils simulate requires PyTorch. Install it with `pip install 'snputils[torch]'`."
        print(hint, file=sys.stderr)
        return 2
    status = run_simulator_command(args)
    return int(status)
def _run_mdpca(args: argparse.Namespace) -> int:
    """Run the ``mdpca`` subcommand.

    Reads the SNP and local-ancestry inputs, calls ``mdPCA`` with the parsed
    options and, when ``--plot`` was given, renders a scatter plot from the
    table at ``args.coords`` (presumably written by ``mdPCA`` through
    ``embedding_table_path`` -- confirm against ``mdPCA``).

    Args:
        args: Namespace produced by the ``mdpca`` subparser.

    Returns:
        ``0`` on completion.

    Raises:
        ValueError: If a plot is requested but the coordinates table has no
            usable coordinate columns.
    """
    from snputils.ancestry.io.local.read import read_lai
    from snputils.processing.mdpca import mdPCA
    from snputils.snp.io.read import read_snp

    snpobj = read_snp(args.snp_path, sum_strands=False)
    laiobj = read_lai(args.lai_path)
    mdPCA(
        snpobj=snpobj,
        laiobj=laiobj,
        labels_file=args.labels_file,
        ancestry=args.ancestry,
        method=args.method,
        is_masked=not args.unmasked,
        average_strands=args.average_strands,
        force_nan_incomplete_strands=args.force_nan_incomplete_strands,
        is_weighted=args.weighted,
        groups_to_remove=args.groups_to_remove,
        min_percent_snps=args.min_percent_snps,
        group_snp_frequencies_only=not args.include_individual_frequencies,
        save_masks=args.save_masks,
        load_masks=args.load_masks,
        masks_file=args.masks_file,
        embedding_table_path=args.coords,
        covariance_matrix_file=args.covariance_matrix_file,
        n_components=args.n_components,
        rsid_or_chrompos=args.rsid_or_chrompos,
        percent_vals_masked=args.percent_vals_masked,
    )
    if args.plot is not None:
        _save_mdpca_scatter(args.coords, args.plot)
    return 0


def _save_mdpca_scatter(coords_path: str, plot_path: str) -> None:
    """Scatter-plot the first two coordinate columns of *coords_path* into *plot_path*.

    Columns whose names start with ``PC`` or ``MDS`` are preferred; otherwise
    float/int columns are used. With a single coordinate column the y axis is
    a constant zero.

    Raises:
        ValueError: If no coordinate column can be found.
    """
    import matplotlib

    # Force the non-interactive backend before pyplot is imported.
    matplotlib.use("Agg", force=True)
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    from snputils.visualization._figure_export import default_savefig_kwargs, scatter_rasterized_for_path

    # sep=None with the python engine lets pandas sniff the delimiter.
    df = pd.read_csv(coords_path, sep=None, engine="python")
    coord_cols = [c for c in df.columns if c.startswith(("PC", "MDS"))]
    if not coord_cols:
        coord_cols = df.select_dtypes(include=[float, int]).columns.tolist()
    if not coord_cols:
        raise ValueError("Could not find coordinate columns in the mdPCA output table for plotting.")

    x_label = coord_cols[0]
    x = df[x_label].to_numpy(dtype=float)
    if len(coord_cols) >= 2:
        y_label = coord_cols[1]
        y = df[y_label].to_numpy(dtype=float)
    else:
        y_label = "Constant (0)"
        y = np.zeros_like(x)

    fig = plt.figure(figsize=(10, 8))
    try:
        scatter_kwargs: dict = {"linewidth": 0, "alpha": 0.5, "color": get_palette_color(0)}
        if scatter_rasterized_for_path(str(plot_path)):
            scatter_kwargs["rasterized"] = True
        plt.scatter(x, y, **scatter_kwargs)
        plt.xlabel(x_label, fontsize=20)
        plt.ylabel(y_label, fontsize=20)
        plt.tight_layout()
        plt.savefig(plot_path, **default_savefig_kwargs(str(plot_path)))
    finally:
        # pyplot keeps figures alive until closed; release it so repeated
        # main() calls in one process do not accumulate figures.
        plt.close(fig)
def _run_maasmds(args: argparse.Namespace) -> int:
    """Run the ``maasmds`` subcommand.

    ``--snp-path`` and ``--lai-path`` are comma-separated lists. With exactly
    one pair the single objects are passed to ``maasMDS``; otherwise the lists
    are passed.

    Args:
        args: Namespace produced by the ``maasmds`` subparser.

    Returns:
        ``0`` on completion.

    Raises:
        ValueError: If the two path lists differ in length or are empty.
    """
    # Validate the cheap inputs before importing heavy modules or reading files.
    snp_paths = _split_csv(args.snp_path)
    lai_paths = _split_csv(args.lai_path)
    if len(snp_paths) != len(lai_paths):
        raise ValueError("--snp-path and --lai-path must contain the same number of comma-separated paths.")
    if not snp_paths:
        # e.g. "" or "," would otherwise hand empty lists to maasMDS.
        raise ValueError("--snp-path and --lai-path must each contain at least one path.")

    from snputils.ancestry.io.local.read import read_lai
    from snputils.processing.maasmds import maasMDS
    from snputils.snp.io.read import read_snp

    snp_objects = [read_snp(path, sum_strands=False) for path in snp_paths]
    lai_objects = [read_lai(path) for path in lai_paths]
    single_array = len(snp_objects) == 1

    maasMDS(
        snpobj=snp_objects[0] if single_array else snp_objects,
        laiobj=lai_objects[0] if single_array else lai_objects,
        labels_file=args.labels_file,
        ancestry=args.ancestry,
        is_masked=not args.unmasked,
        average_strands=args.average_strands,
        force_nan_incomplete_strands=args.force_nan_incomplete_strands,
        is_weighted=args.weighted,
        groups_to_remove=args.groups_to_remove,
        min_percent_snps=args.min_percent_snps,
        group_snp_frequencies_only=not args.include_individual_frequencies,
        save_masks=args.save_masks,
        load_masks=args.load_masks,
        masks_file=args.masks_file,
        distance_type=args.distance_type,
        n_components=args.n_components,
        rsid_or_chrompos=args.rsid_or_chrompos,
        embedding_table_path=args.coords,
    )
    return 0
def _run_plot_manhattan(args: argparse.Namespace) -> int:
    """Run the ``plot-manhattan`` subcommand and return ``0``.

    Forwards the parsed options to ``manhattan_plot`` with ``save=True`` and
    ``args.output_path`` as the output filename.
    """
    import matplotlib

    # Force the non-interactive backend before the plotting module is imported.
    matplotlib.use("Agg", force=True)
    from snputils.visualization.manhattan_plot import manhattan_plot

    results = args.results_path
    plot_options = {
        "significance_threshold": args.significance_threshold,
        "point_size": args.point_size,
        "line_width": args.line_width,
        "line_color": args.line_color,
        "title": args.title,
        "save": True,
        "output_filename": args.output_path,
    }
    manhattan_plot(results, **plot_options)
    return 0
def _run_plot_qq(args: argparse.Namespace) -> int:
    """Run the ``plot-qq`` subcommand and return ``0``.

    Forwards the parsed options to ``qq_plot`` with ``save=True`` and
    ``args.output_path`` as the output filename.
    """
    import matplotlib

    # Force the non-interactive backend before the plotting module is imported.
    matplotlib.use("Agg", force=True)
    from snputils.visualization.qq_plot import qq_plot

    results = args.results_path
    plot_options = {
        "color": args.color,
        "significance_threshold": args.significance_threshold,
        "point_size": args.point_size,
        "line_width": args.line_width,
        "expected_line_color": args.expected_line_color,
        "threshold_line_color": args.threshold_line_color,
        "title": args.title,
        "save": True,
        "output_filename": args.output_path,
    }
    qq_plot(results, **plot_options)
    return 0
def _split_csv(value: str) -> List[str]:
return [item.strip() for item in value.split(",") if item.strip()]
def _add_pca_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the options of the ``pca`` subcommand on *parser*.

    The options are declared as an ordered table and added in that order.
    """
    fitting_help = (
        "SVD mode: exact (standard PCA; sklearn uses svd_solver='full') or "
        "lowrank approximate (sklearn randomized / torch svd_lowrank)."
    )
    option_table = (
        ("--snp-path", dict(
            dest="snp_path", required=True, type=str,
            help="Path to genotype input (VCF, BED, or PGEN fileset).",
        )),
        ("--plot", dict(
            dest="plot", required=True, type=str,
            help="Path to save the PCA scatter plot (.pdf / .svg / .png, ...; vector formats use rasterized points at 300 dpi by default).",
        )),
        ("--coords", dict(
            dest="coords", default=None, type=str,
            help="Optional path to write PC coordinates as TSV/CSV (see snputils.processing.dimred_tabular).",
        )),
        ("--components", dict(
            dest="components", default=None, type=str,
            help="Optional path to save PCA components as a .npy file.",
        )),
        ("--backend", dict(
            choices=("sklearn", "pytorch"), default="sklearn",
            help="PCA backend to use.",
        )),
        ("--n-components", dict(
            dest="n_components", default=2, type=_positive_int,
            help="Number of principal components to compute.",
        )),
        ("--fitting", dict(
            dest="fitting", choices=("exact", "lowrank"), default="exact",
            help=fitting_help,
        )),
        ("--sum-strands", dict(
            dest="sum_strands", action="store_true",
            help="Read diploid genotypes as per-individual summed strand counts.",
        )),
        ("--vcf-backend", dict(
            dest="vcf_backend", choices=("default", "polars"), default="default",
            help="VCF reader backend (used only when input is VCF).",
        )),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
def _add_admixture_map_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--phe-id",
dest="phe_id",
required=True,
type=str,
help="Phenotype ID / column name to analyze.",
)
parser.add_argument(
"--phe-path",
dest="phe_path",
required=True,
type=str,
help="Path to phenotype file.",
)
parser.add_argument(
"--lai-path",
dest="lai_path",
required=True,
type=str,
help="Path to local ancestry file (.msp/.msp.tsv or FLARE .anc.vcf.gz).",
)
parser.add_argument(
"--results-path",
dest="results_path",
required=True,
type=str,
help="Output directory or output .tsv/.tsv.gz path.",
)
parser.add_argument(
"--batch-size",
dest="batch_size",
default=32768,
type=int,
help="Max windows processed per chunk.",
)
parser.add_argument(
"--memory",
dest="memory",
default=None,
type=int,
help="Peak RSS-delta memory cap in MiB.",
)
parser.add_argument(
"--keep-hla",
dest="keep_hla",
action="store_true",
help="Keep chr6 HLA windows (default removes them).",
)
parser.add_argument(
"--quantitative",
dest="quantitative",
action="store_true",
default=None,
help="Force quantitative (linear) mode.",
)
parser.add_argument(
"--verbose",
dest="verbose",
action="store_true",
help="Print progress updates.",
)
parser.add_argument(
"--covar-path",
dest="covar_path",
default=None,
type=str,
help="Path to covariate file.",
)
parser.add_argument(
"--covar-col-nums",
dest="covar_col_nums",
default=None,
type=str,
help='Covariate columns relative to first covariate column (e.g. "1-5,7").',
)
parser.add_argument(
"--covar-variance-standardize",
dest="covar_variance_standardize",
action="store_true",
help="Center and variance-standardize selected covariates.",
)
parser.add_argument(
"--ci",
dest="ci",
default=None,
type=float,
help="Confidence level in (0, 1), e.g. 0.95.",
)
parser.add_argument(
"--adjust",
dest="adjust",
action="store_true",
help="Add Bonferroni and Benjamini-Hochberg FDR p-values.",
)
parser.add_argument(
"--keep-path",
dest="keep_path",
default=None,
type=str,
help="Path to keep file (FID IID or IID per line).",
)
parser.add_argument(
"--sample-remove",
dest="remove_path",
default=None,
type=str,
help="Path to remove file (FID IID or IID per line).",
)
parser.add_argument(
"--manhattan-plot",
dest="manhattan_plot",
default=None,
type=str,
help="Optional path to save a Manhattan plot after admixture mapping completes (.pdf / .svg / .png, ...).",
)
parser.add_argument(
"--qq-plot",
dest="qq_plot",
default=None,
type=str,
help="Optional path to save a Q-Q plot after admixture mapping completes (.pdf / .svg / .png, ...).",
)
def _add_gwas_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--phe-id",
dest="phe_id",
required=True,
type=str,
help="Phenotype ID / column name to analyze.",
)
parser.add_argument(
"--phe-path",
dest="phe_path",
required=True,
type=str,
help="Path to phenotype file with IID and one or more phenotype columns.",
)
parser.add_argument(
"--snp-path",
dest="snp_path",
required=True,
type=str,
help="Path to genotype input (VCF/BED/PGEN).",
)
parser.add_argument(
"--results-path",
dest="results_path",
required=True,
type=str,
help="Output directory or output .tsv/.tsv.gz path.",
)
parser.add_argument(
"--batch-size",
dest="batch_size",
default=32768,
type=int,
help="Max variants processed per chunk.",
)
parser.add_argument(
"--memory",
dest="memory",
default=None,
type=int,
help="Peak RSS-delta memory cap in MiB.",
)
parser.add_argument(
"--quantitative",
dest="quantitative",
action="store_true",
default=None,
help="Force quantitative (linear) mode.",
)
parser.add_argument(
"--verbose",
dest="verbose",
action="store_true",
help="Print progress updates.",
)
parser.add_argument(
"--covar-path",
dest="covar_path",
default=None,
type=str,
help="Path to covariate file.",
)
parser.add_argument(
"--covar-col-nums",
dest="covar_col_nums",
default=None,
type=str,
help='Covariate columns relative to first covariate column (e.g. "1-5,7").',
)
parser.add_argument(
"--covar-variance-standardize",
dest="covar_variance_standardize",
action="store_true",
help="Center and variance-standardize selected covariates.",
)
parser.add_argument(
"--variant-exclude",
dest="exclude_path",
default=None,
type=str,
help="Path to variant exclusion file (one or more variant selectors per line).",
)
parser.add_argument(
"--ci",
dest="ci",
default=None,
type=float,
help="Confidence level in (0, 1), e.g. 0.95.",
)
parser.add_argument(
"--adjust",
dest="adjust",
action="store_true",
help="Add Bonferroni and Benjamini-Hochberg FDR p-values.",
)
parser.add_argument(
"--keep-path",
dest="keep_path",
default=None,
type=str,
help="Path to keep file (FID IID or IID per line).",
)
parser.add_argument(
"--sample-remove",
dest="remove_path",
default=None,
type=str,
help="Path to remove file (FID IID or IID per line).",
)
parser.add_argument(
"--vcf-backend",
dest="vcf_backend",
choices=("polars", "scikit-allel"),
default="polars",
help="VCF reader backend (used only when input is VCF).",
)
parser.add_argument(
"--manhattan-plot",
dest="manhattan_plot",
default=None,
type=str,
help="Optional path to save a Manhattan plot after GWAS completes (.pdf / .svg / .png, ...).",
)
parser.add_argument(
"--qq-plot",
dest="qq_plot",
default=None,
type=str,
help="Optional path to save a Q-Q plot after GWAS completes (.pdf / .svg / .png, ...).",
)
def _add_dimred_common_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by the ``mdpca`` and ``maasmds`` subcommands.

    Consecutive options of the same shape are added from small tables; the
    overall registration order is unchanged.
    """
    add = parser.add_argument

    def add_switches(pairs) -> None:
        for flag, help_text in pairs:
            add(flag, action="store_true", help=help_text)

    required_inputs = (
        ("--snp-path", "Path to SNP input. For maasmds, pass comma-separated paths for multiple arrays."),
        ("--lai-path", "Path to local ancestry input. For maasmds, pass comma-separated paths matching --snp-path."),
        ("--labels-file", "TSV labels file with indID and label columns."),
        ("--ancestry", "Ancestry index or ancestry-map label to analyze."),
        ("--coords", "Output TSV/CSV path for coordinates and row metadata."),
    )
    for flag, help_text in required_inputs:
        add(flag, required=True, help=help_text)

    add(
        "--n-components",
        type=_positive_int,
        default=2,
        help="Number of dimensions/components to compute.",
    )
    add_switches((
        ("--unmasked", "Use unmasked genotypes instead of ancestry-specific masking."),
        ("--average-strands", "Average each individual's two haplotypes."),
        ("--force-nan-incomplete-strands", "Set averaged strand pairs to NaN if either haplotype is missing."),
        ("--weighted", "Read individual weights from the labels file."),
    ))
    add(
        "--groups-to-remove",
        nargs="+",
        default=None,
        help="Population labels to remove before analysis.",
    )
    add(
        "--min-percent-snps",
        type=float,
        default=4,
        help="Minimum percent of non-missing SNPs required per row.",
    )
    add_switches((
        ("--include-individual-frequencies", "Keep individual-level data when weighted group combinations are present."),
        ("--save-masks", "Save masked genotype data to --masks-file."),
        ("--load-masks", "Load masked genotype data from --masks-file."),
    ))
    add(
        "--masks-file",
        default="masks.npz",
        help="Path for saving/loading masked genotype data.",
    )
    add(
        "--rsid-or-chrompos",
        type=int,
        choices=(1, 2),
        default=2,
        help="Variant ID mode: 1=rsID, 2=chromosome_position.",
    )
def _add_mdpca_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the ``mdpca`` options: the shared dimred options plus mdPCA-specific ones."""
    _add_dimred_common_arguments(parser)

    method_names = (
        "weighted_cov_pca",
        "regularized_optimization_ils",
        "cov_matrix_imputation",
        "cov_matrix_imputation_ils",
        "nonmissing_pca_ils",
    )
    add = parser.add_argument
    add("--method", default="weighted_cov_pca", choices=method_names, help="mdPCA method.")
    add(
        "--covariance-matrix-file",
        default=None,
        help="Optional .npy path for the covariance matrix.",
    )
    add(
        "--percent-vals-masked",
        type=float,
        default=0,
        help="Percent of covariance values to mask for imputation methods.",
    )
    add(
        "--plot",
        dest="plot",
        default=None,
        type=str,
        help="Optional path to save a scatter plot of the first two mdPCA components (.pdf / .svg / .png, ...).",
    )
def _add_maasmds_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the ``maasmds`` options: the shared dimred options plus ``--distance-type``."""
    _add_dimred_common_arguments(parser)

    distance_options = {
        "choices": ("Manhattan", "RMS", "AP"),
        "default": "AP",
        "help": "Pairwise distance used before MDS.",
    }
    parser.add_argument("--distance-type", **distance_options)
def _add_simulate_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--snp", required=True, help="Path to SNP input (VCF, BED, or PGEN fileset).")
parser.add_argument("--metadata", required=True, help="TSV/CSV file with at least Sample / Population / Latitude / Longitude.")
parser.add_argument("--output-dir", required=True, help="Directory in which to save the simulated batches.")
parser.add_argument("--genetic-map", default=None, help="Genetic map table with columns: chrom, pos, cM.")
parser.add_argument("--chromosome", type=int, default=None, help="If provided, restrict genetic map rows to this chromosome id.")
parser.add_argument("--window-size", type=int, default=1000, help="#SNPs per window.")
parser.add_argument("--store-latlon-as-nvec", action="store_true", help="Convert lat/lon to unit n-vectors (x,y,z).")
parser.add_argument("--make-haploid", action="store_true", help="Flatten diploid genotypes into haplotypes.")
parser.add_argument("--device", default="cpu", help="torch device string, e.g. 'cuda:0'.")
parser.add_argument("--batch-size", type=int, default=256, help="#simulated haplotypes per batch.")
parser.add_argument("--num-generations", type=int, default=10, help="Upper bound on random generations since admixture.")
parser.add_argument("--n-batches", type=int, default=1, help="#separate batches to generate & save.")
parser.add_argument("-v", "--verbose", action="store_true", help="Print additional debugging info.")
def _add_plot_manhattan_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--results-path", required=True, help="Association result TSV with #CHROM, POS, and P columns.")
parser.add_argument("--output-path", required=True, help="Output figure path (.pdf, .svg, .png, ...).")
parser.add_argument("--significance-threshold", type=float, default=0.05, help="Nominal alpha used for the Bonferroni line.")
parser.add_argument("--point-size", type=float, default=7.0, help="Scatter point size.")
parser.add_argument("--line-width", type=float, default=1.0, help="Reference line width.")
parser.add_argument("--line-color", default="r", help="Bonferroni reference line color.")
parser.add_argument("--title", default=None, help="Optional plot title.")
def _add_plot_qq_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--results-path", required=True, help="Association result TSV with a P column.")
parser.add_argument("--output-path", required=True, help="Output figure path (.pdf, .svg, .png, ...).")
parser.add_argument("--significance-threshold", type=float, default=0.05, help="Nominal alpha used for the Bonferroni line.")
parser.add_argument("--point-size", type=float, default=7.0, help="Scatter point size.")
parser.add_argument("--line-width", type=float, default=1.0, help="Reference line width.")
parser.add_argument("--color", default="black", help="Scatter point color.")
parser.add_argument("--expected-line-color", default="red", help="Expected-null reference line color.")
parser.add_argument("--threshold-line-color", default="orange", help="Bonferroni threshold line color.")
parser.add_argument("--title", default=None, help="Optional plot title.")
# Registry of subcommands, keyed by the name used on the command line.
# build_parser() iterates this mapping in insertion order to create one
# subparser per entry.
_COMMANDS: Dict[str, _Command] = {
    # Dimensionality reduction.
    "pca": _Command(
        help="Run PCA and save plot/components.",
        add_arguments=_add_pca_arguments,
        run=_run_pca,
    ),
    "mdpca": _Command(
        help="Run missing-data PCA and save an embedding table.",
        add_arguments=_add_mdpca_arguments,
        run=_run_mdpca,
    ),
    "maasmds": _Command(
        help="Run multi-array ancestry-specific MDS and save an embedding table.",
        add_arguments=_add_maasmds_arguments,
        run=_run_maasmds,
    ),
    # Association analyses.
    "admixture-map": _Command(
        help="Run admixture mapping.",
        add_arguments=_add_admixture_map_arguments,
        run=_run_admixture_map,
    ),
    "gwas": _Command(
        help="Run GWAS.",
        add_arguments=_add_gwas_arguments,
        run=_run_gwas,
    ),
    # Simulation.
    "simulate": _Command(
        help="Simulate admixed haplotype batches.",
        add_arguments=_add_simulate_arguments,
        run=_run_simulate,
    ),
    # Standalone plots of association results.
    "plot-manhattan": _Command(
        help="Create a Manhattan plot from association results.",
        add_arguments=_add_plot_manhattan_arguments,
        run=_run_plot_manhattan,
    ),
    "plot-qq": _Command(
        help="Create a QQ plot from association results.",
        add_arguments=_add_plot_qq_arguments,
        run=_run_plot_qq,
    ),
}
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``snputils`` argument parser.

    One subparser is created per entry of ``_COMMANDS``, each storing its run
    callback as the ``_handler`` default, followed by a ``version`` subcommand.
    The top-level parser also accepts ``--version``.
    """
    description = (
        "snputils command-line interface for common file-backed workflows. "
        f"Version: {__version__}. Docs: {DOCS_URL}. Source: {SOURCE_URL}."
    )
    root = argparse.ArgumentParser(
        prog="snputils",
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    root.add_argument(
        "--version",
        action="version",
        version=f"snputils {__version__}",
        help="Show the installed snputils version and exit.",
    )

    command_parsers = root.add_subparsers(dest="command")
    for command_name, spec in _COMMANDS.items():
        sub = command_parsers.add_parser(command_name, help=spec.help, description=spec.help)
        spec.add_arguments(sub)
        # main() looks the callback up again through this attribute.
        sub.set_defaults(_handler=spec.run)

    version_sub = command_parsers.add_parser("version", help="Show the installed snputils version.")
    version_sub.set_defaults(_handler=lambda _args: _print_version())
    return root
def _print_version() -> int:
    """Print ``snputils <version>`` to stdout and return ``0``."""
    banner = f"snputils {__version__}"
    print(banner)
    return 0
# Public entry point of the snputils command-line interface.
def main(argv: Optional[List[str]] = None) -> int:
    """Parse *argv* and run the selected subcommand.

    Args:
        argv: Argument list without the program name; ``sys.argv[1:]`` is used
            when ``None``.

    Returns:
        The handler's result coerced to ``int``, or ``1`` (after printing the
        help text to stderr) when no arguments were given or no subcommand
        handler was selected.
    """
    cli = build_parser()
    tokens = argv if argv is not None else sys.argv[1:]
    if tokens:
        namespace = cli.parse_args(tokens)
        run = getattr(namespace, "_handler", None)
        if run is not None:
            return int(run(namespace))
    # Reached with no arguments at all, or with options but no subcommand.
    cli.print_help(sys.stderr)
    return 1
# Script entry: exit the interpreter with the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())