"""Filtering functions for ``proteopy.pp`` (module ``proteopy.pp.filtering``)."""

import warnings
from pathlib import Path
from typing import Callable
import numpy as np
import pandas as pd
import scipy.sparse as sp
import anndata as ad
from Bio import SeqIO

from proteopy.utils.functools import partial_with_docsig
from proteopy.utils.anndata import check_proteodata, is_proteodata


def _missingness_matrix(X, zero_to_na):
    """Return a float copy of ``X`` with missing values encoded uniformly.

    Dense input: NaN marks a missing value; when ``zero_to_na`` is True zeros
    are converted to NaN as well. ``np.array`` always copies, so the caller's
    matrix is never mutated.

    Sparse input: converted to CSR (copy) and every stored NaN or zero is
    dropped from the sparsity structure, so that exactly the stored entries
    count as present. Unstored entries are always treated as missing, which is
    why ``zero_to_na`` has no additional effect for sparse input.
    """
    dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    if sp.issparse(X):
        X = sp.csr_matrix(X, dtype=dtype, copy=True)
        # Stored NaNs would otherwise be counted by getnnz as present values.
        X.data[np.isnan(X.data)] = 0
        X.eliminate_zeros()
        return X
    X = np.array(X, dtype=dtype)
    if zero_to_na:
        X[X == 0] = np.nan
    return X


def _count_nonmissing(X, axis):
    """Count non-missing entries of a ``_missingness_matrix`` result along ``axis``."""
    if sp.issparse(X):
        return np.asarray(X.getnnz(axis=axis)).ravel()
    return np.count_nonzero(~np.isnan(X), axis=axis)


def _max_over_groups(per_group, index):
    """Element-wise maximum of equally long per-group arrays, as a Series on ``index``.

    Returns a float Series of zeros when ``per_group`` is empty (no groups).
    """
    if not per_group:
        return pd.Series(0, index=index, dtype=float)
    return pd.Series(np.max(np.vstack(per_group), axis=0), index=index)


