# Source code for snputils.visualization.lai

import numpy as np
from typing import Optional, Tuple, Dict, cast
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches

from snputils.ancestry.genobj.local import LocalAncestryObject


def _custom_cmap(colors: Dict, padding: float = 1.05):
    """
    Create a discrete colormap from a level-to-color dictionary.

    Colors are ordered by the numeric value of their keys, so entry ``i`` of
    the returned colormap holds the color of the ``i``-th smallest level.

    Args:
        colors: Dictionary with levels as keys (numbers, or strings parseable
            by ``float``) and color names as values. A ``None`` key is ignored.
        padding: Multiplicative factor applied to the levels when building the
            boundaries passed to ``matplotlib.colors.from_levels_and_colors``.
            Only the colormap is returned (the norm built from these boundaries
            is discarded), so this value does not change the returned colors.
            Defaults to 1.05.

    Returns:
        cmap: The custom colormap, with one color per level.

    Raises:
        ValueError: If ``colors`` contains no non-``None`` keys.
    """
    # Sort the ORIGINAL keys by numeric value and look colors up with those
    # keys, so string keys such as '0' work too (a float lookup would miss them).
    keys = sorted((key for key in colors if key is not None), key=float)
    if not keys:
        raise ValueError("colors must contain at least one non-None level")
    labels = [float(key) for key in keys]
    clrs = [colors[key] for key in keys]

    # `from_levels_and_colors` requires exactly one more level than colors:
    # a lower bound below the smallest label plus one padded bound per label.
    # NOTE(review): these padded bounds would put labels 0 and 1 in the same
    # bin ([0, 1.05)), so the accompanying norm is not usable as-is; it is
    # discarded below and callers must normalize the data themselves.
    levels = sorted([labels[0] - 1] + [x * padding for x in labels])

    cmap, _ = mcolors.from_levels_and_colors(levels, clrs)
    return cmap


def _sort_by_predominant_ancestry(lai_T, n_samples, n_windows):
    """
    Reorder samples so that samples with the same predominant ancestry are adjacent.

    Groups appear in decreasing order of the total count of their predominant
    ancestry across all samples; within a group, samples are ordered by how
    many of their windows (over both haplotypes) carry that ancestry, in
    decreasing order.

    Args:
        lai_T: Integer array of shape ``(2 * n_samples, n_windows)`` in which
            rows ``2*i`` and ``2*i + 1`` hold the two haplotypes of sample ``i``.
            Values must be non-negative, as required by ``np.bincount``.
        n_samples: Number of samples.
        n_windows: Number of windows.

    Returns:
        Tuple ``(sorted_lai_T, order, sorted_ancestry_values)``:
        ``sorted_lai_T`` has the shape of ``lai_T`` with haplotype pairs
        reordered; ``order`` lists the original sample index of each pair in
        the new order; ``sorted_ancestry_values`` holds the ancestry values
        present in ``lai_T`` by decreasing total count.
    """
    # Pair up the haplotypes of each sample: n_samples x 2 x n_windows.
    haplotype_pairs = lai_T.reshape(n_samples, 2, n_windows)
    # One row per sample with both haplotypes concatenated: n_samples x (2*n_windows).
    flat_pairs = haplotype_pairs.reshape(n_samples, 2 * n_windows)

    # Most frequent ancestry value of each sample.
    predominant = np.apply_along_axis(
        lambda row: np.bincount(row).argmax(), axis=1, arr=flat_pairs
    )

    # Ancestry values ordered by total count across all samples (descending).
    ancestry_values, ancestry_counts = np.unique(lai_T, return_counts=True)
    sorted_ancestry_values = ancestry_values[np.argsort(ancestry_counts)[::-1]]

    order = []
    for ancestry in sorted_ancestry_values:
        members = np.where(predominant == ancestry)[0]
        member_counts = np.sum(flat_pairs[members, :] == ancestry, axis=1)
        order += list(members[np.argsort(member_counts)[::-1]])

    sorted_lai_T = haplotype_pairs[order, :].reshape(-1, n_windows)
    return sorted_lai_T, order, sorted_ancestry_values


