# Source code for proteopy.utils.stat_tests

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from adjustText import adjust_text


def _validate_numeric_inputs(fc_vals, pvals):
    """Coerce fold changes and p-values to matching 1D float arrays.

    Parameters
    ----------
    fc_vals, pvals
        Array-likes convertible to ``float``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(fc_vals, pvals)`` as float arrays.

    Raises
    ------
    ValueError
        If either input cannot be converted to float, is not 1D, or
        if the two shapes differ.
    """
    # (raw input, conversion error message, dimensionality error message)
    specs = (
        (
            fc_vals,
            "fc_vals must contain numeric values.",
            "fc_vals must be 1D.",
        ),
        (
            pvals,
            "pvals must contain numeric values.",
            "pvals must be 1D.",
        ),
    )

    # Convert both inputs first so conversion errors take precedence
    # over dimensionality errors.
    converted = []
    for raw, conversion_msg, _ in specs:
        try:
            converted.append(np.asarray(raw, dtype=float))
        except (ValueError, TypeError):
            raise ValueError(conversion_msg)

    for arr, (_, _, ndim_msg) in zip(converted, specs):
        if arr.ndim != 1:
            raise ValueError(ndim_msg)

    fc_arr, pval_arr = converted
    if fc_arr.shape != pval_arr.shape:
        raise ValueError(
            "fc_vals and pvals must have the same length."
        )

    return fc_arr, pval_arr


def _validate_highlight_labels(highlight_labels):
    """Validate the sequence of labels requested for highlighting.

    Parameters
    ----------
    highlight_labels
        Sized collection of label strings.

    Warns
    -----
    UserWarning
        If ``highlight_labels`` is empty. An empty collection is
        accepted; nothing else is checked in that case.

    Raises
    ------
    ValueError
        If ``highlight_labels`` contains duplicates or non-string
        values.
    """
    if len(highlight_labels) == 0:
        warnings.warn(
            "highlight_labels is empty.",
            UserWarning,
        )
        # ``np.asarray([])`` has a float dtype, so running the string
        # dtype check below on an empty input would always raise right
        # after the warning. Empty input is only worth a warning.
        return

    if len(highlight_labels) != len(set(highlight_labels)):
        raise ValueError(
            "highlight_labels must not contain "
            "duplicates."
        )
    if not np.issubdtype(
        np.asarray(highlight_labels).dtype, np.str_
    ):
        raise ValueError(
            "highlight_labels must contain "
            "string values."
        )


def _validate_label_inputs(
    labels, top_labels, highlight_labels, n_points,
):
    """Validate the labelling arguments and normalise ``labels``.

    Parameters
    ----------
    labels
        Per-point labels, or ``None``.
    top_labels
        Number of points to label per side, or ``None``. Python and
        NumPy integers are accepted; booleans are not.
    highlight_labels
        Labels to highlight, or ``None``. Checked by
        ``_validate_highlight_labels`` when given.
    n_points : int
        Expected number of entries in ``labels``.

    Returns
    -------
    np.ndarray | None
        ``labels`` converted with ``np.asarray``, or ``None`` when no
        labels were supplied.

    Raises
    ------
    ValueError
        If ``top_labels`` and ``highlight_labels`` are both set; if
        either is set without ``labels``; if ``top_labels`` is not a
        positive integer; if ``labels`` is not a 1D string array of
        length ``n_points``.
    """
    if (
        top_labels is not None
        and highlight_labels is not None
    ):
        raise ValueError(
            "top_labels and highlight_labels are "
            "mutually exclusive."
        )

    if (
        labels is None
        and (top_labels is not None
             or highlight_labels is not None)
    ):
        raise ValueError(
            "labels must be provided when "
            "top_labels or highlight_labels is set."
        )

    if top_labels is not None:
        # ``bool`` is a subclass of ``int``; without the explicit
        # exclusion ``top_labels=True`` would silently mean 1.
        is_integer = (
            isinstance(top_labels, (int, np.integer))
            and not isinstance(top_labels, (bool, np.bool_))
        )
        if not is_integer or top_labels <= 0:
            raise ValueError(
                "top_labels must be a positive integer."
            )

    if highlight_labels is not None:
        _validate_highlight_labels(highlight_labels)

    if labels is not None:
        labels = np.asarray(labels)
        if not np.issubdtype(labels.dtype, np.str_):
            raise ValueError(
                "labels must contain string values."
            )
        # Checked before ``shape[0]`` so that a bare string (0-d) or a
        # nested sequence gives a clear error instead of an IndexError
        # or a silently accepted 2D array.
        if labels.ndim != 1:
            raise ValueError("labels must be 1D.")
        if labels.shape[0] != n_points:
            raise ValueError(
                "labels must have the same length as "
                "fc_vals."
            )

    return labels


