diff --git a/dynamo/plot/cell_cycle.py b/dynamo/plot/cell_cycle.py index 9fd4686f9..a0b1794c5 100644 --- a/dynamo/plot/cell_cycle.py +++ b/dynamo/plot/cell_cycle.py @@ -8,8 +8,7 @@ from anndata import AnnData from matplotlib.axes import Axes -from ..tools.utils import update_dict -from .utils import save_fig +from .utils import save_show_ret def cell_cycle_scores( @@ -25,8 +24,8 @@ def cell_cycle_scores( cells: a list of cell ids used to subset the AnnData object. If None, all cells would be used. Defaults to None. save_show_or_return: whether to save, show, or return the figure. Available flags are `"save"`, `"show"`, and `"return"`. Defaults to "show". - save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the - save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -72,25 +71,4 @@ def cell_cycle_scores( # Heatmap returns an axes obj but you need to get a mappable obj (get_children) colorbar(ax.get_children()[0], cax=cax, ticks=[-0.9, 0, 0.9]) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_direct_graph", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_direct_graph", save_show_or_return, save_kwargs, ax) diff --git a/dynamo/plot/connectivity.py b/dynamo/plot/connectivity.py index dc7aa22df..88aa0d1e4 100755 --- a/dynamo/plot/connectivity.py +++ b/dynamo/plot/connectivity.py @@ -32,14 +32,13 @@ from ..configuration import _themes from ..docrep import DocstringProcessor from ..tools.connectivity import check_and_recompute_neighbors -from ..tools.utils import update_dict from .utils import is_list_of_lists # is_gene_name from .utils import ( _datashade_points, _embed_datashader_in_an_axis, _get_extent, _select_font_color, - save_fig, + save_show_ret, ) docstrings = DocstringProcessor() @@ -185,8 +184,8 @@ def connectivity_base( sort: the method to reorder data so that high values points will be on top of background points. Can be one of {'raw', 'abs'}, i.e. sorted by raw data or sort by absolute values. Defaults to "raw". save_show_or_return: whether to save, show or return the figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the { "path": None, "prefix": 'connectivity_base', @@ -203,7 +202,6 @@ def connectivity_base( ImportError: `datashader` is not installed. NotImplementedError: invalid `theme`. ValueError: invalid `edge_bundling`. - NotImplementedError: invalid `save_show_or_return`. Returns: The matplotlib axis with the relevant plot displayed by default. If `save_show_or_return` is set to be `"show"` @@ -309,28 +307,7 @@ def connectivity_base( ax.set(xticks=[], yticks=[]) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "connectivity_base", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("connectivity_base", save_show_or_return, save_kwargs, ax) docstrings.delete_params("con_base.parameters", "edge_df", "save_show_or_return", "save_kwargs") @@ -429,8 +406,8 @@ def nneighbors( ax: the axis on which the subplot would be shown. If set to be `None`, a new axis would be created. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "return". - save_kwargs: a dictionary that will passed to the save_fig function. By default it is an empty dictionary and - the save_fig function will use the + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the { "path": None, "prefix": 'connectivity_base', @@ -445,7 +422,6 @@ def nneighbors( Raises: TypeError: wrong type of `x` and `y`. - NotImplementedError: invalid `save_show_or_return`. Returns: The matplotlib axis with the plotted knn graph by default. If `save_show_or_return` is set to be `"show"` @@ -580,29 +556,7 @@ def nneighbors( ax.set_ylabel(cur_b + "_2") ax.set_title(cur_c) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "nneighbors", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g - else: - raise NotImplementedError('Invalid "save_show_or_return".') + return save_show_ret("nneighbors", save_show_or_return, save_kwargs, g) def pgraph(): diff --git a/dynamo/plot/dynamics.py b/dynamo/plot/dynamics.py index ad366dbb2..676aa6191 100755 --- a/dynamo/plot/dynamics.py +++ b/dynamo/plot/dynamics.py @@ -27,7 +27,7 @@ despline, despline_all, quiver_autoscaler, - save_fig, + save_show_ret, ) from .utils_dynamics import * @@ -155,8 +155,8 @@ def phase_portraits( currently using. If None, only the first panel in the expression / velocity plot will have the arrowed spine. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the { "path": None, "prefix": 'phase_portraits', @@ -1067,31 +1067,7 @@ def phase_portraits( update_vel_params(adata, params_df=vel_params_df) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "phase_portraits", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("phase_portraits", save_show_or_return, save_kwargs, g) def dynamics( @@ -1142,8 +1118,8 @@ def dynamics( gene_order: the order of genes to present on the figure, either row-major or column major. Defaults to "column". font_size_scale: the scale factor of fonts. Defaults to 1. save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the { "path": None, "prefix": 'dynamics', @@ -2733,27 +2709,7 @@ def dynamics( elif experiment_type == "coassay": pass # show protein velocity (steady state and the Gamma distribution model) # g.autofmt_xdate(rotation=-30, ha='right') - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "dynamics", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("dynamics", save_show_or_return, save_kwargs, g) def dynamics_( diff --git a/dynamo/plot/fate.py b/dynamo/plot/fate.py index 0ac572c24..84ed8ca83 100755 --- a/dynamo/plot/fate.py +++ b/dynamo/plot/fate.py @@ -12,9 +12,8 @@ from matplotlib.axes import Axes from ..prediction.fate import fate_bias as fate_bias_pd -from ..tools.utils import update_dict -from .scatters import save_fig, scatters -from .utils import map2color +from .scatters import scatters +from .utils import map2color, save_show_ret def fate_bias( @@ -44,8 +43,8 @@ def fate_bias( fate_bias_df = dyn.tl.fate_bias(adata). Defaults to None. figsize: the size of the figure. Defaults to (6, 4). save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the { "path": None, "prefix": 'phase_portraits', @@ -74,27 +73,7 @@ def fate_bias( fate_bias, col_cluster=True, row_cluster=True, figsize=figsize, yticklabels=False, **cluster_maps_kwargs ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "fate_bias", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("fate_bias", save_show_or_return, save_kwargs, ax) def fate( @@ -119,8 +98,8 @@ def fate( ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component. If None, new axis would be created. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the { "path": None, "prefix": 'phase_portraits', @@ -147,26 +126,4 @@ def fate( ax.scatter(*i[:, [x, y]].T, c=map2color(j)) ax.plot(*i[:, [x, y]].T, c="k") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "kinetic_curves", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("kinetic_curves", save_show_or_return, save_kwargs, ax) diff --git a/dynamo/plot/heatmaps.py b/dynamo/plot/heatmaps.py index e87f8fe42..7a1326596 100644 --- a/dynamo/plot/heatmaps.py +++ b/dynamo/plot/heatmaps.py @@ -23,7 +23,7 @@ hill_inh_func, hill_inh_grad, ) -from ..tools.utils import flatten, update_dict +from ..tools.utils import flatten from ..vectorfield.utils import get_jacobian from ..vectorfield.vector_calculus import hessian as run_hessian from .utils import ( @@ -37,7 +37,7 @@ is_gene_name, is_layer_keys, is_list_of_lists, - save_fig, + save_show_ret, ) @@ -249,8 +249,8 @@ def response( figsize: size of the figure. Defaults to (6, 4). save_show_or_return: whether to save or show the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. return_data: whether to return the data used to generate the heatmap. Defaults to False. @@ -609,29 +609,7 @@ def scale_func(x, X, grid_num): axes[i, j].set_yticklabels(ylabels) plt.subplots_adjust(left=0.1, right=1, top=0.80, bottom=0.1, wspace=0.1) - if save_show_or_return in ["save", "both"]: - s_kwargs = { - "path": None, - "prefix": "scatters", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return == "both": - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() + save_show_ret("scatters", save_show_or_return, save_kwargs) list_for_return = [] @@ -680,8 +658,8 @@ def plot_hill_function( linewidth: the line width of the curve. Defaults to 2. save_show_or_return: whether to save or show the figure. Could be one of "save", "show", "both", or "all". "both" and "all" have the same effect. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. **plot_kwargs: any other kwargs passed to `pyplot.plot`. @@ -762,29 +740,7 @@ def plot_hill_function( raise NotImplementedError(f"The fit mode `{mode}` is not supported.") plt.subplots_adjust(left=0.1, right=1, top=0.80, bottom=0.1, wspace=0.1) - if save_show_or_return in ["save", "both"]: - s_kwargs = { - "path": None, - "prefix": "scatters", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return == "both": - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() + save_show_ret("scatters", save_show_or_return, save_kwargs) def causality( @@ -854,8 +810,8 @@ def causality( figsize: the size of the figure. Defaults to (6, 4). save_show_or_return: whether to save or show the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. Defaults to {}. return_data: whether to return the calculated causality data. Defaults to False. @@ -1190,29 +1146,7 @@ def causality( # plt.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) plt.subplots_adjust(left=0.1, right=1, top=0.80, bottom=0.1, wspace=0.1) - if save_show_or_return in ["save", "both"]: - s_kwargs = { - "path": None, - "prefix": "scatters", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return == "both": - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() + save_show_ret("scatters", save_show_or_return, save_kwargs) if return_data: return flat_res @@ -1292,8 +1226,8 @@ def comb_logic( figsize: the size of the figure. Defaults to (6, 4). save_show_or_return: whether to save or show the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. Defaults to {}. return_data: whether to return the calculated causality data. Defaults to False. @@ -1453,8 +1387,8 @@ def hessian( figsize: the size of the figure. Defaults to (6, 4). save_show_or_return: whether to save or show the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. Defaults to {}. return_data: whether to return the calculated causality data. Defaults to False. diff --git a/dynamo/plot/least_action_path.py b/dynamo/plot/least_action_path.py index f56f5d889..aa42e1295 100644 --- a/dynamo/plot/least_action_path.py +++ b/dynamo/plot/least_action_path.py @@ -14,10 +14,9 @@ interp_second_derivative, kneedle_difference, ) -from ..tools.utils import update_dict from ..utils import denormalize, normalize from .ezplots import plot_X, zscatter -from .scatters import save_fig, scatters +from .scatters import save_show_ret, scatters from .utils import map2color @@ -45,8 +44,8 @@ def least_action( save_show_or_return: whether the figure should be saved, show, or return. Can be one of "save", "show", "return", "both", "all". "both" means that the figure would be shown and saved but not returned. Defaults to "show". - save_kwargs:a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary and - the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -66,28 +65,7 @@ def least_action( ax.scatter(*i[:, [x, y]].T, c=map2color(j)) ax.plot(*i[:, [x, y]].T, c="k") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "kinetic_curves", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("kinetic_curves", save_show_or_return, save_kwargs, ax) def lap_min_time( @@ -116,8 +94,8 @@ def lap_min_time( n_col: the number of subplot columns. Defaults to 3. save_show_or_return: whether to save or show the figure. Can be one of "save", "show", "both" or "all". "both" and "all" have the same effect. The axis of the plot cannot be returned here. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. **kwargs: not used here. @@ -187,24 +165,4 @@ def lap_min_time( # scatters(adata, basis=basis, color=color, ax=axes[i, j], **kwargs) # axes[i, j].scatter(*i[:, [x, y]].T, c=map2color(j)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "kinetic_curves", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return == "both": - s_kwargs["close"] = False - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both"]: - plt.tight_layout() - plt.show() + return save_show_ret("kinetic_curves", save_show_or_return, save_kwargs) diff --git a/dynamo/plot/markers.py b/dynamo/plot/markers.py index 8a3343b57..c6bd0012d 100644 --- a/dynamo/plot/markers.py +++ b/dynamo/plot/markers.py @@ -14,9 +14,9 @@ from matplotlib.figure import Figure from scipy.sparse import issparse -from ..configuration import _themes, reset_rcParams, set_figure_params -from ..tools.utils import get_mapper, update_dict -from .utils import save_fig +from ..configuration import _themes, set_figure_params +from ..tools.utils import get_mapper +from .utils import save_show_ret def bubble( @@ -112,8 +112,8 @@ def bubble( figsize: the size of the figure. Defaults to None. save_show_or_return: whether to save, show or return the figure. Can be one of "save", "show", or "return". Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. @@ -316,34 +316,4 @@ def bubble( ) axes[igene].set_xlabel("") if transpose else axes[igene].set_ylabel("") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "violin", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if background is not None: - reset_rcParams() - if save_show_or_return in ["show", "both", "all"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() - if background is not None: - reset_rcParams() - if save_show_or_return in ["return", "all"]: - if background is not None: - reset_rcParams() - - return fig, axes + return save_show_ret("violin", save_show_or_return, save_kwargs, (fig, axes), background=background) diff --git a/dynamo/plot/networks.py b/dynamo/plot/networks.py index 484e87f7c..6b6cd9f9a 100644 --- a/dynamo/plot/networks.py +++ b/dynamo/plot/networks.py @@ -12,7 +12,7 @@ from matplotlib.axes import Axes from ..tools.utils import flatten, index_gene, update_dict -from .utils import save_fig, set_colorbar +from .utils import save_fig, save_show_ret, set_colorbar from .utils_graph import ArcPlot @@ -583,24 +583,4 @@ def hivePlot( # ax.legend(custom_lines, reg_groups, loc='upper left', bbox_to_anchor=(0.37, 0.35), # title="Regulatory network based on Jacobian analysis") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "hiveplot", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("hiveplot", save_show_or_return, save_kwargs, ax) diff --git a/dynamo/plot/preprocess.py b/dynamo/plot/preprocess.py index f41787596..696c74d3b 100755 --- a/dynamo/plot/preprocess.py +++ b/dynamo/plot/preprocess.py @@ -19,8 +19,8 @@ from ..preprocessing import gene_selection from ..preprocessing.gene_selection import get_prediction_by_svr from ..preprocessing.utils import detect_experiment_datatype -from ..tools.utils import get_mapper, update_dict -from .utils import save_fig +from ..tools.utils import get_mapper +from .utils import save_fig, save_show_ret def basic_stats( @@ -39,8 +39,8 @@ def basic_stats( figsize: the size of each panel in the figure. Defaults to (4, 3). save_show_or_return: whether to save, show, or return the plots. Could be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'basic_stats', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'basic_stats', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -103,25 +103,7 @@ def basic_stats( g.set_ylabels("") g.set(ylim=(0, None)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "basic_stats", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("basic_stats", save_show_or_return, save_kwargs, g) def show_fraction( @@ -142,8 +124,8 @@ def show_fraction( figsize: the size of each panel in the figure. Defaults to (4, 3). save_show_or_return: whether to save, show, or return the plots. Could be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'show_fraction', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'show_fraction', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -339,27 +321,7 @@ def show_fraction( g.set_ylabels("Fraction") g.set(ylim=(0, None)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "show_fraction", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("show_fraction", save_show_or_return, save_kwargs, g) def variance_explained( @@ -382,8 +344,8 @@ def variance_explained( figsize: the size of each panel of the figure. Defaults to (4, 3). save_show_or_return: whether to save, show, or return the generated figure. Can be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'variance_explained', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'variance_explained', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -403,27 +365,7 @@ def variance_explained( ax.set_xticks(list(ax.get_xticks()) + [n_comps]) ax.set_xlim(0, len(var_)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "variance_explained", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("variance_explained", save_show_or_return, save_kwargs, ax) def biplot( @@ -461,8 +403,8 @@ def biplot( draw_pca_embedding: whether to draw the pca embedding. Defaults to False. save_show_or_return: whether to save, show, or return the generated figure. Can be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'variance_explained', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the function will use the {"path": None, "prefix": 'variance_explained', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -515,27 +457,7 @@ def biplot( ax.plot(xs[i] * scalex, ys[i] * scaley, "b", alpha=0.1) ax.text(xs[i] * scalex * 1.01, ys[i] * scaley * 1.01, list(adata.obs.cluster)[i], color="b", alpha=0.1) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "biplot", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("biplot", save_show_or_return, save_kwargs, ax) def loading( @@ -561,8 +483,8 @@ def loading( figsize: the size of each panel of the figure. Defaults to (6, 4). save_show_or_return: whether to save, show, or return the generated figure. Can be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'biplot', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the function will use the {"path": None, "prefix": 'biplot', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -605,27 +527,7 @@ def loading( axes[cur_row, cur_col].set_title("PC " + str(i)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "loading", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return axes + return save_show_ret("loading", save_show_or_return, save_kwargs, axes) def feature_genes( @@ -645,8 +547,8 @@ def feature_genes( figsize: the size of each panel of the figure. Defaults to (4, 3). save_show_or_return: whether to save, show, or return the generated figure. Can be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'feature_genes', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'feature_genes', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -722,27 +624,7 @@ def feature_genes( plt.xlabel("Mean (log)") plt.ylabel("Dispersion (log)") if mode == "dispersion" else plt.ylabel("CV (log)") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "feature_genes", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("feature_genes", save_show_or_return, save_kwargs, ax) def exp_by_groups( @@ -779,8 +661,8 @@ def exp_by_groups( figsize: the size of each panel of the figure. Defaults to (4, 3). save_show_or_return: whether to save, show, or return the generated figure. Can be one of 'save', 'show', or 'return'. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'exp_by_groups', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'exp_by_groups', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -906,27 +788,7 @@ def exp_by_groups( g.set_xlabels("") g.set(ylim=(0, None)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "exp_by_groups", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("exp_by_groups", save_show_or_return, save_kwargs, g) def highest_frac_genes( diff --git a/dynamo/plot/pseudotime.py b/dynamo/plot/pseudotime.py index a684a0ddb..0f6d0573e 100755 --- a/dynamo/plot/pseudotime.py +++ b/dynamo/plot/pseudotime.py @@ -13,8 +13,7 @@ from anndata import AnnData from scipy.sparse import csr_matrix -from ..tools.utils import update_dict -from .utils import get_color_map_from_labels, save_fig +from .utils import get_color_map_from_labels, save_show_ret def _calculate_cells_mapping( @@ -187,27 +186,7 @@ def plot_dim_reduced_direct_graph( fontsize="medium", ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_dim_reduced_direct_graph", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("plot_dim_reduced_direct_graph", save_show_or_return, save_kwargs, g) def plot_direct_graph( @@ -266,24 +245,4 @@ def plot_direct_graph( else: raise Exception("layout", layout, " is not supported.") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_direct_graph", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("plot_direct_graph", save_show_or_return, save_kwargs, g) diff --git a/dynamo/plot/scPotential.py b/dynamo/plot/scPotential.py index 74843657b..263582c84 100755 --- a/dynamo/plot/scPotential.py +++ b/dynamo/plot/scPotential.py @@ -9,8 +9,7 @@ from anndata import AnnData from matplotlib.axes import Axes -from ..tools.utils import update_dict -from .utils import save_fig +from .utils import save_show_ret def show_landscape( @@ -32,8 +31,8 @@ def show_landscape( basis: the method of dimension reduction. By default, it is trimap. Currently, it is not checked with Xgrid and Ygrid. Defaults to "umap". save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'show_landscape', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'show_landscape', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your need. Defaults to {}. @@ -87,27 +86,7 @@ def show_landscape( ax.set_ylabel(basis + "_2") ax.set_zlabel("U") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "show_landscape", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("show_landscape", save_show_or_return, save_kwargs, ax) # show_pseudopot(Xgrid, Ygrid, Zgrid) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 0d7460af2..ff227eaa0 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -33,7 +33,7 @@ default_quiver_args, quiver_autoscaler, retrieve_plot_save_path, - save_fig, + save_show_ret, save_plotly_figure, save_pyvista_plotter, set_arrow_alpha, @@ -149,8 +149,8 @@ def cell_wise_vectors_3d( vector: which vector type will be used for plotting, one of {'velocity', 'acceleration'} or either velocity field or acceleration field will be plotted. Defaults to "velocity". save_show_or_return: whether to save, show or return the generated figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - an the save_fig function will use the {"path": None, "prefix": 'cell_wise_velocity', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'cell_wise_velocity', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. quiver_3d_kwargs: any other kwargs to be passed to `pyplot.quiver`. Defaults to { "zorder": 3, "length": 2, @@ -491,26 +491,7 @@ def add_axis_label(ax, labels): ax.set_facecolor(background) add_axis_label(ax, axis_labels) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "cell_wise_vectors_3d", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list + return save_show_ret("cell_wise_vectors_3d", save_show_or_return, save_kwargs, axes_list, tight=False) def grid_vectors_3d(): @@ -615,8 +596,8 @@ def line_integral_conv( field or acceleration field will be plotted. Defaults to "velocity". file: the path to save the slice figure. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'line_integral_conv', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'line_integral_conv', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. g_kwargs_dict: any other kwargs that would be passed to `dynamo.tl.grid_velocity_filter`. Defaults to {}. @@ -626,8 +607,8 @@ def line_integral_conv( Exception: _description_ Returns: - None would be returned by default. If `save_show_or_return` is set to be True, the generated `yt.SlicePlot` will - be returned. + None would be returned by default. If `save_show_or_return` is set to "return" or "all", the generated + `yt.SlicePlot` will be returned. """ import matplotlib.pyplot as plt @@ -742,27 +723,7 @@ def line_integral_conv( # plot_LIC_gray(velocyto_tex) pass - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "line_integral_conv", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return slc + return save_show_ret("line_integral_conv", save_show_or_return, save_kwargs, slc) @docstrings.with_indent(4) @@ -893,8 +854,8 @@ def cell_wise_vectors( tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'cell_wise_velocity', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'cell_wise_velocity', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. s_kwargs_dict: any other kwargs that will be passed to `dynamo.pl.scatters`. Defaults to {}. @@ -1083,28 +1044,7 @@ def cell_wise_vectors( ) ax.set_facecolor(background) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "cell_wise_vector", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - if projection != "3d": - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list + return save_show_ret("cell_wise_vector", save_show_or_return, save_kwargs, axes_list, tight = projection != "3d") @docstrings.with_indent(4) @@ -1241,8 +1181,8 @@ def grid_vectors( tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'grid_velocity', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'grid_velocity', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. s_kwargs_dict: any other kwargs that would be passed to `dynamo.pl.scatters`. Defaults to {}. @@ -1448,27 +1388,7 @@ def grid_vectors( axes_list.quiver(X_grid[0], X_grid[1], V_grid[0], V_grid[1], **quiver_kwargs) axes_list.set_facecolor(background) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "grid_velocity", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list + return save_show_ret("grid_velocity", save_show_or_return, save_kwargs, axes_list) @docstrings.with_indent(4) @@ -1601,8 +1521,8 @@ def streamline_plot( tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'streamline_plot', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'streamline_plot', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. s_kwargs_dict: any other kwargs that would be passed to `dynamo.pl.scatters`. Defaults to {}. @@ -1814,27 +1734,7 @@ def streamplot_2d(ax): ax = axes_list[i] streamplot_2d(ax) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "streamline_plot", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list + return save_show_ret("streamline_plot", save_show_or_return, save_kwargs, axes_list) # refactor line_conv_integration @@ -1862,8 +1762,8 @@ def plot_energy( fig: the figure object where panels of the energy or energy change rate over iteration plots will be appended to. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'energy', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'energy', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. @@ -1909,24 +1809,4 @@ def plot_energy( plt.xlabel("Iteration") plt.ylabel("Energy change rate") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "energy", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return fig + return save_show_ret("energy", save_show_or_return, save_kwargs, fig) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index e00b520d2..b2c7532d0 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -19,7 +19,7 @@ from matplotlib.colors import rgb2hex, to_hex from pandas.api.types import is_categorical_dtype -from ..configuration import _themes, reset_rcParams +from ..configuration import _themes from ..docrep import DocstringProcessor from ..dynamo_logger import main_debug, main_info, main_warning from ..preprocessing.utils import affine_transform, gen_rotation_2d @@ -39,7 +39,7 @@ is_layer_keys, is_list_of_lists, retrieve_plot_save_path, - save_fig, + save_show_ret, save_plotly_figure, save_pyvista_plotter, ) @@ -184,8 +184,8 @@ def scatters( save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. Defaults to "show". - save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary and - the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: A dictionary that will passed to the save_show_ret function. By default it is an empty dictionary and + the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. return_all: whether to return all the scatter related variables. Defaults to False. @@ -876,45 +876,12 @@ def _plot_basis_layer(cur_b, cur_l): _plot_basis_layer(cur_b, cur_l) main_debug("show, return or save...") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "scatters", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - save_fig(**s_kwargs) - if background is not None: - reset_rcParams() - if save_show_or_return in ["show", "both", "all"]: - if show_legend: - plt.subplots_adjust(right=0.85) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() - if background is not None: - reset_rcParams() - if save_show_or_return in ["return", "all"]: - if background is not None: - reset_rcParams() - - if return_all: - return (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) - else: - return axes_list if total_panels > 1 else ax + return_value = None + if return_all: + return_value = (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) + else: + return_value = axes_list if total_panels > 1 else ax + return save_show_ret("scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, background=background) def map_to_points( diff --git a/dynamo/plot/state_graph.py b/dynamo/plot/state_graph.py index 71838e462..92bc95dda 100755 --- a/dynamo/plot/state_graph.py +++ b/dynamo/plot/state_graph.py @@ -10,9 +10,8 @@ from anndata import AnnData from matplotlib.axes import Axes -from ..tools.utils import update_dict from .scatters import docstrings, scatters -from .utils import save_fig +from .utils import save_show_ret docstrings.delete_params("scatters.parameters", "aggregate", "kwargs", "save_kwargs") @@ -226,8 +225,8 @@ def state_graph( scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. If `contour` is set to be True, `frontier` will be ignored as `contour` also add an outlier for data points. Defaults to False. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'state_graph', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'state_graph', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. s_kwargs_dict: any other kwargs that would be passed to `dynamo.pl.scatters`. Defaults to {"alpha": 1}. @@ -323,26 +322,4 @@ def state_graph( plt.axis("off") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "state_graph", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - if show_legend: - plt.subplots_adjust(right=0.85) - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list, color_list, font_color + save_show_ret("state_graph", save_show_or_return, save_kwargs, (axes_list, color_list, font_color), adjust = show_legend) diff --git a/dynamo/plot/time_series.py b/dynamo/plot/time_series.py index 4592622f9..4d02cf1ac 100755 --- a/dynamo/plot/time_series.py +++ b/dynamo/plot/time_series.py @@ -17,7 +17,7 @@ from ..external.hodge import ddhodge from ..prediction.utils import fetch_exprs from ..tools.utils import update_dict -from .utils import _to_hex, save_fig +from .utils import _to_hex, save_show_ret docstrings = DocstringProcessor() @@ -78,8 +78,8 @@ def kinetic_curves( when predicted data is not inverse transformed back to original expression space, no transformation will be applied. Defaults to True. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. @@ -171,27 +171,7 @@ def kinetic_curves( facet_kws={"sharex": True, "sharey": False}, ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "kinetic_curves", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return g + return save_show_ret("kinetic_curves", save_show_or_return, save_kwargs, g) docstrings.delete_params("kin_curves.parameters", "ncol", "color", "c_palette") @@ -289,8 +269,8 @@ def kinetic_heatmap( vline_cols: the indices of column that we can place a line on the heatmap. Defaults to None. vlines_kwargs: a dictionary of arguments that will be passed into sns_heatmap.ax_heatmap.vlines. Defaults to {}. save_show_or_return: whether to save, show, or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'kinetic_heatmap', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'kinetic_heatmap', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. transpose: whether to transpose the dataframe and swap X-Y in heatmap. In single cell case, `transpose=True` @@ -467,29 +447,7 @@ def kinetic_heatmap( vline_kwargs = update_dict({"linestyles": "dashdot"}, vlines_kwargs) sns_heatmap.ax_heatmap.vlines(vline_cols, *sns_heatmap.ax_heatmap.get_ylim(), **vline_kwargs) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "kinetic_heatmap", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - if show_colorbar: - plt.subplots_adjust(right=0.85) - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return sns_heatmap + return save_show_ret("kinetic_heatmap", save_show_or_return, save_kwargs, sns_heatmap, adjust = show_colorbar) def _half_max_ordering(exprs, time, mode, interpolate=False, spaced_num=100): @@ -702,8 +660,8 @@ def jacobian_kinetics( to 1. n_convolve: the number of cells for convolution. Defaults to 30. save_show_or_return: whether to save, show, or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. **kwargs: any other kwargs that would be passed to `seaborn.clustermap`. @@ -843,29 +801,7 @@ def jacobian_kinetics( if not show_colorbar: sns_heatmap.cax.set_visible(False) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "jacobian_kinetics", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - if show_colorbar: - plt.subplots_adjust(right=0.85) - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return sns_heatmap + return save_show_ret("jacobian_kinetics", save_show_or_return, save_kwargs, sns_heatmap, adjust = show_colorbar) @docstrings.with_indent(4) @@ -920,8 +856,8 @@ def sensitivity_kinetics( meaning for each row or column, subtract the minimum and divide each by its maximum. Defaults to 1. n_convolve: the number of cells for convolution. Defaults to 30. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'kinetic_curves', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. **kwargs: any other kwargs that would be passed to `heatmap(). Currently `xticklabels=False, yticklabels='auto'` @@ -1062,26 +998,4 @@ def sensitivity_kinetics( if not show_colorbar: sns_heatmap.cax.set_visible(False) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "sensitivity_kinetics", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - if show_colorbar: - plt.subplots_adjust(right=0.85) - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return sns_heatmap + return save_show_ret("sensitivity_kinetics", save_show_or_return, save_kwargs, sns_heatmap, adjust = show_colorbar) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 9bf45b7cd..e8f45bec4 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -33,7 +33,7 @@ default_quiver_args, quiver_autoscaler, retrieve_plot_save_path, - save_fig, + save_show_ret, save_plotly_figure, save_pyvista_plotter, set_arrow_alpha, @@ -79,8 +79,8 @@ def plot_flow_field( streamline_alpha: the alpha value applied to the vector field streamlines. Defaults to 0.4. color_start_points: the color of the starting point that will be used to predict cell fates. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'plot_flow_field', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'plot_flow_field', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. ax: the Axis on which to make the plot. Defaults to None. @@ -178,27 +178,7 @@ def plot_flow_field( set_arrow_alpha(ax, streamline_alpha) set_stream_line_alpha(s, streamline_alpha) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_flow_field", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_flow_field", save_show_or_return, save_kwargs, ax) def plot_nullclines( @@ -218,8 +198,8 @@ def plot_nullclines( lw: the linewidth of the nullcline. Defaults to 3. background: the background color of the plot. Defaults to None. save_show_or_return: whether to save, show, or return the figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'plot_nullclines', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'plot_nullclines', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. ax: the matplotlib axes used for plotting. Default is to use the current axis. Defaults to None. @@ -282,27 +262,7 @@ def plot_nullclines( for ncy in NCy: ax.plot(*ncy.T, c=colors[1], lw=lw) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_nullclines", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_nullclines", save_show_or_return, save_kwargs, ax) def plot_fixed_points_2d( @@ -329,8 +289,8 @@ def plot_fixed_points_2d( respectively. Defaults to ["full", "top", "none"]. background: the background color of the plot. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'plot_fixed_points', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'plot_fixed_points', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. ax: the matplotlib axes used for plotting. Default is to use the current axis. Defaults to None. @@ -394,27 +354,7 @@ def plot_fixed_points_2d( ] ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_fixed_points", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_fixed_points", save_show_or_return, save_kwargs, ax) def plot_fixed_points( @@ -447,8 +387,8 @@ def plot_fixed_points( respectively. Defaults to ["full", "top", "none"]. background: the background color of the plot. Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'plot_fixed_points', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'plot_fixed_points', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. plot_method: the method to plot 3D points. Options include `pv` (pyvista) and `matplotlib`. @@ -636,27 +576,7 @@ def plot_fixed_points( ] ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_fixed_points", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_fixed_points", save_show_or_return, save_kwargs, ax) def plot_traj( @@ -686,8 +606,8 @@ def plot_traj( integration_direction: Determines whether to integrate the trajectory in the forward, backward, or both direction. Default to "both". save_show_or_return: whether to save, show or return the figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'plot_traj', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'plot_traj', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. ax: the axis on which to make the plot. If None, new axis would be created. Defaults to None. @@ -718,27 +638,7 @@ def plot_traj( cur_y0 = y0[i, None] # don't drop dimension ax = _plot_traj(cur_y0, t, args, integration_direction, ax, color, lw, f) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_traj", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_traj", save_show_or_return, save_kwargs, ax) def plot_separatrix( @@ -766,8 +666,8 @@ def plot_separatrix( vecfld_dict: a dict with entries to create a `VectorField2D` instance. Defaults to None. background: the background color of the plot. Defaults to None. save_show_or_return: whether to save, show, or return the generated figure. Defaults to "return". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'plot_separatrix', "dpi": None, + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'plot_separatrix', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. ax: the axis on which to make the plot. Defaults to None. @@ -850,27 +750,7 @@ def rhs(ab, t): all_sep_a = sep_a if all_sep_a is None else np.concatenate((all_sep_a, sep_a)) all_sep_b = sep_b if all_sep_b is None else np.concatenate((all_sep_b, sep_b)) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_separatrix", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + return save_show_ret("plot_separatrix", save_show_or_return, save_kwargs, ax) @docstrings.with_indent(4) @@ -1067,8 +947,8 @@ def topography( None, the default color map will set to be viridis (inferno) when the background is white (black). Defaults to None. save_show_or_return: whether to save, show or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'topography', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'topography', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. @@ -1426,31 +1306,7 @@ def topography( **quiver_kwargs, ) # color='red', facecolors='gray' - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "topography", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - plt.tight_layout() - - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list if len(axes_list) > 1 else axes_list[0] + save_show_ret("topography", save_show_or_return, save_kwargs, axes_list if len(axes_list) > 1 else axes_list[0]) # TODO: Implement more `terms` like streamline and trajectory for 3D topography @@ -1571,8 +1427,8 @@ def topography_3D( None, the default color map will set to be viridis (inferno) when the background is white (black). Defaults to None. save_show_or_return: Whether to save, show or return the figure. Defaults to `show`. - save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'topography', "dpi": None, "ext": 'pdf', + save_kwargs: A dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'topography', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. aggregate: The column in adata.obs that will be used to aggregate data points. Defaults to None. @@ -1895,28 +1751,4 @@ def topography_3D( cmap=marker_cmap, ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "topography", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - plt.tight_layout() - - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list if len(axes_list) > 1 else axes_list[0] \ No newline at end of file + return save_show_ret("topography", save_show_or_return, save_kwargs, axes_list if len(axes_list) > 1 else axes_list[0]) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index ca1ec1358..3a1440862 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -4,7 +4,7 @@ # import matplotlib.tri as tri import warnings -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple from warnings import warn import matplotlib @@ -17,7 +17,7 @@ from matplotlib.patches import Patch from scipy.spatial import Delaunay -from ..configuration import _themes +from ..configuration import _themes, reset_rcParams from ..dynamo_logger import main_debug from ..tools.utils import integrate_vf, update_dict # integrate_vf_ivp @@ -1659,6 +1659,84 @@ def save_fig( print("Done") +def save_show_ret( + prefix: str, + save_show_or_return: Literal["save", "show", "return", "both", "all"], + save_kwargs: Dict[str, Any], + ret_value = None, + tight: bool = True, + adjust: bool = False, + background: Optional[str] = None, +): + """ + Helper function that performs actions based on the variable save_show_or_return. + Should always have at least 3 inputs (prefix, save_show__or_return, save_kwargs). + + Args: + prefix: Prefix added to name of figure that will be saved. See the `s_kwargs` variable. + save_show_or_return: Whether the figure should be saved, shown, or returned. + "both" means that the figure would be shown and saved but not returned. Defaults + to "show". + save_kwargs: A dictionary that will be passed to the save_fig() function. + The save_fig() function will use + { + "path": None, + "prefix": [prefix input], + "dpi": None, + "ext": 'pdf', + "transparent": True, + "close": True, + "verbose": True + } + as its parameters. `save_kwargs` modifies those keys according to your needs. Defaults to {}. + ret_value: Value to be returned if `save_show_or_return` equals "return" or "all". + tight: Toggles whether plt.tight_layout() is called. + adjust: Toggles whether plt.subplots_adjust() is called. Some functions, such as scatters(), pass + in a string rather than a boolean. + background: Toggles whether reset_rcParams() is called. + + Returns: + If `save_show_or_return` is set as "return" or "all", returns `ret_value`. + """ + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": prefix, + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) + + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if background is not None: + reset_rcParams() + + if save_show_or_return in ["show", "both", "all"]: + if adjust: + plt.subplots_adjust(right=0.85) + + if tight: + #Do note that warnings should not be ignored in the future. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + plt.tight_layout() + + plt.show() + if background is not None: + reset_rcParams() + + if save_show_or_return in ["return", "all"]: + if background is not None: + reset_rcParams() + return ret_value + + def retrieve_plot_save_path( path: Optional[str] = None, prefix: Optional[str] = None, diff --git a/dynamo/plot/vector_calculus.py b/dynamo/plot/vector_calculus.py index 901d8523b..f4ef5a7fb 100644 --- a/dynamo/plot/vector_calculus.py +++ b/dynamo/plot/vector_calculus.py @@ -25,7 +25,7 @@ is_cell_anno_column, is_gene_name, is_layer_keys, - save_fig, + save_show_ret, ) docstrings.delete_params("scatters.parameters", "adata", "color", "cmap", "frontier", "sym_c") @@ -396,8 +396,8 @@ def jacobian( stacked_fraction: whether the jacobian will be represented as a stacked fraction in the title or a linear fraction style will be used. Defaults to False. save_show_or_return: whether to save, show, or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. **kwargs: any other kwargs that would be passed to `plt._matplotlib_points`. @@ -579,27 +579,7 @@ def jacobian( despline_all(ax) deaxis_all(ax) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": jkey, - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return gs + return save_show_ret(jkey, save_show_or_return, save_kwargs, gs) def jacobian_heatmap( @@ -638,8 +618,8 @@ def jacobian_heatmap( cmap: the mapping from data values to color space. If not provided, the default will depend on whether center is set. Defaults to "bwr". save_show_or_return: whether to save, show, or return the generated figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. **kwargs: any other kwargs passed to `sns.heatmap`. @@ -738,27 +718,7 @@ def jacobian_heatmap( ) ax.title(name) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": jkey + "_heatmap", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return gs + return save_show_ret(jkey + "_heatmap", save_show_or_return, save_kwargs, gs) @docstrings.with_indent(4) @@ -828,8 +788,8 @@ def sensitivity( stacked_fraction: whether to represent the jacobianas a stacked fraction in the title or a linear fraction style will be used. Defaults to False. save_show_or_return: whether to save, show, or return the fugure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. **kwargs: any other kwargs passed to `plt._matplotlib_points`. @@ -1008,27 +968,7 @@ def sensitivity( despline_all(ax) deaxis_all(ax) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": skey, - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return gs + return save_show_ret(skey, save_show_or_return, save_kwargs, gs) def sensitivity_heatmap( @@ -1065,8 +1005,8 @@ def sensitivity_heatmap( cmap: the mapping from data values to color space. If not provided, the default will depend on whether center is set. Defaults to "bwr". save_show_or_return: whether to save, show, or return the figure. Defaults to "show". - save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs.. Defaults to {}. **kwargs: any other kwargs passed to `sns.heatmap`. @@ -1151,24 +1091,4 @@ def sensitivity_heatmap( ) plt.title(name) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": skey + "_heatmap", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) - - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return gs + return save_show_ret(skey + "_heatmap", save_show_or_return, save_kwargs, gs)