def filter_axis(
    adata,
    axis,
    min_fraction=None,
    min_count=None,
    group_by=None,
    zero_to_na=False,
    inplace=True,
):
    """
    Filter observations or variables based on non-missing value content.

    This function filters the AnnData object along a specified axis
    (observations or variables) based on the fraction or number of
    non-missing (non-NaN) values. Filtering can be performed globally or
    within groups defined by the `group_by` parameter.

    For sparse ``adata.X``, unstored entries (implicit zeros), explicitly
    stored zeros and stored NaN values are all counted as missing, so
    `zero_to_na` only changes the result for dense data.

    Parameters
    ----------
    adata : anndata.AnnData
        The annotated data matrix to filter.
    axis : int
        The axis to filter on. `0` for observations, `1` for variables.
    min_fraction : float, optional
        The minimum fraction of non-missing values required to keep an
        observation or variable. If `group_by` is provided, this threshold is
        applied to the maximum completeness across all groups.
    min_count : int, optional
        The minimum number of non-missing values required to keep an
        observation or variable. If `group_by` is provided, this threshold is
        applied to the maximum count across all groups.
    group_by : str, optional
        A column key in `adata.obs` (if `axis=1`) or `adata.var` (if `axis=0`)
        used for grouping before applying the filter. The maximum completeness
        or count across the groups is used for filtering. Rows with a missing
        group label are ignored.
    zero_to_na : bool, optional
        If True, zeros in the data matrix are treated as missing values (NaN).
    inplace : bool, optional
        If True, modifies the `adata` object in place. Otherwise, returns a
        filtered copy.

    Returns
    -------
    anndata.AnnData or None
        If `inplace=False`, returns a new filtered AnnData object (a copy).
        Otherwise, returns `None`.

    Raises
    ------
    ValueError
        If `axis` is not 0 or 1.
    KeyError
        If the `group_by` key is not found in the corresponding annotation
        DataFrame.
    """
    check_proteodata(adata)

    if axis not in (0, 1):
        raise ValueError(f"`axis` must be 0 (obs) or 1 (var), got {axis!r}.")

    if min_fraction is None and min_count is None:
        warnings.warn(
            "Neither `min_fraction` nor `min_count` were provided, so "
            "the function does nothing."
        )
        return None if inplace else adata.copy()

    X = _missingness_matrix(adata.X, zero_to_na)

    axis_i = 1 - axis  # values are counted along the *other* axis
    axis_labels = adata.obs_names if axis == 0 else adata.var_names
    completeness = None  # only computed when min_fraction is set

    if group_by is not None:
        metadata = adata.obs if axis == 1 else adata.var
        if group_by not in metadata.columns:
            raise KeyError(
                f'`group_by`="{group_by}" not present in '
                f'adata.{"obs" if axis == 1 else "var"}'
            )
        grouping = metadata[group_by]

        counts_by_group = []
        completeness_by_group = []
        for label in grouping.dropna().unique():
            mask = (grouping == label).values
            subset = X[mask, :] if axis == 1 else X[:, mask]
            group_size = subset.shape[axis_i]
            if group_size == 0:
                continue
            group_counts = _count_nonmissing(subset, axis_i)
            counts_by_group.append(group_counts)
            if min_fraction is not None:
                completeness_by_group.append(group_counts / group_size)

        counts = _max_over_groups(counts_by_group, axis_labels)
        if min_fraction is not None:
            completeness = _max_over_groups(completeness_by_group, axis_labels)
    else:
        counts = pd.Series(_count_nonmissing(X, axis_i), index=axis_labels)
        if min_fraction is not None:
            completeness = counts / adata.shape[axis_i]

    mask_filt = pd.Series(True, index=axis_labels)
    if min_fraction is not None:
        mask_filt &= completeness >= min_fraction
    if min_count is not None:
        mask_filt &= counts >= min_count

    n_removed = int((~mask_filt).sum())
    axis_name = ("obs", "var")[axis]
    print(f"{n_removed} {axis_name} removed")

    keep = mask_filt.to_numpy()
    if inplace:
        if axis == 0:
            adata._inplace_subset_obs(keep)
        else:
            adata._inplace_subset_var(keep)
        check_proteodata(adata)
        return None

    # Materialize the subset so the result is a real copy, as documented.
    adata_filtered = (
        adata[keep, :].copy() if axis == 0 else adata[:, keep].copy()
    )
    check_proteodata(adata_filtered)
    return adata_filtered


# Axis-specific convenience wrappers around `filter_axis`. Each binds `axis`
# (and, for the *_completeness variants, `min_count=None`) and prepends the
# header text below to the generated docstring.
docstr_header = """
Filter observations based on non-missing value content.

This function filters the AnnData object along the `obs` axis based on the
fraction or number of non-missing (non-NaN) values. Filtering can be performed
globally or within groups defined by the `group_by` parameter.
"""
filter_samples = partial_with_docsig(
    filter_axis,
    axis=0,
    docstr_header=docstr_header,
    )

docstr_header = """
Filter observations based on data completeness.

This function filters the AnnData object along the `obs` axis based on the
fraction of non-missing (non-NaN) values. Filtering can be performed globally
or within groups defined by the `group_by` parameter.
"""
filter_samples_completeness = partial_with_docsig(
    filter_axis,
    axis=0,
    min_count=None,
    docstr_header=docstr_header,
    )

docstr_header = """
Filter variables based on non-missing value content.

This function filters the AnnData object along the `var` axis based on the
fraction or number of non-missing (non-NaN) values. Filtering can be performed
globally or within groups defined by the `group_by` parameter.
"""
filter_var = partial_with_docsig(
    filter_axis,
    axis=1,
    docstr_header=docstr_header,
    )

docstr_header = """
Filter variables based on data completeness.

This function filters the AnnData object along the `var` axis based on the
fraction of non-missing (non-NaN) values. Filtering can be performed globally
or within groups defined by the `group_by` parameter.
"""
filter_var_completeness = partial_with_docsig(
    filter_axis,
    axis=1,
    min_count=None,
    docstr_header=docstr_header,
    )


