Skip to content

Commit

Permalink
Cytotrace kernel (#527)
Browse files Browse the repository at this point in the history
* Add CytoTRACE kernel

* Add new kernel to docs

* Fix literal

* Work on the docstring

* Rename variables

* Rename the gcs variable

* Update the docstring

* Add docstring to kernel class

* Move the workflow example

* [ci skip] Enable CI

* Move to CytoTRACEKernel

* Improve logging, remove arbitrary restriction

* Do not use short kernel names

* Add CytoTRACE test skeletons

* Set use_raw=False, add tests

* Improve tests

* Fix raw and normal adata with different n_vars

* Add cytotrace score to adata_200

* Add regression test for ct_score

* Add regression test for gene_corr

* Fix tests

Co-authored-by: Marius Lange <marius.lange@t-online.de>
  • Loading branch information
michalk8 and Marius1311 committed Mar 19, 2021
1 parent fc8630e commit fa3673b
Show file tree
Hide file tree
Showing 11 changed files with 360 additions and 19 deletions.
3 changes: 2 additions & 1 deletion cellrank/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from statsmodels.stats.multitest import multipletests

from anndata import AnnData

import numpy as np
import pandas as pd
from pandas import Series
Expand All @@ -44,7 +46,6 @@
from cellrank.tl._linear_solver import _solve_lin_system
from cellrank.tl.kernels._utils import np_std, np_mean, _filter_kwargs

AnnData = TypeVar("AnnData")
ColorLike = TypeVar("ColorLike")
GPCCA = TypeVar("GPCCA")
CFLARE = TypeVar("CFLARE")
Expand Down
1 change: 1 addition & 0 deletions cellrank/tl/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cellrank.tl.kernels._base_kernel import Kernel, Constant
from cellrank.tl.kernels._velocity_kernel import VelocityKernel
from cellrank.tl.kernels._cytotrace_kernel import CytoTRACEKernel
from cellrank.tl.kernels._velocity_schemes import (
CosineScheme,
DotProductScheme,
Expand Down
4 changes: 2 additions & 2 deletions cellrank/tl/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,14 @@ def adata(self, _adata: AnnData):
self._adata = _adata

def __repr__(self):
return f"{'~' if self.backward and self._parent is None else ''}<{self.__class__.__name__[:4]}>"
return f"{'~' if self.backward and self._parent is None else ''}<{self.__class__.__name__}>"

def __str__(self):
params_fmt = self._format_params()
if params_fmt:
return (
f"{'~' if self.backward and self._parent is None else ''}"
f"<{self.__class__.__name__[:4]}[{params_fmt}]>"
f"<{self.__class__.__name__}[{params_fmt}]>"
)
return repr(self)

Expand Down
239 changes: 239 additions & 0 deletions cellrank/tl/kernels/_cytotrace_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from typing import Any

from typing_extensions import Literal

from anndata import AnnData

import numpy as np
from scipy.stats import gmean, hmean

from cellrank import logging as logg
from cellrank.ul._docs import d
from cellrank.tl._utils import _correlation_test_helper
from cellrank.tl._constants import ModeEnum
from cellrank.tl.kernels._pseudotime_kernel import PseudotimeKernel


def _ct(key: str) -> str:
return f"ct_{key}"


class CytoTRACEAggregation(ModeEnum):  # noqa
    """Ways of aggregating the expression of the top-correlating genes into one score per cell."""

    MEAN = "mean"  # arithmetic mean
    MEDIAN = "median"
    GMEAN = "gmean"  # geometric mean
    HMEAN = "hmean"  # harmonic mean


@d.dedent
class CytoTRACEKernel(PseudotimeKernel):
    """
    Kernel which computes directed transition probabilities based on a KNN graph and the CytoTRACE score [Cyto20]_.

    The KNN graph contains information about the (undirected) connectivities among cells, reflecting their similarity.
    CytoTRACE can be used to estimate cellular plasticity and in turn, a pseudotemporal ordering of cells from more
    plastic to less plastic states.
    This kernel internally uses the :class:`cellrank.tl.kernels.PseudotimeKernel` to direct the KNN graph
    on the basis of the CytoTRACE-derived pseudotime.

    %(density_correction)s

    Parameters
    ----------
    %(adata)s
    %(backward)s
    %(cytotrace.parameters)s
    compute_cond_num
        Whether to compute condition number of the transition matrix. Note that this might be costly,
        since it does not use sparse implementation.
    check_connectivity
        Passed through unchanged to the parent :class:`cellrank.tl.kernels.PseudotimeKernel`.

    Examples
    --------
    Workflow::

        # import packages and load the data
        import scanpy as sc
        import scvelo as scv
        import cellrank as cr

        adata = cr.datasets.pancreas()

        # standard pre-processing, keep the unnormalized data in `adata.raw`
        sc.pp.filter_genes(adata, min_cells=10)
        adata.raw = adata.copy()
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
        sc.pp.highly_variable_genes(adata)

        if 'spliced' not in adata.layers or 'unspliced' not in adata.layers:
            # use the following trick to get scvelo's moments function working
            adata.layers['spliced'] = adata.X
            adata.layers['unspliced'] = adata.X

        # compute the imputed expression, by default stored in `adata.layers['Ms']`
        scv.pp.moments(adata, n_pcs=None, n_neighbors=None)

        # initialize the kernel and compute the transition matrix
        from cellrank.tl.kernels import CytoTRACEKernel

        ctk = CytoTRACEKernel(adata).compute_transition_matrix()
    """

    def __init__(
        self,
        adata: AnnData,
        backward: bool = False,
        layer: str = "Ms",
        aggregation: Literal["mean", "median", "hmean", "gmean"] = "mean",
        use_raw: bool = False,
        compute_cond_num: bool = False,
        check_connectivity: bool = False,
    ):
        # `layer`, `aggregation` and `use_raw` are forwarded as extra keyword arguments
        # and consumed by `_read_from_adata` below
        super().__init__(
            adata,
            backward=backward,
            time_key=_ct("pseudotime"),
            compute_cond_num=compute_cond_num,
            check_connectivity=check_connectivity,
            layer=layer,
            aggregation=aggregation,
            use_raw=use_raw,
        )
        self._time_key = _ct("pseudotime")  # quirk of PT kernel

    def _read_from_adata(
        self,
        time_key: str,
        layer: str = "Ms",
        aggregation: Literal["mean", "median", "hmean", "gmean"] = "mean",
        use_raw: bool = False,
        **kwargs: Any,
    ):
        """
        Compute the CytoTRACE score and then let the parent class read the resulting pseudotime.

        The score must be computed first, because the parent implementation expects
        ``time_key`` to be already present in :attr:`anndata.AnnData.obs`.

        Parameters
        ----------
        time_key
            Key in :attr:`anndata.AnnData.obs` where the pseudotime is stored.
        layer
            Forwarded to :meth:`compute_cytotrace`.
        aggregation
            Forwarded to :meth:`compute_cytotrace`.
        use_raw
            Forwarded to :meth:`compute_cytotrace`. The default matches the one
            of :meth:`compute_cytotrace` and of the constructor.
        kwargs
            Keyword arguments for the parent implementation.
        """
        self.compute_cytotrace(layer=layer, aggregation=aggregation, use_raw=use_raw)

        super()._read_from_adata(time_key=time_key, **kwargs)

    @d.get_sections(base="cytotrace", sections=["Parameters"])
    def compute_cytotrace(
        self,
        layer: str = "Ms",
        aggregation: Literal["mean", "median", "hmean", "gmean"] = "mean",
        use_raw: bool = False,
    ) -> None:
        """
        Re-implementation of the CytoTRACE algorithm [Cyto20]_ to estimate cellular plasticity.

        Computes the number of genes expressed per cell and ranks genes according to their correlation with this
        measure. Next, it selects the top-correlating genes and aggregates their (imputed) expression to obtain
        the CytoTRACE score. A high score stands for high differentiation potential (naive, plastic cells) and
        a low score stands for low differentiation potential (mature, differentiated cells).

        Note that this will not exactly reproduce the results of the original CytoTRACE algorithm [Cyto20]_ because we
        allow for any normalization and imputation techniques whereas CytoTRACE has built-in specific methods for that.

        Parameters
        ----------
        layer
            Key in :attr:`anndata.AnnData.layers` or `'X'` for :attr:`anndata.AnnData.X`
            from where to get the expression.
        aggregation
            How to aggregate expression of the top-correlating genes. Valid options are:

                - `'mean'`: arithmetic mean.
                - `'median'`: median.
                - `'gmean'`: geometric mean.
                - `'hmean'`: harmonic mean.
        use_raw
            Whether to use the :attr:`anndata.AnnData.raw` to compute the number of genes expressed per cell
            (#genes/cell) and the correlation of gene expression across cells with #genes/cell.

        Returns
        -------
        Nothing, just modifies :attr:`anndata.AnnData.obs` with the following keys:

            - `'ct_score'`: the normalized CytoTRACE score.
            - `'ct_pseudotime'`: associated pseudotime, essentially `1 - CytoTRACE score`.
            - `'ct_num_exp_genes'`: the number of genes expressed per cell, basis of the CytoTRACE score.

        It also modifies :attr:`anndata.AnnData.var` with the following keys:

            - `'ct_gene_corr'`: the correlation as specified above.
            - `'ct_correlates'`: indication of the genes used to compute the CytoTRACE score, i.e. the ones that
              correlated best with `'num_exp_genes'`.

        Lastly, the parameters used are saved under `'ct_params'` in :attr:`anndata.AnnData.uns`.

        Raises
        ------
        KeyError
            If ``layer`` is neither `'X'` nor a key in :attr:`anndata.AnnData.layers`.
        ValueError
            If the aggregated score is constant across cells or contains NaN values,
            in which case it cannot be scaled to the 0-1 range.
        """
        # local import, only needed to densify sparse expression before the aggregation
        from scipy.sparse import issparse

        aggregation = CytoTRACEAggregation(aggregation)

        # check use_raw
        if use_raw and self.adata.raw is None:
            logg.warning("`adata.raw` is `None`. Setting `use_raw=False`")
            use_raw = False
        if use_raw and self.adata.raw.n_vars != self.adata.n_vars:
            # the correlations below are written to `adata.var`, so the genes must match
            logg.warning(
                f"`adata.raw` has different number of genes ({self.adata.raw.n_vars}) "
                f"than `adata` ({self.adata.n_vars}). Setting `use_raw=False`"
            )
            use_raw = False

        adata_mraw = self.adata.raw if use_raw else self.adata
        if layer != "X" and layer not in self.adata.layers:
            raise KeyError(
                f"Unable to find `{layer!r}` in `adata.layers`. "
                f"Valid option are: `{sorted({'X'} | set(self.adata.layers.keys()))}`."
            )

        msg = f"Computing CytoTRACE score with `{self.adata.n_vars}` genes"
        if self.adata.n_vars < 10000:
            msg += ". Consider using more than `10000` genes"
        start = logg.info(msg)

        # compute number of expressed genes per cell
        logg.debug(
            f"Computing number of genes expressed per cell with `use_raw={use_raw}`"
        )
        num_exp_genes = np.array((adata_mraw.X > 0).sum(axis=1)).reshape(-1)
        self.adata.obs[_ct("num_exp_genes")] = num_exp_genes

        # fmt: off
        # compute correlation with all genes
        logg.debug("Correlating all genes with number of genes expressed per cell")
        gene_corr, _, _, _ = _correlation_test_helper(adata_mraw.X.T, num_exp_genes[:, None])

        # annotate the top 200 genes in terms of correlation
        logg.debug("Finding the top `200` most correlated genes")
        self.adata.var[_ct("gene_corr")] = gene_corr
        top_200 = self.adata.var.sort_values(by=_ct("gene_corr"), ascending=False).index[:200]
        self.adata.var[_ct("correlates")] = False
        self.adata.var.loc[top_200, _ct("correlates")] = True

        # compute mean/median over top 200 genes, aggregate over genes and shift to [0, 1] range
        logg.debug(f"Aggregating imputed gene expression using aggregation `{aggregation}` in layer `{layer}`")
        corr_mask = self.adata.var[_ct("correlates")]
        imputed_exp = self.adata[:, corr_mask].X if layer == "X" else self.adata[:, corr_mask].layers[layer]
        if issparse(imputed_exp):
            # the aggregation functions below need a dense array; at most 200 genes are kept,
            # so the dense copy is small
            imputed_exp = imputed_exp.toarray()

        # aggregate across the top 200 genes
        if aggregation == CytoTRACEAggregation.MEAN:
            cytotrace_score = np.mean(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.MEDIAN:
            cytotrace_score = np.median(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.GMEAN:
            cytotrace_score = gmean(imputed_exp, axis=1)
        elif aggregation == CytoTRACEAggregation.HMEAN:
            cytotrace_score = hmean(imputed_exp, axis=1)
        else:
            raise NotImplementedError(f"Aggregation method `{aggregation}` is not yet implemented.")
        # fmt: on

        # scale to 0-1 range
        cytotrace_score -= np.min(cytotrace_score)
        score_range = np.max(cytotrace_score)
        if not score_range > 0:
            # also catches NaN; dividing would silently write an all-NaN score and pseudotime
            raise ValueError(
                "Unable to scale the CytoTRACE score to the 0-1 range, "
                "because it is constant across cells or contains NaN values."
            )
        cytotrace_score /= score_range
        self.adata.obs[_ct("score")] = cytotrace_score
        self.adata.obs[_ct("pseudotime")] = 1 - cytotrace_score

        self.adata.uns[_ct("params")] = {
            "aggregation": aggregation.s,
            "layer": layer,
            "use_raw": use_raw,
        }

        logg.info(
            f"Adding `adata.obs[{_ct('score')!r}]`\n"
            f" `adata.obs[{_ct('pseudotime')!r}]`\n"
            f" `adata.obs[{_ct('num_exp_genes')!r}]`\n"
            f" `adata.var[{_ct('gene_corr')!r}]`\n"
            f" `adata.var[{_ct('correlates')!r}]`\n"
            f" `adata.uns[{_ct('params')!r}]`\n"
            f" Finish",
            time=start,
        )
5 changes: 4 additions & 1 deletion cellrank/tl/kernels/_precomputed_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def __invert__(self) -> "PrecomputedKernel":
return self

def __repr__(self):
return f"{'~' if self.backward and self._parent is None else ''}<Precomputed[origin={self._origin}]>"
return (
f"{'~' if self.backward and self._parent is None else ''}"
f"<{self.__class__.__name__}[origin={self._origin}]>"
)

def __str__(self):
return repr(self)
Expand Down
14 changes: 6 additions & 8 deletions cellrank/tl/kernels/_pseudotime_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,21 @@ def __init__(
time_key: str = "dpt_pseudotime",
compute_cond_num: bool = False,
check_connectivity: bool = False,
**kwargs: Any,
):
super().__init__(
adata,
backward=backward,
time_key=time_key,
compute_cond_num=compute_cond_num,
check_connectivity=check_connectivity,
**kwargs,
)
self._time_key = time_key

def _read_from_adata(self, **kwargs):
def _read_from_adata(self, time_key: str, **kwargs):
super()._read_from_adata(**kwargs)

time_key = kwargs.pop("time_key", "dpt_pseudotime")
if time_key not in self.adata.obs.keys():
raise KeyError(f"Could not find time key in `adata.obs[{time_key!r}]`.")

Expand All @@ -71,9 +72,6 @@ def _read_from_adata(self, **kwargs):
if np.any(np.isnan(self._pseudotime)):
raise ValueError("Encountered NaN values in pseudotime.")

logg.debug("Clipping the pseudotime to 0-1 range")
self._pseudotime = np.clip(self._pseudotime, 0, 1)

@d.dedent
def compute_transition_matrix(
self,
Expand Down Expand Up @@ -118,7 +116,7 @@ def compute_transition_matrix(
:class:`cellrank.tl.kernels.PseudotimeKernel`
Makes :paramref:`transition_matrix` available.
"""
start = logg.info("Computing transition matrix based on pseudotime")
start = logg.info(f"Computing transition matrix based on `{self._time_key}`")

# get the connectivities and number of neighbors
n_neighbors = (
Expand Down Expand Up @@ -161,7 +159,7 @@ def compute_transition_matrix(

# handle backward case and run biasing function
pseudotime = (
np.nanmax(self.pseudotime) - self.pseudotime
np.max(self.pseudotime) - self.pseudotime
if self._direction == Direction.BACKWARD
else self.pseudotime
)
Expand Down Expand Up @@ -202,5 +200,5 @@ def copy(self) -> "PseudotimeKernel":

def __invert__(self) -> "PseudotimeKernel":
super().__invert__()
self._pseudotime = np.nanmax(self.pseudotime) - self.pseudotime
self._pseudotime = np.max(self.pseudotime) - self.pseudotime
return self
3 changes: 2 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ Kernels are part of the low-level API and are used to estimate cell-to-cell tran

tl.kernels.VelocityKernel
tl.kernels.ConnectivityKernel
tl.kernels.PalantirKernel
tl.kernels.PseudotimeKernel
tl.kernels.CytoTRACEKernel
tl.kernels.PrecomputedKernel

Estimators
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ Soft threshold scheme
:members:
:special-members: __call__

CytoTRACE Kernel
----------------

.. autoclass:: cellrank.tl.kernels.CytoTRACEKernel
:members:

Precomputed Kernel
------------------
Expand Down
4 changes: 4 additions & 0 deletions docs/source/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ References
*Geometric diffusions as a tool for harmonic analysis and structure definition of data: Diffusion maps*,
`PNAS <https://doi.org/10.1073/pnas.0500334102>`__.
.. [Cyto20] Gulati *et al.* (2020),
*Single-cell transcriptional diversity is a hallmark of developmental potential*,
`Science <https://doi.org/10.1126/science.aax0249>`__.
.. [GPCCA18] Reuter, B. *et al.* (2018),
*Generalized Markov State Modeling Method for Nonequilibrium Biomolecular Dynamics: Exemplified on Amyloid β
Conformational Dynamics Driven by an Oscillating Electric Field.*,
Expand Down
Binary file modified tests/_ground_truth_adatas/adata_200.h5ad
Binary file not shown.
Loading

0 comments on commit fa3673b

Please sign in to comment.