Source code for snputils.visualization.scatter_plot

from __future__ import annotations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from typing import Callable, Mapping, Optional, Sequence, Union

from adjustText import adjust_text

from ._figure_export import default_savefig_kwargs, scatter_rasterized_for_path
from .constants import get_palette_colors

# Matplotlib rcParams overrides applied (via ``plt.rc_context``) when
# ``scatter`` is called with ``style="publication"``.
PUBLICATION_RC = {
    # Typography: sans-serif family with an ordered fallback list.
    "font.family": "sans-serif",
    "font.sans-serif": ["DejaVu Sans", "Helvetica", "Arial", "Liberation Sans"],
    "font.size": 11,
    "axes.labelsize": 12,
    "axes.titlesize": 12,
    "legend.fontsize": 9,
    # Line weights: spines and major ticks share the same width.
    "axes.linewidth": 0.9,
    "xtick.major.width": 0.9,
    "ytick.major.width": 0.9,
    # Ticks point away from the plotting area.
    "xtick.direction": "out",
    "ytick.direction": "out",
}


def _normalize_label_key(label: str) -> list[str]:
    s = str(label).strip()
    return [
        s,
        s.replace(" ", "_"),
        s.replace("_", " "),
    ]


def _resolve_label_color(
    label: str,
    idx: int,
    label_colors: Optional[Mapping[str, str]],
    cmap_fn: Callable[[int], np.ndarray],
) -> Union[str, np.ndarray]:
    if label_colors is not None:
        for key in _normalize_label_key(label):
            if key in label_colors:
                return label_colors[key]
    return cmap_fn(int(idx))


def _generate_distinct_colors(n: int) -> list[str]:
    """Return *n* colors obtained from ``get_palette_colors``.

    This is a thin wrapper so that the plotting functions in this module
    share one palette source.
    """
    palette = get_palette_colors(n)
    return palette


