Skip to content

Commit

Permalink
Feature/spearman correlation (#444)
Browse files Browse the repository at this point in the history
* add spearman correlation

* add tests

* adapt tests
  • Loading branch information
MUCDK authored and lucaeyring committed Mar 15, 2023
1 parent 3d22677 commit b5370e0
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 33 deletions.
22 changes: 17 additions & 5 deletions src/moscot/problems/base/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@
_validate_args_cell_transition,
_check_argument_compatibility_cell_transition,
)
from moscot._constants._constants import Key, AdataKeys, PlottingKeys, CorrTestMethod, AggregationMode, PlottingDefaults
from moscot._constants._constants import (
Key,
AdataKeys,
CorrMethod,
PlottingKeys,
CorrTestMethod,
AggregationMode,
PlottingDefaults,
)
from moscot.problems._subset_policy import SubsetPolicy
from moscot.problems.base._compound_problem import B, K, ApplyOutput_t

Expand Down Expand Up @@ -446,7 +454,8 @@ def _cell_aggregation_transition(
def compute_feature_correlation(
self: AnalysisMixinProtocol[K, B],
obs_key: str,
method: Literal["fischer", "perm_test"] = CorrTestMethod.FISCHER,
corr_method: CorrMethod = CorrMethod.PEARSON,
significance_method: Literal["fischer", "perm_test"] = CorrTestMethod.FISCHER,
annotation: Optional[Dict[str, Iterable[str]]] = None,
layer: Optional[str] = None,
features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None,
Expand All @@ -465,7 +474,9 @@ def compute_feature_correlation(
----------
obs_key
Column of :attr:`anndata.AnnData.obs` containing push-forward or pull-back distributions.
method
corr_method
Which type of correlation to compute, options are `pearson`, and `spearman`.
significance_method
Mode to use when calculating p-values and confidence intervals. Valid options are:
- `fischer` - use Fischer transformation :cite:`fischer:21`.
Expand Down Expand Up @@ -512,7 +523,7 @@ def compute_feature_correlation(
if obs_key not in self.adata.obs:
raise KeyError(f"Unable to access data in `adata.obs[{obs_key!r}]`.")

method = CorrTestMethod(method)
significance_method = CorrTestMethod(significance_method)

if annotation is not None:
annotation_key, annotation_vals = next(iter(annotation.items()))
Expand Down Expand Up @@ -542,7 +553,8 @@ def compute_feature_correlation(
X=sc.get.obs_df(adata, keys=features, layer=layer).values,
Y=distribution,
feature_names=features,
method=method,
corr_method=corr_method,
significance_method=significance_method,
confidence_level=confidence_level,
n_perms=n_perms,
seed=seed,
Expand Down
41 changes: 26 additions & 15 deletions src/moscot/problems/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import warnings

from scipy.stats import norm
from scipy.stats import norm, rankdata
from scipy.sparse import issparse, spmatrix, csr_matrix, isspmatrix_csr
from statsmodels.stats.multitest import multipletests
import pandas as pd
Expand All @@ -16,7 +16,7 @@
from moscot._types import ArrayLike, Str_Dict_t
from moscot._logging import logger
from moscot._docs._docs import d
from moscot._constants._constants import CorrTestMethod, AggregationMode
from moscot._constants._constants import CorrMethod, CorrTestMethod, AggregationMode

__all__ = [
"attributedispatch",
Expand Down Expand Up @@ -279,7 +279,8 @@ def _correlation_test(
X: Union[ArrayLike, spmatrix],
Y: pd.DataFrame,
feature_names: Sequence[str],
method: CorrTestMethod = CorrTestMethod.FISCHER,
corr_method: CorrMethod = CorrMethod.PEARSON,
significance_method: CorrTestMethod = CorrTestMethod.FISCHER,
confidence_level: float = 0.95,
n_perms: Optional[int] = None,
seed: Optional[int] = None,
Expand All @@ -298,7 +299,9 @@ def _correlation_test(
Data frame of shape ``(n_cells, 1)`` containing the pull/push distribution.
feature_names
Sequence of shape ``(n_features,)`` containing the feature names.
method
corr_method
Which type of correlation to compute, options are `pearson`, and `spearman`.
significance_method
Method for p-value calculation.
confidence_level
Confidence level for the confidence interval calculation. Must be in `[0, 1]`.
Expand All @@ -322,7 +325,8 @@ def _correlation_test(
corr, pvals, ci_low, ci_high = _correlation_test_helper(
X.T,
Y.values,
method=method,
corr_method=corr_method,
significance_method=significance_method,
n_perms=n_perms,
seed=seed,
confidence_level=confidence_level,
Expand Down Expand Up @@ -363,7 +367,8 @@ def _correlation_test(
def _correlation_test_helper(
X: ArrayLike,
Y: ArrayLike,
method: CorrTestMethod = CorrTestMethod.FISCHER,
corr_method: CorrMethod = CorrMethod.SPEARMAN,
significance_method: CorrTestMethod = CorrTestMethod.FISCHER,
n_perms: Optional[int] = None,
seed: Optional[int] = None,
confidence_level: float = 0.95,
Expand All @@ -378,7 +383,9 @@ def _correlation_test_helper(
Array or matrix of `(M, N)` elements.
Y
Array of `(N, K)` elements.
method
corr_method
Which type of correlation to compute, options are `pearson`, and `spearman`.
significance_method
Method for p-value calculation.
n_perms
Number of permutations if ``method='perm_test'``.
Expand Down Expand Up @@ -415,19 +422,23 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr

if issparse(X) and not isspmatrix_csr(X):
X = csr_matrix(X)
corr = _mat_mat_corr_sparse(X, Y) if issparse(X) else _mat_mat_corr_dense(X, Y)

if method == CorrTestMethod.FISCHER:
if corr_method == CorrMethod.SPEARMAN:
X, Y = rankdata(X, method="average", axis=0), rankdata(Y, method="average", axis=0)
corr = _pearson_mat_mat_corr_sparse(X, Y) if issparse(X) else _pearson_mat_mat_corr_dense(X, Y)

if significance_method == CorrTestMethod.FISCHER:
# see: https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#Using_the_Fisher_transformation
mean, se = np.arctanh(corr), 1.0 / np.sqrt(n - 3)
# for spearman see: https://www.sciencedirect.com/topics/mathematics/spearman-correlation
mean, se = np.arctanh(corr), 1 / np.sqrt(n - 3)
z_score = (np.arctanh(corr) - np.arctanh(0)) * np.sqrt(n - 3)

z = norm.ppf(qh)
corr_ci_low = np.tanh(mean - z * se)
corr_ci_high = np.tanh(mean + z * se)
pvals = 2 * norm.cdf(-np.abs(z_score))

elif method == CorrTestMethod.PERM_TEST:
elif significance_method == CorrTestMethod.PERM_TEST:
if not isinstance(n_perms, int):
raise TypeError(f"Expected `n_perms` to be an integer, found `{type(n_perms).__name__}`.")
if n_perms <= 0:
Expand All @@ -443,12 +454,12 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr
)(corr, X, Y, seed=seed)

else:
raise NotImplementedError(method)
raise NotImplementedError(significance_method)

return corr, pvals, corr_ci_low, corr_ci_high


def _mat_mat_corr_sparse(
def _pearson_mat_mat_corr_sparse(
X: csr_matrix,
Y: ArrayLike,
) -> ArrayLike:
Expand All @@ -464,7 +475,7 @@ def _mat_mat_corr_sparse(
return (X @ Y - (n * X_bar * y_bar)) / ((n - 1) * X_std * y_std)


def _mat_mat_corr_dense(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
def _pearson_mat_mat_corr_dense(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
from moscot._utils import np_std, np_mean

n = X.shape[1]
Expand Down Expand Up @@ -493,7 +504,7 @@ def _perm_test(
pvals = np.zeros_like(corr, dtype=np.float64)
corr_bs = np.zeros((len(ixs), X.shape[0], Y.shape[1])) # perms x genes x lineages

mmc = _mat_mat_corr_sparse if issparse(X) else _mat_mat_corr_dense
mmc = _pearson_mat_mat_corr_sparse if issparse(X) else _pearson_mat_mat_corr_dense

for i, _ in enumerate(ixs):
rs.shuffle(cell_ixs)
Expand Down
50 changes: 39 additions & 11 deletions tests/analysis_mixins/test_base_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,14 @@ def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnD
ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
)

@pytest.mark.parametrize("method", ["fischer", "perm_test"])
def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal["fischer", "perm_test"]):
@pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
@pytest.mark.parametrize("significance_method", ["fischer", "perm_test"])
def test_compute_feature_correlation(
self,
adata_time: AnnData,
corr_method: Literal["pearson", "spearman"],
significance_method: Literal["fischer", "perm_test"],
):
key_added = "test"
rng = np.random.RandomState(42)
adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
Expand All @@ -180,7 +186,9 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal[

adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

res = problem.compute_feature_correlation(obs_key=key_added, method=method)
res = problem.compute_feature_correlation(
obs_key=key_added, corr_method=corr_method, significance_method=significance_method
)

assert isinstance(res, pd.DataFrame)
assert res.isnull().values.sum() == 0
Expand All @@ -191,10 +199,15 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal[
assert np.all(res[f"{key_added}_qval"] >= 0)
assert np.all(res[f"{key_added}_qval"] <= 1.0)

@pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
@pytest.mark.parametrize("features", [10, None])
@pytest.mark.parametrize("method", ["fischer", "perm_test"])
def test_compute_feature_correlation_subset(
self, adata_time: AnnData, features: Optional[int], method: Literal["fischer", "perm_test"]
self,
adata_time: AnnData,
features: Optional[int],
corr_method: Literal["pearson", "spearman"],
method: Literal["fischer", "perm_test"],
):
key_added = "test"
rng = np.random.RandomState(42)
Expand All @@ -215,7 +228,11 @@ def test_compute_feature_correlation_subset(
else:
features_validation = list(adata_time.var_names)
res = problem.compute_feature_correlation(
obs_key=key_added, annotation={"celltype": ["A"]}, method=method, features=features
obs_key=key_added,
annotation={"celltype": ["A"]},
corr_method=corr_method,
significance_method=method,
features=features,
)
assert isinstance(res, pd.DataFrame)
assert res.isnull().values.sum() == 0
Expand Down Expand Up @@ -275,9 +292,15 @@ def test_seed_reproducible(self, adata_time: AnnData):

adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

res_a = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=0, method="perm_test")
res_b = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=0, method="perm_test")
res_c = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=1, method="perm_test")
res_a = problem.compute_feature_correlation(
obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test"
)
res_b = problem.compute_feature_correlation(
obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test"
)
res_c = problem.compute_feature_correlation(
obs_key=key_added, n_perms=10, n_jobs=1, seed=2, significance_method="perm_test"
)

assert res_a is not res_b
np.testing.assert_array_equal(res_a.index, res_b.index)
Expand Down Expand Up @@ -312,7 +335,8 @@ def test_seed_reproducible_parallelized(self, adata_time: AnnData):
np.testing.assert_array_equal(res_a.columns, res_b.columns)
np.testing.assert_allclose(res_a.values, res_b.values)

def test_confidence_level(self, adata_time: AnnData):
@pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
def test_confidence_level(self, adata_time: AnnData, corr_method: Literal["pearson", "spearman"]):
key_added = "test"
rng = np.random.RandomState(42)
adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
Expand All @@ -326,8 +350,12 @@ def test_confidence_level(self, adata_time: AnnData):

adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

res_narrow = problem.compute_feature_correlation(obs_key=key_added, confidence_level=0.95)
res_wide = problem.compute_feature_correlation(obs_key=key_added, confidence_level=0.99)
res_narrow = problem.compute_feature_correlation(
obs_key=key_added, corr_method=corr_method, confidence_level=0.95
)
res_wide = problem.compute_feature_correlation(
obs_key=key_added, corr_method=corr_method, confidence_level=0.99
)

assert np.all(res_narrow[f"{key_added}_ci_low"] >= res_wide[f"{key_added}_ci_low"])
assert np.all(res_narrow[f"{key_added}_ci_high"] <= res_wide[f"{key_added}_ci_high"])
2 changes: 1 addition & 1 deletion tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_compute_feature_correlation(self, adata_space_rotate: AnnData, method:

key_added = "test_push"
problem.push(source="0", target="1", data="celltype", subset="A", key_added=key_added)
feature_correlation = problem.compute_feature_correlation(key_added, method=method)
feature_correlation = problem.compute_feature_correlation(key_added, significance_method=method)

assert isinstance(feature_correlation, pd.DataFrame)
suffix = ["_corr", "_pval", "_qval", "_ci_low", "_ci_high"]
Expand Down
2 changes: 1 addition & 1 deletion tests/problems/generic/test_sinkhorn_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: str):

key_added = "test_push"
problem.push(source=0, target=1, data="celltype", subset="A", key_added=key_added)
feature_correlation = problem.compute_feature_correlation(key_added, method=method)
feature_correlation = problem.compute_feature_correlation(key_added, significance_method=method)

assert isinstance(feature_correlation, pd.DataFrame)
suffix = ["_corr", "_pval", "_qval", "_ci_low", "_ci_high"]
Expand Down

0 comments on commit b5370e0

Please sign in to comment.