snputils.visualization.scatter_plot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from typing import Optional
from adjustText import adjust_text


def scatter(
    dimredobj: np.ndarray,
    labels_file: str,
    abbreviation_inside_dots: bool = True,
    arrows_for_titles: bool = False,
    dots: bool = True,
    legend: bool = True,
    color_palette=None,
    show: bool = True,
    save_path: Optional[str] = None
) -> None:
    """
    Plot a 2-D scatter plot of samples coloured by group, with one centroid marker per group.

    Args:
        dimredobj:
            Object holding the reduced data. Despite the `np.ndarray` annotation, this
            function reads two attributes from it: `samples_` (sample identifiers) and
            `X_new_` (coordinates, one row per sample). Only the first two columns of
            `X_new_` are plotted.
        labels_file (str):
            Path to a TSV file with columns 'indID' and 'label'. Rows whose 'indID' is
            not present in `dimredobj.samples_` are ignored.
        abbreviation_inside_dots (bool):
            If True, writes the first 2 characters of the label (upper-cased) inside
            each centroid marker.
        arrows_for_titles (bool):
            If True, draws an arrow towards each centroid with the full group label
            next to it; label positions are de-overlapped with `adjust_text`.
        dots (bool):
            If True, every sample is drawn as a small dot. If False, every sample is
            drawn as the 2-character upper-cased abbreviation of its label instead.
        legend (bool):
            If True, adds a legend in the upper-right corner.
        color_palette (optional):
            Colours for the groups. `None` or an empty sequence selects 'tab10'
            resampled to the number of groups; a string is looked up as a colormap
            name; a callable is called with the group index; any other value is
            treated as a sequence of colours and indexed cyclically.
        show (bool, optional):
            Whether to display the plot. Defaults to True. When False the figure is
            closed after (optionally) being saved.
        save_path (str, optional):
            Path to save the plot image. If None, the plot is not saved.

    Returns:
        None
    """
    labels_df = pd.read_csv(labels_file, sep='\t')
    labels_df['indID'] = labels_df['indID'].astype(str)

    # The label IDs are compared as strings, so the sample IDs must be strings too.
    sample_ids = np.asarray(dimredobj.samples_).astype(str)
    coords = np.asarray(dimredobj.X_new_)

    filtered_labels_df = labels_df[labels_df['indID'].isin(sample_ids)]
    unique_labels = filtered_labels_df['label'].unique()
    color_for = _palette_lookup(color_palette, len(unique_labels))

    fig, ax = plt.subplots(figsize=(10, 8))

    # Overall centre of the data; decides on which side of a centroid its title goes.
    plot_center = coords.mean(axis=0)

    centroids = {}
    group_colors = {}

    for i, label in enumerate(unique_labels):
        member_ids = filtered_labels_df.loc[
            filtered_labels_df['label'] == label, 'indID'
        ].to_numpy()
        points = coords[np.isin(sample_ids, member_ids)]
        if len(points) == 0:
            # e.g. a missing (NaN) label never compares equal to itself.
            continue

        title = str(label)
        tag = title[:2].upper()
        color = color_for(i)

        if dots:
            ax.scatter(points[:, 0], points[:, 1], s=30, color=color,
                       alpha=0.6, label=title)
        else:
            for x, y in points[:, :2]:
                ax.text(x, y, tag, ha='center', va='center',
                        color=color, fontsize=8, weight='bold')
            # Text artists do not extend the data limits, so register the
            # coordinates explicitly to keep every abbreviation inside the axes.
            ax.update_datalim(points[:, :2])

        centroid = points.mean(axis=0)
        centroids[label] = centroid
        group_colors[label] = color

        # Without per-sample dots the centroid is the only artist that can carry
        # the legend entry for this group.
        ax.scatter(centroid[0], centroid[1], color=color, s=300,
                   label=None if dots else title)

        if abbreviation_inside_dots:
            ax.text(centroid[0], centroid[1], tag, ha='center', va='center',
                    color='white', fontsize=8, weight='bold')

    ax.autoscale_view()

    texts = []
    if arrows_for_titles:
        for label, centroid in centroids.items():
            # Push the title away from the centre of the plot.
            offset_x = 0.07 if centroid[0] >= plot_center[0] else -0.07
            offset_y = 0.07 if centroid[1] >= plot_center[1] else -0.07
            anchor = (centroid[0] + offset_x, centroid[1] + offset_y)
            color = group_colors[label]

            ax.annotate('', xy=(centroid[0], centroid[1]), xytext=anchor,
                        arrowprops=dict(facecolor=color, shrink=0.05,
                                        width=1.5, headwidth=8))
            texts.append(ax.text(anchor[0], anchor[1], str(label), color=color,
                                 fontsize=12, weight='bold'))

    if texts:
        adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.6))

    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    if legend:
        ax.legend(loc='upper right')

    if save_path:
        fig.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close(fig)


def _palette_lookup(color_palette, n_groups):
    """Return a function that maps a group index to a matplotlib colour."""
    n_colors = max(n_groups, 1)
    if color_palette is None:
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed
        # in Matplotlib 3.9.
        return plt.get_cmap('tab10', n_colors)
    if isinstance(color_palette, str):
        return plt.get_cmap(color_palette, n_colors)
    if callable(color_palette):
        return color_palette
    palette = list(color_palette)
    if not palette:
        return plt.get_cmap('tab10', n_colors)
    return lambda index: palette[index % len(palette)]
def
scatter( dimredobj: numpy.ndarray, labels_file: str, abbreviation_inside_dots: bool = True, arrows_for_titles: bool = False, dots: bool = True, legend: bool = True, color_palette=None, show: bool = True, save_path: Optional[str] = None) -> None:
def scatter(
    dimredobj: np.ndarray,
    labels_file: str,
    abbreviation_inside_dots: bool = True,
    arrows_for_titles: bool = False,
    dots: bool = True,
    legend: bool = True,
    color_palette=None,
    show: bool = True,
    save_path: Optional[str] = None
) -> None:
    """
    Plot a 2-D scatter plot of samples coloured by group, with one centroid marker per group.

    Args:
        dimredobj:
            Object holding the reduced data. Despite the `np.ndarray` annotation, this
            function reads two attributes from it: `samples_` (sample identifiers) and
            `X_new_` (coordinates, one row per sample). Only the first two columns of
            `X_new_` are plotted.
        labels_file (str):
            Path to a TSV file with columns 'indID' and 'label'. Rows whose 'indID' is
            not present in `dimredobj.samples_` are ignored.
        abbreviation_inside_dots (bool):
            If True, writes the first 2 characters of the label (upper-cased) inside
            each centroid marker.
        arrows_for_titles (bool):
            If True, draws an arrow towards each centroid with the full group label
            next to it; label positions are de-overlapped with `adjust_text`.
        dots (bool):
            If True, every sample is drawn as a small dot. If False, every sample is
            drawn as the 2-character upper-cased abbreviation of its label instead.
        legend (bool):
            If True, adds a legend in the upper-right corner.
        color_palette (optional):
            Colours for the groups. `None` or an empty sequence selects 'tab10'
            resampled to the number of groups; a string is looked up as a colormap
            name; a callable is called with the group index; any other value is
            treated as a sequence of colours and indexed cyclically.
        show (bool, optional):
            Whether to display the plot. Defaults to True. When False the figure is
            closed after (optionally) being saved.
        save_path (str, optional):
            Path to save the plot image. If None, the plot is not saved.

    Returns:
        None
    """
    labels_df = pd.read_csv(labels_file, sep='\t')
    labels_df['indID'] = labels_df['indID'].astype(str)

    # The label IDs are compared as strings, so the sample IDs must be strings too.
    sample_ids = np.asarray(dimredobj.samples_).astype(str)
    coords = np.asarray(dimredobj.X_new_)

    filtered_labels_df = labels_df[labels_df['indID'].isin(sample_ids)]
    unique_labels = filtered_labels_df['label'].unique()
    color_for = _palette_lookup(color_palette, len(unique_labels))

    fig, ax = plt.subplots(figsize=(10, 8))

    # Overall centre of the data; decides on which side of a centroid its title goes.
    plot_center = coords.mean(axis=0)

    centroids = {}
    group_colors = {}

    for i, label in enumerate(unique_labels):
        member_ids = filtered_labels_df.loc[
            filtered_labels_df['label'] == label, 'indID'
        ].to_numpy()
        points = coords[np.isin(sample_ids, member_ids)]
        if len(points) == 0:
            # e.g. a missing (NaN) label never compares equal to itself.
            continue

        title = str(label)
        tag = title[:2].upper()
        color = color_for(i)

        if dots:
            ax.scatter(points[:, 0], points[:, 1], s=30, color=color,
                       alpha=0.6, label=title)
        else:
            for x, y in points[:, :2]:
                ax.text(x, y, tag, ha='center', va='center',
                        color=color, fontsize=8, weight='bold')
            # Text artists do not extend the data limits, so register the
            # coordinates explicitly to keep every abbreviation inside the axes.
            ax.update_datalim(points[:, :2])

        centroid = points.mean(axis=0)
        centroids[label] = centroid
        group_colors[label] = color

        # Without per-sample dots the centroid is the only artist that can carry
        # the legend entry for this group.
        ax.scatter(centroid[0], centroid[1], color=color, s=300,
                   label=None if dots else title)

        if abbreviation_inside_dots:
            ax.text(centroid[0], centroid[1], tag, ha='center', va='center',
                    color='white', fontsize=8, weight='bold')

    ax.autoscale_view()

    texts = []
    if arrows_for_titles:
        for label, centroid in centroids.items():
            # Push the title away from the centre of the plot.
            offset_x = 0.07 if centroid[0] >= plot_center[0] else -0.07
            offset_y = 0.07 if centroid[1] >= plot_center[1] else -0.07
            anchor = (centroid[0] + offset_x, centroid[1] + offset_y)
            color = group_colors[label]

            ax.annotate('', xy=(centroid[0], centroid[1]), xytext=anchor,
                        arrowprops=dict(facecolor=color, shrink=0.05,
                                        width=1.5, headwidth=8))
            texts.append(ax.text(anchor[0], anchor[1], str(label), color=color,
                                 fontsize=12, weight='bold'))

    if texts:
        adjust_text(texts, arrowprops=dict(arrowstyle="->", color='gray', lw=0.6))

    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    if legend:
        ax.legend(loc='upper right')

    if save_path:
        fig.savefig(save_path)

    if show:
        plt.show()
    else:
        plt.close(fig)


def _palette_lookup(color_palette, n_groups):
    """Return a function that maps a group index to a matplotlib colour."""
    n_colors = max(n_groups, 1)
    if color_palette is None:
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed
        # in Matplotlib 3.9.
        return plt.get_cmap('tab10', n_colors)
    if isinstance(color_palette, str):
        return plt.get_cmap(color_palette, n_colors)
    if callable(color_palette):
        return color_palette
    palette = list(color_palette)
    if not palette:
        return plt.get_cmap('tab10', n_colors)
    return lambda index: palette[index % len(palette)]
Plot a scatter plot with centroids for each group, with options for labeling and display styles.
Arguments:
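- dimredobj (np.ndarray): Reduced dimensionality data; expected to have `(n_haplotypes, 2)` shape. Note: the implementation reads the `samples_` and `X_new_` attributes of this argument, so it must be an object exposing both rather than a bare array.
- labels_file (str): Path to a TSV file with columns 'indID' and 'label', providing labels for coloring and annotating points.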
- abbreviation_inside_dots (bool): If True, displays abbreviated labels (first 2 characters, upper-cased) inside the centroid markers.
- arrows_for_titles (bool): If True, adds arrows pointing to centroids with group labels displayed near the centroids.
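- dots (bool): If True, plots each sample as an individual point; if False, writes the abbreviated label (first 2 characters, upper-cased) at each sample's position instead.
- legend (bool): If True, includes a legend indicating each group label.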
- color_palette (optional): Color map or list of colors to use for unique labels. Defaults to 'tab10' if None.
- show (bool, optional): Whether to display the plot. Defaults to True.
- save_path (str, optional): Path to save the plot image. If None, the plot is not saved.
Returns:
None