snputils.simulation.simulator

1from .simulator import OnlineSimulator
2
3__all__ = ['OnlineSimulator']
class OnlineSimulator:
class OnlineSimulator:
    """
    Online simulator of admixed haplotypes with per-SNP labels pooled per window.
    -------------------------------------------------------------------------------
    Core Functionality:
      - Simulates admixed haplotypes by shuffling haplotype suffixes at
        randomly drawn breakpoints.
      - Supports either or both of:
         (a) discrete labels (integer-coded 'Population' values),
         (b) continuous labels (lat/lon, optionally converted to n-vectors).
      - Labels are stored per SNP and, when `window_size` is set, reduced to
        one value per window of SNPs in `simulate`.

    Parameters
    ----------
    vcf_data : mapping
        Read with the keys "samples", "calldata_gt" (variants, samples, ploidy)
        and, when a genetic map is given, "variants_pos".
    meta : pandas.DataFrame
        Must contain 'Sample' plus 'Population' and/or both 'Latitude' and
        'Longitude'.
    genetic_map : optional
        Indexed with 'pos' and 'cM'; used to derive a per-SNP switch rate.
    make_haploid : bool
        If True, each sample is flattened into `ploidy` haplotype rows.
        NOTE(review): later methods unpack `self.snps.shape` as (N, D), so
        make_haploid=False looks unsupported -- confirm before relying on it.
    window_size : int or None
        Window length in SNPs; labels are only returned when it is positive.
    store_latlon_as_nvec : bool
        If True, lat/lon is passed through `latlon_to_nvector`.
    cp_tolerance : int
        Stored on the instance; not read by any method of this class.

    Example usage:
    -------------
        sim = OnlineSimulator(
            vcf_data=my_species_chrX_vcf_data,
            meta=metadata_df,
            genetic_map=genetic_map_df,  # optional
            ...
        )
        # Then to simulate:
        snps, labels_discrete, labels_continuous, changepoints = sim.simulate(batch_size=32)
    """

    def __init__(
        self,
        vcf_data,
        meta,
        genetic_map=None,
        make_haploid=True,
        window_size=None,
        store_latlon_as_nvec=False,
        cp_tolerance=0,
    ):
        """Store the configuration and build the SNP and label tensors."""
        self.vcf_data = vcf_data
        self.meta = meta
        self.genetic_map = genetic_map
        self.make_haploid = make_haploid
        self.window_size = window_size
        self.store_latlon_as_nvec = store_latlon_as_nvec
        self.cp_tolerance = cp_tolerance

        # Discrete and continuous labels are kept separately; each stays None
        # when the matching metadata columns are absent.
        self.labels_discrete = None
        self.labels_continuous = None

        # Order matters: each step consumes attributes set by the previous one.
        self._check_sample_metadata()
        self._intersect_vcf_metadata()
        self._build_descriptors()
        self._broadcast_labels_across_snps()

    def _check_sample_metadata(self):
        """
        Validate `self.meta` and record which label kinds are available.

        Sets `self.has_discrete` ('Population' column present) and
        `self.has_continuous` ('Latitude' and 'Longitude' both present), then
        drops rows with missing values in the columns of each available kind.

        Raises
        ------
        ValueError
            If 'Sample' is missing, or if neither label kind is available.
        """
        if 'Sample' not in self.meta.columns:
            raise ValueError("Expected 'Sample' column in sample metadata.")

        # Either label kind alone is sufficient; both may be present.
        needed_for_continuous = {'Latitude', 'Longitude'}
        self.has_discrete = ('Population' in self.meta.columns)
        self.has_continuous = needed_for_continuous.issubset(self.meta.columns)

        if not (self.has_discrete or self.has_continuous):
            raise ValueError(
                "No recognized columns for descriptors. Need 'Population' for discrete "
                "and/or 'Latitude','Longitude' for continuous."
            )

        # When both kinds are present a row must be complete for both to survive.
        if self.has_discrete:
            self.meta = self.meta.dropna(subset=['Sample', 'Population'])
            log.info("Discrete labeling: found 'Population' column in metadata.")
        if self.has_continuous:
            self.meta = self.meta.dropna(subset=['Sample', 'Latitude', 'Longitude'])
            log.info("Continuous labeling: found 'Latitude'/'Longitude' columns in metadata.")

        log.info('Metadata OK.')

    def _intersect_vcf_metadata(self):
        """
        Restrict genotypes and metadata to the samples present in both.

        Produces:
          self.snps: int8 tensor, (N*ploidy, D) when `make_haploid` is True,
                     otherwise (N, ploidy, D)
          self.samples: array of sample names, one entry per row of self.snps
          self.meta: metadata rows aligned with self.samples
          self.rate_per_snp: per-SNP gradient of the interpolated genetic map
                             (in Morgans), or None without a genetic map

        Raises
        ------
        ValueError
            If the VCF and the metadata share no sample.
        """
        vcf_samples = self.vcf_data["samples"]
        log.info(f"VCF has {len(vcf_samples)} samples total.")
        meta_samples = self.meta["Sample"].values
        # isamples = shared sample IDs (sorted), iidx = their positions in vcf_samples
        inter = np.intersect1d(vcf_samples, meta_samples, assume_unique=False, return_indices=True)
        isamples, iidx = inter[0], inter[1]
        log.info(f"{len(isamples)} samples found in both VCF and metadata.")
        if len(isamples) == 0:
            raise ValueError("No overlap between VCF samples and metadata samples. Check your paths or sample naming.")

        # Reorder the metadata positionally so row i describes isamples[i].
        samp2idx = {s: idx for idx, s in enumerate(meta_samples)}
        meta_idxs = [samp2idx[s] for s in isamples]
        self.meta = self.meta.iloc[meta_idxs].copy().reset_index(drop=True)

        # (variants, samples, ploidy) -> (samples, ploidy, variants), shared samples only
        snps = self.vcf_data["calldata_gt"].transpose(1, 2, 0)[iidx, ...]
        n_samples, ploidy, n_snps = snps.shape

        if self.make_haploid:
            # One row per haplotype: sample0/hap0, sample0/hap1, ...
            snps = snps.reshape(n_samples * ploidy, n_snps)
            isamples = np.repeat(isamples, ploidy)
            # Repeat metadata by the actual ploidy (not a fixed 2) so it stays
            # aligned with the genotype rows and sample names above.
            self.meta = self.meta.loc[self.meta.index.repeat(ploidy)].reset_index(drop=True)

        self.snps = torch.tensor(snps, dtype=torch.int8)
        self.samples = np.array(isamples)
        log.info(f"snps shape = {self.snps.shape}, sample length = {len(self.samples)}")

        if self.genetic_map is not None:
            # Interpolate cM at each variant position, convert to Morgans, and
            # take the gradient as the per-SNP switch rate.
            cm_interp = np.interp(self.vcf_data["variants_pos"], self.genetic_map['pos'], self.genetic_map['cM'])
            self.rate_per_snp = np.gradient(cm_interp / 100.0)
            log.info(f"rate/snp shape = {self.rate_per_snp.shape}")
        else:
            self.rate_per_snp = None

    def _build_descriptors(self):
        """
        Build one label per row of `self.snps`.

        Sets `self.labels_discrete` to an int16 tensor of shape (N, 1) holding
        the index of each 'Population' value among the sorted unique values,
        and `self.labels_continuous` to a float32 tensor of shape (N, 2)
        (lat, lon) or whatever `latlon_to_nvector` returns when
        `store_latlon_as_nvec` is set.

        Raises
        ------
        ValueError
            If the number of sample names differs from the number of SNP rows.
        """
        if len(self.samples) != self.snps.shape[0]:
            raise ValueError("Metadata subset mismatch in length after flattening haplotypes.")

        # 1) Discrete
        if self.has_discrete:
            pop_values = self.meta['Population'].values
            unique_pops = sorted(np.unique(pop_values))
            pop2code = {p: i for i, p in enumerate(unique_pops)}
            discrete_arr = np.array([pop2code[p] for p in pop_values], dtype=np.int16)
            # Keep a trailing axis so the labels can later be repeated per SNP.
            discrete_arr = discrete_arr[:, None]
            self.labels_discrete = torch.tensor(discrete_arr, dtype=torch.int16)
            log.info(f"Built discrete labels => shape {self.labels_discrete.shape}")

        # 2) Continuous
        if self.has_continuous:
            lat_vals = self.meta["Latitude"].values
            lon_vals = self.meta["Longitude"].values
            if self.store_latlon_as_nvec:
                coords = latlon_to_nvector(lat_vals, lon_vals)
            else:
                coords = np.stack([lat_vals, lon_vals], axis=-1)  # (N, 2)
            self.labels_continuous = torch.tensor(coords, dtype=torch.float32)
            log.info(f"Built continuous labels => shape {self.labels_continuous.shape}")

    def _broadcast_labels_across_snps(self):
        """
        Repeat the per-row labels along a SNP axis.

        Discrete labels go from (N, 1) to (N, D); continuous labels go from
        (N, cdim) to (N, D, cdim). The copies are materialized because the
        crossover step later overwrites label slices in place.
        """
        N, D = self.snps.shape

        # Discrete
        if self.labels_discrete is not None:
            self.labels_discrete = self.labels_discrete.repeat(1, D)  # (N, D)
            log.info(f"Broadcast discrete => {self.labels_discrete.shape}")

        # Continuous
        if self.labels_continuous is not None:
            # Insert the SNP axis, then tile along it only.
            self.labels_continuous = self.labels_continuous[:, None, :].repeat(1, D, 1)
            log.info(f"Broadcast continuous => {self.labels_continuous.shape}")

    def _simulate_from_pool(
        self,
        batch_snps,
        batch_labels_discrete,
        batch_labels_continuous,
        num_generation_max,
        device='cpu',
    ):
        """
        Shuffle haplotype suffixes among the rows of a batch, in place.

        Shapes:
          batch_snps => (B, D)
          batch_labels_discrete => (B, D) or None
          batch_labels_continuous => (B, D, cdim) or None

        A generation count G is drawn uniformly from [0, num_generation_max].
        With a genetic map, a SNP becomes a breakpoint when its
        Binomial(G, rate) draw is odd; otherwise G breakpoints are drawn
        uniformly over the D SNPs. At every breakpoint the rows are permuted
        from that SNP onward, identically for SNPs and labels.

        Returns the (possibly device-moved) snps, discrete labels and
        continuous labels.
        """
        if device != 'cpu':
            batch_snps = batch_snps.to(device)
            if batch_labels_discrete is not None:
                batch_labels_discrete = batch_labels_discrete.to(device)
            if batch_labels_continuous is not None:
                batch_labels_continuous = batch_labels_continuous.to(device)

        # 1) Number of generations for this batch
        G = np.random.randint(0, num_generation_max + 1)
        B, D = batch_snps.shape

        # 2) Breakpoint positions
        if self.rate_per_snp is not None:
            switch = np.random.binomial(G, self.rate_per_snp) % 2
            split_points = np.flatnonzero(switch)
        else:
            split_points = torch.randint(D, (G,))

        for sp in split_points:
            # One permutation per breakpoint keeps SNPs and labels in sync.
            perm = torch.randperm(B, device=batch_snps.device)
            batch_snps[:, sp:] = batch_snps[perm, sp:]
            if batch_labels_discrete is not None:
                batch_labels_discrete[:, sp:] = batch_labels_discrete[perm, sp:]
            if batch_labels_continuous is not None:
                batch_labels_continuous[:, sp:, :] = batch_labels_continuous[perm, sp:, :]

        return batch_snps, batch_labels_discrete, batch_labels_continuous

    def simulate(
        self,
        batch_size=256,
        num_generation_max=10,
        balanced=False,
        single_ancestry=False,
        device='cpu',
        pool_method='mode'
    ):
        """
        Draw a random batch of haplotypes and admix them.

        Parameters:
          - batch_size: number of rows sampled (with replacement) from the pool
          - num_generation_max: upper bound for the random generation count
          - balanced, single_ancestry: accepted but currently ignored
          - device: target device for the crossover step
          - pool_method: forwarded to `_chunk_label_array` for continuous labels

        Returns a tuple of:
          ( batch_snps, discrete_out, continuous_out, final_cp_window )

        where:
          - batch_snps is a float tensor of shape (B, D)
          - discrete_out is the window-chunked discrete labels, or None
          - continuous_out is the window-chunked continuous labels, or None
          - final_cp_window is the window-chunked mask of SNPs whose discrete
            label differs from the previous SNP, or None

        The last three are None unless `self.window_size` is a positive number
        and the corresponding labels exist; `final_cp_window` requires
        discrete labels.
        """
        # Sample rows with replacement
        N = self.snps.shape[0]
        idx = torch.randint(N, (batch_size,))
        batch_snps = self.snps[idx, :].clone()  # (B, D)

        # Clone so the in-place crossover never touches the stored pool.
        if self.labels_discrete is not None:
            batch_discrete = self.labels_discrete[idx, :].clone()  # (B, D)
        else:
            batch_discrete = None

        if self.labels_continuous is not None:
            batch_continuous = self.labels_continuous[idx, :, :].clone()  # (B, D, cdim)
        else:
            batch_continuous = None

        # Crossovers
        batch_snps, batch_discrete, batch_continuous = self._simulate_from_pool(
            batch_snps, batch_discrete, batch_continuous,
            num_generation_max=num_generation_max,
            device=device
        )

        discrete_out = None
        continuous_out = None
        final_cp_window = None

        if self.window_size is not None and self.window_size > 0:
            if batch_discrete is not None:
                discrete_out = _chunk_label_array(
                    labels=batch_discrete,
                    window_size=self.window_size,
                    descriptor="discrete",
                    pool_method=None,
                )
                # SNP-level changepoints: True where the label differs from
                # the previous SNP; the first SNP is never a changepoint.
                lab_np = batch_discrete.cpu().numpy()  # (B, D)
                cp_mask = np.zeros_like(lab_np, dtype=bool)
                cp_mask[:, 1:] = (lab_np[:, 1:] != lab_np[:, :-1])

                cp_mask_t = torch.from_numpy(cp_mask)
                final_cp_window = _chunk_changepoints(cp_mask_t, self.window_size)

            if batch_continuous is not None:
                continuous_out = _chunk_label_array(
                    labels=batch_continuous,
                    window_size=self.window_size,
                    descriptor="continuous",
                    pool_method=pool_method,
                )

        return batch_snps.float(), discrete_out, continuous_out, final_cp_window

A refactored 'OnlineSimulator' for haplotype simulation with window-based SNP data.

Core Functionality:
  • Simulates admixed haplotypes.
  • Supports (a) discrete labels (e.g., population codes) and/or (b) continuous lat/lon labels (optionally converted to n-vectors); labels are stored per SNP and pooled per window of SNPs.

Example usage:

sim = OnlineSimulator(
    vcf_data=my_species_chrX_vcf_data,
    meta=metadata_df,
    genetic_map=genetic_map_df,  # optional
    ...
)
# Then to simulate:
snps, labels_discrete, labels_continuous, changepoints = sim.simulate(batch_size=32)
OnlineSimulator( vcf_data, meta, genetic_map=None, make_haploid=True, window_size=None, store_latlon_as_nvec=False, cp_tolerance=0)
285    def __init__(
286        self,
287        vcf_data,
288        meta,
289        genetic_map = None,
290        make_haploid = True,
291        window_size = None,
292        store_latlon_as_nvec = False,
293        cp_tolerance = 0,
294    ):
295        self.vcf_data = vcf_data
296        self.meta = meta
297        self.genetic_map = genetic_map
298        self.make_haploid = make_haploid
299        self.window_size = window_size
300        self.store_latlon_as_nvec = store_latlon_as_nvec
301        self.cp_tolerance = cp_tolerance
302        
303        # We will keep discrete and continuous labels separately
304        self.labels_discrete = None
305        self.labels_continuous = None
306
307        # Load everything
308        self._check_sample_metadata()
309        self._intersect_vcf_metadata()
310        self._build_descriptors()
311        self._broadcast_labels_across_snps() 
vcf_data
meta
genetic_map
make_haploid
window_size
store_latlon_as_nvec
cp_tolerance
labels_discrete
labels_continuous
def simulate( self, batch_size=256, num_generation_max=10, balanced=False, single_ancestry=False, device='cpu', pool_method='mode'):
509    def simulate(
510        self,
511        batch_size=256,
512        num_generation_max=10,
513        balanced=False,
514        single_ancestry=False,
515        device='cpu',
516        pool_method='mode'
517    ):
518        """
519        Returns a tuple of:
520          ( batch_snps, final_discrete_labels_window, final_continuous_labels_window )
521
522        where:
523          - batch_snps.shape == (B, D)
524          - final_discrete_labels_window == (B, n_windows) if discrete was present, else None
525          - final_continuous_labels_window == (B, n_windows, cdim) if continuous was present, else None
526        """
527        # pick random subset of samples
528        N = self.snps.shape[0]
529        idx = torch.randint(N, (batch_size,))
530        batch_snps = self.snps[idx, :].clone()  # shape (B, D)
531        
532        # Subset discrete
533        if self.labels_discrete is not None:
534            batch_discrete = self.labels_discrete[idx, :].clone()  # shape (B, D)
535        else:
536            batch_discrete = None
537
538        # Subset continuous
539        if self.labels_continuous is not None:
540            batch_continuous = self.labels_continuous[idx, :, :].clone() # (B, D, dim)
541        else:
542            batch_continuous = None
543            
544        # 2) possibly do single_ancestry or balanced logic if you want
545        # We'll skip it for brevity; your original code had that logic.
546
547        # Crossovers
548        batch_snps, batch_discrete, batch_continuous = self._simulate_from_pool(
549            batch_snps, batch_discrete, batch_continuous,
550            num_generation_max=num_generation_max,
551            device=device
552        )
553        # Window-chunk each label array if window_size is specified
554        discrete_out = None
555        continuous_out = None
556        final_cp_window = None
557        
558        if self.window_size is not None and self.window_size > 0:
559            if batch_discrete is not None:
560                discrete_out = _chunk_label_array(
561                    labels=batch_discrete,
562                    window_size=self.window_size,
563                    descriptor="discrete",
564                    pool_method=None,
565                )
566                # shape => (B, D)
567                # find SNP-level breakpoints
568                # Compare label[i] vs label[i-1]
569                # We'll do that in NumPy or Torch
570                lab_np = batch_discrete.cpu().numpy()  # (B, D)
571                # define a shift comparison
572                # cp_mask[:,i] = True if lab[:,i] != lab[:,i-1]
573                # We'll do it for i from 1..D-1
574                cp_mask = np.zeros_like(lab_np, dtype=bool)  # (B, D)
575                cp_mask[:, 1:] = (lab_np[:, 1:] != lab_np[:, :-1])
576                
577                # Now chunk that cp_mask into windows
578                cp_mask_t = torch.from_numpy(cp_mask)
579                final_cp_window = _chunk_changepoints(cp_mask_t, self.window_size)
580                
581            if batch_continuous is not None:
582                continuous_out = _chunk_label_array(
583                    labels=batch_continuous,
584                    window_size=self.window_size,
585                    descriptor="continuous",
586                    pool_method=pool_method,
587                )
588
589        return batch_snps.float(), discrete_out, continuous_out, final_cp_window
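Returns a tuple of:

( batch_snps, final_discrete_labels_window, final_continuous_labels_window, final_cp_window )

where:

  • batch_snps.shape == (B, D)
  • final_discrete_labels_window == (B, n_windows) if discrete was present and window_size is positive, else None
  • final_continuous_labels_window == (B, n_windows, cdim) if continuous was present and window_size is positive, else None
  • final_cp_window is the per-window changepoint output derived from the discrete labels if they were present and window_size is positive, else None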