def plot_embedding(
    embedding: pd.DataFrame,
    *,
    x: str = "PC1",
    y: str = "PC2",
    hue: Optional[str] = None,
    title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    figsize: tuple[float, float] = (9.0, 7.0),
    point_size: float = 34.0,
    point_alpha: float = 0.6,
    markers: Sequence[str] = ("o", "v", "^", "<", ">", "s", "P", "X", "D"),
    label_colors: Optional[Mapping[str, str]] = None,
    category_order: Optional[Sequence[str]] = None,
    legend: bool = True,
    legend_title: Optional[str] = "",
    legend_fontsize: float = 14.0,
    legend_outside: bool = True,
    grid: bool = False,
    zero_lines: bool = False,
    despine: bool = True,
    axis_style: str = "standard",
    save_path: Optional[str] = None,
    show: bool = False,
    savefig_kwargs: Optional[dict] = None,
) -> plt.Axes:
    """
    Plot a two-dimensional embedding table, optionally colored by metadata.

    Args:
        embedding: DataFrame containing coordinate columns such as ``PC1`` and
            ``PC2``. Tables produced by
            :func:`snputils.processing.embedding_dataframe_from_model` work
            directly.
        x, y: Coordinate columns to plot.
        hue: Optional metadata column used to color points. Missing values are
            grouped under the category ``"Unknown"``.
        title: Optional axes title.
        ax: Existing matplotlib axes. If omitted, a new figure and axes are created.
        figsize: Figure size used when ``ax`` is omitted.
        point_size: Matplotlib marker area for points.
        point_alpha: Point opacity.
        markers: Marker cycle used across hue categories.
        label_colors: Optional mapping from hue values to matplotlib colors.
        category_order: Optional order for hue categories. Entries that do not
            occur in the data are ignored, repeated entries are used once, and
            categories not listed are appended in sorted order.
        legend: Whether to draw a legend when ``hue`` is set.
        legend_title: Legend title. Empty string removes the title by default.
        legend_fontsize: Legend label font size.
        legend_outside: If True, place the legend outside the right side of the axes.
        grid: Whether to draw a grid.
        zero_lines: Draw horizontal and vertical lines at zero.
        despine: Hide top and right axes spines.
        axis_style: Axis spine style. ``"separated"`` offsets left/bottom spines
            so they do not touch and adds end ticks; ``"standard"`` keeps the
            default matplotlib spine behavior.
        save_path: Optional path for the figure.
        show: If True, call ``plt.show()`` before returning.
        savefig_kwargs: Extra keyword arguments for ``Figure.savefig``.

    Returns:
        The matplotlib axes containing the plot.

    Raises:
        ValueError: If ``x``, ``y`` or ``hue`` is not a column of ``embedding``,
            if ``markers`` is empty, or if ``axis_style`` is not recognized.
    """
    # Validate every argument before touching matplotlib, so a bad call neither
    # leaks a freshly created figure nor half-draws on a caller-supplied axes.
    if x not in embedding.columns:
        raise ValueError(f"embedding does not contain x column {x!r}")
    if y not in embedding.columns:
        raise ValueError(f"embedding does not contain y column {y!r}")
    if hue is not None and hue not in embedding.columns:
        raise ValueError(f"embedding does not contain hue column {hue!r}")
    if len(markers) == 0:
        raise ValueError("markers must contain at least one marker style")
    if axis_style not in {"separated", "standard"}:
        raise ValueError("axis_style must be one of {'separated', 'standard'}")

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # The rasterization choice depends only on the output path; compute it once.
    rasterized = scatter_rasterized_for_path(save_path)

    if hue is None:
        ax.scatter(
            embedding[x],
            embedding[y],
            s=point_size,
            alpha=point_alpha,
            marker=markers[0],
            linewidths=0,
            rasterized=rasterized,
        )
    else:
        hue_values = embedding[hue]
        # Stringify first and patch missing entries afterwards: ``fillna`` with
        # a new value is rejected for categorical columns, which are common
        # for metadata.
        labels = hue_values.astype(str).where(hue_values.notna(), "Unknown")
        present = sorted(labels.unique())
        if category_order is None:
            categories = present
        else:
            present_set = set(present)
            # dict.fromkeys drops repeated entries while keeping caller order.
            requested = dict.fromkeys(str(c) for c in category_order)
            categories = [c for c in requested if c in present_set]
            listed = set(categories)
            categories.extend(c for c in present if c not in listed)
        colors = _generate_distinct_colors(len(categories))

        def palette_color(i: int):
            return colors[int(i) % len(colors)]

        for idx, category in enumerate(categories):
            mask = labels == category
            color = _resolve_label_color(category, idx, label_colors, palette_color)
            ax.scatter(
                embedding.loc[mask, x],
                embedding.loc[mask, y],
                label=category,
                s=point_size,
                alpha=point_alpha,
                linewidths=0,
                color=color,
                marker=markers[idx % len(markers)],
                rasterized=rasterized,
            )

    if zero_lines:
        ax.axhline(0, color="0.85", linewidth=0.8, zorder=0)
        ax.axvline(0, color="0.85", linewidth=0.8, zorder=0)
    if title is not None:
        ax.set_title(title)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.grid(grid)

    # Enlarge axis labels, title and tick labels by 50% relative to their
    # current / rcParams sizes.
    ax.xaxis.label.set_size(ax.xaxis.label.get_size() * 1.5)
    ax.yaxis.label.set_size(ax.yaxis.label.get_size() * 1.5)
    ax.title.set_size(ax.title.get_size() * 1.5)
    tick_size = FontProperties(size=plt.rcParams["xtick.labelsize"]).get_size_in_points()
    ax.tick_params(axis="both", which="both", labelsize=tick_size * 1.5)

    if axis_style == "separated":
        ax.spines["left"].set_position(("outward", 12))
        ax.spines["bottom"].set_position(("outward", 12))
        ax.tick_params(axis="x", direction="out", length=10, width=0.9, pad=10)
        ax.tick_params(axis="y", direction="out", length=10, width=0.9, pad=10)
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        xticks = np.asarray(ax.get_xticks(), dtype=float)
        yticks = np.asarray(ax.get_yticks(), dtype=float)
        # Keep only ticks inside the view so the spines end exactly on a tick.
        xticks = xticks[(xticks >= xlim[0]) & (xticks <= xlim[1])]
        yticks = yticks[(yticks >= ylim[0]) & (yticks <= ylim[1])]
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        # Guard against the case where no tick lies within the limits.
        if xticks.size:
            ax.spines["bottom"].set_bounds(xticks[0], xticks[-1])
        if yticks.size:
            ax.spines["left"].set_bounds(yticks[0], yticks[-1])
    if despine:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    if hue is not None and legend:
        legend_kw = {"title": legend_title, "frameon": False, "fontsize": legend_fontsize}
        if legend_outside:
            legend_kw.update({"bbox_to_anchor": (1.02, 1), "loc": "upper left", "borderaxespad": 0})
        ax.legend(**legend_kw)
    ax.figure.tight_layout()

    if save_path is not None:
        # Caller-supplied keyword arguments override the path-based defaults.
        kw = dict(default_savefig_kwargs(str(save_path)))
        kw.update(savefig_kwargs or {})
        ax.figure.savefig(save_path, **kw)
    if show:
        plt.show()
    return ax