def filter_proteins_by_peptide_count(
    adata,
    min_count=None,
    max_count=None,
    protein_col="protein_id",
    inplace=True,
):
    """
    Filter proteins by how many peptides map to them.

    The peptides of each protein in ``adata.var[protein_col]`` are counted;
    every peptide whose protein falls outside ``[min_count, max_count]`` is
    dropped. A summary of removed proteins and peptides is printed.

    Parameters
    ----------
    adata : anndata.AnnData
        Peptide-level ProteoData with a protein identifier column in
        ``adata.var``.
    min_count : int or None, optional
        Keep peptides whose proteins have at least this many peptides.
    max_count : int or None, optional
        Keep peptides whose proteins have at most this many peptides.
    protein_col : str, optional (default: "protein_id")
        Column in ``adata.var`` holding the protein identifiers.
    inplace : bool, optional (default: True)
        Subset ``adata`` in place when True; otherwise return a filtered view.

    Returns
    -------
    None or anndata.AnnData
        ``None`` when ``inplace=True``; the filtered AnnData view otherwise.
        If neither bound is given, a warning is emitted and nothing is
        filtered (an unfiltered copy is returned when ``inplace=False``).

    Raises
    ------
    ValueError
        If ``adata`` is not peptide-level, a bound is negative, or
        ``min_count > max_count``.
    KeyError
        If ``protein_col`` is not a column of ``adata.var``.
    """
    check_proteodata(adata)
    level = is_proteodata(adata)[1]
    if level != "peptide":
        raise ValueError("`AnnData` object must be in ProteoData peptide format.")

    # No bounds: nothing to do, but keep the inplace/copy contract.
    if min_count is None and max_count is None:
        warnings.warn("Pass at least one argument: min_count | max_count")
        if inplace:
            return None
        unchanged = adata.copy()
        check_proteodata(unchanged)
        return unchanged

    if min_count is not None and min_count < 0:
        raise ValueError("`min_count` must be non-negative.")
    if max_count is not None and max_count < 0:
        raise ValueError("`max_count` must be non-negative.")
    both_given = min_count is not None and max_count is not None
    if both_given and min_count > max_count:
        raise ValueError("`min_count` cannot be greater than `max_count`.")
    if protein_col not in adata.var.columns:
        raise KeyError(f"`protein_col`='{protein_col}' not found in adata.var")

    proteins = adata.var[protein_col]
    peptides_per_protein = proteins.value_counts()

    in_range = pd.Series(True, index=peptides_per_protein.index)
    if min_count is not None:
        in_range &= peptides_per_protein >= min_count
    if max_count is not None:
        in_range &= peptides_per_protein <= max_count

    kept_proteins = peptides_per_protein.index[in_range]
    var_keep_mask = proteins.isin(kept_proteins)

    n_proteins_removed = len(peptides_per_protein.index) - len(kept_proteins)
    n_peptides_removed = int((~var_keep_mask).sum())

    if inplace:
        adata._inplace_subset_var(var_keep_mask.values)
        check_proteodata(adata)
        result = None
    else:
        result = adata[:, var_keep_mask]
        check_proteodata(result)

    print(
        f"Removed {n_proteins_removed} proteins and "
        f"{n_peptides_removed} peptides."
    )
    return result
