Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/interpolate colors sankey #434

Merged
merged 29 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions src/moscot/_docs/_docs_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,18 @@
- `captions`
- `key`
"""

_alpha_transparancy = """\
Alpha
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
Transparancy value.
"""
_interpolate_color = """\
interpolate_color
Whether the color is continuously interpolated.
"""
_sankey_kwargs = """\
kwargs
Keyword arguments for :meth:`matplotlib.pyplot.fill_between`.
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
"""
###############################################################################
# plotting.push/pull
# input
Expand All @@ -85,6 +96,10 @@
basis
Basis of the embedding, saved in :attr:`anndata.AnnData.obsm`.
"""
_scale_push_pull = """\
scale
Whether to linearly scale the distribution.
"""
# return push/pull
_return_push_pull = """\
:class:`matplotlib.figure.Figure` scatterplot in `basis` coordinates.
Expand Down Expand Up @@ -113,7 +128,16 @@
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`."""
_title = """\
title
TODO."""
Title of the plot.
"""
_dot_scale_factor = """\
dot_scale_factor
If `time_points` is not `None`, `dot_scale_factor` increases the size of the dots by this factor.
"""
_na_color = """\
na_color
Color to use for null or masked values. Can be anything matplotlib accepts as a color.
"""

###############################################################################
# general output
Expand All @@ -133,6 +157,7 @@
Path where to save the plot. If `None`, the plot is not saved.
{_ax}"""


d_plotting = DocstringProcessor(
desc_cell_transition=_desc_cell_transition,
transition_labels_cell_transition=_transition_labels_cell_transition,
Expand All @@ -159,4 +184,10 @@
ax=_ax,
figsize_dpi_save=_figsize_dpi_save,
fontsize=_fontsize,
alpha_transparancy=_alpha_transparancy,
interpolate_color=_interpolate_color,
sankey_kwargs=_sankey_kwargs,
na_color=_na_color,
dot_scale_factor=_dot_scale_factor,
scale_push_pull=_scale_push_pull,
)
71 changes: 56 additions & 15 deletions src/moscot/plotting/_plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import MappingProxyType
from typing import Any, Dict, List, Tuple, Union, Mapping, Iterable, Optional
from typing import Any, Dict, List, Tuple, Union, Mapping, Optional, Sequence

from matplotlib import colors as mcolors
from matplotlib.axes import Axes
Expand All @@ -11,7 +11,7 @@

