snputils.visualization.lai

  1import numpy as np
  2from typing import Optional, Tuple, Dict, cast
  3import matplotlib.pyplot as plt
  4import matplotlib.colors as mcolors
  5import matplotlib.patches as patches
  6
  7from snputils.ancestry.genobj.local import LocalAncestryObject
  8
  9
 10def _custom_cmap(colors: Dict, padding: float = 1.05):
 11    """
 12    Create a custom colormap from a dictionary.
 13
 14    Args:
 15        colors: Dictionary with levels as keys and color names as values.
 16        padding: Offset applied to levels. Defaults to 1.05.
 17
 18    Returns:
 19        cmap: The custom colormap.
 20    """
 21    labels = sorted(float(key) for key in colors if key is not None)
 22    clrs = [colors[key] for key in labels]
 23    
 24    # Adjust levels to match the number of colors minus one
 25    levels = sorted([labels[0] - 1] + [x * padding for x in labels])
 26
 27    cmap, _ = mcolors.from_levels_and_colors(levels, clrs)
 28    return cmap
 29
 30
 31def plot_lai(
 32    laiobj: LocalAncestryObject, 
 33    colors: Dict,
 34    sort: Optional[bool]=True,
 35    figsize: Optional[Tuple[float, float]]=None,
 36    legend: Optional[bool]=False,
 37    title: Optional[str]=None,
 38    fontsize: Optional[Dict[str, float]]=None,
 39    scale: int=2,
 40):
 41    """
 42    Plot LAI (Local Ancestry Inference) data with customizable options. Each row 
 43    represents the ancestry of a sample at the window level, distinguishing 
 44    between maternal and paternal strands. Whitespace is used to separate 
 45    individual samples.
 46
 47    Args:
 48        laiobj: A LocalAncestryObject containing LAI data.
 49        colors: A dictionary with ancestry-color mapping.
 50        sort: If True, sort samples based on the most frequent ancestry. 
 51            Samples are displayed with the most predominant ancestry first, followed by the 
 52            second most predominant, and so on. Defaults to True.
 53        figsize: Figure size. If is None, the figure is displayed with a default size of 
 54            (25, 25). Defaults to None.
 55        legend: If True, display a legend. If ``sort==True``, ancestries in the 
 56            legend are sorted based on their total frequency in descending order. Defaults to False.
 57        title: Title for the plot. If None, no title is displayed. Defaults to None.
 58        fontsize: Font sizes for various plot elements. If None, default font sizes are used. 
 59            Defaults to None.
 60        scale: Number of times to duplicate rows for enhanced vertical visibility. Defaults to 2.
 61    """
 62    # The `lai` field is a 2D array containing the window-wise ancestry 
 63    # information for each individual. Consecutive rows of the transformed form 
 64    # correspond to the maternal and paternal ancestries of the same individual
 65    lai_T = laiobj.lai.T
 66    
 67    # Obtain number of samples and windows
 68    n_samples = int(lai_T.shape[0]/2)
 69    n_windows = lai_T.shape[1]
 70    
 71    if fontsize is None:
 72        fontsize = {
 73            'xticks' : 20, 
 74            'yticks' : 9, 
 75            'xlabel' : 20, 
 76            'ylabel' : 20,
 77            'legend' : 20
 78        }
 79    
 80    if sort:
 81        # Reshape `lai_T` to a 3D array where the third dimension represents 
 82        # pairs of consecutive rows, corresponding to maternal and paternal ancestries
 83        # Dimension: n_samples x 2 x n_windows
 84        maternal_paternal_pairs = lai_T.reshape(n_samples, 2, n_windows)
 85        
 86        # Reshape `maternal_paternal_pairs` to a 2D array where each row contains 
 87        # concatenated maternal and paternal ancestries
 88        # Dimension: n_samples x (2·n_windows)
 89        if maternal_paternal_pairs.ndim != 3:
 90            raise ValueError("maternal_paternal_pairs must be a 3D array, got array with ndim=" + str(maternal_paternal_pairs.ndim))
 91        num_samples, num_maternal_paternal, num_windows = cast(Tuple[int, int, int], maternal_paternal_pairs.shape)
 92        
 93        flat_ancestry_pairs = maternal_paternal_pairs.reshape(
 94            num_samples, num_maternal_paternal * num_windows
 95        )
 96
 97        # Determine the most frequent ancestry for each sample
 98        most_freq_ancestry_sample = np.apply_along_axis(
 99            lambda row: np.bincount(row).argmax(), axis=1, arr=flat_ancestry_pairs
100        )
101        
102        # Obtain unique ancestry values and total counts
103        ancestry_values, ancestry_counts = np.unique(lai_T, return_counts=True)
104        # Sort ancestry values by total counts in decreasing order
105        sorted_ancestry_values = ancestry_values[np.argsort(ancestry_counts)[::-1]]        
106        
107        # For each ancestry, obtain the samples where that ancestry is predominant
108        # and store the indexes sorted in decreasing order based on the ancestry count
109        all_sorted_row_idxs = []
110        for ancestry in sorted_ancestry_values:
111            ancestry_filter = np.where(most_freq_ancestry_sample == ancestry)[0]
112            ancestry_counts_sample = np.sum(
113                flat_ancestry_pairs[ancestry_filter, :]==ancestry, axis=1
114            )
115            sorted_row_idxs_1 = np.argsort(ancestry_counts_sample)[::-1]
116            sorted_row_idxs = ancestry_filter[sorted_row_idxs_1]
117            all_sorted_row_idxs += list(sorted_row_idxs)
118        
119        # Sort sample IDs based on most frequent ancestry
120        if laiobj.samples is not None:
121            sample_ids = [laiobj.samples[idx] for idx in all_sorted_row_idxs]
122        else:
123            sample_ids = None
124        
125        # Sort `lai_T` based on most frequent ancestry
126        num_samples, num_maternal_paternal, num_windows = cast(Tuple[int, int, int], maternal_paternal_pairs.shape)
127        lai_T = maternal_paternal_pairs[all_sorted_row_idxs, :].reshape(
128            -1, num_windows
129        )
130    
131    # Check if ancestry_map keys are strings of integers (reversed form)
132    if all(isinstance(key, str) and key.isdigit() for key in laiobj.ancestry_map.keys()):
133        # Reverse the dictionary to match integer-to-ancestry format
134        ancestry_map_reverse = {int(key): value for key, value in laiobj.ancestry_map.items()}
135    else:
136        # The ancestry_map is already in the correct integer-to-ancestry format
137        ancestry_map_reverse = laiobj.ancestry_map
138    
139    # Dictionary with integer-to-color mapping
140    colors_map = {key : colors[value] for key, value in ancestry_map_reverse.items()}
141    
142    # Total number of ancestries
143    n_ancestries = len(ancestry_map_reverse.keys())
144    
145    # Insert whitespace between samples
146    lai_T_with_whitespace = np.insert(lai_T, np.arange(2, n_samples*2, 2), n_ancestries, axis=0)
147    
148    # Repeat rows to increase height of samples in plot
149    lai_T_repeat = np.repeat(lai_T_with_whitespace, scale, axis=0)
150    
151    colors_map[n_ancestries] = 'white'
152    
153    # Configure custom map from matrix values to colors
154    cmap = _custom_cmap(colors_map)
155    
156    # Plot LAI output
157    if figsize is None:
158        plt.figure(figsize=(25, 25))
159    else:
160        plt.figure(figsize=figsize)
161    plt.imshow(lai_T_repeat, cmap=cmap)
162    
163    # Display sample IDs in y-axis
164    yticks_positions = np.arange(scale, lai_T_repeat.shape[0]+1, scale*(2+1))
165    plt.xticks(fontsize=fontsize['xticks'])
166    plt.yticks(yticks_positions, sample_ids, fontsize=fontsize['yticks'])
167    
168    ax = plt.gca()
169    ax.set_xlabel('Window', fontsize=fontsize['xlabel'], labelpad=8)
170    ax.set_ylabel('Sample', fontsize=fontsize['ylabel'])
171    ax.spines['top'].set_linewidth(2)
172    ax.spines['right'].set_linewidth(2)
173    ax.spines['bottom'].set_linewidth(2)
174    ax.spines['left'].set_linewidth(2)
175    ax.tick_params(axis='both', which='major', length=8, width=2)
176    ax.tick_params(axis='x', which='major', pad=4)
177    ax.tick_params(axis='y', which='major', pad=4)
178    
179    if legend:
180        if sort:
181            # Sort legend based on global ancestry frequency in decresaing order 
182            sorted_colors = [colors_map[x] for x in sorted_ancestry_values]
183            ancestries = [ancestry_map_reverse[x] for x in sorted_ancestry_values]
184            legend_patches = [patches.Patch(color=color, label=label) 
185                              for color, label in zip(sorted_colors, ancestries)]
186        else:
187            # Add patches for each color to represent the legend squares
188            legend_patches = [patches.Patch(color=color, label=label) 
189                              for color, label in zip(colors.values(), colors.keys())]
190    
191        ax.legend(
192            handles=legend_patches, 
193            loc='lower center', 
194            borderaxespad=-5,
195            ncol=n_ancestries,
196            fontsize=fontsize['legend']
197        )
198    
199    if title:
200        plt.title(title)
def plot_lai( laiobj: snputils.ancestry.genobj.LocalAncestryObject, colors: Dict, sort: Optional[bool] = True, figsize: Optional[Tuple[float, float]] = None, legend: Optional[bool] = False, title: Optional[str] = None, fontsize: Optional[Dict[str, float]] = None, scale: int = 2):
 32def plot_lai(
 33    laiobj: LocalAncestryObject, 
 34    colors: Dict,
 35    sort: Optional[bool]=True,
 36    figsize: Optional[Tuple[float, float]]=None,
 37    legend: Optional[bool]=False,
 38    title: Optional[str]=None,
 39    fontsize: Optional[Dict[str, float]]=None,
 40    scale: int=2,
 41):
 42    """
 43    Plot LAI (Local Ancestry Inference) data with customizable options. Each row 
 44    represents the ancestry of a sample at the window level, distinguishing 
 45    between maternal and paternal strands. Whitespace is used to separate 
 46    individual samples.
 47
 48    Args:
 49        laiobj: A LocalAncestryObject containing LAI data.
 50        colors: A dictionary with ancestry-color mapping.
 51        sort: If True, sort samples based on the most frequent ancestry. 
 52            Samples are displayed with the most predominant ancestry first, followed by the 
 53            second most predominant, and so on. Defaults to True.
 54        figsize: Figure size. If is None, the figure is displayed with a default size of 
 55            (25, 25). Defaults to None.
 56        legend: If True, display a legend. If ``sort==True``, ancestries in the 
 57            legend are sorted based on their total frequency in descending order. Defaults to False.
 58        title: Title for the plot. If None, no title is displayed. Defaults to None.
 59        fontsize: Font sizes for various plot elements. If None, default font sizes are used. 
 60            Defaults to None.
 61        scale: Number of times to duplicate rows for enhanced vertical visibility. Defaults to 2.
 62    """
 63    # The `lai` field is a 2D array containing the window-wise ancestry 
 64    # information for each individual. Consecutive rows of the transformed form 
 65    # correspond to the maternal and paternal ancestries of the same individual
 66    lai_T = laiobj.lai.T
 67    
 68    # Obtain number of samples and windows
 69    n_samples = int(lai_T.shape[0]/2)
 70    n_windows = lai_T.shape[1]
 71    
 72    if fontsize is None:
 73        fontsize = {
 74            'xticks' : 20, 
 75            'yticks' : 9, 
 76            'xlabel' : 20, 
 77            'ylabel' : 20,
 78            'legend' : 20
 79        }
 80    
 81    if sort:
 82        # Reshape `lai_T` to a 3D array where the third dimension represents 
 83        # pairs of consecutive rows, corresponding to maternal and paternal ancestries
 84        # Dimension: n_samples x 2 x n_windows
 85        maternal_paternal_pairs = lai_T.reshape(n_samples, 2, n_windows)
 86        
 87        # Reshape `maternal_paternal_pairs` to a 2D array where each row contains 
 88        # concatenated maternal and paternal ancestries
 89        # Dimension: n_samples x (2·n_windows)
 90        if maternal_paternal_pairs.ndim != 3:
 91            raise ValueError("maternal_paternal_pairs must be a 3D array, got array with ndim=" + str(maternal_paternal_pairs.ndim))
 92        num_samples, num_maternal_paternal, num_windows = cast(Tuple[int, int, int], maternal_paternal_pairs.shape)
 93        
 94        flat_ancestry_pairs = maternal_paternal_pairs.reshape(
 95            num_samples, num_maternal_paternal * num_windows
 96        )
 97
 98        # Determine the most frequent ancestry for each sample
 99        most_freq_ancestry_sample = np.apply_along_axis(
100            lambda row: np.bincount(row).argmax(), axis=1, arr=flat_ancestry_pairs
101        )
102        
103        # Obtain unique ancestry values and total counts
104        ancestry_values, ancestry_counts = np.unique(lai_T, return_counts=True)
105        # Sort ancestry values by total counts in decreasing order
106        sorted_ancestry_values = ancestry_values[np.argsort(ancestry_counts)[::-1]]        
107        
108        # For each ancestry, obtain the samples where that ancestry is predominant
109        # and store the indexes sorted in decreasing order based on the ancestry count
110        all_sorted_row_idxs = []
111        for ancestry in sorted_ancestry_values:
112            ancestry_filter = np.where(most_freq_ancestry_sample == ancestry)[0]
113            ancestry_counts_sample = np.sum(
114                flat_ancestry_pairs[ancestry_filter, :]==ancestry, axis=1
115            )
116            sorted_row_idxs_1 = np.argsort(ancestry_counts_sample)[::-1]
117            sorted_row_idxs = ancestry_filter[sorted_row_idxs_1]
118            all_sorted_row_idxs += list(sorted_row_idxs)
119        
120        # Sort sample IDs based on most frequent ancestry
121        if laiobj.samples is not None:
122            sample_ids = [laiobj.samples[idx] for idx in all_sorted_row_idxs]
123        else:
124            sample_ids = None
125        
126        # Sort `lai_T` based on most frequent ancestry
127        num_samples, num_maternal_paternal, num_windows = cast(Tuple[int, int, int], maternal_paternal_pairs.shape)
128        lai_T = maternal_paternal_pairs[all_sorted_row_idxs, :].reshape(
129            -1, num_windows
130        )
131    
132    # Check if ancestry_map keys are strings of integers (reversed form)
133    if all(isinstance(key, str) and key.isdigit() for key in laiobj.ancestry_map.keys()):
134        # Reverse the dictionary to match integer-to-ancestry format
135        ancestry_map_reverse = {int(key): value for key, value in laiobj.ancestry_map.items()}
136    else:
137        # The ancestry_map is already in the correct integer-to-ancestry format
138        ancestry_map_reverse = laiobj.ancestry_map
139    
140    # Dictionary with integer-to-color mapping
141    colors_map = {key : colors[value] for key, value in ancestry_map_reverse.items()}
142    
143    # Total number of ancestries
144    n_ancestries = len(ancestry_map_reverse.keys())
145    
146    # Insert whitespace between samples
147    lai_T_with_whitespace = np.insert(lai_T, np.arange(2, n_samples*2, 2), n_ancestries, axis=0)
148    
149    # Repeat rows to increase height of samples in plot
150    lai_T_repeat = np.repeat(lai_T_with_whitespace, scale, axis=0)
151    
152    colors_map[n_ancestries] = 'white'
153    
154    # Configure custom map from matrix values to colors
155    cmap = _custom_cmap(colors_map)
156    
157    # Plot LAI output
158    if figsize is None:
159        plt.figure(figsize=(25, 25))
160    else:
161        plt.figure(figsize=figsize)
162    plt.imshow(lai_T_repeat, cmap=cmap)
163    
164    # Display sample IDs in y-axis
165    yticks_positions = np.arange(scale, lai_T_repeat.shape[0]+1, scale*(2+1))
166    plt.xticks(fontsize=fontsize['xticks'])
167    plt.yticks(yticks_positions, sample_ids, fontsize=fontsize['yticks'])
168    
169    ax = plt.gca()
170    ax.set_xlabel('Window', fontsize=fontsize['xlabel'], labelpad=8)
171    ax.set_ylabel('Sample', fontsize=fontsize['ylabel'])
172    ax.spines['top'].set_linewidth(2)
173    ax.spines['right'].set_linewidth(2)
174    ax.spines['bottom'].set_linewidth(2)
175    ax.spines['left'].set_linewidth(2)
176    ax.tick_params(axis='both', which='major', length=8, width=2)
177    ax.tick_params(axis='x', which='major', pad=4)
178    ax.tick_params(axis='y', which='major', pad=4)
179    
180    if legend:
181        if sort:
182            # Sort legend based on global ancestry frequency in decresaing order 
183            sorted_colors = [colors_map[x] for x in sorted_ancestry_values]
184            ancestries = [ancestry_map_reverse[x] for x in sorted_ancestry_values]
185            legend_patches = [patches.Patch(color=color, label=label) 
186                              for color, label in zip(sorted_colors, ancestries)]
187        else:
188            # Add patches for each color to represent the legend squares
189            legend_patches = [patches.Patch(color=color, label=label) 
190                              for color, label in zip(colors.values(), colors.keys())]
191    
192        ax.legend(
193            handles=legend_patches, 
194            loc='lower center', 
195            borderaxespad=-5,
196            ncol=n_ancestries,
197            fontsize=fontsize['legend']
198        )
199    
200    if title:
201        plt.title(title)

Plot LAI (Local Ancestry Inference) data with customizable options. Each row represents the window-level ancestry of one strand (maternal or paternal) of a sample. The two strands of a sample are drawn as consecutive rows, and whitespace separates individual samples.

Arguments:
  • laiobj: A LocalAncestryObject containing LAI data.
  • colors: A dictionary with ancestry-color mapping.
  • sort: If True, sort samples based on the most frequent ancestry. Samples are displayed with the most predominant ancestry first, followed by the second most predominant, and so on. Defaults to True.
  • figsize: Figure size. If None, the figure is displayed with a default size of (25, 25). Defaults to None.
  • legend: If True, display a legend. If sort==True, ancestries in the legend are sorted based on their total frequency in descending order. Defaults to False.
  • title: Title for the plot. If None, no title is displayed. Defaults to None.
  • fontsize: Font sizes for various plot elements. If None, default font sizes are used. Defaults to None.
  • scale: Number of times to duplicate rows for enhanced vertical visibility. Defaults to 2.