snputils.visualization.scatter_plot

  1import numpy as np
  2import pandas as pd
  3import matplotlib.pyplot as plt
  4from matplotlib import cm
  5from typing import Optional
  6from adjustText import adjust_text
  7
  8
  9def scatter(
 10    dimredobj: np.ndarray,
 11    labels_file: str,
 12    abbreviation_inside_dots: bool = True,
 13    arrows_for_titles: bool = False,
 14    dots: bool = True,
 15    legend: bool = True,
 16    color_palette=None,
 17    show: bool = True,
 18    save_path: Optional[str] = None
 19) -> None:
 20    """
 21    Plot a scatter plot with centroids for each group, with options for labeling and display styles.
 22    
 23    Args:
 24        dimredobj (np.ndarray): 
 25            Reduced dimensionality data; expected to have `(n_haplotypes, 2)` shape.
 26        labels_file (str): 
 27            Path to a TSV file with columns 'indID' and 'label', providing labels for coloring and annotating points.
 28        abbreviation_inside_dots (bool): 
 29            If True, displays abbreviated labels (first 3 characters) inside the centroid markers.
 30        arrows_for_titles (bool): 
 31            If True, adds arrows pointing to centroids with group labels displayed near the centroids.
 32        legend (bool): 
 33            If True, includes a legend indicating each group label.
 34        color_palette (optional): 
 35            Color map or list of colors to use for unique labels. Defaults to 'tab10' if None.
 36        show (bool, optional):
 37            Whether to display the plot. Defaults to False.
 38        save_path (str, optional):
 39            Path to save the plot image. If None, the plot is not saved.
 40        
 41    Returns:
 42        None
 43    """
 44    # Load labels from TSV
 45    labels_df = pd.read_csv(labels_file, sep='\t')
 46
 47    # Ensure 'indID' is treated as a string
 48    labels_df['indID'] = labels_df['indID'].astype(str)
 49
 50    # Filter labels based on the indIDs in dimredobj
 51    sample_ids = dimredobj.samples_
 52    filtered_labels_df = labels_df[labels_df['indID'].isin(sample_ids)]
 53
 54    # Define unique colors for each group label, either from color_palette or defaulting to 'tab10'
 55    unique_labels = filtered_labels_df['label'].unique()
 56    colors = color_palette if color_palette else cm.get_cmap('tab10', len(unique_labels))
 57
 58    # Initialize the plot
 59    fig, ax = plt.subplots(figsize=(10, 8))
 60
 61    # Calculate the overall center of the plot (used for positioning arrows)
 62    plot_center = dimredobj.X_new_.mean(axis=0)
 63
 64    # Dictionary to hold centroid positions for each label
 65    centroids = {}
 66
 67    # Plot data points and centroids by label
 68    for i, label in enumerate(unique_labels):
 69        # Get sample IDs corresponding to the current label
 70        sample_ids_for_label = filtered_labels_df[filtered_labels_df['label'] == label]['indID']
 71        
 72        # Filter points based on sample IDs
 73        points = dimredobj.X_new_[np.isin(dimredobj.samples_, sample_ids_for_label)]
 74
 75        if dots:
 76            # Plot individual points for the current group
 77            ax.scatter(points[:, 0], points[:, 1], s=30, color=colors(i), alpha=0.6, label=label)
 78        else:
 79            # TODO: solve bug
 80            for point in points:
 81                print(point[0], point[1])
 82                ax.text(point[0], point[1], label[:2].upper(), ha='center', va='center', color=colors(i), fontsize=8, weight='bold')
 83
 84        # Calculate and mark the centroid for the current group
 85        centroid = points.mean(axis=0)
 86        centroids[label] = centroid  # Store centroid for later use
 87        
 88        # Plot the centroid as a larger dot
 89        ax.scatter(*centroid, color=colors(i), s=300)
 90
 91        # Optionally add an abbreviation inside the centroid dot
 92        if abbreviation_inside_dots:
 93            ax.text(centroid[0], centroid[1], label[:2].upper(), ha='center', va='center', color='white', fontsize=8, weight='bold')
 94
 95    # Adding arrows and labels with `adjust_text` for no overlap
 96    texts = []
 97    for label, centroid in centroids.items():
 98        # Determine the direction of the arrow based on centroid position
 99        offset_x = 0.07 if centroid[0] >= plot_center[0] else -0.07
100        offset_y = 0.07 if centroid[1] >= plot_center[1] else -0.07
101        
102        if arrows_for_titles:
103            ax.annotate('', xy=centroid, 
104                        xytext=(centroid[0] + offset_x, centroid[1] + offset_y),
105                        arrowprops=dict(facecolor=colors(unique_labels.tolist().index(label)), 
106                                        shrink=0.05, width=1.5, headwidth=8))
107        
108            # Label text for centroid
109            texts.append(ax.text(centroid[0] + offset_x, centroid[1] + offset_y, label, 
110                                color=colors(unique_labels.tolist().index(label)), 
111                                fontsize=12, weight='bold'))
112
113        # Adjust text to prevent overlap
114        adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.6))
115
116    # Configure additional plot settings
117    ax.set_xlabel("Component 1")
118    ax.set_ylabel("Component 2")
119    if legend:
120        ax.legend(loc='upper right')
121    
122    # Save plot if save_path is provided
123    if save_path:
124        plt.savefig(save_path)
125
126    # Display plot if show is True
127    if show:
128        plt.show()
129    else:
130        plt.close()
def scatter( dimredobj: numpy.ndarray, labels_file: str, abbreviation_inside_dots: bool = True, arrows_for_titles: bool = False, dots: bool = True, legend: bool = True, color_palette=None, show: bool = True, save_path: Optional[str] = None) -> None:
 10def scatter(
 11    dimredobj: np.ndarray,
 12    labels_file: str,
 13    abbreviation_inside_dots: bool = True,
 14    arrows_for_titles: bool = False,
 15    dots: bool = True,
 16    legend: bool = True,
 17    color_palette=None,
 18    show: bool = True,
 19    save_path: Optional[str] = None
 20) -> None:
 21    """
 22    Plot a scatter plot with centroids for each group, with options for labeling and display styles.
 23    
 24    Args:
 25        dimredobj (np.ndarray): 
 26            Reduced dimensionality data; expected to have `(n_haplotypes, 2)` shape.
 27        labels_file (str): 
 28            Path to a TSV file with columns 'indID' and 'label', providing labels for coloring and annotating points.
 29        abbreviation_inside_dots (bool): 
 30            If True, displays abbreviated labels (first 3 characters) inside the centroid markers.
 31        arrows_for_titles (bool): 
 32            If True, adds arrows pointing to centroids with group labels displayed near the centroids.
 33        legend (bool): 
 34            If True, includes a legend indicating each group label.
 35        color_palette (optional): 
 36            Color map or list of colors to use for unique labels. Defaults to 'tab10' if None.
 37        show (bool, optional):
 38            Whether to display the plot. Defaults to False.
 39        save_path (str, optional):
 40            Path to save the plot image. If None, the plot is not saved.
 41        
 42    Returns:
 43        None
 44    """
 45    # Load labels from TSV
 46    labels_df = pd.read_csv(labels_file, sep='\t')
 47
 48    # Ensure 'indID' is treated as a string
 49    labels_df['indID'] = labels_df['indID'].astype(str)
 50
 51    # Filter labels based on the indIDs in dimredobj
 52    sample_ids = dimredobj.samples_
 53    filtered_labels_df = labels_df[labels_df['indID'].isin(sample_ids)]
 54
 55    # Define unique colors for each group label, either from color_palette or defaulting to 'tab10'
 56    unique_labels = filtered_labels_df['label'].unique()
 57    colors = color_palette if color_palette else cm.get_cmap('tab10', len(unique_labels))
 58
 59    # Initialize the plot
 60    fig, ax = plt.subplots(figsize=(10, 8))
 61
 62    # Calculate the overall center of the plot (used for positioning arrows)
 63    plot_center = dimredobj.X_new_.mean(axis=0)
 64
 65    # Dictionary to hold centroid positions for each label
 66    centroids = {}
 67
 68    # Plot data points and centroids by label
 69    for i, label in enumerate(unique_labels):
 70        # Get sample IDs corresponding to the current label
 71        sample_ids_for_label = filtered_labels_df[filtered_labels_df['label'] == label]['indID']
 72        
 73        # Filter points based on sample IDs
 74        points = dimredobj.X_new_[np.isin(dimredobj.samples_, sample_ids_for_label)]
 75
 76        if dots:
 77            # Plot individual points for the current group
 78            ax.scatter(points[:, 0], points[:, 1], s=30, color=colors(i), alpha=0.6, label=label)
 79        else:
 80            # TODO: solve bug
 81            for point in points:
 82                print(point[0], point[1])
 83                ax.text(point[0], point[1], label[:2].upper(), ha='center', va='center', color=colors(i), fontsize=8, weight='bold')
 84
 85        # Calculate and mark the centroid for the current group
 86        centroid = points.mean(axis=0)
 87        centroids[label] = centroid  # Store centroid for later use
 88        
 89        # Plot the centroid as a larger dot
 90        ax.scatter(*centroid, color=colors(i), s=300)
 91
 92        # Optionally add an abbreviation inside the centroid dot
 93        if abbreviation_inside_dots:
 94            ax.text(centroid[0], centroid[1], label[:2].upper(), ha='center', va='center', color='white', fontsize=8, weight='bold')
 95
 96    # Adding arrows and labels with `adjust_text` for no overlap
 97    texts = []
 98    for label, centroid in centroids.items():
 99        # Determine the direction of the arrow based on centroid position
100        offset_x = 0.07 if centroid[0] >= plot_center[0] else -0.07
101        offset_y = 0.07 if centroid[1] >= plot_center[1] else -0.07
102        
103        if arrows_for_titles:
104            ax.annotate('', xy=centroid, 
105                        xytext=(centroid[0] + offset_x, centroid[1] + offset_y),
106                        arrowprops=dict(facecolor=colors(unique_labels.tolist().index(label)), 
107                                        shrink=0.05, width=1.5, headwidth=8))
108        
109            # Label text for centroid
110            texts.append(ax.text(centroid[0] + offset_x, centroid[1] + offset_y, label, 
111                                color=colors(unique_labels.tolist().index(label)), 
112                                fontsize=12, weight='bold'))
113
114        # Adjust text to prevent overlap
115        adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.6))
116
117    # Configure additional plot settings
118    ax.set_xlabel("Component 1")
119    ax.set_ylabel("Component 2")
120    if legend:
121        ax.legend(loc='upper right')
122    
123    # Save plot if save_path is provided
124    if save_path:
125        plt.savefig(save_path)
126
127    # Display plot if show is True
128    if show:
129        plt.show()
130    else:
131        plt.close()

Plot a scatter plot with centroids for each group, with options for labeling and display styles.

Arguments:
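  • dimredobj (annotated np.ndarray; in practice an object exposing `samples_` and `X_new_` attributes): Reduced dimensionality data; `X_new_` is expected to have (n_haplotypes, 2) shape, with `samples_` giving the sample ID of each row.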
  • labels_file (str): Path to a TSV file with columns 'indID' and 'label', providing labels for coloring and annotating points.
  • abbreviation_inside_dots (bool): If True, displays abbreviated labels (first 2 characters, upper-cased) inside the centroid markers.
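  • arrows_for_titles (bool): If True, adds arrows pointing to centroids with group labels displayed near the centroids.
  • dots (bool): If True, plots each sample as an individual dot; if False, draws the abbreviated label as text at each sample position instead.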
  • legend (bool): If True, includes a legend indicating each group label.
  • color_palette (optional): Color map or list of colors to use for unique labels. Defaults to 'tab10' if None.
  • show (bool, optional): Whether to display the plot. Defaults to True.
  • save_path (str, optional): Path to save the plot image. If None, the plot is not saved.
Returns:

None