Skip to content

Commit

Permalink
Feature/interpolate colors sankey (#434)
Browse files Browse the repository at this point in the history
* fix return type in mpl

* change import acronyms

* fix tests

* add interpolation option to sankey

* add test to interpolate color

* define colors for pull/push`

* adapt tests

* introduce axes in mpl.push/pull

* incorporate requested changes

* change default color

* adapt plotting

* introduce scaling

* fix scale

* make start/end categorical in plot

* regenerate images
  • Loading branch information
MUCDK authored Dec 23, 2022
1 parent 5e8f847 commit ce367a4
Show file tree
Hide file tree
Showing 13 changed files with 350 additions and 96 deletions.
60 changes: 59 additions & 1 deletion src/moscot/_constants/_key.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Any, Callable
from typing import Any, Set, List, Callable, Optional

import numpy as np

from anndata import AnnData


class cprop:
Expand Down Expand Up @@ -28,3 +32,57 @@ def spatial(cls) -> str:
@classmethod
def nhood_enrichment(cls, cluster: str) -> str:
return f"{cluster}_nhood_enrichment"


class RandomKeys:
"""
Create random keys inside an :class:`anndata.AnnData` object.
Parameters
----------
adata
Annotated data object.
n
Number of keys, If `None`, create just 1 keys.
where
Attribute of ``adata``. If `'obs'`, also clean up `'{key}_colors'` for each generated key.
"""

def __init__(self, adata: AnnData, n: Optional[int] = None, where: str = "obs"):
self._adata = adata
self._where = where
self._n = n or 1
self._keys: List[str] = []

def _generate_random_keys(self):
def generator():
return f"RNG_COL_{np.random.randint(2 ** 16)}"

where = getattr(self._adata, self._where)
names: List[str] = []
seen: Set[str] = set(where.keys())

while len(names) != self._n:
name = generator()
if name not in seen:
seen.add(name)
names.append(name)

return names

def __enter__(self):
self._keys = self._generate_random_keys()
return self._keys

def __exit__(self, exc_type, exc_val, exc_tb):
for key in self._keys:
try:
getattr(self._adata, self._where).drop(key, axis="columns", inplace=True)
except KeyError:
pass
if self._where == "obs":
try:
del self._adata.uns[f"{key}_colors"]
except KeyError:
pass
43 changes: 39 additions & 4 deletions src/moscot/_docs/_docs_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""
_cbar_kwargs_cell_transition = """\
cbar_kwargs
Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`."""
Keyword arguments for :func:`matplotlib.figure.Figure.colorbar`."""
# return cell transition
_return_cell_transition = """\
:class:`matplotlib.figure.Figure` heatmap of cell transition matrix.
Expand Down Expand Up @@ -64,7 +64,18 @@
- `captions`
- `key`
"""

_alpha_transparency = """\
alpha
Transparancy value.
"""
_interpolate_color = """\
interpolate_color
Whether the color is continuously interpolated.
"""
_sankey_kwargs = """\
kwargs
Keyword arguments for :func:`matplotlib.pyplot.fill_between`.
"""
###############################################################################
# plotting.push/pull
# input
Expand All @@ -85,6 +96,10 @@
basis
Basis of the embedding, saved in :attr:`anndata.AnnData.obsm`.
"""
_scale_push_pull = """\
scale
Whether to linearly scale the distribution.
"""
# return push/pull
_return_push_pull = """\
:class:`matplotlib.figure.Figure` scatterplot in `basis` coordinates.
Expand Down Expand Up @@ -113,8 +128,20 @@
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`."""
_title = """\
title
TODO."""

Title of the plot.
"""
_dot_scale_factor = """\
dot_scale_factor
If `time_points` is not `None`, `dot_scale_factor` increases the size of the dots by this factor.
"""
_na_color = """\
na_color
Color to use for null or masked values. Can be anything matplotlib accepts as a color.
"""
_suptitle_fontsize = """
suptitle_fontsize
Fontsize of the suptitle.
"""
###############################################################################
# general output
_return_fig = """\
Expand All @@ -133,6 +160,7 @@
Path where to save the plot. If `None`, the plot is not saved.
{_ax}"""


d_plotting = DocstringProcessor(
desc_cell_transition=_desc_cell_transition,
transition_labels_cell_transition=_transition_labels_cell_transition,
Expand All @@ -159,4 +187,11 @@
ax=_ax,
figsize_dpi_save=_figsize_dpi_save,
fontsize=_fontsize,
alpha_transparency=_alpha_transparency,
interpolate_color=_interpolate_color,
sankey_kwargs=_sankey_kwargs,
na_color=_na_color,
dot_scale_factor=_dot_scale_factor,
scale_push_pull=_scale_push_pull,
suptitle_fontsize=_suptitle_fontsize,
)
78 changes: 59 additions & 19 deletions src/moscot/plotting/_plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import MappingProxyType
from typing import Any, Dict, List, Tuple, Union, Mapping, Iterable, Optional
from typing import Any, Dict, List, Tuple, Union, Mapping, Optional, Sequence

from matplotlib import colors as mcolors
from matplotlib.axes import Axes
Expand All @@ -11,10 +11,9 @@

from moscot.problems.base import CompoundProblem # type: ignore[attr-defined]
from moscot.problems.time import LineageProblem, TemporalProblem # type: ignore[attr-defined]
from moscot.plotting._utils import _sankey, _heatmap, _plot_temporal, _input_to_adatas
from moscot.plotting._utils import _sankey, _heatmap, _plot_temporal, _input_to_adatas, _create_col_colors
from moscot._docs._docs_plot import d_plotting
from moscot._constants._constants import AdataKeys, PlottingKeys, PlottingDefaults
from moscot.problems.base._compound_problem import K


@d_plotting.dedent
Expand All @@ -30,7 +29,7 @@ def cell_transition(
dpi: Optional[int] = None,
save: Optional[str] = None,
ax: Optional[Axes] = None,
return_fig: bool = True,
return_fig: bool = False,
cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> mpl.figure.Figure:
Expand Down Expand Up @@ -99,12 +98,14 @@ def sankey(
captions: Optional[List[str]] = None,
title: Optional[str] = None,
colors_dict: Optional[Dict[str, float]] = None,
alpha: float = 1.0,
interpolate_color: bool = False,
cmap: Union[str, mcolors.Colormap] = "viridis",
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
ax: Optional[Axes] = None,
return_fig: bool = True,
return_fig: bool = False,
**kwargs: Any,
) -> mpl.figure.Figure:
"""
Expand All @@ -119,8 +120,11 @@ def sankey(
%(captions_sankey)s
%(title)s
%(colors_dict_sankey)s
%(alpha_transparency)s
%(interpolate_color)s
%(cmap)s
%(figsize_dpi_save)s
%(sankey_kwargs)s
Returns
-------
Expand Down Expand Up @@ -150,6 +154,8 @@ def sankey(
figsize=figsize,
dpi=dpi,
ax=ax,
alpha=alpha,
interpolate_color=interpolate_color,
**kwargs,
)
if save:
Expand All @@ -162,17 +168,21 @@ def sankey(
def push(
inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],
uns_key: Optional[str] = None,
time_points: Optional[Iterable[K]] = None,
time_points: Optional[Sequence[float]] = None,
basis: str = "umap",
result_key: str = "plot_push",
fill_value: float = np.nan,
title: Optional[str] = None,
cmap: Union[str, mcolors.Colormap] = "viridis",
scale: bool = True,
title: Optional[Union[str, List[str]]] = None,
suptitle: Optional[str] = None,
cmap: Optional[Union[str, mcolors.Colormap]] = None,
dot_scale_factor: float = 2.0,
na_color: str = "#e8ebe9",
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
ax: Optional[Axes] = None,
return_fig: bool = True,
return_fig: bool = False,
suptitle_fontsize: Optional[float] = None,
**kwargs: Any,
) -> mpl.figure.Figure:
"""
Expand All @@ -186,11 +196,14 @@ def push(
%(uns_key)s
%(time_points_push_pull)s
%(basis_push_pull)s
%(result_key_push_pull)s
%(fill_value_push_pull)s
%(scale_push_pull)s
%(title)s
%(cmap)s
%(dot_scale_factor)s
%(na_color)s
%(figsize_dpi_save)s
%(suptitle_fontsize)s
Returns
-------
Expand All @@ -206,20 +219,30 @@ def push(
if key not in adata.obs:
raise KeyError(f"No data found in `adata.obs[{key!r}]`.")
data = adata.uns[AdataKeys.UNS][PlottingKeys.PUSH][key]
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])
fig = _plot_temporal(
adata=adata,
temporal_key=data["temporal_key"],
key_stored=key,
start=data["start"],
end=data["end"],
categories=data["subset"],
push=True,
time_points=time_points,
basis=basis,
result_key=result_key,
constant_fill_value=fill_value,
scale=scale,
save=save,
cont_cmap=cmap,
dot_scale_factor=dot_scale_factor,
na_color=na_color,
title=title,
suptitle=suptitle,
figsize=figsize,
dpi=dpi,
ax=ax,
suptitle_fontsize=suptitle_fontsize,
**kwargs,
)
if return_fig:
Expand All @@ -230,17 +253,21 @@ def push(
def pull(
inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],
uns_key: Optional[str] = None,
time_points: Optional[Iterable[K]] = None,
time_points: Optional[Sequence[float]] = None,
basis: str = "umap",
result_key: str = "plot_pull",
fill_value: float = np.nan,
title: Optional[str] = None,
cmap: Union[str, mcolors.Colormap] = "viridis",
scale: bool = True,
title: Optional[Union[str, List[str]]] = None,
suptitle: Optional[str] = None,
cmap: Optional[Union[str, mcolors.Colormap]] = None,
dot_scale_factor: float = 2.0,
na_color: str = "#e8ebe9",
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
ax: Optional[Axes] = None,
return_fig: bool = True,
return_fig: bool = False,
suptitle_fontsize: Optional[float] = None,
**kwargs: Any,
) -> mpl.figure.Figure:
"""
Expand All @@ -254,11 +281,14 @@ def pull(
%(uns_key)s
%(time_points_push_pull)s
%(basis_push_pull)s
%(result_key_push_pull)s
%(fill_value_push_pull)s
%(scale_push_pull)s
%(title)s
%(cmap)s
%(dot_scale_factor)s
%(na_color)s
%(figsize_dpi_save)s
%(suptitle_fontsize)s
Returns
-------
Expand All @@ -274,20 +304,30 @@ def pull(
if key not in adata.obs:
raise KeyError(f"No data found in `adata.obs[{key!r}]`.")
data = adata.uns[AdataKeys.UNS][PlottingKeys.PULL][key]
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])
fig = _plot_temporal(
adata=adata,
temporal_key=data["temporal_key"],
key_stored=key,
start=data["start"],
end=data["end"],
categories=data["subset"],
push=False,
time_points=time_points,
basis=basis,
result_key=result_key,
constant_fill_value=fill_value,
scale=scale,
save=save,
cont_cmap=cmap,
dot_scale_factor=dot_scale_factor,
na_color=na_color,
title=title,
suptitle=suptitle,
figsize=figsize,
dpi=dpi,
ax=ax,
suptitle_fontsize=suptitle_fontsize,
**kwargs,
)
if return_fig:
Expand Down
Loading

0 comments on commit ce367a4

Please sign in to comment.