from moscot.problems.base import CompoundProblem # type: ignore[attr-defined]
from moscot.problems.time import LineageProblem, TemporalProblem # type: ignore[attr-defined]
from moscot.plotting._utils import _sankey, _heatmap, _plot_temporal, _input_to_adatas
from moscot.plotting._utils import _sankey, _heatmap, _plot_temporal, _input_to_adatas, _create_col_colors
from moscot._docs._docs_plot import d_plotting
from moscot._constants._constants import AdataKeys, PlottingKeys, PlottingDefaults
from moscot.problems.base._compound_problem import K
Expand Down Expand Up @@ -99,6 +99,8 @@ def sankey(
captions: Optional[List[str]] = None,
title: Optional[str] = None,
colors_dict: Optional[Dict[str, float]] = None,
alpha: float = 1.0,
interpolate_color: bool = False,
cmap: Union[str, mcolors.Colormap] = "viridis",
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
Expand All @@ -119,8 +121,11 @@ def sankey(
%(captions_sankey)s
%(title)s
%(colors_dict_sankey)s
%(alpha_transparency)s
%(interpolate_color)s
%(cmap)s
%(figsize_dpi_save)s
%(sankey_kwargs)s

Returns
-------
Expand Down Expand Up @@ -150,6 +155,8 @@ def sankey(
figsize=figsize,
dpi=dpi,
ax=ax,
alpha=alpha,
interpolate_color=interpolate_color,
**kwargs,
)
if save:
Expand All @@ -162,17 +169,21 @@ def sankey(
def push(
inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],
uns_key: Optional[str] = None,
time_points: Optional[Iterable[K]] = None,
time_points: Optional[Sequence[K]] = None,
basis: str = "umap",
result_key: str = "plot_push",
fill_value: float = np.nan,
title: Optional[str] = None,
cmap: Union[str, mcolors.Colormap] = "viridis",
scale: bool = True,
title: Optional[Union[str, List[str]]] = None,
suptitle: Optional[str] = None,
cmap: Optional[Union[str, mcolors.Colormap]] = None,
dot_scale_factor: float = 2.0,
na_color: Optional[str] = "#e8ebe9",
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
ax: Optional[Axes] = None,
return_fig: bool = True,
suptitle_fontsize: Optional[float] = None,
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> mpl.figure.Figure:
"""
Expand All @@ -186,10 +197,12 @@ def push(
%(uns_key)s
%(time_points_push_pull)s
%(basis_push_pull)s
%(result_key_push_pull)s
%(fill_value_push_pull)s
%(scale_push_pull)s
%(title)s
%(cmap)s
%(dot_scale_factor)s
%(na_color)s
%(figsize_dpi_save)s

Returns
Expand All @@ -206,20 +219,30 @@ def push(
if key not in adata.obs:
raise KeyError(f"No data found in `adata.obs[{key!r}]`.")
data = adata.uns[AdataKeys.UNS][PlottingKeys.PUSH][key]
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])
fig = _plot_temporal(
adata=adata,
temporal_key=data["temporal_key"],
key_stored=key,
start=data["start"],
end=data["end"],
categories=data["subset"],
push=True,
time_points=time_points,
basis=basis,
result_key=result_key,
constant_fill_value=fill_value,
scale=scale,
save=save,
cont_cmap=cmap,
dot_scale_factor=dot_scale_factor,
na_color=na_color,
title=title,
suptitle=suptitle,
figsize=figsize,
dpi=dpi,
ax=ax,
suptitle_fontsize=suptitle_fontsize,
**kwargs,
)
if return_fig:
Expand All @@ -230,17 +253,21 @@ def push(
def pull(
inp: Union[AnnData, TemporalProblem, LineageProblem, CompoundProblem],
uns_key: Optional[str] = None,
time_points: Optional[Iterable[K]] = None,
time_points: Optional[Sequence[K]] = None,
basis: str = "umap",
result_key: str = "plot_pull",
fill_value: float = np.nan,
title: Optional[str] = None,
cmap: Union[str, mcolors.Colormap] = "viridis",
scale: bool = True,
title: Optional[Union[str, List[str]]] = None,
suptitle: Optional[str] = None,
cmap: Optional[Union[str, mcolors.Colormap]] = None,
dot_scale_factor: float = 2.0,
na_color: Optional[str] = "#e8ebe9",
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
ax: Optional[Axes] = None,
return_fig: bool = True,
suptitle_fontsize: Optional[float] = None,
**kwargs: Any,
) -> mpl.figure.Figure:
"""
Expand All @@ -254,10 +281,12 @@ def pull(
%(uns_key)s
%(time_points_push_pull)s
%(basis_push_pull)s
%(result_key_push_pull)s
%(fill_value_push_pull)s
%(scale_push_pull)s
%(title)s
%(cmap)s
%(dot_scale_factor)s
%(na_color)s
%(figsize_dpi_save)s

Returns
Expand All @@ -269,25 +298,37 @@ def pull(
%(return_push_pull)s
"""
adata, _ = _input_to_adatas(inp)

if time_points is not None:
time_points = list(time_points)
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
time_points.sort()
key = PlottingDefaults.PULL if uns_key is None else uns_key
if key not in adata.obs:
raise KeyError(f"No data found in `adata.obs[{key!r}]`.")
data = adata.uns[AdataKeys.UNS][PlottingKeys.PULL][key]
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])
fig = _plot_temporal(
adata=adata,
temporal_key=data["temporal_key"],
key_stored=key,
start=data["start"],
end=data["end"],
categories=data["subset"],
push=False,
time_points=time_points,
basis=basis,
result_key=result_key,
constant_fill_value=fill_value,
scale=scale,
save=save,
cont_cmap=cmap,
dot_scale_factor=dot_scale_factor,
na_color=na_color,
title=title,
suptitle=suptitle,
figsize=figsize,
dpi=dpi,
ax=ax,
suptitle_fontsize=suptitle_fontsize,
**kwargs,
)
if return_fig:
Expand Down
Loading