def filter_samples_by_category_count(
    adata,
    category_col,
    min_count=None,
    max_count=None,
    inplace=True,
):
    """
    Filter observations by the frequency of their category in ``adata.obs``.

    The number of observations per category of ``adata.obs[category_col]``
    is counted, and every observation whose category count falls outside
    ``[min_count, max_count]`` is removed. Missing values (NaN) are counted
    as a category of their own. The number of removed observations is
    printed.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix.
    category_col : str
        Column in ``adata.obs`` containing the categories to count.
    min_count : int or None, optional
        Keep categories with at least this many observations.
    max_count : int or None, optional
        Keep categories with at most this many observations.
    inplace : bool, optional (default: True)
        If True, modify ``adata`` in place. Otherwise, return a filtered copy.

    Returns
    -------
    None or anndata.AnnData
        ``None`` if ``inplace=True``; otherwise the filtered AnnData.

    Raises
    ------
    ValueError
        If neither ``min_count`` nor ``max_count`` is given, a bound is
        negative, or ``min_count > max_count``.
    KeyError
        If ``category_col`` is not a column of ``adata.obs``.
    """
    check_proteodata(adata)

    if min_count is None and max_count is None:
        raise ValueError(
            "At least one argument must be passed: min_count | max_count"
        )
    if min_count is not None and min_count < 0:
        raise ValueError("`min_count` must be non-negative.")
    if max_count is not None and max_count < 0:
        raise ValueError("`max_count` must be non-negative.")
    if (
        min_count is not None
        and max_count is not None
        and min_count > max_count
    ):
        raise ValueError("`min_count` cannot be greater than `max_count`.")
    if category_col not in adata.obs.columns:
        raise KeyError(f"`category_col`='{category_col}' not found in adata.obs")

    obs_series = adata.obs[category_col]
    # dropna=False: NaN is counted (and can be kept) like any other category.
    counts = obs_series.value_counts(dropna=False)
    counts_filt = counts
    if min_count is not None:
        counts_filt = counts_filt[counts_filt >= min_count]
    if max_count is not None:
        counts_filt = counts_filt[counts_filt <= max_count]

    obs_keep_mask = obs_series.isin(counts_filt.index).values
    removed = int((~obs_keep_mask).sum())
    print(f"Removed {removed} observations.")

    if inplace:
        adata._inplace_subset_obs(obs_keep_mask)
        check_proteodata(adata)
        return None

    new_adata = adata[obs_keep_mask, :].copy()
    check_proteodata(new_adata)
    return new_adata
def _validate_remove_zero_variance_vars_input(
    adata,
    group_by,
    atol,
    inplace,
    verbose,
):
    """Validate inputs for ``remove_zero_variance_vars``.

    Raises
    ------
    TypeError
        If an argument has the wrong type. ``atol`` must be a real number
        (Python or NumPy int/float scalar, but not a bool).
    ValueError
        If ``atol`` is NaN or negative, or the ``group_by`` column contains
        NaN values.
    KeyError
        If ``group_by`` is not a column of ``adata.obs``.
    """
    if not isinstance(adata, ad.AnnData):
        raise TypeError(
            f"`adata` must be an AnnData object, "
            f"got {type(adata).__name__}."
        )
    if group_by is not None and not isinstance(group_by, str):
        raise TypeError(
            f"`group_by` must be a string or None, "
            f"got {type(group_by).__name__}."
        )
    # bool is a subclass of int but is never a meaningful tolerance; NumPy
    # scalars such as np.float32 are not subclasses of float but are valid.
    if isinstance(atol, bool) or not isinstance(
        atol, (int, float, np.integer, np.floating)
    ):
        raise TypeError(
            f"`atol` must be a numeric value, "
            f"got {type(atol).__name__}."
        )
    # A NaN tolerance makes every comparison False, which would silently drop
    # all variables (global) or none (grouped) downstream.
    if np.isnan(atol):
        raise ValueError("`atol` must not be NaN.")
    if atol < 0:
        raise ValueError("`atol` must be non-negative.")
    if not isinstance(inplace, bool):
        raise TypeError(
            f"`inplace` must be a bool, "
            f"got {type(inplace).__name__}."
        )
    if not isinstance(verbose, bool):
        raise TypeError(
            f"`verbose` must be a bool, "
            f"got {type(verbose).__name__}."
        )
    if group_by is not None:
        if group_by not in adata.obs.columns:
            raise KeyError(
                f"`group_by`='{group_by}' not found "
                f"in adata.obs"
            )
        if adata.obs[group_by].isna().any():
            raise ValueError(
                f"`group_by`='{group_by}' column in "
                f"adata.obs contains NaN values."
            )
