Skip to content

Commit

Permalink
Feature/plot order (#453)
Browse files Browse the repository at this point in the history
* make push/pull plot in good order

* [CI skip], try setting adata.uns color explicitly

* [CI skip], fix copying of adata

* fix pre commits

* fix bug
  • Loading branch information
MUCDK authored and lucaeyring committed Mar 15, 2023
1 parent 34ef904 commit 9806854
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def set_palette(
adata: AnnData,
key: str,
cont_cmap: Union[str, mcolors.Colormap] = "viridis",
force_update_colors: bool = False,
force_update_colors: bool = True,
**_: Any,
) -> None:
"""Set palette."""
Expand Down Expand Up @@ -408,7 +408,8 @@ def _plot_temporal(
else:
titles = [f"{categories} at time {source if push else target} and {name}"]
for i, ax in enumerate(axs):
with RandomKeys(adata, n=1, where="obs") as keys:
# we need to create adata_view because otherwise the view of the adata is copied in the next step i+1
with RandomKeys(adata, n=2, where="obs") as keys:
if time_points is None:
if scale:
adata.obs[keys[0]] = (
Expand All @@ -417,6 +418,7 @@ def _plot_temporal(
else:
adata.obs[keys[0]] = adata.obs[key_stored]
size = None
adata_view = adata
else:
tmp = np.full(len(adata), constant_fill_value)
mask = adata.obs[temporal_key] == time_points[i]
Expand All @@ -439,15 +441,28 @@ def _plot_temporal(
column = pd.Series(tmp).fillna(st).astype("category")
if len(np.unique(column[mask.values].values)) != 2:
raise ValueError(f"Not exactly two categories, found `{column.cat.categories}`.")
kwargs["palette"] = {vmax: cont_cmap.reversed()(0), vmin: cont_cmap(0), st: na_color}
adata.obs[keys[0]] = column.values
adata.obs[keys[0]] = adata.obs[keys[0]].astype("category")
adata.obs[keys[1]] = list(range(adata.n_obs))

set_palette(
adata, keys[0], cont_cmap={vmax: cont_cmap.reversed()(0), vmin: cont_cmap(0), st: na_color}
)

cells_with_vmax = adata[adata.obs[keys[0]] == vmax].obs[keys[1]].values
cells_with_vmin = adata[adata.obs[keys[0]] == vmin].obs[keys[1]].values
cells_with_st = adata[adata.obs[keys[0]] == st].obs[keys[1]].values
indices = list(cells_with_st) + list(cells_with_vmin) + list(cells_with_vmax)
adata_view = adata[indices, :]
size = size[indices]
else:
kwargs["color_map"] = cont_cmap
kwargs["na_color"] = na_color
adata.obs[keys[0]] = tmp
adata_view = adata

sc.pl.embedding(
adata=adata,
adata=adata_view,
basis=basis,
color=keys[0],
title=titles[i],
Expand Down

0 comments on commit 9806854

Please sign in to comment.