def plot_lai(
    laiobj: LocalAncestryObject,
    colors: Dict,
    sort: Optional[bool] = True,
    figsize: Optional[Tuple[float, float]] = None,
    legend: Optional[bool] = False,
    title: Optional[str] = None,
    fontsize: Optional[Dict[str, float]] = None,
    scale: int = 2,
):
    """
    Plot LAI (Local Ancestry Inference) data with customizable options.

    Each sample is drawn as two consecutive rows, one per haplotype
    (maternal and paternal strands), with one cell per window. A white row
    separates consecutive samples.

    Args:
        laiobj: A LocalAncestryObject containing LAI data. ``laiobj.lai`` is
            transposed so that consecutive rows hold the two haplotypes of the
            same sample; presumably its shape is ``(n_windows, 2 * n_samples)``.
        colors: A dictionary mapping each ancestry name (a value of
            ``laiobj.ancestry_map``) to a matplotlib color.
        sort: If True, group samples by their most frequent ancestry. Groups
            with the most predominant ancestry come first, followed by the
            second most predominant, and so on; within a group, samples with
            the largest count of that ancestry come first. Defaults to True.
        figsize: Figure size. If None, a default size of (25, 25) is used.
            Defaults to None.
        legend: If True, display a legend. If ``sort`` is True, ancestries in
            the legend are sorted by their total frequency in descending order;
            otherwise the entries of ``colors`` are listed in order.
            Defaults to False.
        title: Title for the plot. If None, no title is displayed. Defaults to None.
        fontsize: Font sizes keyed by ``'xticks'``, ``'yticks'``, ``'xlabel'``,
            ``'ylabel'`` and ``'legend'``. Missing keys fall back to the defaults
            (20, 9, 20, 20 and 20 respectively). Defaults to None.
        scale: Number of times to duplicate rows for enhanced vertical
            visibility. Defaults to 2.

    Raises:
        ValueError: If the transposed ``laiobj.lai`` has an odd number of rows,
            so haplotypes cannot be paired into samples.
    """
    default_fontsize = {'xticks': 20, 'yticks': 9, 'xlabel': 20, 'ylabel': 20, 'legend': 20}
    # Merge rather than replace, so a partial dict does not raise KeyError below.
    fontsize = {**default_fontsize, **(fontsize or {})}

    # Rows of the transposed array are haplotypes; consecutive rows belong to
    # the same sample (maternal, then paternal).
    lai_T = laiobj.lai.T
    n_haplotypes = lai_T.shape[0]
    n_windows = lai_T.shape[1]
    if n_haplotypes % 2 != 0:
        raise ValueError(
            "Expected an even number of haplotypes (two per sample), got "
            + str(n_haplotypes)
        )
    n_samples = n_haplotypes // 2

    # Unsorted plots keep the original sample order.
    sample_ids = laiobj.samples
    sorted_ancestry_values = None
    if sort:
        lai_T, order, sorted_ancestry_values = _sort_by_predominant_ancestry(
            lai_T, n_samples, n_windows
        )
        if laiobj.samples is not None:
            sample_ids = [laiobj.samples[idx] for idx in order]

    # Keys that are strings of digits are converted to int so the map goes
    # from integer code to ancestry name; otherwise the map is used unchanged.
    if all(isinstance(key, str) and key.isdigit() for key in laiobj.ancestry_map.keys()):
        ancestry_map_reverse = {int(key): value for key, value in laiobj.ancestry_map.items()}
    else:
        ancestry_map_reverse = laiobj.ancestry_map

    # Integer code -> color; the code `n_ancestries` is reserved for the white
    # separator rows between samples.
    colors_map = {key: colors[value] for key, value in ancestry_map_reverse.items()}
    n_ancestries = len(ancestry_map_reverse)
    colors_map[n_ancestries] = 'white'

    # Insert a separator row after every haplotype pair except the last, then
    # repeat rows to increase the height of samples in the plot.
    lai_T_with_whitespace = np.insert(
        lai_T, np.arange(2, n_samples * 2, 2), n_ancestries, axis=0
    )
    lai_T_repeat = np.repeat(lai_T_with_whitespace, scale, axis=0)

    cmap = _custom_cmap(colors_map)

    plt.figure(figsize=(25, 25) if figsize is None else figsize)
    # Pin the color range to [0, n_ancestries] so each code keeps its own color
    # even when some ancestry (or the separator, e.g. with a single sample) is
    # absent from the data; assumes ancestry codes are 0..n_ancestries-1.
    # 'nearest' prevents neighbouring categorical codes from being blended.
    plt.imshow(
        lai_T_repeat, cmap=cmap, vmin=0, vmax=n_ancestries, interpolation='nearest'
    )

    # Each sample spans 2*scale haplotype rows plus `scale` separator rows;
    # center its label between its two haplotypes (pixel edges are at k - 0.5).
    rows_per_sample = 3 * scale
    yticks_positions = np.arange(n_samples) * rows_per_sample + scale - 0.5
    plt.xticks(fontsize=fontsize['xticks'])
    plt.yticks(yticks_positions, sample_ids, fontsize=fontsize['yticks'])

    ax = plt.gca()
    ax.set_xlabel('Window', fontsize=fontsize['xlabel'], labelpad=8)
    ax.set_ylabel('Sample', fontsize=fontsize['ylabel'])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(2)
    ax.tick_params(axis='both', which='major', length=8, width=2, pad=4)

    if legend:
        if sort:
            # Sort legend based on global ancestry frequency in decreasing order.
            legend_items = [
                (colors_map[x], ancestry_map_reverse[x]) for x in sorted_ancestry_values
            ]
        else:
            # One square per user-supplied ancestry color.
            legend_items = [(color, label) for label, color in colors.items()]
        legend_patches = [
            patches.Patch(color=color, label=label) for color, label in legend_items
        ]
        ax.legend(
            handles=legend_patches,
            loc='lower center',
            borderaxespad=-5,
            ncol=n_ancestries,
            fontsize=fontsize['legend'],
        )

    if title:
        plt.title(title)