# Source code for snputils.simulation.simulator.simulator

import logging
import torch
import numpy as np

# Module-wide logger used by every function and class in this file.
log = logging.getLogger("simulator_cli")

# Configure the root logger once at import time: INFO and above, timestamped.
# NOTE(review): the format string has no separator between the timestamp,
# the level name and the message -- confirm that this is intentional.
logging.basicConfig(
    format="%(asctime)s%(levelname)-8s%(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
    

def latlon_to_nvector(lat, lon):
    """
    Map latitude/longitude given in DEGREES to x, y, z n-vector components.

    Both inputs are passed through ``np.radians`` first; callers that already
    hold radians must drop that conversion. The three components are stacked
    along a new trailing axis, so scalar inputs give shape (3,) and inputs of
    shape (N,) give shape (N, 3).
    """
    phi = np.radians(lat)
    lam = np.radians(lon)
    cos_phi = np.cos(phi)
    components = [
        cos_phi * np.cos(lam),  # x
        cos_phi * np.sin(lam),  # y
        np.sin(phi),            # z
    ]
    return np.stack(components, axis=-1)


def nvector_to_latlon(nvec):
    """
    Convert an n-vector (x,y,z) back to latitude/longitude in DEGREES.

    Latitude is computed as ``arctan2(z, hypot(x, y))`` rather than
    ``arcsin(z)``. The two agree for unit vectors, but the arctan2 form
    does not return NaN when rounding pushes ``|z|`` slightly above 1, and
    it still yields the correct direction for vectors that are not exactly
    unit length (for example the mean of several n-vectors).

    Parameters
    ----------
    nvec : array-like of shape (3,) or (..., 3)
        x, y, z coordinates of the n-vector(s), stored along the last axis.

    Returns
    -------
    (lat_deg, lon_deg) : tuple of floats or np.ndarrays
        Latitude(s) and longitude(s) in degrees. A (3,) input gives NumPy
        scalars; an (N, 3) input gives arrays of shape (N,).
    """
    nvec = np.asarray(nvec)
    # Indexing the last axis handles the single-vector and batched cases alike.
    x = nvec[..., 0]
    y = nvec[..., 1]
    z = nvec[..., 2]
    lat_rad = np.arctan2(z, np.hypot(x, y))
    lon_rad = np.arctan2(y, x)
    return (np.degrees(lat_rad), np.degrees(lon_rad))
    
def approximate_mode_per_row(
    row_2d: torch.Tensor,   # shape (B, W)
    nbins=32
) -> torch.Tensor:
    """
    Estimate the mode of every row with a fixed-width histogram.

    row_2d: shape (B, W), continuous values
    nbins:  number of histogram bins spanning each row's [min, max]

    Returns: float32 tensor of shape (B,) on the same device as row_2d.
             Each entry is the midpoint of the fullest bin of that row, or
             the row's first element when the row's min equals its max.
    """
    n_rows, _ = row_2d.shape

    # Per-row value range; these define the histogram limits below.
    lows = row_2d.min(dim=1).values    # (B,)
    highs = row_2d.max(dim=1).values   # (B,)

    modes = torch.zeros(n_rows, device=row_2d.device, dtype=torch.float32)

    # torch.histc works on one flat tensor at a time, hence the row loop.
    for r, (values, lo_t, hi_t) in enumerate(zip(row_2d, lows, highs)):
        lo = lo_t.item()
        hi = hi_t.item()

        if hi == lo:
            # Constant row: the mode is that single value.
            modes[r] = values[0]
        else:
            counts = torch.histc(values, bins=nbins, min=lo, max=hi)  # (nbins,)
            fullest = counts.argmax().item()                          # in [0, nbins-1]
            width = (hi - lo) / nbins
            # Report the centre of the winning bin.
            modes[r] = lo + (fullest + 0.5) * width

    return modes


def _chunk_label_array(
    labels: torch.Tensor,
    window_size: int,
    descriptor="continuous",
    pool_method="mode",
    nbins=32
):
    """
    labels: shape (B, D) or (B, D, dim) after final crossovers
      - If descriptor="continuous" => shape (B, D[, dim])
      - If descriptor="discrete"   => shape (B, D)
    window_size: # of SNPs per window
    descriptor:  "continuous" or "discrete"
    pool_method: "mean" or "mode"
      - If continuous + "mode" => uses an approximate histogram-based mode (GPU-friendly)
    nbins: # of bins if using approximate mode for continuous

    Returns:
      chunked => shape (B, n_win, dim), (B, n_win, 1), or (B, n_win)
        depending on descriptor + dimension
    """
    if labels.ndim == 2:
        # shape => (B, D) => discrete or continuous w/ dim=1
        B, D = labels.shape
        label_dim = None
    else:
        # shape => (B, D, dim)
        B, D, label_dim = labels.shape

    n_full = D // window_size
    leftover = D % window_size

    chunks = []
    start = 0
    for _ in range(n_full):
        end = start + window_size
        # slice => (B, window_size[, dim])
        window_segment = (
            labels[:, start:end, ...]
            if label_dim else labels[:, start:end]
        )

        if descriptor == "continuous":
            if pool_method == "mean":
                # normal PyTorch .mean(...)
                if label_dim:
                    # shape => (B, window_size, dim)
                    mean_vals = window_segment.mean(dim=1)  # => (B, dim)
                    chunks.append(mean_vals)
                else:
                    # shape => (B, window_size)
                    mean_vals = window_segment.float().mean(dim=1)  # => (B,)
                    chunks.append(mean_vals.unsqueeze(-1))

            elif pool_method == "mode":
                # approximate GPU mode
                if label_dim:
                    # shape => (B, window_size, dim)
                    # do dimension by dimension
                    # => we'll gather a list of (B,) for each dim, then stack
                    mode_vals_list = []
                    for d_i in range(label_dim):
                        # slice => (B, window_size)
                        slice_2d = window_segment[:, :, d_i]
                        # approximate mode => (B,)
                        approx_m = approximate_mode_per_row(slice_2d, nbins=nbins)
                        mode_vals_list.append(approx_m)
                    # stack => (B, label_dim)
                    mode_vals_cat = torch.stack(mode_vals_list, dim=1)
                    chunks.append(mode_vals_cat)
                else:
                    # shape => (B, window_size)
                    approx_m = approximate_mode_per_row(window_segment, nbins=nbins) # => (B,)
                    chunks.append(approx_m.unsqueeze(-1))
            else:
                raise ValueError(f"pool_method '{pool_method}' not implemented for continuous.")

        else:
            # descriptor == "discrete" => use built-in .mode(dim=1)
            # shape => (B, window_size)
            # mode along dimension=1 => shape (B,)
            mode_vals = window_segment.mode(dim=1).values
            chunks.append(mode_vals)

        start = end

    # leftover
    if leftover > 0:
        window_segment = (
            labels[:, start:, ...]
            if label_dim else labels[:, start:]
        )
        if descriptor == "continuous":
            if label_dim:
                if pool_method == "mean":
                    mean_vals = window_segment.mean(dim=1)  # => (B, dim)
                    chunks.append(mean_vals)
                else:
                    # pool_method == "mode" => approximate
                    mode_vals_list = []
                    for d_i in range(label_dim):
                        slice_2d = window_segment[:, :, d_i]
                        approx_m = approximate_mode_per_row(slice_2d, nbins=nbins)
                        mode_vals_list.append(approx_m)
                    mode_vals_cat = torch.stack(mode_vals_list, dim=1)
                    chunks.append(mode_vals_cat)
            else:
                if pool_method == "mean":
                    mean_vals = window_segment.float().mean(dim=1)  # (B,)
                    chunks.append(mean_vals.unsqueeze(-1))
                else:
                    # approximate mode
                    approx_m = approximate_mode_per_row(window_segment, nbins=nbins)
                    chunks.append(approx_m.unsqueeze(-1))
        else:
            # discrete => leftover => .mode(dim=1)
            mode_vals = window_segment.mode(dim=1).values
            chunks.append(mode_vals)

    if len(chunks) == 0:
        # if window_size >= D => no chunk
        return None

    # Now stack => shape (B, n_windows[, dim]) or (B, n_windows)
    if descriptor == "continuous":
        cat_res = torch.stack(chunks, dim=1)  
        return cat_res
    else:
        # discrete => shape => (B,) in each chunk => stack => (B, n_windows)
        cat_res = torch.stack(chunks, dim=1)
        return cat_res


def _chunk_changepoints(cp_mask, window_size):
    """
    cp_mask: shape (B, D), a boolean (or 0/1) array indicating
             breakpoint positions at the SNP level.

    Returns: shape (B, n_windows), with 1 if any SNP in that window
             was a breakpoint, else 0.
    """
    B, D = cp_mask.shape
    n_full = D // window_size
    leftover = D % window_size

    chunks = []
    start = 0
    for _ in range(n_full):
        end = start + window_size
        # if any True in that window => 1
        any_cp = cp_mask[:, start:end].any(axis=1)
        chunks.append(any_cp)
        start = end

    if leftover > 0:
        any_cp = cp_mask[:, start:].any(axis=1)
        chunks.append(any_cp)

    if len(chunks) == 0:
        return None

    # stack along new dim => (B, n_windows)
    out = np.stack(chunks, axis=1).astype(np.int8)
    return torch.tensor(out)
    
class OnlineSimulator:
    """
    A refactored 'OnlineSimulator' for haplotype simulation with window-based SNP data.

    -------------------------------------------------------------------------------
    Core Functionality:
      - Simulates admixed haplotypes.
      - Supports:
          (a) discrete labels (e.g., population codes), or
          (b) lat/lon (converted to n-vectors)
        stored per window of SNPs.

    Example usage:
    -------------
        sim = OnlineSimulator(
            snp_data=my_snpobj,
            meta=metadata_df,
            genetic_map=genetic_map_df,  # optional
            ...
        )

        # Then to simulate:
        snps, labels_discrete, labels_continuous, changepoints = sim.simulate(batch_size=32)
    """

    def __init__(
        self,
        snp_data,
        meta,
        genetic_map=None,
        make_haploid=True,
        window_size=None,
        store_latlon_as_nvec=False,
        cp_tolerance=0,
    ):
        """
        Store the inputs and eagerly build the per-SNP label tensors.

        snp_data: object exposing ``samples``, ``genotypes`` and
            ``variants_pos`` (the attributes read by this class).
        meta: DataFrame with a 'Sample' column plus 'Population' and/or
            'Latitude' + 'Longitude'.
        genetic_map: optional mapping/DataFrame with 'pos' and 'cM' entries.
        make_haploid: flatten the ploidy axis into separate haplotype rows.
        window_size: # of SNPs per label window; None/<=0 disables windowing.
        store_latlon_as_nvec: store continuous labels as x,y,z instead of lat/lon.
        cp_tolerance: stored on the instance; not read anywhere in this class.
        """
        self.snp_data = snp_data
        self.meta = meta
        self.genetic_map = genetic_map
        self.make_haploid = make_haploid
        self.window_size = window_size
        self.store_latlon_as_nvec = store_latlon_as_nvec
        self.cp_tolerance = cp_tolerance

        self.labels_discrete = None
        self.labels_continuous = None

        # Order matters: each step consumes attributes set by the previous one.
        self._check_sample_metadata()
        self._intersect_snp_metadata()
        self._build_descriptors()
        self._broadcast_labels_across_snps()

    def _check_sample_metadata(self):
        """
        Ensures the DataFrame `self.meta` has the necessary columns.

        - If 'discrete', we expect 'Population' column
        - If 'continuous', we expect 'Latitude' and 'Longitude'

        Sets ``self.has_discrete`` / ``self.has_continuous`` and drops rows
        with missing values in the columns that are in use.

        Raises ValueError if 'Sample' is missing or no descriptor columns exist.
        """
        if 'Sample' not in self.meta.columns:
            raise ValueError("Expected 'Sample' column in sample metadata.")

        # Presence of the columns decides which label kinds are produced;
        # it is fine if only one kind is available.
        needed_for_continuous = {'Latitude', 'Longitude'}
        self.has_discrete = ('Population' in self.meta.columns)
        self.has_continuous = needed_for_continuous.issubset(self.meta.columns)

        if not (self.has_discrete or self.has_continuous):
            raise ValueError(
                "No recognized columns for descriptors. Need 'Population' for discrete "
                "and/or 'Latitude','Longitude' for continuous."
            )

        # Drop rows that lack the necessary fields.
        if self.has_discrete:
            self.meta = self.meta.dropna(subset=['Sample', 'Population'])
            log.info("Discrete labeling: found 'Population' column in metadata.")

        if self.has_continuous:
            self.meta = self.meta.dropna(subset=['Sample', 'Latitude', 'Longitude'])
            log.info("Continuous labeling: found 'Latitude'/'Longitude' columns in metadata.")

        log.info('Metadata OK.')

    def _intersect_snp_metadata(self):
        """
        Intersects SNP samples with metadata samples.

        Produces:
          self.snps: int8 tensor, shape (N, ploidy, D), or (N*ploidy, D)
              when self.make_haploid is True
          self.samples: array of sample names (one entry per row of self.snps)
          self.meta: metadata rows re-ordered (and, when flattening, repeated)
              to line up with self.snps
          self.rate_per_snp: per-SNP rate derived from the genetic map, or None

        Raises ValueError when no sample name is shared.
        """
        snp_samples = np.asarray(self.snp_data.samples)
        log.info(f"SNP input has {len(snp_samples)} samples total.")

        meta_samples = self.meta["Sample"].values
        inter = np.intersect1d(snp_samples, meta_samples, assume_unique=False, return_indices=True)
        isamples, iidx = inter[0], inter[1]
        log.info(f"{len(isamples)} samples found in both SNP input and metadata.")

        if len(isamples) == 0:
            raise ValueError("No overlap between SNP samples and metadata samples. Check your paths or sample naming.")

        # Positional lookup into the metadata; if a name occurs more than
        # once, the last occurrence wins.
        samp2idx = {s: idx for idx, s in enumerate(meta_samples)}
        meta_idxs = [samp2idx[s] for s in isamples]
        self.meta = self.meta.iloc[meta_idxs].copy().reset_index(drop=True)

        # Move the first genotype axis to the end, then keep the shared samples.
        snps = np.asarray(self.snp_data.genotypes).transpose(1, 2, 0)[iidx, ...]
        n_samples, ploidy, n_snps = snps.shape

        if self.make_haploid:
            snps = snps.reshape(n_samples * ploidy, n_snps)
            isamples = np.repeat(isamples, ploidy)
            # Repeat each metadata row once per haplotype. This must use the
            # real ploidy (not a fixed 2) so rows stay aligned with `snps`.
            self.meta = self.meta.loc[self.meta.index.repeat(ploidy)].reset_index(drop=True)

        self.snps = torch.tensor(snps, dtype=torch.int8)
        self.samples = np.array(isamples)
        log.info(f"snps shape = {self.snps.shape}, sample length = {len(self.samples)}")

        if self.genetic_map is not None:
            # Interpolate cM at each variant position; the gradient of the
            # resulting Morgans gives a per-SNP rate.
            cm_interp = np.interp(self.snp_data.variants_pos, self.genetic_map['pos'], self.genetic_map['cM'])
            self.rate_per_snp = np.gradient(cm_interp/100.0)
            log.info(f"rate/snp shape = {self.rate_per_snp.shape}")
        else:
            self.rate_per_snp = None

    def _build_descriptors(self):
        """
        Build the per-sample label tensors.

        If discrete, we store integer-coded labels in self.labels_discrete,
        shape (N, 1), int16. If continuous, we store lat/lon or x,y,z for
        each sample in self.labels_continuous, shape (N, 2 or 3), float32.

        Raises ValueError if self.samples and self.snps disagree in length.
        """
        if len(self.samples) != self.snps.shape[0]:
            raise ValueError("Metadata subset mismatch in length after flattening haplotypes.")

        # 1) Discrete: code populations by their sorted order.
        if self.has_discrete:
            pop_values = self.meta['Population'].values
            unique_pops = sorted(np.unique(pop_values))
            pop2code = {p: i for i, p in enumerate(unique_pops)}
            discrete_arr = np.array([pop2code[p] for p in pop_values], dtype=np.int16)
            # shape => (N,1)
            discrete_arr = discrete_arr[:, None]
            self.labels_discrete = torch.tensor(discrete_arr, dtype=torch.int16)
            log.info(f"Built discrete labels => shape {self.labels_discrete.shape}")

        # 2) Continuous
        if self.has_continuous:
            lat_vals = self.meta["Latitude"].values
            lon_vals = self.meta["Longitude"].values

            if self.store_latlon_as_nvec:
                coords = latlon_to_nvector(lat_vals, lon_vals)  # (N, 3)
            else:
                coords = np.stack([lat_vals, lon_vals], axis=-1)  # (N, 2)

            self.labels_continuous = torch.tensor(coords, dtype=torch.float32)
            log.info(f"Built continuous labels => shape {self.labels_continuous.shape}")

    def _broadcast_labels_across_snps(self):
        """
        Repeat each per-sample label along the SNP axis, giving shape (N, D)
        for discrete or (N, D, coord_dim) for continuous, so that per-SNP
        crossovers can scramble the labels together with the genotypes.
        """
        # D is always the last axis, for both (N, ploidy, D) and (N, D).
        D = self.snps.shape[-1]

        # Discrete: (N, 1) => (N, D)
        if self.labels_discrete is not None:
            self.labels_discrete = self.labels_discrete.repeat(1, D)
            log.info(f"Broadcast discrete => {self.labels_discrete.shape}")

        # Continuous: (N, cdim) => (N, D, cdim)
        if self.labels_continuous is not None:
            self.labels_continuous = self.labels_continuous.unsqueeze(1).repeat(1, D, 1)
            log.info(f"Broadcast continuous => {self.labels_continuous.shape}")

    def _simulate_from_pool(
        self,
        batch_snps,
        batch_labels_discrete,
        batch_labels_continuous,
        num_generation_max,
        device='cpu',
    ):
        """
        Shuffle segments for admixture on snps, discrete labels, continuous
        labels (if they exist). Each has shape:
          snps => (B, D)
          batch_labels_discrete => (B, D) or None
          batch_labels_continuous => (B, D, cdim) or None

        The tensors are modified in place (after an optional move to
        `device`) and returned as a 3-tuple in the same order.
        """
        if device != 'cpu':
            batch_snps = batch_snps.to(device)
            if batch_labels_discrete is not None:
                batch_labels_discrete = batch_labels_discrete.to(device)
            if batch_labels_continuous is not None:
                batch_labels_continuous = batch_labels_continuous.to(device)

        # 1) Pick the number of generations (0 => no crossover at all).
        G = np.random.randint(0, num_generation_max+1)

        # 2) Choose the switch points: with a genetic map, a SNP becomes a
        #    switch point when its binomial draw is odd; otherwise G points
        #    are drawn uniformly.
        B, D = batch_snps.shape
        if self.rate_per_snp is not None:
            switch = np.random.binomial(G, self.rate_per_snp) % 2
            split_points = np.flatnonzero(switch)
        else:
            split_points = torch.randint(D, (G,))

        for sp in split_points:
            # One row permutation per switch point, applied identically to
            # genotypes and labels so they stay aligned.
            perm = torch.randperm(B, device=batch_snps.device)

            # Swap SNPs
            batch_snps[:, sp:] = batch_snps[perm, sp:]

            # Swap discrete
            if batch_labels_discrete is not None:
                batch_labels_discrete[:, sp:] = batch_labels_discrete[perm, sp:]

            # Swap continuous
            if batch_labels_continuous is not None:
                batch_labels_continuous[:, sp:, :] = batch_labels_continuous[perm, sp:, :]

        return batch_snps, batch_labels_discrete, batch_labels_continuous

    def simulate(
        self,
        batch_size=256,
        num_generation_max=10,
        balanced=False,
        single_ancestry=False,
        device='cpu',
        pool_method='mode'
    ):
        """
        Draw a random batch, apply crossovers and pool labels per window.

        `balanced` and `single_ancestry` are accepted but currently ignored.

        Returns a 4-tuple of:
          (
            batch_snps,
            final_discrete_labels_window,
            final_continuous_labels_window,
            final_changepoints_window
          )
        where:
          - batch_snps is a float tensor of shape (B, D), or (B*ploidy, D)
            when the stored genotypes still carry a ploidy axis
          - final_discrete_labels_window == (B, n_windows) if discrete was present, else None
          - final_continuous_labels_window == (B, n_windows, cdim) if continuous was present, else None
          - final_changepoints_window == (B, n_windows) with 1 where a
            discrete-label breakpoint falls inside the window, else None

        All label outputs are None when self.window_size is None or <= 0.
        """
        del balanced, single_ancestry

        # Pick a random subset of samples (with replacement).
        N = self.snps.shape[0]
        idx = torch.randint(N, (batch_size,))
        batch_snps = self.snps[idx].clone()

        # Subset discrete
        if self.labels_discrete is not None:
            batch_discrete = self.labels_discrete[idx].clone()
        else:
            batch_discrete = None

        # Subset continuous
        if self.labels_continuous is not None:
            batch_continuous = self.labels_continuous[idx].clone()
        else:
            batch_continuous = None

        # Diploid input: (B, 2, D) -> flatten strands into haplotype rows (B*2, D)
        # so that _simulate_from_pool and all downstream logic see a 2-D tensor.
        if batch_snps.ndim == 3:
            B_dip, ploidy, D = batch_snps.shape
            batch_snps = batch_snps.reshape(B_dip * ploidy, D)
            if batch_discrete is not None:
                batch_discrete = batch_discrete.repeat_interleave(ploidy, dim=0)
            if batch_continuous is not None:
                batch_continuous = batch_continuous.repeat_interleave(ploidy, dim=0)

        # Crossovers
        batch_snps, batch_discrete, batch_continuous = self._simulate_from_pool(
            batch_snps,
            batch_discrete,
            batch_continuous,
            num_generation_max=num_generation_max,
            device=device
        )

        # Window-chunk each label array if window_size is specified
        discrete_out = None
        continuous_out = None
        final_cp_window = None

        if self.window_size is not None and self.window_size > 0:
            if batch_discrete is not None:
                discrete_out = _chunk_label_array(
                    labels=batch_discrete,
                    window_size=self.window_size,
                    descriptor="discrete",
                    pool_method=None,
                )

                # SNP-level breakpoints: position i is a breakpoint when its
                # label differs from position i-1 (position 0 never is).
                lab_np = batch_discrete.cpu().numpy()  # (B, D)
                cp_mask = np.zeros_like(lab_np, dtype=bool)  # (B, D)
                cp_mask[:, 1:] = (lab_np[:, 1:] != lab_np[:, :-1])

                # Now chunk that cp_mask into windows
                cp_mask_t = torch.from_numpy(cp_mask)
                final_cp_window = _chunk_changepoints(cp_mask_t, self.window_size)

            if batch_continuous is not None:
                continuous_out = _chunk_label_array(
                    labels=batch_continuous,
                    window_size=self.window_size,
                    descriptor="continuous",
                    pool_method=pool_method,
                )

        return batch_snps.float(), discrete_out, continuous_out, final_cp_window