def _validate_alt_color(alt_color, n_points):
    """Validate the optional boolean colouring mask.

    Parameters
    ----------
    alt_color
        Boolean array-like, or ``None``.
    n_points : int
        Required number of entries.

    Returns
    -------
    np.ndarray | None
        The mask as an array, or ``None`` when no mask was given.

    Raises
    ------
    ValueError
        If the mask is not 1D, has the wrong length, or is not of
        boolean dtype (checked in that order).
    """
    if alt_color is not None:
        mask = np.asarray(alt_color)

        one_dimensional = mask.ndim == 1
        if not one_dimensional:
            raise ValueError(
                "alt_color must be a 1D boolean sequence."
            )

        (n_entries,) = mask.shape
        if n_entries != n_points:
            raise ValueError(
                "alt_color must have the same length as "
                "fc_vals."
            )

        is_boolean = np.issubdtype(mask.dtype, np.bool_)
        if not is_boolean:
            raise ValueError("alt_color must be boolean.")

        return mask

    return None


def _validate_thresholds(fc_thresh, pval_thresh):
    """Validate the fold-change and p-value thresholds.

    Parameters
    ----------
    fc_thresh : float | None
        Must be strictly positive when set.
    pval_thresh : float | None
        Must lie in ``[0, 1]`` when set.

    Raises
    ------
    ValueError
        If a threshold is set and falls outside its valid range.
        NaN is rejected for both thresholds.
    """
    # The checks are phrased as "not (valid)" rather than "(invalid)"
    # because every comparison with NaN is False: ``nan <= 0`` would
    # let a NaN threshold through, ``not nan > 0`` does not.
    if fc_thresh is not None and not fc_thresh > 0:
        raise ValueError(
            "fc_thresh must be a positive number."
        )
    if pval_thresh is not None and not 0 <= pval_thresh <= 1:
        raise ValueError(
            "pval_thresh must be in [0, 1]."
        )


def _filter_volcano_data(fc_vals, pvals, labels, alt_color):
    """Remove entries that cannot be placed on the volcano plot.

    An entry is dropped when its fold change or p-value is NaN or
    infinite, or when its p-value is not strictly positive. ``labels``
    and ``alt_color`` (when given) are filtered with the same mask.

    Returns
    -------
    tuple
        ``(fc_vals, pvals, labels, alt_color)`` after filtering.

    Warns
    -----
    RuntimeWarning
        One warning per reason (NaN, infinite, non-positive p-value)
        that caused at least one entry to be dropped.

    Raises
    ------
    ValueError
        If no entries remain.
    """
    has_nan = np.isnan(fc_vals) | np.isnan(pvals)
    not_finite = ~np.isfinite(fc_vals) | ~np.isfinite(pvals)
    not_positive = pvals <= 0

    # Each reason is reported only for entries not already covered by
    # an earlier reason, so a NaN does not also count as "non-finite".
    reports = (
        (
            has_nan,
            "Dropping entries with NaN fold changes or "
            "p-values.",
        ),
        (
            not_finite & ~has_nan,
            "Dropping entries with non-finite fold changes "
            "or p-values.",
        ),
        (
            not_positive & ~not_finite,
            "Dropping non-positive p-values before log "
            "transform.",
        ),
    )
    for affected, message in reports:
        if affected.any():
            warnings.warn(message, RuntimeWarning)

    retained = ~not_finite & ~not_positive
    kept_fc = fc_vals[retained]
    kept_pvals = pvals[retained]
    kept_labels = None if labels is None else labels[retained]
    kept_alt = None if alt_color is None else alt_color[retained]

    if not len(kept_fc):
        raise ValueError(
            "No valid results available for plotting."
        )

    return kept_fc, kept_pvals, kept_labels, kept_alt