def remove_zero_variance_vars(
    adata,
    group_by=None,
    atol=1e-8,
    inplace=True,
    verbose=False,
):
    """
    Remove variables whose NaN-ignoring variance is zero or near zero.

    The population variance (``ddof=0``) of every variable is computed with
    NaN values skipped. A variable is removed when that variance is
    ``<= atol`` or when the variance is NaN (e.g. because every value of the
    variable is NaN). With ``group_by``, the check is performed separately
    for each group and a variable is removed if it fails in *any* group.
    Sparse matrices are densified first, so unstored entries count as zeros.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` annotated data matrix.
    group_by : str | None, optional
        Column in ``adata.obs`` defining the groups in which variance is
        computed. A variable is removed if its variance is ``<= atol`` or
        all-NaN in *any* group.
    atol : float, optional
        Absolute tolerance. Variables with variance ``<= atol`` are treated
        as zero-variance and removed.
    inplace : bool, optional
        Modify ``adata`` in place and return ``None``. Otherwise, return a
        filtered copy.
    verbose : bool, optional
        Print how many variables were present, removed, and remaining.

    Returns
    -------
    None | AnnData
        ``None`` when ``inplace=True``; otherwise a new
        :class:`anndata.AnnData` holding only variables with variance
        ``> atol``.

    Raises
    ------
    TypeError
        If any argument has an incorrect type.
    ValueError
        If ``atol`` is negative or the ``group_by`` column contains NaN
        values.
    KeyError
        If ``group_by`` is not a column in ``adata.obs``.

    Warns
    -----
    UserWarning
        When one or more variables are removed because they are entirely NaN
        (globally or within at least one group).

    Examples
    --------
    Build a small protein-level dataset with four variables: ``p1`` varies,
    ``p2`` is constant, ``p3`` is all-NaN, and ``p4`` varies.

    >>> import numpy as np
    >>> import pandas as pd
    >>> import anndata as ad
    >>> import proteopy as pr
    >>> X = np.array([
    ...     [1.0, 5.0, np.nan, 7.0],
    ...     [2.0, 5.0, np.nan, 7.0],
    ...     [3.0, 5.0, np.nan, 8.0],
    ... ])
    >>> obs = pd.DataFrame(
    ...     {"sample_id": ["s1", "s2", "s3"]},
    ...     index=["s1", "s2", "s3"],
    ... )
    >>> var = pd.DataFrame(
    ...     {"protein_id": ["p1", "p2", "p3", "p4"]},
    ...     index=["p1", "p2", "p3", "p4"],
    ... )
    >>> adata = ad.AnnData(X=X, obs=obs, var=var)
    >>> pr.pp.remove_zero_variance_vars(adata)
    >>> adata.var_names.tolist()
    ['p1', 'p4']

    With ``group_by``, a variable is removed if it has zero variance or is
    all-NaN in *any* group. Here ``p2`` is constant in group A, ``p3`` is
    all-NaN in group A, and ``p4`` is constant in both groups:

    >>> X_grp = np.array([
    ...     [1.0, 5.0, np.nan, 9.0],
    ...     [2.0, 5.0, np.nan, 9.0],
    ...     [3.0, 7.0, 8.0, 9.0],
    ...     [4.0, 8.0, 8.0, 9.0],
    ... ])
    >>> obs_grp = pd.DataFrame(
    ...     {"sample_id": ["s1", "s2", "s3", "s4"],
    ...      "group": ["A", "A", "B", "B"]},
    ...     index=["s1", "s2", "s3", "s4"],
    ... )
    >>> adata = ad.AnnData(X=X_grp, obs=obs_grp, var=var)
    >>> pr.pp.remove_zero_variance_vars(adata, group_by="group")
    >>> adata.var_names.tolist()
    ['p1']
    """
    _validate_remove_zero_variance_vars_input(
        adata, group_by, atol, inplace, verbose,
    )
    check_proteodata(adata)

    def _column_nanvar(matrix):
        # Densify and compute per-column variance; RuntimeWarnings from
        # all-NaN columns are silenced because they are reported below.
        dense = matrix.toarray() if sp.issparse(matrix) else np.asarray(matrix)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            return np.nanvar(dense, axis=0, ddof=0)

    X = adata.X
    n_vars = adata.n_vars

    if group_by is None:
        variances = _column_nanvar(X)
        nan_variance = np.isnan(variances)  # nanvar of an all-NaN column is NaN
        low_variance = (variances <= atol) & ~nan_variance
    else:
        labels = adata.obs[group_by].astype("category")
        nan_variance = np.zeros(n_vars, dtype=bool)
        low_variance = np.zeros(n_vars, dtype=bool)
        for category in labels.cat.categories:
            rows = np.flatnonzero(labels.values == category)
            if rows.size == 0:
                continue
            group_var = _column_nanvar(X[rows, :])
            group_nan = np.isnan(group_var)
            nan_variance |= group_nan
            low_variance |= (group_var <= atol) & ~group_nan

    keep_mask = ~(low_variance | nan_variance)
    n_allnan = int(nan_variance.sum())

    if n_allnan > 0:
        warnings.warn(
            f"{n_allnan} variable(s) contained only NaN values"
            f"{' in at least one group' if group_by else ''}"
            " and were treated as zero variance.",
            UserWarning,
            stacklevel=2,
        )

    removed = int((~keep_mask).sum())
    if verbose:
        remaining = int(keep_mask.sum())
        print(
            f"{n_vars} variables present, "
            f"{removed} removed, "
            f"{remaining} remaining."
        )

    if not inplace:
        new_adata = adata[:, keep_mask].copy()
        check_proteodata(new_adata)
        return new_adata

    adata._inplace_subset_var(keep_mask)
    check_proteodata(adata)
    return None
