snputils.simulation.simulator
class
OnlineSimulator:
class OnlineSimulator:
    """
    A refactored 'OnlineSimulator' for haplotype simulation with window-based SNP data.
    -------------------------------------------------------------------------------
    Core Functionality:
        - Simulates admixed haplotypes by swapping SNP segments (together with their
          labels) between randomly permuted haplotypes of a batch.
        - Supports:
            (a) discrete labels (e.g., population codes), and/or
            (b) lat/lon coordinates (optionally converted to n-vectors),
          pooled per window of SNPs when ``window_size`` is set.

    Parameters:
        vcf_data: Mapping read with the keys ``"samples"``, ``"calldata_gt"``
            (indexed as (variants, samples, ploidy)) and, when a genetic map is
            given, ``"variants_pos"``.
        meta: DataFrame with a ``'Sample'`` column plus ``'Population'`` and/or
            ``'Latitude'``/``'Longitude'``.
        genetic_map: Optional DataFrame with ``'pos'`` and ``'cM'`` columns used to
            derive per-SNP switch probabilities.
        make_haploid: Must be True; genotypes are flattened to one row per haplotype.
        window_size: If a positive integer, labels returned by ``simulate`` are
            chunked into windows of this many SNPs; otherwise they are returned as None.
        store_latlon_as_nvec: Store coordinates as n-vectors instead of (lat, lon).
        cp_tolerance: Stored on the instance; not used by this class.

    Example usage:
    -------------
        sim = OnlineSimulator(
            vcf_data=my_species_chrX_vcf_data,
            meta=metadata_df,
            genetic_map=genetic_map_df,  # optional
            window_size=1000,            # required to get labels back
        )
        snps, labels_discrete, labels_continuous, changepoints = sim.simulate(batch_size=32)
    """

    def __init__(
        self,
        vcf_data,
        meta,
        genetic_map=None,
        make_haploid=True,
        window_size=None,
        store_latlon_as_nvec=False,
        cp_tolerance=0,
    ):
        # Label broadcasting, crossovers and windowing all index SNPs as a 2-D
        # (haplotypes, SNPs) tensor, so an unflattened (samples, ploidy, SNPs)
        # layout cannot work. Fail fast instead of with an unpacking error later.
        if not make_haploid:
            raise ValueError(
                "make_haploid=False is not supported: the simulator requires "
                "haplotype-level SNPs with shape (N, D)."
            )

        self.vcf_data = vcf_data
        self.meta = meta
        self.genetic_map = genetic_map
        self.make_haploid = make_haploid
        self.window_size = window_size
        self.store_latlon_as_nvec = store_latlon_as_nvec
        self.cp_tolerance = cp_tolerance  # not used by this class

        # Discrete and continuous labels are kept separately.
        self.labels_discrete = None
        self.labels_continuous = None
        # Sorted population names; discrete label code i means populations[i].
        self.populations = None

        # Load everything (order matters: each step consumes the previous one's output).
        self._check_sample_metadata()
        self._intersect_vcf_metadata()
        self._build_descriptors()
        self._broadcast_labels_across_snps()

    def _check_sample_metadata(self):
        """
        Validate ``self.meta`` and drop rows missing the fields in use.

        Sets ``self.has_discrete`` (a 'Population' column exists) and
        ``self.has_continuous`` (both 'Latitude' and 'Longitude' exist); either or
        both may be True.

        Raises:
            ValueError: If 'Sample' is missing or neither label type is available.
        """
        if 'Sample' not in self.meta.columns:
            raise ValueError("Expected 'Sample' column in sample metadata.")

        self.has_discrete = 'Population' in self.meta.columns
        self.has_continuous = {'Latitude', 'Longitude'}.issubset(self.meta.columns)

        if not (self.has_discrete or self.has_continuous):
            raise ValueError(
                "No recognized columns for descriptors. Need 'Population' for discrete "
                "and/or 'Latitude','Longitude' for continuous."
            )

        # When both label types are present, a row must have all fields to survive.
        if self.has_discrete:
            self.meta = self.meta.dropna(subset=['Sample', 'Population'])
            log.info("Discrete labeling: found 'Population' column in metadata.")
        if self.has_continuous:
            self.meta = self.meta.dropna(subset=['Sample', 'Latitude', 'Longitude'])
            log.info("Continuous labeling: found 'Latitude'/'Longitude' columns in metadata.")

        log.info('Metadata OK.')

    def _intersect_vcf_metadata(self):
        """
        Restrict the data to samples present in both the VCF and the metadata.

        Sets:
            self.meta: metadata rows aligned one-to-one with the rows of ``self.snps``.
            self.snps: int8 tensor of shape (N * ploidy, D), one row per haplotype.
            self.samples: sample ID for each row of ``self.snps``.
            self.rate_per_snp: per-SNP switch probability from the genetic map, or None.

        Raises:
            ValueError: If no sample is shared between the VCF and the metadata.
        """
        vcf_samples = self.vcf_data["samples"]
        log.info(f"VCF has {len(vcf_samples)} samples total.")
        meta_samples = self.meta["Sample"].values

        # Sorted common IDs plus, for each, its index in vcf_samples.
        isamples, iidx, _ = np.intersect1d(
            vcf_samples, meta_samples, assume_unique=False, return_indices=True
        )
        log.info(f"{len(isamples)} samples found in both VCF and metadata.")
        if len(isamples) == 0:
            raise ValueError("No overlap between VCF samples and metadata samples. Check your paths or sample naming.")

        # Reorder metadata rows to follow `isamples` (positional lookup; for a
        # duplicated sample ID the last metadata row wins).
        samp2idx = {s: idx for idx, s in enumerate(meta_samples)}
        meta_idxs = [samp2idx[s] for s in isamples]
        self.meta = self.meta.iloc[meta_idxs].copy().reset_index(drop=True)

        # calldata_gt is (variants, samples, ploidy) -> (samples, ploidy, variants)
        snps = self.vcf_data["calldata_gt"].transpose(1, 2, 0)[iidx, ...]
        n_samples, ploidy, n_snps = snps.shape

        if self.make_haploid:
            # (samples, ploidy, snps) -> (samples * ploidy, snps): each sample's
            # haplotypes become consecutive rows, so sample IDs and metadata rows
            # must be repeated `ploidy` times (not a fixed 2) to stay aligned.
            snps = snps.reshape(n_samples * ploidy, n_snps)
            isamples = np.repeat(isamples, ploidy)
            self.meta = self.meta.loc[self.meta.index.repeat(ploidy)].reset_index(drop=True)

        self.snps = torch.tensor(snps, dtype=torch.int8)
        self.samples = np.array(isamples)
        log.info(f"snps shape = {self.snps.shape}, sample length = {len(self.samples)}")

        # Interpolate each variant's genetic position (cM), convert to Morgans and
        # take the local gradient as that SNP's switch probability.
        if self.genetic_map is not None:
            cm_interp = np.interp(self.vcf_data["variants_pos"], self.genetic_map['pos'], self.genetic_map['cM'])
            self.rate_per_snp = np.gradient(cm_interp / 100.0)
            log.info(f"rate/snp shape = {self.rate_per_snp.shape}")
        else:
            self.rate_per_snp = None

    def _build_descriptors(self):
        """
        Build one label row per haplotype.

        Discrete: ``self.labels_discrete`` becomes an int16 tensor (N, 1) of codes
        indexing the sorted ``self.populations`` list.
        Continuous: ``self.labels_continuous`` becomes a float32 tensor of
        (lat, lon) pairs, or of n-vectors when ``store_latlon_as_nvec`` is set.

        Raises:
            ValueError: If metadata and SNP rows are misaligned.
        """
        if len(self.samples) != self.snps.shape[0]:
            raise ValueError("Metadata subset mismatch in length after flattening haplotypes.")

        if self.has_discrete:
            pop_values = self.meta['Population'].values
            # np.unique returns sorted distinct values; keep them so integer codes
            # can be decoded back to population names.
            self.populations = np.unique(pop_values).tolist()
            pop2code = {p: i for i, p in enumerate(self.populations)}
            codes = np.array([pop2code[p] for p in pop_values], dtype=np.int16)
            self.labels_discrete = torch.tensor(codes[:, None], dtype=torch.int16)  # (N, 1)
            log.info(f"Built discrete labels => shape {self.labels_discrete.shape}")

        if self.has_continuous:
            lat_vals = self.meta["Latitude"].values
            lon_vals = self.meta["Longitude"].values
            if self.store_latlon_as_nvec:
                coords = latlon_to_nvector(lat_vals, lon_vals)
            else:
                coords = np.stack([lat_vals, lon_vals], axis=-1)  # (N, 2)
            self.labels_continuous = torch.tensor(coords, dtype=torch.float32)
            log.info(f"Built continuous labels => shape {self.labels_continuous.shape}")

    def _broadcast_labels_across_snps(self):
        """
        Repeat each haplotype's label along the SNP axis.

        Afterwards ``self.labels_discrete`` is (N, D) and ``self.labels_continuous``
        is (N, D, coord_dim), so crossovers can swap labels exactly like SNPs.
        """
        _, n_snps = self.snps.shape

        if self.labels_discrete is not None:
            # (N, 1) -> (N, D); Tensor.repeat materializes a fresh int16 copy.
            self.labels_discrete = self.labels_discrete.repeat(1, n_snps)
            log.info(f"Broadcast discrete => {self.labels_discrete.shape}")

        if self.labels_continuous is not None:
            # (N, cdim) -> (N, 1, cdim) -> (N, D, cdim)
            self.labels_continuous = self.labels_continuous.unsqueeze(1).repeat(1, n_snps, 1)
            log.info(f"Broadcast continuous => {self.labels_continuous.shape}")

    def _simulate_from_pool(
        self,
        batch_snps,
        batch_labels_discrete,
        batch_labels_continuous,
        num_generation_max,
        device='cpu',
    ):
        """
        Admix a batch in place by swapping SNP suffixes between permuted rows.

        Shapes:
            batch_snps              => (B, D)
            batch_labels_discrete   => (B, D) or None
            batch_labels_continuous => (B, D, cdim) or None

        A number of generations G is drawn uniformly from [0, num_generation_max].
        With a genetic map, a swap starts at every SNP where Binomial(G, rate) is
        odd; otherwise G split points are drawn uniformly over the SNPs. At each
        split point, every array is permuted from that SNP onward with the same
        row permutation, so labels follow their SNPs.

        Returns:
            The (possibly device-moved) ``(batch_snps, batch_labels_discrete,
            batch_labels_continuous)``.
        """
        if device != 'cpu':
            batch_snps = batch_snps.to(device)
            if batch_labels_discrete is not None:
                batch_labels_discrete = batch_labels_discrete.to(device)
            if batch_labels_continuous is not None:
                batch_labels_continuous = batch_labels_continuous.to(device)

        G = np.random.randint(0, num_generation_max + 1)
        B, D = batch_snps.shape

        if self.rate_per_snp is not None:
            switch = np.random.binomial(G, self.rate_per_snp) % 2
            split_points = np.flatnonzero(switch)
        else:
            split_points = torch.randint(D, (G,))

        for sp in split_points:
            perm = torch.randperm(B, device=batch_snps.device)
            # Advanced indexing on the right-hand side copies, so in-place
            # assignment of the suffix is safe.
            batch_snps[:, sp:] = batch_snps[perm, sp:]
            if batch_labels_discrete is not None:
                batch_labels_discrete[:, sp:] = batch_labels_discrete[perm, sp:]
            if batch_labels_continuous is not None:
                batch_labels_continuous[:, sp:, :] = batch_labels_continuous[perm, sp:, :]

        return batch_snps, batch_labels_discrete, batch_labels_continuous

    def simulate(
        self,
        batch_size=256,
        num_generation_max=10,
        balanced=False,
        single_ancestry=False,
        device='cpu',
        pool_method='mode'
    ):
        """
        Draw a batch of haplotypes and admix them with random crossovers.

        Args:
            batch_size: Number of haplotypes drawn (with replacement) from the pool.
            num_generation_max: Inclusive upper bound for the number of generations.
            balanced: Accepted for API compatibility; currently ignored.
            single_ancestry: Accepted for API compatibility; currently ignored.
            device: Device on which crossovers are computed.
            pool_method: Pooling method passed to ``_chunk_label_array`` for the
                continuous labels.

        Returns:
            A 4-tuple ``(batch_snps, discrete_out, continuous_out, final_cp_window)``:
              - batch_snps: float tensor of shape (B, D).
              - discrete_out: windowed discrete labels, (B, n_windows), or None.
              - continuous_out: windowed continuous labels, (B, n_windows, cdim), or None.
              - final_cp_window: windowed SNP-level changepoints derived from the
                discrete labels, or None.
            The three label outputs are None unless ``window_size`` is a positive
            integer and the corresponding labels exist.
        """
        N = self.snps.shape[0]
        idx = torch.randint(N, (batch_size,))
        batch_snps = self.snps[idx, :].clone()  # (B, D)

        batch_discrete = self.labels_discrete[idx, :].clone() if self.labels_discrete is not None else None
        batch_continuous = (
            self.labels_continuous[idx, :, :].clone() if self.labels_continuous is not None else None
        )

        batch_snps, batch_discrete, batch_continuous = self._simulate_from_pool(
            batch_snps, batch_discrete, batch_continuous,
            num_generation_max=num_generation_max,
            device=device
        )

        discrete_out = None
        continuous_out = None
        final_cp_window = None

        if self.window_size is not None and self.window_size > 0:
            if batch_discrete is not None:
                discrete_out = _chunk_label_array(
                    labels=batch_discrete,
                    window_size=self.window_size,
                    descriptor="discrete",
                    pool_method=None,
                )
                # SNP-level changepoints: cp_mask[:, i] is True where the label
                # differs from the previous SNP's label (column 0 is never a changepoint).
                lab_np = batch_discrete.cpu().numpy()  # (B, D)
                cp_mask = np.zeros_like(lab_np, dtype=bool)
                cp_mask[:, 1:] = lab_np[:, 1:] != lab_np[:, :-1]
                final_cp_window = _chunk_changepoints(torch.from_numpy(cp_mask), self.window_size)

            if batch_continuous is not None:
                continuous_out = _chunk_label_array(
                    labels=batch_continuous,
                    window_size=self.window_size,
                    descriptor="continuous",
                    pool_method=pool_method,
                )

        return batch_snps.float(), discrete_out, continuous_out, final_cp_window
A refactored 'OnlineSimulator' for haplotype simulation with window-based SNP data.
Core Functionality:
- Simulates admixed haplotypes.
- Supports: (a) discrete labels (e.g., population codes) and/or (b) latitude/longitude coordinates (optionally converted to n-vectors), pooled per window of SNPs.
Example usage:
sim = OnlineSimulator(
vcf_data=my_species_chrX_vcf_data,
meta=metadata_df,
genetic_map=genetic_map_df, # optional
...
)
# Then to simulate:
snps, labels_discrete, labels_continuous, changepoints = sim.simulate(batch_size=32)
OnlineSimulator( vcf_data, meta, genetic_map=None, make_haploid=True, window_size=None, store_latlon_as_nvec=False, cp_tolerance=0)
def __init__(
    self,
    vcf_data,
    meta,
    genetic_map=None,
    make_haploid=True,
    window_size=None,
    store_latlon_as_nvec=False,
    cp_tolerance=0,
):
    """
    Store configuration and load/prepare SNPs and labels.

    Raises:
        ValueError: If ``make_haploid`` is False (only haplotype-level (N, D) SNPs
            are supported), or if metadata validation/intersection fails.
    """
    # Label broadcasting, crossovers and windowing all index SNPs as a 2-D
    # (haplotypes, SNPs) tensor, so an unflattened (samples, ploidy, SNPs)
    # layout cannot work. Fail fast instead of with an unpacking error later.
    if not make_haploid:
        raise ValueError(
            "make_haploid=False is not supported: the simulator requires "
            "haplotype-level SNPs with shape (N, D)."
        )

    self.vcf_data = vcf_data
    self.meta = meta
    self.genetic_map = genetic_map
    self.make_haploid = make_haploid
    self.window_size = window_size
    self.store_latlon_as_nvec = store_latlon_as_nvec
    self.cp_tolerance = cp_tolerance  # not used by this class

    # Discrete and continuous labels are kept separately.
    self.labels_discrete = None
    self.labels_continuous = None
    # Sorted population names; discrete label code i means populations[i].
    self.populations = None

    # Load everything (order matters: each step consumes the previous one's output).
    self._check_sample_metadata()
    self._intersect_vcf_metadata()
    self._build_descriptors()
    self._broadcast_labels_across_snps()
def
simulate( self, batch_size=256, num_generation_max=10, balanced=False, single_ancestry=False, device='cpu', pool_method='mode'):
def simulate(
    self,
    batch_size=256,
    num_generation_max=10,
    balanced=False,
    single_ancestry=False,
    device='cpu',
    pool_method='mode'
):
    """
    Draw a batch of haplotypes and admix them with random crossovers.

    Args:
        batch_size: Number of haplotypes drawn (with replacement) from the pool.
        num_generation_max: Inclusive upper bound for the number of generations.
        balanced: Accepted for API compatibility; currently ignored.
        single_ancestry: Accepted for API compatibility; currently ignored.
        device: Device on which crossovers are computed.
        pool_method: Pooling method passed to ``_chunk_label_array`` for the
            continuous labels.

    Returns:
        A 4-tuple ``(batch_snps, discrete_out, continuous_out, final_cp_window)``:
          - batch_snps: float tensor of shape (B, D).
          - discrete_out: windowed discrete labels, (B, n_windows), or None.
          - continuous_out: windowed continuous labels, (B, n_windows, cdim), or None.
          - final_cp_window: windowed SNP-level changepoints derived from the
            discrete labels, or None.
        The three label outputs are None unless ``window_size`` is a positive
        integer and the corresponding labels exist.
    """
    N = self.snps.shape[0]
    idx = torch.randint(N, (batch_size,))
    batch_snps = self.snps[idx, :].clone()  # (B, D)

    batch_discrete = self.labels_discrete[idx, :].clone() if self.labels_discrete is not None else None
    batch_continuous = (
        self.labels_continuous[idx, :, :].clone() if self.labels_continuous is not None else None
    )

    batch_snps, batch_discrete, batch_continuous = self._simulate_from_pool(
        batch_snps, batch_discrete, batch_continuous,
        num_generation_max=num_generation_max,
        device=device
    )

    discrete_out = None
    continuous_out = None
    final_cp_window = None

    if self.window_size is not None and self.window_size > 0:
        if batch_discrete is not None:
            discrete_out = _chunk_label_array(
                labels=batch_discrete,
                window_size=self.window_size,
                descriptor="discrete",
                pool_method=None,
            )
            # SNP-level changepoints: cp_mask[:, i] is True where the label
            # differs from the previous SNP's label (column 0 is never a changepoint).
            lab_np = batch_discrete.cpu().numpy()  # (B, D)
            cp_mask = np.zeros_like(lab_np, dtype=bool)
            cp_mask[:, 1:] = lab_np[:, 1:] != lab_np[:, :-1]
            final_cp_window = _chunk_changepoints(torch.from_numpy(cp_mask), self.window_size)

        if batch_continuous is not None:
            continuous_out = _chunk_label_array(
                labels=batch_continuous,
                window_size=self.window_size,
                descriptor="continuous",
                pool_method=pool_method,
            )

    return batch_snps.float(), discrete_out, continuous_out, final_cp_window
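Returns a 4-tuple:
( batch_snps, final_discrete_labels_window, final_continuous_labels_window, final_cp_window )
where:
- batch_snps is a float tensor with shape (B, D)
- final_discrete_labels_window == (B, n_windows) if discrete labels were present and window_size > 0, else None
- final_continuous_labels_window == (B, n_windows, cdim) if continuous labels were present and window_size > 0, else None
- final_cp_window holds the windowed SNP-level changepoints derived from the discrete labels (same conditions as final_discrete_labels_window), else None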