def _draw_scatter(
    _ax,
    fc_vals,
    y_vals,
    up_mask,
    down_mask,
    other_mask,
    alt_color,
):
    """Scatter the points onto ``_ax``, one call per colour group.

    When ``alt_color`` is ``None`` the groups are the non-significant
    points (grey), the down-regulated points (blue) and the
    up-regulated points (red), drawn in that order. Otherwise the
    significance masks are ignored: entries where ``alt_color`` is
    False are drawn grey, then entries where it is True are drawn
    purple.
    """
    # Each layer: (boolean selector, colour, alpha, marker size).
    # Layers are drawn in order, so later layers sit on top.
    if alt_color is not None:
        layers = (
            (~alt_color, "grey", 0.5, 12),
            (alt_color, "#8E54E5", 0.8, 14),
        )
    else:
        layers = (
            (other_mask, "grey", 0.5, 12),
            (down_mask, "#1f77b4", 0.8, 14),
            (up_mask, "#d62728", 0.8, 14),
        )

    for selector, colour, opacity, size in layers:
        _ax.scatter(
            fc_vals[selector],
            y_vals[selector],
            color=colour,
            alpha=opacity,
            s=size,
        )


def _draw_threshold_lines(_ax, fc_thresh, pval_thresh, yscale_log):
    """Draw the threshold guides and configure the y-axis scale.

    Parameters
    ----------
    _ax
        Axes to draw on.
    fc_thresh : float | None
        When set, dashed vertical lines are drawn at ``+fc_thresh``
        and ``-fc_thresh``.
    pval_thresh : float | None
        When set, a dashed horizontal line is drawn at ``pval_thresh``
        (``yscale_log=True``) or at ``-log10(pval_thresh)``
        (``yscale_log=False``).
    yscale_log : bool
        When True, the y-axis is switched to a base-10 log scale and
        inverted.
    """
    line_style = dict(color="black", linestyle="--", linewidth=1)

    if fc_thresh is not None:
        for x_pos in (fc_thresh, -fc_thresh):
            _ax.axvline(x_pos, **line_style)

    if pval_thresh is not None:
        y_pos = pval_thresh if yscale_log else -np.log10(pval_thresh)
        _ax.axhline(y_pos, **line_style)

    if yscale_log:
        _ax.set_yscale("log", base=10)
        # ``invert_yaxis`` toggles the direction. Drawing onto the same
        # Axes a second time would flip the axis back to normal, so
        # only invert when it is not inverted already.
        if not _ax.yaxis_inverted():
            _ax.invert_yaxis()


def _annotate_top_labels(
    _ax,
    fc_vals,
    pvals,
    y_vals,
    labels,
    top_labels,
    sig_mask,
    fc_thresh,
):
    """Annotate the best-ranked significant points on each side.

    Candidates are the points selected by ``sig_mask`` and, when
    ``fc_thresh`` is set, with ``|fc| >= fc_thresh``. They are split
    into a non-negative and a negative fold-change side; on each side
    the first ``top_labels`` points, ordered by ascending p-value and
    then by descending ``|fc|``, receive a text label. Label positions
    are then adjusted with ``adjust_text``.
    """
    magnitude = np.abs(fc_vals)
    eligible = sig_mask
    if fc_thresh is not None:
        eligible = sig_mask & (magnitude >= fc_thresh)

    candidates = np.where(eligible)[0]
    if len(candidates) == 0:
        return

    cand_fc = fc_vals[candidates]
    cand_pv = pvals[candidates]
    cand_abs = magnitude[candidates]
    cand_names = labels[candidates]
    cand_y = y_vals[candidates]

    picked = []
    # Non-negative side first, then the negative side.
    for on_side in (cand_fc >= 0, cand_fc < 0):
        members = np.where(on_side)[0]
        if len(members) == 0:
            continue
        # lexsort treats the LAST key as primary: p-value ascending,
        # ties broken by the larger absolute fold change.
        ranking = np.lexsort(
            (-cand_abs[members], cand_pv[members]),
        )
        picked.extend(members[ranking[:top_labels]])

    annotations = [
        _ax.text(
            cand_fc[k],
            cand_y[k],
            str(cand_names[k]),
            fontsize=8,
        )
        for k in picked
    ]
    if annotations:
        adjust_text(
            annotations,
            ax=_ax,
            arrowprops=dict(
                arrowstyle="->",
                color="0.4",
                lw=0.7,
            ),
        )


