Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/interpolate colors sankey #434

Merged
merged 29 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/moscot/_docs/_docs_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,18 @@
- `captions`
- `key`
"""

_alpha_transparancy = """\
Alpha
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
Transparancy value.
"""
_interpolate_color = """\
interpolate_color
Whether the color is continuously interpolated.
"""
_sankey_kwargs = """\
kwargs
Keyword arguments for :meth:`matplotlib.pyplot.fill_between`.
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
"""
###############################################################################
# plotting.push/pull
# input
Expand Down Expand Up @@ -133,6 +144,7 @@
Path where to save the plot. If `None`, the plot is not saved.
{_ax}"""


d_plotting = DocstringProcessor(
desc_cell_transition=_desc_cell_transition,
transition_labels_cell_transition=_transition_labels_cell_transition,
Expand All @@ -159,4 +171,7 @@
ax=_ax,
figsize_dpi_save=_figsize_dpi_save,
fontsize=_fontsize,
alpha_transparancy=_alpha_transparancy,
interpolate_color=_interpolate_color,
sankey_kwargs=_sankey_kwargs,
)
15 changes: 12 additions & 3 deletions src/moscot/plotting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from moscot.problems.base import CompoundProblem # type: ignore[attr-defined]
from moscot.problems.time import LineageProblem, TemporalProblem # type: ignore[attr-defined]
from moscot.plotting._utils import _sankey, _heatmap, _plot_temporal, _input_to_adatas
from moscot.plotting._utils import _sankey, _heatmap, _plot_temporal, _input_to_adatas, _create_col_colors
from moscot._docs._docs_plot import d_plotting
from moscot._constants._constants import AdataKeys, PlottingKeys, PlottingDefaults
from moscot.problems.base._compound_problem import K
Expand Down Expand Up @@ -99,6 +99,8 @@ def sankey(
captions: Optional[List[str]] = None,
title: Optional[str] = None,
colors_dict: Optional[Dict[str, float]] = None,
alpha: float = 1.0,
interpolate_color: bool = True,
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
cmap: Union[str, mcolors.Colormap] = "viridis",
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
Expand All @@ -119,8 +121,11 @@ def sankey(
%(captions_sankey)s
%(title)s
%(colors_dict_sankey)s
%(alpha_transparency)s
%(interpolate_color)s
%(cmap)s
%(figsize_dpi_save)s
%(sankey_kwargs)s

Returns
-------
Expand Down Expand Up @@ -150,6 +155,8 @@ def sankey(
figsize=figsize,
dpi=dpi,
ax=ax,
alpha=alpha,
interpolate_color=interpolate_color,
**kwargs,
)
if save:
Expand All @@ -167,7 +174,7 @@ def push(
result_key: str = "plot_push",
fill_value: float = np.nan,
title: Optional[str] = None,
cmap: Union[str, mcolors.Colormap] = "viridis",
cmap: Optional[Union[str, mcolors.Colormap]] = None,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
Expand Down Expand Up @@ -206,6 +213,7 @@ def push(
if key not in adata.obs:
raise KeyError(f"No data found in `adata.obs[{key!r}]`.")
data = adata.uns[AdataKeys.UNS][PlottingKeys.PUSH][key]
cmap = _create_col_colors(adata, data["annotation"], data["subset"]) if cmap is None else cmap
fig = _plot_temporal(
adata=adata,
temporal_key=data["temporal_key"],
Expand Down Expand Up @@ -235,7 +243,7 @@ def pull(
result_key: str = "plot_pull",
fill_value: float = np.nan,
title: Optional[str] = None,
cmap: Union[str, mcolors.Colormap] = "viridis",
cmap: Optional[Union[str, mcolors.Colormap]] = None,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
save: Optional[str] = None,
Expand Down Expand Up @@ -274,6 +282,7 @@ def pull(
if key not in adata.obs:
raise KeyError(f"No data found in `adata.obs[{key!r}]`.")
data = adata.uns[AdataKeys.UNS][PlottingKeys.PULL][key]
cmap = _create_col_colors(adata, data["annotation"], data["subset"]) if cmap is None else cmap
fig = _plot_temporal(
adata=adata,
temporal_key=data["temporal_key"],
Expand Down
115 changes: 86 additions & 29 deletions src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import matplotlib as mp
import matplotlib as mpl

import numpy as np

Expand All @@ -20,6 +20,8 @@
from moscot._constants._constants import AggregationMode
from moscot.problems.base._compound_problem import K

_N = 200
MUCDK marked this conversation as resolved.
Show resolved Hide resolved


def set_palette(
adata: AnnData,
Expand Down Expand Up @@ -51,8 +53,10 @@ def _sankey(
fontsize: float = 12.0,
horizontal_space: float = 1.5,
force_update_colors: bool = False,
**_: Any,
) -> mp.figure.Figure:
alpha: float = 1.0,
interpolate_color: bool = True,
**kwargs: Any,
) -> mpl.figure.Figure:
if ax is None:
fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize)
if captions is not None and len(captions) != len(transition_matrices):
Expand Down Expand Up @@ -110,7 +114,8 @@ def _sankey(
2 * [leftWidths[leftLabel]["bottom"]],
2 * [leftWidths[leftLabel]["bottom"] + leftWidths[leftLabel]["left"]],
color=colorDict[leftLabel],
alpha=0.99,
alpha=alpha,
**kwargs,
)
ax.text(
-0.05 * xMax,
Expand All @@ -121,11 +126,12 @@ def _sankey(
)
for rightLabel in rightLabels:
ax.fill_between(
[xMax + left_pos[ind], 1.02 * xMax + left_pos[ind]],
[xMax + left_pos[ind], xMax + left_pos[ind]],
2 * [rightWidths[rightLabel]["bottom"]],
2 * [rightWidths[rightLabel]["bottom"] + rightWidths[rightLabel]["right"]],
color=colorDict[rightLabel],
alpha=0.99,
alpha=alpha,
**kwargs,
)
ax.text(
1.05 * xMax + left_pos[ind],
Expand All @@ -143,7 +149,6 @@ def _sankey(
# Plot strips
for leftLabel in leftLabels:
for rightLabel in rightLabels:
labelColor = leftLabel
if dataFrame.loc[leftLabel, rightLabel] > 0:
# Create array of y values for each strip, half at left value,
# half at right, convolve
Expand All @@ -161,22 +166,41 @@ def _sankey(
leftWidths[leftLabel]["bottom"] += dataFrame.loc[leftLabel, rightLabel]
rightWidths[rightLabel]["bottom"] += dataFrame.loc[leftLabel, rightLabel]

arr = np.linspace(0 + left_pos[ind], xMax + left_pos[ind], len(ys_d))
color = (
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
_color_transition(colorDict[leftLabel], colorDict[rightLabel], num=len(arr), alpha=alpha)
if interpolate_color
else colorDict[leftLabel]
)
if ind == 0:
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
ax.fill_between(
np.linspace(0 + left_pos[ind], xMax + left_pos[ind], len(ys_d)),
ys_d,
ys_u,
alpha=0.65,
color=colorDict[labelColor],
)
if interpolate_color:
for l in range(len(ys_d)): # necessary to get smooth lines
ax.fill_between(
arr[l:], ys_d[l:], ys_u[l:], color=color[l], ec=color[l], alpha=alpha, **kwargs
)
else:
ax.fill_between(arr, ys_d, ys_u, alpha=alpha, color=color, **kwargs)

else:
ax.fill_between(
np.linspace(0 + left_pos[ind], xMax + left_pos[ind], len(ys_d)),
ys_d,
ys_u,
alpha=0.65,
color=colorDict[labelColor],
)
if interpolate_color:
for l in range(len(ys_d)):
ax.fill_between(
arr[l:],
ys_d[l:],
ys_u[l:],
color=color[l],
ec=color[l],
**kwargs,
)
else:
ax.fill_between(
np.linspace(0 + left_pos[ind], xMax + left_pos[ind], len(ys_d)),
ys_d,
ys_u,
alpha=alpha,
color=color,
**kwargs,
)

ax.axis("off")
ax.set_title(title)
Expand All @@ -200,7 +224,7 @@ def _heatmap(
cbar_kwargs: Mapping[str, Any] = MappingProxyType({}),
ax: Optional[Axes] = None,
**kwargs: Any,
) -> mp.figure.Figure:
) -> mpl.figure.Figure:
cbar_kwargs = dict(cbar_kwargs)

if ax is None:
Expand All @@ -214,10 +238,10 @@ def _heatmap(
row_adata, col_adata, transition_matrix, row_annotation, col_annotation
)

row_sm = mp.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
col_sm = mp.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)
row_sm = mpl.cm.ScalarMappable(cmap=row_cmap, norm=row_norm)
col_sm = mpl.cm.ScalarMappable(cmap=col_cmap, norm=col_norm)

norm = mp.colors.Normalize(
norm = mpl.colors.Normalize(
vmin=kwargs.pop("vmin", np.nanmin(transition_matrix)), vmax=kwargs.pop("vmax", np.nanmax(transition_matrix))
)
cont_cmap = copy(plt.get_cmap(cont_cmap))
Expand Down Expand Up @@ -274,9 +298,9 @@ def _get_black_or_white(value: float, cmap: mcolors.Colormap) -> str:

def _annotate_heatmap(
transition_matrix: pd.DataFrame,
im: mp.image.AxesImage,
im: mpl.image.AxesImage,
valfmt: str = "{x:.2f}",
cmap: Union[mp.colors.Colormap, str] = "viridis",
cmap: Union[mpl.colors.Colormap, str] = "viridis",
fontsize: float = 5,
**kwargs: Any,
) -> None:
Expand All @@ -288,7 +312,7 @@ def _annotate_heatmap(
kw.update(**kwargs)

if isinstance(valfmt, str):
valfmt = mp.ticker.StrMethodFormatter(valfmt)
valfmt = mpl.ticker.StrMethodFormatter(valfmt)
if TYPE_CHECKING:
assert callable(valfmt)

Expand Down Expand Up @@ -371,7 +395,7 @@ def _plot_temporal(
ax: Optional[Axes] = None,
show: bool = False,
**kwargs: Any,
) -> mp.figure.Figure:
) -> mpl.figure.Figure:
all_keys = adata.obs[temporal_key].unique()
if time_points is None:
constant_fill_keys: Set[K] = set()
Expand Down Expand Up @@ -401,3 +425,36 @@ def _plot_temporal(
if save:
fig.figure.savefig(save, bbox_inches="tight")
return fig


def _color_transition(c1: str, c2: str, num: int, alpha: float) -> List[str]:
if not mpl.colors.is_color_like(c1):
raise ValueError(f"{c1} cannot be interpreted as an RGB color.")
if not mpl.colors.is_color_like(c2):
raise ValueError(f"{c2} cannot be interpreted as an RGB color.")
c1_rgb = np.array(mpl.colors.to_rgb(c1))
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
c2_rgb = np.array(mpl.colors.to_rgb(c2))
return [mpl.colors.to_rgb((1 - n / num) * c1_rgb + n / num * c2_rgb) + (alpha,) for n in range(num)]


def _create_col_colors(adata: AnnData, obs_col: Optional[str], subset: Optional[str]) -> Optional[mcolors.Colormap]:
if obs_col is not None and subset is not None:
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(subset, list) and isinstance(subset[0], str):
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
subset = subset[0]
if isinstance(subset, str):
try:
colorDict = {
cat: adata.uns[f"{obs_col}_colors"][i] for i, cat in enumerate(adata.obs[obs_col].cat.categories)
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
}
except KeyError:
raise KeyError(f"Unable to access `adata.uns[{obs_col}_colors]`.") from None
MUCDK marked this conversation as resolved.
Show resolved Hide resolved

color = colorDict[subset]

h, _, v = mcolors.rgb_to_hsv(mcolors.to_rgb(color))
end_color = mcolors.hsv_to_rgb([h, 1, v])

col_cmap = mcolors.LinearSegmentedColormap.from_list("lineage_cmap", ["#ffffff", end_color], N=_N)

return col_cmap
return None
4 changes: 4 additions & 0 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def push(
if key_added is not None:
plot_vars = {
"temporal_key": self.temporal_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
}
self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key)
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PUSH, key_added, plot_vars)
Expand Down Expand Up @@ -393,6 +395,8 @@ def pull(
if key_added is not None:
plot_vars = {
"temporal_key": self.temporal_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
}
self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key)
Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PULL, key_added, plot_vars)
Expand Down
Binary file modified tests/plotting/_images/Plotting_pull.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/plotting/_images/Plotting_push.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/plotting/_images/Plotting_sankey.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/plotting/_images/Plotting_sankey_params.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions tests/plotting/_images/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from anndata import AnnData
import scanpy as sc

from tests.plotting.conftest import PlotTester, PlotTesterMeta
import moscot.plotting as mpl

sc.pl.set_rcParams_defaults()
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
sc.set_figure_params(dpi=40, color_map="viridis")

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestPlotting(PlotTester, metaclass=PlotTesterMeta):
def test_plot_cell_transition(self, adata_pl_cell_transition: AnnData):
mpl.cell_transition(adata_pl_cell_transition)

def test_plot_cell_transition_params(self, adata_pl_cell_transition: AnnData):
mpl.cell_transition(adata_pl_cell_transition, annotate=None, cmap="inferno", fontsize=15)

def test_plot_sankey(self, adata_pl_sankey: AnnData):
mpl.sankey(adata_pl_sankey)

def test_plot_sankey_params(self, adata_pl_sankey: AnnData):
mpl.sankey(adata_pl_sankey, captions=["Test", "Other test"], title="Title", figsize=(3, 3))

def test_plot_push(self, adata_pl_push: AnnData):
mpl.push(adata_pl_push, time_points=[2])

def test_plot_pull(self, adata_pl_pull: AnnData):
mpl.pull(adata_pl_pull, time_points=[1])
8 changes: 6 additions & 2 deletions tests/plotting/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
@pytest.fixture()
def adata_pl_push(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
plot_vars = {"temporal_key": "time"}
plot_vars = {"temporal_key": "time", "annotation": "celltype", "subset": "A"}
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.PUSH, PlottingDefaults.PUSH, plot_vars)
adata_time.obs[PlottingDefaults.PUSH] = np.abs(rng.randn(len(adata_time)))
return adata_time
Expand All @@ -50,7 +52,9 @@ def adata_pl_push(adata_time: AnnData) -> AnnData:
@pytest.fixture()
def adata_pl_pull(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
plot_vars = {"temporal_key": "time"}
plot_vars = {"temporal_key": "time", "annotation": "celltype", "subset": "A"}
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.PULL, PlottingDefaults.PULL, plot_vars)
adata_time.obs[PlottingDefaults.PULL] = np.abs(rng.randn(len(adata_time)))
return adata_time
Expand Down
Loading