From b5370e0c6af8393c91391157c8843c4e61eb3b82 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Tue, 17 Jan 2023 11:27:52 +0100 Subject: [PATCH] Feature/spearman correlation (#444) * add spearman correlation * add tests * adapt tests --- src/moscot/problems/base/_mixins.py | 22 ++++++-- src/moscot/problems/base/_utils.py | 41 +++++++++------ tests/analysis_mixins/test_base_analysis.py | 50 +++++++++++++++---- tests/problems/generic/test_gw_problem.py | 2 +- .../problems/generic/test_sinkhorn_problem.py | 2 +- 5 files changed, 84 insertions(+), 33 deletions(-) diff --git a/src/moscot/problems/base/_mixins.py b/src/moscot/problems/base/_mixins.py index 3d28b3b9f..0cbd71d5f 100644 --- a/src/moscot/problems/base/_mixins.py +++ b/src/moscot/problems/base/_mixins.py @@ -33,7 +33,15 @@ _validate_args_cell_transition, _check_argument_compatibility_cell_transition, ) -from moscot._constants._constants import Key, AdataKeys, PlottingKeys, CorrTestMethod, AggregationMode, PlottingDefaults +from moscot._constants._constants import ( + Key, + AdataKeys, + CorrMethod, + PlottingKeys, + CorrTestMethod, + AggregationMode, + PlottingDefaults, +) from moscot.problems._subset_policy import SubsetPolicy from moscot.problems.base._compound_problem import B, K, ApplyOutput_t @@ -446,7 +454,8 @@ def _cell_aggregation_transition( def compute_feature_correlation( self: AnalysisMixinProtocol[K, B], obs_key: str, - method: Literal["fischer", "perm_test"] = CorrTestMethod.FISCHER, + corr_method: CorrMethod = CorrMethod.PEARSON, + significance_method: Literal["fischer", "perm_test"] = CorrTestMethod.FISCHER, annotation: Optional[Dict[str, Iterable[str]]] = None, layer: Optional[str] = None, features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None, @@ -465,7 +474,9 @@ def compute_feature_correlation( ---------- obs_key Column of :attr:`anndata.AnnData.obs` containing push-forward or pull-back distributions. - method + corr_method + Which type of correlation to compute, options are `pearson`, and `spearman`. + significance_method Mode to use when calculating p-values and confidence intervals. Valid options are: - `fischer` - use Fischer transformation :cite:`fischer:21`. @@ -512,7 +523,7 @@ def compute_feature_correlation( if obs_key not in self.adata.obs: raise KeyError(f"Unable to access data in `adata.obs[{obs_key!r}]`.") - method = CorrTestMethod(method) + significance_method = CorrTestMethod(significance_method) if annotation is not None: annotation_key, annotation_vals = next(iter(annotation.items())) @@ -542,7 +553,8 @@ def compute_feature_correlation( X=sc.get.obs_df(adata, keys=features, layer=layer).values, Y=distribution, feature_names=features, - method=method, + corr_method=corr_method, + significance_method=significance_method, confidence_level=confidence_level, n_perms=n_perms, seed=seed, diff --git a/src/moscot/problems/base/_utils.py b/src/moscot/problems/base/_utils.py index 1a5254bd5..b200fa73b 100644 --- a/src/moscot/problems/base/_utils.py +++ b/src/moscot/problems/base/_utils.py @@ -4,7 +4,7 @@ import inspect import warnings -from scipy.stats import norm +from scipy.stats import norm, rankdata from scipy.sparse import issparse, spmatrix, csr_matrix, isspmatrix_csr from statsmodels.stats.multitest import multipletests import pandas as pd @@ -16,7 +16,7 @@ from moscot._types import ArrayLike, Str_Dict_t from moscot._logging import logger from moscot._docs._docs import d -from moscot._constants._constants import CorrTestMethod, AggregationMode +from moscot._constants._constants import CorrMethod, CorrTestMethod, AggregationMode __all__ = [ "attributedispatch", @@ -279,7 +279,8 @@ def _correlation_test( X: Union[ArrayLike, spmatrix], Y: pd.DataFrame, feature_names: Sequence[str], - method: CorrTestMethod = CorrTestMethod.FISCHER, + corr_method: CorrMethod = CorrMethod.PEARSON, + significance_method: CorrTestMethod = CorrTestMethod.FISCHER, confidence_level: float = 0.95, n_perms: Optional[int] = None, seed: Optional[int] = None, @@ -298,7 +299,9 @@ def _correlation_test( Data frame of shape ``(n_cells, 1)`` containing the pull/push distribution. feature_names Sequence of shape ``(n_features,)`` containing the feature names. - method + corr_method + Which type of correlation to compute, options are `pearson`, and `spearman`. + significance_method Method for p-value calculation. confidence_level Confidence level for the confidence interval calculation. Must be in `[0, 1]`. @@ -322,7 +325,8 @@ def _correlation_test( corr, pvals, ci_low, ci_high = _correlation_test_helper( X.T, Y.values, - method=method, + corr_method=corr_method, + significance_method=significance_method, n_perms=n_perms, seed=seed, confidence_level=confidence_level, @@ -363,7 +367,8 @@ def _correlation_test( def _correlation_test_helper( X: ArrayLike, Y: ArrayLike, - method: CorrTestMethod = CorrTestMethod.FISCHER, + corr_method: CorrMethod = CorrMethod.SPEARMAN, + significance_method: CorrTestMethod = CorrTestMethod.FISCHER, n_perms: Optional[int] = None, seed: Optional[int] = None, confidence_level: float = 0.95, @@ -378,7 +383,9 @@ def _correlation_test_helper( Array or matrix of `(M, N)` elements. Y Array of `(N, K)` elements. - method + corr_method + Which type of correlation to compute, options are `pearson`, and `spearman`. + significance_method Method for p-value calculation. n_perms Number of permutations if ``method='perm_test'``. @@ -415,11 +422,15 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr if issparse(X) and not isspmatrix_csr(X): X = csr_matrix(X) - corr = _mat_mat_corr_sparse(X, Y) if issparse(X) else _mat_mat_corr_dense(X, Y) - if method == CorrTestMethod.FISCHER: + if corr_method == CorrMethod.SPEARMAN: + X, Y = rankdata(X, method="average", axis=0), rankdata(Y, method="average", axis=0) + corr = _pearson_mat_mat_corr_sparse(X, Y) if issparse(X) else _pearson_mat_mat_corr_dense(X, Y) + + if significance_method == CorrTestMethod.FISCHER: # see: https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#Using_the_Fisher_transformation - mean, se = np.arctanh(corr), 1.0 / np.sqrt(n - 3) + # for spearman see: https://www.sciencedirect.com/topics/mathematics/spearman-correlation + mean, se = np.arctanh(corr), 1 / np.sqrt(n - 3) z_score = (np.arctanh(corr) - np.arctanh(0)) * np.sqrt(n - 3) z = norm.ppf(qh) @@ -427,7 +438,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr corr_ci_high = np.tanh(mean + z * se) pvals = 2 * norm.cdf(-np.abs(z_score)) - elif method == CorrTestMethod.PERM_TEST: + elif significance_method == CorrTestMethod.PERM_TEST: if not isinstance(n_perms, int): raise TypeError(f"Expected `n_perms` to be an integer, found `{type(n_perms).__name__}`.") if n_perms <= 0: @@ -443,12 +454,12 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr )(corr, X, Y, seed=seed) else: - raise NotImplementedError(method) + raise NotImplementedError(significance_method) return corr, pvals, corr_ci_low, corr_ci_high -def _mat_mat_corr_sparse( +def _pearson_mat_mat_corr_sparse( X: csr_matrix, Y: ArrayLike, ) -> ArrayLike: @@ -464,7 +475,7 @@ def _mat_mat_corr_sparse( return (X @ Y - (n * X_bar * y_bar)) / ((n - 1) * X_std * y_std) -def _mat_mat_corr_dense(X: ArrayLike, Y: ArrayLike) -> ArrayLike: +def _pearson_mat_mat_corr_dense(X: ArrayLike, Y: ArrayLike) -> ArrayLike: from moscot._utils import np_std, np_mean n = X.shape[1] @@ -493,7 +504,7 @@ def _perm_test( pvals = np.zeros_like(corr, dtype=np.float64) corr_bs = np.zeros((len(ixs), X.shape[0], Y.shape[1])) # perms x genes x lineages - mmc = _mat_mat_corr_sparse if issparse(X) else _mat_mat_corr_dense + mmc = _pearson_mat_mat_corr_sparse if issparse(X) else _pearson_mat_mat_corr_dense for i, _ in enumerate(ixs): rs.shuffle(cell_ixs) diff --git a/tests/analysis_mixins/test_base_analysis.py b/tests/analysis_mixins/test_base_analysis.py index aa01cd3a7..3d63538d1 100644 --- a/tests/analysis_mixins/test_base_analysis.py +++ b/tests/analysis_mixins/test_base_analysis.py @@ -165,8 +165,14 @@ def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnD ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL ) - @pytest.mark.parametrize("method", ["fischer", "perm_test"]) - def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal["fischer", "perm_test"]): + @pytest.mark.parametrize("corr_method", ["pearson", "spearman"]) + @pytest.mark.parametrize("significance_method", ["fischer", "perm_test"]) + def test_compute_feature_correlation( + self, + adata_time: AnnData, + corr_method: Literal["pearson", "spearman"], + significance_method: Literal["fischer", "perm_test"], + ): key_added = "test" rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() @@ -180,7 +186,9 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal[ adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze())) - res = problem.compute_feature_correlation(obs_key=key_added, method=method) + res = problem.compute_feature_correlation( + obs_key=key_added, corr_method=corr_method, significance_method=significance_method + ) assert isinstance(res, pd.DataFrame) assert res.isnull().values.sum() == 0 @@ -191,10 +199,15 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal[ assert np.all(res[f"{key_added}_qval"] >= 0) assert np.all(res[f"{key_added}_qval"] <= 1.0) + @pytest.mark.parametrize("corr_method", ["pearson", "spearman"]) @pytest.mark.parametrize("features", [10, None]) @pytest.mark.parametrize("method", ["fischer", "perm_test"]) def test_compute_feature_correlation_subset( - self, adata_time: AnnData, features: Optional[int], method: Literal["fischer", "perm_test"] + self, + adata_time: AnnData, + features: Optional[int], + corr_method: Literal["pearson", "spearman"], + method: Literal["fischer", "perm_test"], ): key_added = "test" rng = np.random.RandomState(42) @@ -215,7 +228,11 @@ def test_compute_feature_correlation_subset( else: features_validation = list(adata_time.var_names) res = problem.compute_feature_correlation( - obs_key=key_added, annotation={"celltype": ["A"]}, method=method, features=features + obs_key=key_added, + annotation={"celltype": ["A"]}, + corr_method=corr_method, + significance_method=method, + features=features, ) assert isinstance(res, pd.DataFrame) assert res.isnull().values.sum() == 0 @@ -275,9 +292,15 @@ def test_seed_reproducible(self, adata_time: AnnData): adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze())) - res_a = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=0, method="perm_test") - res_b = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=0, method="perm_test") - res_c = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=1, method="perm_test") + res_a = problem.compute_feature_correlation( + obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test" + ) + res_b = problem.compute_feature_correlation( + obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test" + ) + res_c = problem.compute_feature_correlation( + obs_key=key_added, n_perms=10, n_jobs=1, seed=2, significance_method="perm_test" + ) assert res_a is not res_b np.testing.assert_array_equal(res_a.index, res_b.index) @@ -312,7 +335,8 @@ def test_seed_reproducible_parallelized(self, adata_time: AnnData): np.testing.assert_array_equal(res_a.columns, res_b.columns) np.testing.assert_allclose(res_a.values, res_b.values) - def test_confidence_level(self, adata_time: AnnData): + @pytest.mark.parametrize("corr_method", ["pearson", "spearman"]) + def test_confidence_level(self, adata_time: AnnData, corr_method: Literal["pearson", "spearman"]): key_added = "test" rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() @@ -326,8 +350,12 @@ def test_confidence_level(self, adata_time: AnnData): adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze())) - res_narrow = problem.compute_feature_correlation(obs_key=key_added, confidence_level=0.95) - res_wide = problem.compute_feature_correlation(obs_key=key_added, confidence_level=0.99) + res_narrow = problem.compute_feature_correlation( + obs_key=key_added, corr_method=corr_method, confidence_level=0.95 + ) + res_wide = problem.compute_feature_correlation( + obs_key=key_added, corr_method=corr_method, confidence_level=0.99 + ) assert np.all(res_narrow[f"{key_added}_ci_low"] >= res_wide[f"{key_added}_ci_low"]) assert np.all(res_narrow[f"{key_added}_ci_high"] <= res_wide[f"{key_added}_ci_high"]) diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 45dacc98c..b64a6eb13 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -77,7 +77,7 @@ def test_compute_feature_correlation(self, adata_space_rotate: AnnData, method: key_added = "test_push" problem.push(source="0", target="1", data="celltype", subset="A", key_added=key_added) - feature_correlation = problem.compute_feature_correlation(key_added, method=method) + feature_correlation = problem.compute_feature_correlation(key_added, significance_method=method) assert isinstance(feature_correlation, pd.DataFrame) suffix = ["_corr", "_pval", "_qval", "_ci_low", "_ci_high"] diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 114ed1571..2af8ab263 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -64,7 +64,7 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: str): key_added = "test_push" problem.push(source=0, target=1, data="celltype", subset="A", key_added=key_added) - feature_correlation = problem.compute_feature_correlation(key_added, method=method) + feature_correlation = problem.compute_feature_correlation(key_added, significance_method=method) assert isinstance(feature_correlation, pd.DataFrame) suffix = ["_corr", "_pval", "_qval", "_ci_low", "_ci_high"]