From 7700e7c8d647c0aab708a43b65e2fa0a0786f938 Mon Sep 17 00:00:00 2001 From: Siavash Ameli Date: Wed, 7 Feb 2024 18:43:38 -0800 Subject: [PATCH] util updatte --- ortho/__main__.py | 2 +- ortho/__version__.py | 2 +- .../orthogonal_functions.py | 5 +- ortho/_orthogonal_functions/plot_functions.py | 26 +- ortho/_orthogonal_functions/plot_utilities.py | 223 ++++++++++++++---- tests/test_orthogonal_functions.py | 4 +- 6 files changed, 192 insertions(+), 70 deletions(-) diff --git a/ortho/__main__.py b/ortho/__main__.py index 2b85a58..12bfbf9 100755 --- a/ortho/__main__.py +++ b/ortho/__main__.py @@ -47,7 +47,7 @@ def main(args=None): # Plot results if arguments['plot_flag']: - OF.plot() + OF.plot(filename='orthogonal_functions') # =========== diff --git a/ortho/__version__.py b/ortho/__version__.py index 6a9beea..3d26edf 100644 --- a/ortho/__version__.py +++ b/ortho/__version__.py @@ -1 +1 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" diff --git a/ortho/_orthogonal_functions/orthogonal_functions.py b/ortho/_orthogonal_functions/orthogonal_functions.py index 9c75e58..994a226 100644 --- a/ortho/_orthogonal_functions/orthogonal_functions.py +++ b/ortho/_orthogonal_functions/orthogonal_functions.py @@ -249,7 +249,7 @@ def print(self): # Plot # ---- - def plot(self): + def plot(self, filename=None): """ Plot the generated functions. """ @@ -261,4 +261,5 @@ def plot(self): plot_functions( self._phi_orthonormalized_list, self.start_index, - self._interval) + self._interval, + filename=filename) diff --git a/ortho/_orthogonal_functions/plot_functions.py b/ortho/_orthogonal_functions/plot_functions.py index b8bc39c..cb061ae 100644 --- a/ortho/_orthogonal_functions/plot_functions.py +++ b/ortho/_orthogonal_functions/plot_functions.py @@ -14,7 +14,7 @@ import numpy import sympy import os -from .plot_utilities import plt, matplotlib, get_custom_theme, save_plot +from .plot_utilities import plt, matplotlib, get_theme, show_or_save_plot from .declarations import t __all__ = ['plot_functions'] @@ -24,8 +24,12 @@ # Plot Functions # ============== -@matplotlib.rc_context(get_custom_theme(font_scale=1.2)) -def plot_functions(phi_orthonormalized_list, start_index, interval): +@matplotlib.rc_context(get_theme(font_scale=1.2)) +def plot_functions( + phi_orthonormalized_list, + start_index, + interval, + filename=None): """ Plots the generated functions, also saves the plots as both ``svg`` and ``pdf`` format. @@ -40,6 +44,9 @@ def plot_functions(phi_orthonormalized_list, start_index, interval): :param Interval: The right side of the interval of the domain of the functions. :param Interval: float + + :param: filename: Name of file to save the plot. If `None`, plot is shown + but not saved. If string, plot is not shown, rather, saved. """ # Axis @@ -81,13 +88,6 @@ def plot_functions(phi_orthonormalized_list, start_index, interval): save_dir = os.getcwd() # Save plot in both svg and pdf format - if os.access(save_dir, os.W_OK): - save_filename = 'orthogonal_functions' - save_plot(plt, save_filename, save_dir=save_dir, - transparent_background=True) - else: - print('Cannot save plot to %s. Directory is not writable.' % save_dir) - - # If no display backend is enabled, do not plot in the interactive mode - if matplotlib.get_backend() != 'agg': - plt.show() + show_or_save_plot(plt, filename=filename, + default_filename='orthogonal_functions', + transparent_background=True) diff --git a/ortho/_orthogonal_functions/plot_utilities.py b/ortho/_orthogonal_functions/plot_utilities.py index 46ef482..3178f7a 100644 --- a/ortho/_orthogonal_functions/plot_utilities.py +++ b/ortho/_orthogonal_functions/plot_utilities.py @@ -51,8 +51,7 @@ message=('This figure includes Axes that are not compatible with ' + 'tight_layout, so results might be incorrect.')) -__all__ = ['get_custom_theme', 'set_custom_theme', 'save_plot', - 'show_or_save_plot'] +__all__ = ['get_theme', 'set_theme', 'reset_theme', 'show_or_save_plot'] # ===================== @@ -228,22 +227,22 @@ def _customize_theme_text(): return text_dict -# ================ -# get custom theme -# ================ +# ========= +# get theme +# ========= -def get_custom_theme( +def get_theme( context="notebook", font_scale=1, use_latex=True, - **kwargs): + rc=None): """ Returns a dictionary that can be used to update plt.rcParams. Usage: Before a function, add this line: - @matplotlib.rc_context(get_custom_theme(font_scale=1.2)) + @matplotlib.rc_context(get_theme(font_scale=1.2)) def some_plotting_function(): ... @@ -276,35 +275,50 @@ def some_plotting_function(): plt_rc_params.update(_customize_theme_text()) # Add extra arguments - plt_rc_params.update(kwargs) + if rc is not None: + plt_rc_params.update(rc) return plt_rc_params -# ================ -# set custom theme -# ================ +# ========= +# set theme +# ========= -def set_custom_theme(context="notebook", font_scale=1, use_latex=True): +def set_theme( + context="notebook", + font_scale=1, + use_latex=True, + rc=None): """ Sets a customized theme for plotting. """ - plt_rc_params = get_custom_theme(context=context, font_scale=font_scale, - use_latex=use_latex) + plt_rc_params = get_theme(context=context, font_scale=font_scale, + use_latex=use_latex, rc=rc) matplotlib.rcParams.update(plt_rc_params) +# =========== +# reset theme +# =========== + +def reset_theme(): + """ + Reset the matplotlib theme back it default. + """ + + matplotlib.rcParams.update(matplotlib.rcParamsDefault) + + # ========= # save plot # ========= -def save_plot( +def _save_plot( plt, filename, - save_dir=None, transparent_background=True, - pdf=True, bbox_extra_artists=None, dpi=200, verbose=False): @@ -322,34 +336,48 @@ def save_plot( :type transparent_background: bool """ - # If no directory specified, write in the current working directory - if save_dir is None: - save_dir = os.getcwd() + # Is filename contain path, use the current path + if os.path.isabs(filename) or os.path.isabs(os.path.expanduser(filename)): + + # Remove redundant separators + filename = os.path.normpath(filename) + + # Extract the directory, base filename, and file extension + directory, base_and_ext = os.path.split(filename) + base_filename, extension = os.path.splitext(base_and_ext) - # Save plot in both svg and pdf format - filename_svg = filename + '.svg' - filename_pdf = filename + '.pdf' - if os.access(save_dir, os.W_OK): - save_fullname_svg = os.path.join(save_dir, filename_svg) - save_fullname_pdf = os.path.join(save_dir, filename_pdf) + # Initialize a list of extensions + extensions = [extension] + + else: + # If no directory specified, write in the current working directory + directory = os.getcwd() + base_filename, extension = os.path.splitext(filename) + + # Determine whether a file extension is provided or not. + if bool(extension): + # filename contains an extension + extensions = [extension] + else: + # No extension is provided. Save to both svg and pdf + extensions = ['svg', 'pdf'] + + if not os.access(directory, os.W_OK): + raise RuntimeError( + 'Cannot save plot to %s. Directory is not writable.' % directory) + + # For each extension, save a file + for extension in extensions: + fullpath_filename = os.path.join(directory, + base_filename + '.' + extension) plt.savefig( - save_fullname_svg, + fullpath_filename, dpi=dpi, transparent=transparent_background, - bbox_inches='tight') - if verbose: - print('Plot saved to "%s".' % (save_fullname_svg)) + bbox_extra_artists=bbox_extra_artists, bbox_inches='tight') - if pdf: - plt.savefig( - save_fullname_pdf, dpi=dpi, - transparent=transparent_background, - bbox_extra_artists=bbox_extra_artists, bbox_inches='tight') - plt.close() - if verbose: - print('Plot saved to "%s".' % (save_fullname_pdf)) - else: - print('Cannot save plot to %s. Directory is not writable.' % save_dir) + if verbose: + print('Plot saved to "%s".' % fullpath_filename) # ================= @@ -358,22 +386,115 @@ def save_plot( def show_or_save_plot( plt, - filename, + filename=None, + default_filename=None, transparent_background=True, - pdf=True, bbox_extra_artists=None, dpi=200, + show_and_save=False, verbose=False): """ - Shows the plot. If no graphical beckend exists, saves the plot. + Show and/or save plot. + + Parameters + ---------- + + filename : str or bool, default=None + If `False`, the plot is neither shown nor saved. If `True` or `None`, + the plot is shown. If string, the plot is saved instead, where the + filename is the given string. If the filename contains no file + extension, the plot is saved as both ``svg`` and ``pdf`` format. If the + filename contains no directory path, the plot is saved in the current + directory. + + default_filename : str, default=None + If the plot cannot be shown (such as when no graphical backend exists), + the plot is saved instead with the default filename. + + transparent_background : bool, default=True + If `True`, the figure and axes backgrounds will be rendered transparent + in saved file. This only works for the file formats that support the + alpha channel, such as ``svg``, ``pdf``, and ``png`` files. + + bbox_extra_artist : list, default=None + A list of extra artists to pass to the renderer. + + dpi : int, default=200 + Dots per inch for rendering plots. + + show_and_save : bool, default=False + By default, the plot is either shown xor saved. But when this argument + is `True`, the plot is forced to both be shown and saved. + + verbose : bool, default=False + If `True`, the fullpath filename of the saved plot is printed. + + Notes + ----- + + The ``filename`` argument determines that whether the plot should be shown + of saved. However, there the following exceptions are made: + + 1. If ``show_and_save`` is enabled, plot is both shown and saved. + + 2. If no graphical beckend exists, it the plot is saved instead of being + shown, even of it was not intended to be saved. + + 2. If the environment is Jupyter notebook, the plot is always shown, even + if it was not intended to be shown. If a filename is given, it is both + shown and saved. """ - # Check if the graphical back-end exists - if matplotlib.get_backend() != 'agg' or is_notebook(): - plt.show() + if filename is False: + # Do nothing + return + elif (filename is True) or (filename is None): + # Show the plot, but do not save it + if matplotlib.get_backend() != 'agg': + show = True + if show_and_save: + save = True + else: + save = False + else: + # There is no graphical backend, no other choice but to save it + save = True + if show_and_save: + show = True + else: + show = False + elif isinstance(filename, str): + # Save the plot, but do not show it, unless it is a Jupyter notebook + save = True + if is_notebook(): + show = True + else: + show = False else: + raise ValueError('"filename" should be boolean or a string.') + + # Determine filename + if (save is True) and (not isinstance(filename, str)): + if isinstance(default_filename, str): + filename = default_filename + else: + raise NotImplementedError('"default_filename" is not set. ' + + 'This is a bug. Please report the ' + + 'issue.') + + # Save plot to file(s) + if save: + # write the plot as SVG file in the current working directory - save_plot(plt, filename, transparent_background=transparent_background, - pdf=pdf, bbox_extra_artists=bbox_extra_artists, dpi=dpi, - verbose=verbose) - plt.close() + _save_plot(plt, filename, + transparent_background=transparent_background, + bbox_extra_artists=bbox_extra_artists, dpi=dpi, + verbose=verbose) + + # Closing is necessary especially if a large number of plots are saved. + if not show: + plt.close() + + # Show plot + if show: + plt.show() diff --git a/tests/test_orthogonal_functions.py b/tests/test_orthogonal_functions.py index 3dd9a28..481b2d1 100755 --- a/tests/test_orthogonal_functions.py +++ b/tests/test_orthogonal_functions.py @@ -20,7 +20,7 @@ from ortho import OrthogonalFunctions # noqa: E402 -import warnings +import warnings # noqa: E402 warnings.resetwarnings() warnings.filterwarnings("error") @@ -60,7 +60,7 @@ def test_orthogonal_functions(): OF.check(verbose=True) # Plot the results - OF.plot() + OF.plot(filename='orthogonal_functions') # Remove saved plots remove_file('orthogonal_functions.svg')