def _annotate_highlight_labels(
    _ax,
    fc_vals,
    y_vals,
    labels,
    highlight_labels,
):
    """Annotate every point whose label is in ``highlight_labels``.

    Requested labels that do not occur in ``labels`` are reported in a
    single ``RuntimeWarning``. Matching points receive a text label,
    and label positions are then adjusted with ``adjust_text``.
    """
    requested = set(highlight_labels)
    matches = np.where(np.isin(labels, list(requested)))[0]

    # Report requested labels that are absent from ``labels``.
    absent = requested - set(labels[matches])
    if absent:
        warnings.warn(
            "highlight_labels not found after "
            f"filtering: {sorted(absent)}",
            RuntimeWarning,
        )

    if len(matches) == 0:
        return

    annotations = [
        _ax.text(
            fc_vals[k],
            y_vals[k],
            str(labels[k]),
            fontsize=8,
        )
        for k in matches
    ]
    adjust_text(
        annotations,
        ax=_ax,
        arrowprops=dict(
            arrowstyle="->",
            color="0.4",
            lw=0.7,
        ),
    )


def volcano_plot(
    fc_vals: Sequence[float] | np.ndarray,
    pvals: Sequence[float] | np.ndarray,
    fc_thresh: float | None = None,
    pval_thresh: float | None = None,
    *,
    labels: Sequence[str] | np.ndarray | None = None,
    top_labels: int | None = None,
    highlight_labels: Sequence[str] | None = None,
    figsize: tuple[float, float] = (6.0, 5.0),
    xlabel: str | None = None,
    ylabel: str | None = None,
    alt_color: list[bool] | np.ndarray | None = None,
    yscale_log: bool = True,
    title: str | None = None,
    show: bool = True,
    save: str | Path | None = None,
    ax: Axes | None = None,
) -> Axes:
    """
    Volcano plot renderer (framework-agnostic).

    Draws a scatter plot of fold change (x-axis) versus p-value
    (y-axis). Points are colored by significance or by an optional
    alternative boolean mask.

    Parameters
    ----------
    fc_vals : Sequence[float] | np.ndarray
        Fold change values (x-axis).
    pvals : Sequence[float] | np.ndarray
        P-values (y-axis). Must be same length as ``fc_vals``.
    fc_thresh : float | None, optional
        Absolute fold change threshold for significance. When
        ``None``, the fold change requirement for significance
        coloring is dropped. Threshold line is not drawn.
    pval_thresh : float | None, optional
        P-value threshold for significance. When ``None``, the
        p-value requirement for significance coloring is dropped.
        Threshold line is not drawn. When both thresholds are
        ``None``, all points are colored as significant (blue for
        negative FC, red for positive FC).
    labels : Sequence[str] | np.ndarray | None, optional
        Labels for each point, same length as ``fc_vals``. Required
        when ``top_labels`` is set.
    top_labels : int | None, optional
        Number of top proteins to label per side (up to 2N total).
        Ranked by smallest p-value, then largest ``|fc|``.
    highlight_labels : Sequence[str] | None, optional
        Sequence of label strings to highlight on the plot. Each
        entry must match a value in ``labels``. Matched points are
        annotated with their label and a connecting arrow. Labels
        not found after filtering trigger a warning.
    figsize : tuple[float, float], optional
        Figure dimensions (width, height) in inches.
    xlabel : str | None, optional
        X-axis label. Defaults to ``"logFC"`` when ``None``.
    ylabel : str | None, optional
        Y-axis label. When ``None``, defaults to ``"pval"`` if
        ``yscale_log=True`` or ``"-log10(pval)"`` if
        ``yscale_log=False``.
    alt_color : list[bool] | np.ndarray | None, optional
        Boolean mask (same length as ``fc_vals``) for alternative
        coloring. ``True`` entries are colored purple, ``False``
        gray. Overrides significance-based coloring.
    yscale_log : bool, optional
        When ``True``, plot raw p-values on a log10-scaled inverted
        y-axis. When ``False``, plot ``-log10(pval)`` on a linear
        y-axis.
    title : str | None, optional
        Plot title. If ``None``, no title is set.
    show : bool, optional
        Call ``matplotlib.pyplot.show()`` to display the plot.
    save : str | Path | None, optional
        File path to save the figure at 300 DPI.
    ax : matplotlib.axes.Axes | None, optional
        Matplotlib Axes to plot onto. If ``None``, a new figure and
        axes are created.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.

    Raises
    ------
    ValueError
        If ``fc_vals`` or ``pvals`` are not 1D, contain non-numeric
        values, or have different lengths; if no valid data remains
        after filtering; if ``fc_thresh`` is not positive (when
        set); if ``pval_thresh`` is not in ``[0, 1]`` (when set);
        if ``top_labels`` is set but ``labels`` is ``None``, or is
        not a positive integer; if ``highlight_labels`` is set but
        ``labels`` is ``None``, contains duplicates, or contains
        non-string values; if both ``top_labels`` and
        ``highlight_labels`` are set; if ``labels`` contains
        non-string values; if ``ax`` is not a
        ``matplotlib.axes.Axes`` object (when set); if
        ``alt_color`` fails validation.

    Examples
    --------
    Basic usage with lists:

    >>> from proteopy.utils.stat_tests import volcano_plot
    >>> fc = [-2.1, -0.5, 0.3, 1.8, 3.0]
    >>> pv = [0.001, 0.3, 0.5, 0.04, 0.0005]
    >>> volcano_plot(fc, pv, show=False)
    <Axes: ...>

    With fold change and p-value thresholds:

    >>> volcano_plot(
    ...     fc, pv,
    ...     fc_thresh=1.5,
    ...     pval_thresh=0.05,
    ...     show=False,
    ... )
    <Axes: ...>

    With ``top_labels`` to annotate the most significant hits:

    >>> genes = ["GeneA", "GeneB", "GeneC", "GeneD", "GeneE"]
    >>> volcano_plot(
    ...     fc, pv,
    ...     fc_thresh=1.5,
    ...     pval_thresh=0.05,
    ...     labels=genes,
    ...     top_labels=2,
    ...     show=False,
    ... )
    <Axes: ...>

    With ``highlight_labels`` to annotate specific proteins:

    >>> volcano_plot(
    ...     fc, pv,
    ...     labels=genes,
    ...     highlight_labels=["GeneA", "GeneE"],
    ...     show=False,
    ... )
    <Axes: ...>
    """
    # -- Validate inputs (all validation happens before any drawing)
    fc_vals, pvals = _validate_numeric_inputs(fc_vals, pvals)
    _validate_thresholds(fc_thresh, pval_thresh)

    if ax is not None and not isinstance(ax, Axes):
        raise ValueError(
            "ax must be a matplotlib Axes object."
        )

    labels = _validate_label_inputs(
        labels, top_labels, highlight_labels,
        fc_vals.shape[0],
    )
    alt_color = _validate_alt_color(
        alt_color, fc_vals.shape[0],
    )

    fc_vals, pvals, labels, alt_color = _filter_volcano_data(
        fc_vals, pvals, labels, alt_color,
    )

    # -- Prepare plotting arrays
    # Raw p-values go on a log-scaled axis; otherwise the transform is
    # applied to the data and the axis stays linear.
    y_vals = pvals if yscale_log else -np.log10(pvals)

    sig_mask = (
        np.ones(len(pvals), dtype=bool)
        if pval_thresh is None
        else pvals <= pval_thresh
    )

    if fc_thresh is None:
        up_mask = sig_mask & (fc_vals > 0)
        down_mask = sig_mask & (fc_vals < 0)
    else:
        up_mask = sig_mask & (fc_vals >= fc_thresh)
        down_mask = sig_mask & (fc_vals <= -fc_thresh)
    other_mask = ~(up_mask | down_mask)

    # -- Plot
    if ax is None:
        fig, _ax = plt.subplots(figsize=figsize)
    else:
        _ax = ax
        fig = _ax.get_figure()

    _draw_scatter(
        _ax, fc_vals, y_vals,
        up_mask, down_mask, other_mask, alt_color,
    )
    _draw_threshold_lines(_ax, fc_thresh, pval_thresh, yscale_log)

    # Axis labels
    _ax.set_xlabel("logFC" if xlabel is None else xlabel)
    if ylabel is None:
        ylabel = "pval" if yscale_log else "-log10(pval)"
    _ax.set_ylabel(ylabel)

    if title is not None:
        _ax.set_title(title)

    # -- Labels
    if top_labels is not None and labels is not None:
        _annotate_top_labels(
            _ax, fc_vals, pvals, y_vals, labels,
            top_labels, sig_mask, fc_thresh,
        )

    if highlight_labels is not None and labels is not None:
        _annotate_highlight_labels(
            _ax, fc_vals, y_vals, labels,
            highlight_labels,
        )

    if save:
        fig.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()

    return _ax