[docs] def remove_contaminants( adata, contaminant_path, protein_key="protein_id", header_parser: Callable[[str], str] | None = None, inplace=True, ): """ Remove variables whose protein identifier matches a contaminant FASTA entry. Parameters ---------- adata : anndata.AnnData Annotated data. contaminant_path : str | Path Path to the contaminant list. The file can be in FASTA format, in which case the headers are parsed to extract the contaminant ids (see param: header_parser); or tabular format TSV/CSV files, in which case the first column is extracted as contaminant ids.. protein_key : str, optional (default: "protein_id") Column in ``adata.var`` containing protein identifiers to match. header_parser : callable, optional Function to extract protein IDs from FASTA headers. Defaults to splitting the header on ``"|"`` and returning the second element, falling back to the full header if not present. inplace : bool, optional (default: False) If True, modify ``adata`` in place. Otherwise, return a filtered view. Returns ------- None or anndata.AnnData ``None`` if ``inplace=True``; otherwise the filtered AnnData view. 
""" check_proteodata(adata) if header_parser is None: def header_parser(header: str) -> str: parts = header.split("|") return parts[1] if len(parts) > 1 else header def _load_contaminant_ids_from_fasta(fasta_path: Path) -> set[str]: contaminant_ids = set() for record in SeqIO.parse(fasta_path, "fasta"): parsed = header_parser(record.id) if parsed == "": warnings.warn( f"Header parser returned empty ID for record '{record.id}'.", ) continue contaminant_ids.add(parsed) return contaminant_ids def _load_contaminant_ids_from_table(table_path: Path, sep: str) -> set[str]: series = pd.read_csv(table_path, sep=sep, usecols=[0]).iloc[:, 0] series = series.dropna().astype(str) return set(series.tolist()) cont_path = Path(contaminant_path) if not cont_path.exists(): raise FileNotFoundError(f"Contaminant file not found at {cont_path}") if protein_key not in adata.var.columns: raise KeyError(f"`protein_key`='{protein_key}' not found in adata.var") suffix = cont_path.suffix.lower() match suffix: case ".fasta" | ".fa" | ".faa": contaminant_ids = _load_contaminant_ids_from_fasta(cont_path) case ".csv": contaminant_ids = _load_contaminant_ids_from_table(cont_path, ",") case ".tsv": contaminant_ids = _load_contaminant_ids_from_table(cont_path, "\t") case _: raise ValueError( "Unsupported contaminant file type. 
Use FASTA (.fasta/.fa/.faa), " "CSV (.csv), or TSV (.tsv).", ) proteins = adata.var[protein_key] keep_mask = ~proteins.isin(contaminant_ids) _, level = is_proteodata(adata) if level == "peptide": removed_peptides = int((~keep_mask).sum()) removed_proteins = int(proteins[~keep_mask].nunique()) print( f"Removed {removed_peptides} contaminating peptides and " f"{removed_proteins} contaminating proteins.", ) elif level == "protein": removed_proteins = int((~keep_mask).sum()) print(f"Removed {removed_proteins} contaminating proteins.") else: removed = int((~keep_mask).sum()) print(f"Removed {removed} contaminating variables.") if inplace: adata._inplace_subset_var(keep_mask.values) check_proteodata(adata) return None new_adata = adata[:, keep_mask] check_proteodata(new_adata) return new_adata