def scatter(
    dimredobj: object,
    labels_file: Union[str, pd.DataFrame],
    abbreviation_inside_dots: bool = True,
    arrows_for_titles: bool = False,
    dots: bool = True,
    legend: bool = True,
    color_palette=None,
    show: bool = True,
    save_path: Optional[str] = None,
    *,
    label_mode: Optional[str] = None,
    style: str = "default",
    figsize: Optional[tuple[float, float]] = None,
    label_colors: Optional[Mapping[str, str]] = None,
    legend_outside: Optional[bool] = None,
    despine: Optional[bool] = None,
    axis_xlabel: Optional[str] = None,
    axis_ylabel: Optional[str] = None,
    point_size: Optional[float] = None,
    centroid_size: Optional[float] = None,
    point_alpha: Optional[float] = None,
    savefig_kwargs: Optional[dict] = None,
    equal_aspect: Optional[bool] = None,
) -> None:
    """
    Plot a scatter with group centroids and optional label styling.

    Args:
        dimredobj: Object produced by a dimensionality-reduction step, e.g.
            :class:`~snputils.processing.maasmds.maasMDS`,
            :class:`~snputils.processing.mdpca.mdPCA`, or
            :class:`~snputils.processing.pca.PCA`. Must expose ``X_new_``
            (``(n, 2)`` embedding) and ``samples_`` (identifiers aligned with
            embedding rows). Sample identifiers are compared as strings.
        labels_file (str or pandas.DataFrame): TSV path or in-memory table with
            columns ``indID`` and ``label``.
        abbreviation_inside_dots (bool): If True, show a short acronym inside
            each centroid marker.
        arrows_for_titles (bool): If True, draw arrows from text labels to
            centroids.
        dots (bool): If True, draw scatter points; if False, print coordinates
            and use text markers instead.
        legend (bool): If True, include a legend for group labels.
        color_palette (optional): Either a callable mapping an integer group
            index to a color (e.g. a matplotlib colormap) or an indexable
            sequence of colors, which is cycled. A default palette is chosen
            automatically if None.
        show (bool, optional): If True, call ``plt.show()``; otherwise close the
            figure after saving. Default True.
        save_path (str, optional): If set, save the figure to this path.
            Prefer ``.pdf`` or ``.svg`` for publication: dense scatter is
            rasterized at ``dpi`` (default 300) while axes and text stay
            vector. Bitmap formats (``.png``, ...) also default to that
            ``dpi``. Override via ``savefig_kwargs``.
        label_mode (str, optional): Overrides ``abbreviation_inside_dots``,
            ``arrows_for_titles``, and ``legend``.
            ``"legend"`` - legend plus abbreviations inside centroids.
            ``"acronym"`` - abbreviations inside centroids only.
            ``"arrow"`` - labels near centroids with ``adjustText`` arrows;
            best for many groups.
            ``None`` keeps the individual boolean flags.
        style (str): ``"default"`` - legacy appearance. ``"publication"`` -
            typography, despine, room for an outside legend, slightly larger
            markers, MDS-oriented axis labels.
        figsize (tuple, optional): Figure size in inches; chosen from ``style``
            when None.
        label_colors (Mapping, optional): Map group labels (as in the TSV) to
            matplotlib color strings; unlisted labels use the palette.
        legend_outside (bool, optional): If True, place the legend outside the
            axes. Default True when ``style=="publication"``.
        despine (bool, optional): Hide top and right spines. Default True when
            ``style=="publication"``.
        axis_xlabel, axis_ylabel (str, optional): Axis labels; defaults depend
            on ``style``.
        point_size, centroid_size, point_alpha (float, optional): Override
            scatter sizes and point alpha.
        savefig_kwargs (dict, optional): Extra keyword arguments for
            ``Figure.savefig`` when ``save_path`` is set.
        equal_aspect (bool, optional): If True, equal data aspect (typical for
            MDS/PCA). Default True when ``style="publication"``.

    Returns:
        None

    Raises:
        ValueError: If ``style`` or ``label_mode`` is not a supported value, or
            if ``color_palette`` is an empty sequence.
    """
    if style not in ("default", "publication"):
        raise ValueError(f"style must be 'default' or 'publication', got {style!r}")

    if label_mode is not None:
        _valid = ("legend", "acronym", "arrow")
        if label_mode not in _valid:
            raise ValueError(f"label_mode must be one of {_valid}, got {label_mode!r}")
        # (legend, abbreviation_inside_dots, arrows_for_titles) per mode.
        legend, abbreviation_inside_dots, arrows_for_titles = {
            "legend": (True, True, False),
            "acronym": (False, True, False),
            "arrow": (False, False, True),
        }[label_mode]

    # Style-dependent defaults for everything the caller left as None.
    pub = style == "publication"
    if legend_outside is None:
        legend_outside = pub
    if despine is None:
        despine = pub
    if equal_aspect is None:
        equal_aspect = pub
    if figsize is None:
        if arrows_for_titles:
            figsize = (16.0, 14.0)
        elif pub and legend_outside:
            figsize = (12.0, 8.0)
        else:
            figsize = (10.0, 8.0)
    if axis_xlabel is None:
        axis_xlabel = "MDS 1" if pub else "Component 1"
    if axis_ylabel is None:
        axis_ylabel = "MDS 2" if pub else "Component 2"
    if point_size is None:
        point_size = 42.0 if pub else 30.0
    if centroid_size is None:
        centroid_size = 220.0 if pub else 300.0
    if point_alpha is None:
        point_alpha = 0.72 if pub else 0.6

    rc = PUBLICATION_RC if pub else {}

    # Copy so the caller's dict is never mutated; explicit caller values win
    # over the publication and path-based defaults.
    savefig_kwargs = dict(savefig_kwargs or {})
    if save_path:
        if pub:
            savefig_kwargs.setdefault("bbox_inches", "tight")
            savefig_kwargs.setdefault("pad_inches", 0.08)
        for k, v in default_savefig_kwargs(str(save_path)).items():
            savefig_kwargs.setdefault(k, v)

    # Load labels from TSV or use an in-memory table.
    if isinstance(labels_file, pd.DataFrame):
        labels_df = labels_file.copy()
    else:
        labels_df = pd.read_csv(labels_file, sep="\t")

    # Compare identifiers as strings on both sides so that numeric sample ids
    # in the embedding still match the (string) ids of the label table.
    labels_df["indID"] = labels_df["indID"].astype(str)
    sample_ids = np.asarray(dimredobj.samples_).astype(str)

    # Keep only labels for samples present in the embedding.
    filtered_labels_df = labels_df[labels_df["indID"].isin(sample_ids)]
    unique_labels = filtered_labels_df["label"].unique()
    n_labels = len(unique_labels)

    # Build the index -> color fallback used for labels without an explicit
    # entry in ``label_colors``.
    if color_palette is None:
        _auto_colors = _generate_distinct_colors(n_labels)

        def cmap_fn(i: int):
            return _auto_colors[int(i) % len(_auto_colors)]

    elif callable(color_palette):

        def cmap_fn(i: int):
            return color_palette(int(i))

    else:
        if len(color_palette) == 0:
            raise ValueError("color_palette must not be empty")

        def cmap_fn(i: int):
            return color_palette[int(i) % len(color_palette)]

    # Saving/showing stays inside the rc context so that settings resolved at
    # render time (e.g. the font fallback list) also use the chosen style.
    with plt.rc_context(rc):
        fig, ax = plt.subplots(figsize=figsize)
        rasterized = scatter_rasterized_for_path(save_path)

        centroids = {}
        group_colors = {}
        all_scatter_x: list[float] = []
        all_scatter_y: list[float] = []

        for i, label in enumerate(unique_labels):
            # Labels read from a table may be numeric; slice their text form.
            label_text = str(label)
            acronym = label_text[:2].upper()

            label_rows = filtered_labels_df["label"] == label
            ids_for_label = filtered_labels_df.loc[label_rows, "indID"].to_numpy(dtype=str)
            points = dimredobj.X_new_[np.isin(sample_ids, ids_for_label)]
            if len(points) == 0:
                # E.g. a NaN label never equals itself; skip it rather than
                # plotting a NaN centroid.
                continue

            c = _resolve_label_color(label, i, label_colors, cmap_fn)
            group_colors[label] = c
            all_scatter_x.extend(points[:, 0].tolist())
            all_scatter_y.extend(points[:, 1].tolist())

            if dots:
                ax.scatter(
                    points[:, 0],
                    points[:, 1],
                    s=point_size,
                    color=c,
                    alpha=point_alpha,
                    label=label_text,
                    # Thin dark outline in publication style only.
                    edgecolors="0.15" if pub else "none",
                    linewidths=0.25 if pub else 0,
                    rasterized=rasterized,
                )
            else:
                for point in points:
                    print(point[0], point[1])
                    ax.text(
                        point[0],
                        point[1],
                        acronym,
                        ha="center",
                        va="center",
                        color=c,
                        fontsize=8,
                        weight="bold",
                    )

            centroid = points.mean(axis=0)
            centroids[label] = centroid
            ax.scatter(
                *centroid,
                color=c,
                s=centroid_size,
                edgecolors="none",
                linewidths=0,
                zorder=5,
            )
            if abbreviation_inside_dots:
                ax.text(
                    centroid[0],
                    centroid[1],
                    acronym,
                    ha="center",
                    va="center",
                    color="white",
                    fontsize=8 if not pub else 9,
                    weight="bold",
                    zorder=6,
                )

        texts = []
        if arrows_for_titles:
            for label, centroid in centroids.items():
                texts.append(
                    ax.text(
                        centroid[0],
                        centroid[1],
                        label,
                        color=group_colors[label],
                        fontsize=9 if pub else 10,
                        weight="bold",
                        zorder=7,
                    )
                )

        if texts:
            # Arrows point from the repositioned text back to its centroid.
            target_x = [centroids[lbl][0] for lbl in centroids]
            target_y = [centroids[lbl][1] for lbl in centroids]
            adjust_text(
                texts,
                all_scatter_x,
                all_scatter_y,
                ax=ax,
                force_text=(0.4, 0.6),
                force_static=(0.3, 0.4),
                force_pull=(0.005, 0.01),
                force_explode=(0.2, 0.8),
                expand=(1.2, 1.4),
                max_move=(80, 80),
                explode_radius="auto",
                ensure_inside_axes=False,
                prevent_crossings=True,
                min_arrow_len=1,
                iter_lim=3000,
                target_x=target_x,
                target_y=target_y,
                arrowprops=dict(
                    arrowstyle="->",
                    color="gray",
                    alpha=0.8,
                    lw=1.0,
                    mutation_scale=12,
                    shrinkA=2,
                    shrinkB=2,
                ),
            )

        ax.set_xlabel(axis_xlabel)
        ax.set_ylabel(axis_ylabel)
        if equal_aspect:
            ax.set_aspect("equal", adjustable="box")
        if despine:
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        if legend:
            legend_kw: dict = {
                "frameon": True,
                "fancybox": False,
            }
            if pub:
                legend_kw.update(
                    {
                        "framealpha": 0.96,
                        "edgecolor": "#bfbfbf",
                        "fontsize": 9,
                    }
                )
            if legend_outside:
                legend_kw.update(
                    {
                        "loc": "upper left",
                        "bbox_to_anchor": (1.01, 1.0),
                        "borderaxespad": 0.0,
                    }
                )
                ax.legend(**legend_kw)
                # Reserve horizontal room for the legend beside the axes.
                fig.subplots_adjust(right=0.74)
            else:
                legend_kw.setdefault("loc", "upper right")
                ax.legend(**legend_kw)

        if save_path:
            fig.savefig(save_path, **savefig_kwargs)
        if show:
            plt.show()
        else:
            plt.close(fig)