Skip to content

Commit

Permalink
Merge pull request #609 from AlexanderCaichen/plot_dup_function
Browse files Browse the repository at this point in the history
Plot helper function
  • Loading branch information
Xiaojieqiu authored Dec 10, 2023
2 parents ff4bac2 + ed45e4d commit b074392
Show file tree
Hide file tree
Showing 18 changed files with 252 additions and 1,197 deletions.
30 changes: 4 additions & 26 deletions dynamo/plot/cell_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from anndata import AnnData
from matplotlib.axes import Axes

from ..tools.utils import update_dict
from .utils import save_fig
from .utils import save_show_ret


def cell_cycle_scores(
Expand All @@ -25,8 +24,8 @@ def cell_cycle_scores(
cells: a list of cell ids used to subset the AnnData object. If None, all cells would be used. Defaults to None.
save_show_or_return: whether to save, show, or return the figure. Available flags are `"save"`, `"show"`, and
`"return"`. Defaults to "show".
save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent":
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent":
True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that
properly modify those keys according to your needs. Defaults to {}.
Expand Down Expand Up @@ -72,25 +71,4 @@ def cell_cycle_scores(
# Heatmap returns an axes obj but you need to get a mappable obj (get_children)
colorbar(ax.get_children()[0], cax=cax, ticks=[-0.9, 0, 0.9])

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "plot_direct_graph",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

s_kwargs = update_dict(s_kwargs, save_kwargs)

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return ax
return save_show_ret("plot_direct_graph", save_show_or_return, save_kwargs, ax)
60 changes: 7 additions & 53 deletions dynamo/plot/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
from ..configuration import _themes
from ..docrep import DocstringProcessor
from ..tools.connectivity import check_and_recompute_neighbors
from ..tools.utils import update_dict
from .utils import is_list_of_lists # is_gene_name
from .utils import (
_datashade_points,
_embed_datashader_in_an_axis,
_get_extent,
_select_font_color,
save_fig,
save_show_ret,
)

docstrings = DocstringProcessor()
Expand Down Expand Up @@ -185,8 +184,8 @@ def connectivity_base(
sort: the method to reorder data so that high values points will be on top of background points. Can be one of
{'raw', 'abs'}, i.e. sorted by raw data or sort by absolute values. Defaults to "raw".
save_show_or_return: whether to save, show or return the figure. Defaults to "return".
save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary
and the save_fig function will use the
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the
{
"path": None,
"prefix": 'connectivity_base',
Expand All @@ -203,7 +202,6 @@ def connectivity_base(
ImportError: `datashader` is not installed.
NotImplementedError: invalid `theme`.
ValueError: invalid `edge_bundling`.
NotImplementedError: invalid `save_show_or_return`.
Returns:
The matplotlib axis with the relevant plot displayed by default. If `save_show_or_return` is set to be `"show"`
Expand Down Expand Up @@ -309,28 +307,7 @@ def connectivity_base(

ax.set(xticks=[], yticks=[])

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "connectivity_base",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

save_fig(**s_kwargs)

if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return ax
return save_show_ret("connectivity_base", save_show_or_return, save_kwargs, ax)


docstrings.delete_params("con_base.parameters", "edge_df", "save_show_or_return", "save_kwargs")
Expand Down Expand Up @@ -429,8 +406,8 @@ def nneighbors(
ax: the axis on which the subplot would be shown. If set to be `None`, a new axis would be created. Defaults to
None.
save_show_or_return: whether to save, show or return the figure. Defaults to "return".
save_kwargs: a dictionary that will passed to the save_fig function. By default it is an empty dictionary and
the save_fig function will use the
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the
{
"path": None,
"prefix": 'connectivity_base',
Expand All @@ -445,7 +422,6 @@ def nneighbors(
Raises:
TypeError: wrong type of `x` and `y`.
NotImplementedError: invalid `save_show_or_return`.
Returns:
The matplotlib axis with the plotted knn graph by default. If `save_show_or_return` is set to be `"show"`
Expand Down Expand Up @@ -580,29 +556,7 @@ def nneighbors(
ax.set_ylabel(cur_b + "_2")
ax.set_title(cur_c)

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "nneighbors",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return g
else:
raise NotImplementedError('Invalid "save_show_or_return".')
return save_show_ret("nneighbors", save_show_or_return, save_kwargs, g)


def pgraph():
Expand Down
58 changes: 7 additions & 51 deletions dynamo/plot/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
despline,
despline_all,
quiver_autoscaler,
save_fig,
save_show_ret,
)
from .utils_dynamics import *

Expand Down Expand Up @@ -155,8 +155,8 @@ def phase_portraits(
currently using. If None, only the first panel in the expression / velocity plot will have the arrowed
spine. Defaults to None.
save_show_or_return: whether to save, show or return the figure. Defaults to "show".
save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary
and the save_fig function will use the
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the
{
"path": None,
"prefix": 'phase_portraits',
Expand Down Expand Up @@ -1067,31 +1067,7 @@ def phase_portraits(

update_vel_params(adata, params_df=vel_params_df)

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "phase_portraits",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:

with warnings.catch_warnings():
warnings.simplefilter("ignore")
plt.tight_layout()

plt.show()
if save_show_or_return in ["return", "all"]:
return g
return save_show_ret("phase_portraits", save_show_or_return, save_kwargs, g)


def dynamics(
Expand Down Expand Up @@ -1142,8 +1118,8 @@ def dynamics(
gene_order: the order of genes to present on the figure, either row-major or column major. Defaults to "column".
font_size_scale: the scale factor of fonts. Defaults to 1.
save_show_or_return: whether to save, show or return the figure. Defaults to "show".
save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary
and the save_fig function will use the
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the
{
"path": None,
"prefix": 'dynamics',
Expand Down Expand Up @@ -2733,27 +2709,7 @@ def dynamics(
elif experiment_type == "coassay":
pass # show protein velocity (steady state and the Gamma distribution model)
# g.autofmt_xdate(rotation=-30, ha='right')
if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "dynamics",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return g
return save_show_ret("dynamics", save_show_or_return, save_kwargs, g)


def dynamics_(
Expand Down
59 changes: 8 additions & 51 deletions dynamo/plot/fate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from matplotlib.axes import Axes

from ..prediction.fate import fate_bias as fate_bias_pd
from ..tools.utils import update_dict
from .scatters import save_fig, scatters
from .utils import map2color
from .scatters import scatters
from .utils import map2color, save_show_ret


def fate_bias(
Expand Down Expand Up @@ -44,8 +43,8 @@ def fate_bias(
fate_bias_df = dyn.tl.fate_bias(adata). Defaults to None.
figsize: the size of the figure. Defaults to (6, 4).
save_show_or_return: whether to save, show or return the figure. Defaults to "show".
save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary
and the save_fig function will use the
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the
{
"path": None,
"prefix": 'phase_portraits',
Expand Down Expand Up @@ -74,27 +73,7 @@ def fate_bias(
fate_bias, col_cluster=True, row_cluster=True, figsize=figsize, yticklabels=False, **cluster_maps_kwargs
)

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "fate_bias",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return ax
return save_show_ret("fate_bias", save_show_or_return, save_kwargs, ax)


def fate(
Expand All @@ -119,8 +98,8 @@ def fate(
ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component.
If None, new axis would be created. Defaults to None.
save_show_or_return: whether to save, show or return the figure. Defaults to "show".
save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary
and the save_fig function will use the
save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary
and the save_show_ret function will use the
{
"path": None,
"prefix": 'phase_portraits',
Expand All @@ -147,26 +126,4 @@ def fate(
ax.scatter(*i[:, [x, y]].T, c=map2color(j))
ax.plot(*i[:, [x, y]].T, c="k")

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "kinetic_curves",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}

# prevent the plot from being closed if the plot need to be shown or returned.
if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

s_kwargs = update_dict(s_kwargs, save_kwargs)

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return ax
return save_show_ret("kinetic_curves", save_show_or_return, save_kwargs, ax)
Loading

0 comments on commit b074392

Please sign in to comment.