diff --git a/docs/release-notes/4037.feat.md b/docs/release-notes/4037.feat.md new file mode 100644 index 0000000000..a84539ecec --- /dev/null +++ b/docs/release-notes/4037.feat.md @@ -0,0 +1 @@ +Add `exp_post_agg` argument to {func}`scanpy.tl.rank_genes_groups` for customizing how log-fold-change is calculated {user}`ilan-gold` diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index 697c55e765..9a6e0488e6 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -63,6 +63,7 @@ class PcaPreset(NamedTuple): class RankGenesGroupsPreset(NamedTuple): method: DETest mask_var: str | None + exp_post_agg: bool class ScalePreset(NamedTuple): @@ -167,9 +168,11 @@ def pca() -> Mapping[Preset, PcaPreset]: def rank_genes_groups() -> Mapping[Preset, RankGenesGroupsPreset]: """Correlation method for :func:`~scanpy.tl.rank_genes_groups`.""" return { - Preset.ScanpyV1: RankGenesGroupsPreset(method="t-test", mask_var=None), + Preset.ScanpyV1: RankGenesGroupsPreset( + method="t-test", mask_var=None, exp_post_agg=True + ), Preset.ScanpyV2Preview: RankGenesGroupsPreset( - method="wilcoxon", mask_var=None + method="wilcoxon", mask_var=None, exp_post_agg=False ), } diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 02d2e3ebad..aec02c4b2a 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -201,7 +201,7 @@ def __init__( self.grouping_mask = adata.obs[groupby].isin(self.groups_order) self.grouping = adata.obs.loc[self.grouping_mask, groupby] - def _basic_stats(self) -> None: + def _basic_stats(self, *, exponentiate_values: bool = False) -> None: """Set self.{means,vars,pts}{,_rest} depending on X.""" n_genes = self.X.shape[1] n_groups = self.groups_masks_obs.shape[0] @@ -217,6 +217,8 @@ def _basic_stats(self) -> None: else: mask_rest = self.groups_masks_obs[self.ireference] x_rest = self.X[mask_rest] + if exponentiate_values: + x_rest = self.expm1_func(x_rest) self.means[self.ireference], self.vars[self.ireference] = mean_var( x_rest, axis=0, correction=1 ) @@ -230,6 +232,8 @@ def _basic_stats(self) -> None: for group_index, mask_obs in enumerate(self.groups_masks_obs): x_mask = self.X[mask_obs] + if exponentiate_values: + x_mask = self.expm1_func(x_mask) if self.comp_pts: self.pts[group_index] = get_nonzeros(x_mask) / x_mask.shape[0] @@ -244,6 +248,8 @@ def _basic_stats(self) -> None: if self.ireference is None: mask_rest = ~mask_obs x_rest = self.X[mask_rest] + if exponentiate_values: + x_rest = self.expm1_func(x_rest) ( self.means_rest[group_index], self.vars_rest[group_index], @@ -259,8 +265,6 @@ def t_test( ) -> Generator[tuple[int, NDArray[np.floating], NDArray[np.floating]], None, None]: from scipy import stats - self._basic_stats() - for group_index, (mask_obs, mean_group, var_group) in enumerate( zip(self.groups_masks_obs, self.means, self.vars, strict=True) ): @@ -312,8 +316,6 @@ def wilcoxon( ) -> Generator[tuple[int, NDArray[np.floating], NDArray[np.floating]], None, None]: from scipy import stats - self._basic_stats() - n_genes = self.X.shape[1] # First loop: Loop over all genes if self.ireference is not None: @@ -429,12 +431,16 @@ def compute_statistics( # noqa: PLR0912 n_genes_user: int | None = None, rankby_abs: bool = False, tie_correct: bool = False, + exp_post_agg: bool = True, **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: + self._basic_stats(exponentiate_values=False) generate_test_results = self.t_test(method) elif method == "wilcoxon": generate_test_results = self.wilcoxon(tie_correct=tie_correct) + # If we're not exponentiating after the mean aggregation, then do it now. + self._basic_stats(exponentiate_values=not exp_post_agg) elif method == "logreg": generate_test_results = self.logreg(**kwds) @@ -481,9 +487,12 @@ def compute_statistics( # noqa: PLR0912 mean_rest = self.means_rest[group_index] else: mean_rest = self.means[self.ireference] - foldchanges = (self.expm1_func(mean_group) + 1e-9) / ( - self.expm1_func(mean_rest) + 1e-9 - ) # add small value to remove 0's + if exp_post_agg: + foldchanges = (self.expm1_func(mean_group) + 1e-9) / ( + self.expm1_func(mean_rest) + 1e-9 + ) # add small value to remove 0's + else: + foldchanges = (mean_group + 1e-9) / (mean_rest + 1e-9) self.stats[group_name, "logfoldchanges"] = np.log2( foldchanges[global_indices] ) @@ -511,6 +520,7 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 corr_method: _CorrMethod = "benjamini-hochberg", tie_correct: bool = False, layer: str | None = None, + exp_post_agg: bool = Default(preset=("rank_genes_groups", "exp_post_agg")), **kwds, ) -> AnnData | None: """Rank genes for characterizing groups. @@ -574,6 +584,8 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 The key in `adata.uns` information is saved to. copy Whether to copy `adata` or modify it inplace. + exp_post_agg + Whether to do log(mean(exp(values))) (`False`) or log(exp(mean(values))) (`True`) kwds Are passed to test methods. Currently this affects only parameters that are passed to :class:`sklearn.linear_model.LogisticRegression`. @@ -626,6 +638,8 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 if isinstance(mask_var, Default): mask_var = settings.preset.rank_genes_groups.mask_var + if isinstance(exp_post_agg, Default): + exp_post_agg = settings.preset.rank_genes_groups.exp_post_agg if method is None or isinstance(method, Default): method = settings.preset.rank_genes_groups.method @@ -714,6 +728,7 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 n_genes_user=n_genes_user, rankby_abs=rankby_abs, tie_correct=tie_correct, + exp_post_agg=exp_post_agg, **kwds, ) diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index ba38ffc94d..6f43e1e1de 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -311,3 +311,42 @@ def test_mask_not_equal(): with_mask = pbmc.uns["rank_genes_groups"]["names"] assert not np.array_equal(no_mask, with_mask) + + +@pytest.mark.parametrize( + ("exp_post_agg", "expected_logfc"), + [ + # exp after agg: log2(expm1(mean_log_a) / expm1(mean_log_b)) + # = log2(expm1(ln(9) * 5 / 10) / expm1(ln9)) = log2(2 / 8) = -2.0 + (True, -2.0), + # exp before agg: log2(mean(expm1(linear_a)) / mean(expm1(linear_b))) + # = log2(mean([0] * 5 + [8] * 5) / mean([8] * 10)) = log2(4 / 8) = -1.0 + (False, -1.0), + ], +) +def test_exp_post_agg( + expected_logfc: float, + *, + exp_post_agg: bool, +): + # group_a: 5 cells with log-space value 0, 5 cells with log(9) + # group_b: 10 cells all with log(9) (used as reference) + n_genes = 5 + group_a = np.zeros((10, n_genes)) + group_a[5:] = np.log(9) + group_b = np.full((10, n_genes), np.log(9)) + adata = AnnData( + X=np.concatenate([group_a, group_b]), + obs={"bulk_labels": ["a"] * 10 + ["b"] * 10}, + ) + + rank_genes_groups( + adata, + groupby="bulk_labels", + groups=["a"], + reference="b", + method="wilcoxon", + exp_post_agg=exp_post_agg, + ) + logfcs = adata.uns["rank_genes_groups"]["logfoldchanges"]["a"] + np.testing.assert_equal(logfcs, expected_logfc)