Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make plots resize with windows better #35195

Merged
merged 14 commits into from
Mar 10, 2023
63 changes: 49 additions & 14 deletions Framework/PythonInterface/mantid/plots/plotfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from matplotlib.gridspec import GridSpec
from matplotlib.legend import Legend
import matplotlib as mpl
from mpl_toolkits.axes_grid1.axes_divider import make_axes_area_auto_adjustable

# local imports
from mantid.api import AnalysisDataService, MatrixWorkspace, WorkspaceGroup
Expand Down Expand Up @@ -225,8 +226,11 @@ def plot(

show_legend = "on" == ConfigService.getString("plots.ShowLegend").lower()
for ax in axes:
if ax.get_legend() is not None:
ax.get_legend().set_visible(show_legend)
legend = ax.get_legend()
if legend is not None:
legend.set_visible(show_legend)
# Stop legend interfering with the tight layout
legend.set_in_layout(False)

# Can't have a waterfall plot with only one line.
if len(nums) * len(workspaces) == 1 and waterfall:
Expand Down Expand Up @@ -298,14 +302,16 @@ def _update_show_figure(fig):
return fig


def create_subplots(nplots, fig=None):
def create_subplots(nplots, fig=None, add_cbar_axis=False):
"""
Create a set of subplots suitable for a given number of plots. A stripped down
version of plt.subplots that can accept an existing figure instance.
Figure is given a tight layout.

:param nplots: The number of plots required
:param fig: An optional figure. It is cleared before plotting the new contents
:return: A 2-tuple of (fig, axes)
:param add_cbar_axis: Boolean for whether to add and return an axis for a colour bar on the right of the figure
:return: fig, axes, ncrows, ncols, cbar_axis
"""
import matplotlib.pyplot as plt

Expand All @@ -324,15 +330,44 @@ def create_subplots(nplots, fig=None):
fig.clf()
# annoyling this repl
nplots = nrows * ncols
gs = GridSpec(nrows, ncols)
axes = np.empty(nplots, dtype=object)
ax0 = fig.add_subplot(gs[0, 0], projection=PROJECTION)
axes[0] = ax0
for i in range(1, nplots):
axes[i] = fig.add_subplot(gs[i // ncols, i % ncols], projection=PROJECTION)
axes = axes.reshape(nrows, ncols)
cbar_axis = None
fig.set_layout_engine(layout="tight")

if add_cbar_axis:
# The right most column of the GridSpec is made a SubGridSpec to facilitate both the colour bar
# and the plots in the rightmost column.
# This is done (instead of adding another column for the colour bar) so that the colour bar can have a narrow
# spacing to the right most plots, without effecting the other column spacings.
# Keeping the colour bar in the GridSpec rather than alongside it means it behaves much nicer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be reworded slightly? It seems a bit subjective and not clear what behaviour it is actually doing.

# when within a resizing QT window
# The relative width of 11.55 for the right most column ensures that all plots remain an equal size (**)
gs = GridSpec(nrows, ncols, width_ratios=[10.0] * (ncols - 1) + [11.55])
gs_and_cbar = gs[:, -1].subgridspec(nrows, 2, width_ratios=[10, 1], wspace=0.1)
if ncols > 1:
ax0 = fig.add_subplot(gs[0, 0], projection=PROJECTION)
else:
ax0 = fig.add_subplot(gs_and_cbar[0, 0], projection=PROJECTION)
axes[0] = ax0
for i in range(1, nplots):
if (i + 1) % ncols:
axes[i] = fig.add_subplot(gs[i // ncols, i % ncols], projection=PROJECTION)
else: # last column so add to colour bar grid spec
axes[i] = fig.add_subplot(gs_and_cbar[i // ncols, 0], projection=PROJECTION)
# avoid possible collision between x axis tick labels and colour bar
# (**) this may also cause plots in the right most column to be slightly narrower than the others
make_axes_area_auto_adjustable(axes[i], pad=0, adjust_dirs=["right"])
cbar_axis = fig.add_subplot(gs_and_cbar[:, 1])
fig.sca(axes[-1])
else:
gs = GridSpec(nrows, ncols)
ax0 = fig.add_subplot(gs[0, 0], projection=PROJECTION)
axes[0] = ax0
for i in range(1, nplots):
axes[i] = fig.add_subplot(gs[i // ncols, i % ncols], projection=PROJECTION)

return fig, axes, nrows, ncols
axes = axes.reshape(nrows, ncols)
return fig, axes, nrows, ncols, cbar_axis


def raise_if_not_sequence(value, seq_name, element_type=None):
Expand Down Expand Up @@ -375,7 +410,7 @@ def get_plot_fig(overplot=None, ax_properties=None, window_title=None, axes_num=
if fig and overplot:
fig = fig
elif fig:
fig, _, _, _ = create_subplots(axes_num, fig)
fig, _, _, _, _ = create_subplots(axes_num, fig)
elif overplot:
# The create subplot below assumes no figure was passed in, this is ensured by the elif above
# but add an assert which prevents a future refactoring from breaking this assumption
Expand All @@ -384,9 +419,9 @@ def get_plot_fig(overplot=None, ax_properties=None, window_title=None, axes_num=
if not fig.axes:
plt.close(fig)
# The user is likely trying to overplot on a non-existent plot, create one for them
fig, _, _, _ = create_subplots(axes_num)
fig, _, _, _, _ = create_subplots(axes_num)
else:
fig, _, _, _ = create_subplots(axes_num)
fig, _, _, _, _ = create_subplots(axes_num)

if not ax_properties and not overplot:
ax_properties = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mantid.simpleapi import CreateMDHistoWorkspace, CloneWorkspace, GroupWorkspaces, CreateSampleWorkspace
from mantid.kernel import config
from mantid.plots import MantidAxes
from mantid.plots.plotfunctions import figure_title, manage_workspace_names, plot, plot_md_histo_ws
from mantid.plots.plotfunctions import create_subplots, figure_title, manage_workspace_names, plot, plot_md_histo_ws
from mantid.plots.utility import MantidAxType


Expand Down Expand Up @@ -59,7 +59,6 @@ def workspace_names_dummy_func(workspaces):


class FunctionsTest(unittest.TestCase):

_test_ws = None
# MD workspace to test
_test_md_ws = None
Expand Down Expand Up @@ -286,6 +285,37 @@ def test_superplot_bin_plot(self):
fig.canvas.manager.superplot.set_workspaces.assert_called_once()
fig.canvas.manager.superplot.set_bin_mode.assert_called_once_with(False)

def test_create_subplots_axes_shape(self):
self._test_subplot_axes_shape_helper(cbar=False)

def test_create_subplots_axes_shape_with_cbar(self):
self._test_subplot_axes_shape_helper(cbar=True)

def test_create_subplots_creates_tight_figure(self):
from matplotlib.layout_engine import TightLayoutEngine

fig, _, _, _, _ = create_subplots(4)
layout_engine = fig.get_layout_engine()
self.assertIsInstance(layout_engine, TightLayoutEngine)

def test_create_subplots_colour_bar_is_alongside_all_rows(self):
_, axes, _, _, cbar_axis = create_subplots(9, add_cbar_axis=True)
plots_height = 0
cbar_height = cbar_axis.bbox.height
for ax in axes[:, 0]:
plots_height += ax.bbox.height

# plots_height will be a bit less because of padding around the plots
self.assertLessEqual(plots_height, cbar_height)

def test_create_subplots_plots_in_colour_bar_column_are_the_same_size(self):
_, axes, _, _, _ = create_subplots(4, add_cbar_axis=True)
ax1_bbox = axes[0, 0].bbox
ax2_bbox = axes[0, 1].bbox

self.assertAlmostEqual(ax1_bbox.height, ax2_bbox.height, delta=0.01)
self.assertAlmostEqual(ax1_bbox.width, ax2_bbox.width, delta=0.01)

# ------------- Failure tests -------------
def test_that_manage_workspace_names_raises_on_mix_of_workspaces_and_names(self):
ws = ["some_workspace", self._test_ws]
Expand All @@ -308,6 +338,21 @@ def _compare_errorbar_labels_and_title(self):
# Compare title
self.assertEqual(ax.get_title(), err_ax.get_title())

def _test_subplot_axes_shape_helper(self, cbar):
n_plots_to_correct_shape = {1: (1, 1), 2: (1, 2), 9: (3, 3), 10: (3, 4), 15: (4, 4)}

for n_plots in n_plots_to_correct_shape:
(
_,
axes,
_,
_,
_,
) = create_subplots(n_plots, add_cbar_axis=cbar)
shape = axes.shape
expected_shape = n_plots_to_correct_shape[n_plots]
self.assertEqual(shape, expected_shape)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions docs/source/release/v6.7.0/Workbench/Bugfixes/35014.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Fixed a bug where plot axes labels could be cut off by the window boundaries after resizing.
- Fixed a bug where the colour bar on surface plots could collide with the plot axes.
2 changes: 1 addition & 1 deletion qt/applications/workbench/workbench/plotting/toolbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def is_colormap(self, fig):
@classmethod
def _is_colorbar(cls, ax):
"""Determine whether an axes object is a colorbar"""
return not hasattr(ax, "get_subplotspec")
return not hasattr(ax, "get_subplotspec") or hasattr(ax, "_colorbar")

def set_up_color_selector_toolbar_button(self, fig):
# check if the action is already in the toolbar
Expand Down
9 changes: 7 additions & 2 deletions qt/python/mantidqt/mantidqt/plotting/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from distutils.version import LooseVersion

import matplotlib
from mpl_toolkits.axes_grid1.axes_divider import make_axes_area_auto_adjustable
import numpy as np

# local imports
Expand Down Expand Up @@ -251,7 +252,7 @@ def pcolormesh(workspaces, fig=None, color_norm=None, normalize_by_bin_width=Non
# create a subplot of the appropriate number of dimensions
# extend in number of columns if the number of plottables is not a square number
workspaces_len = len(workspaces)
fig, axes, nrows, ncols = create_subplots(workspaces_len, fig=fig)
fig, axes, nrows, ncols, cbar_axis = create_subplots(workspaces_len, fig=fig, add_cbar_axis=True)

plots = []
row_idx, col_idx = 0, 0
Expand Down Expand Up @@ -285,7 +286,7 @@ def pcolormesh(workspaces, fig=None, color_norm=None, normalize_by_bin_width=Non
fig.subplots_adjust(wspace=SUBPLOT_WSPACE, hspace=SUBPLOT_HSPACE)

axes = axes.ravel()
colorbar = fig.colorbar(pcm, ax=axes.tolist(), pad=0.06)
colorbar = fig.colorbar(pcm, cax=cbar_axis)
add_colorbar_label(colorbar, axes)

if fig.canvas.manager is not None:
Expand Down Expand Up @@ -362,6 +363,9 @@ def plot_surface(workspaces, fig=None):

surface = ax.plot_surface(ws, cmap=ConfigService.getString("plots.images.Colormap"))
ax.set_title(ws.name())
# Stops colour bar colliding with the plot. Also prevents the plot being pushed off the windows when resizing.
# "Top" direction is excluded since the title provides a buffer
make_axes_area_auto_adjustable(ax, pad=0, adjust_dirs=["left", "right", "bottom"])
fig.colorbar(surface, ax=[ax])
fig.show()

Expand All @@ -379,6 +383,7 @@ def plot_wireframe(workspaces, fig=None):
else:
fig, ax = plt.subplots(subplot_kw={"projection": "mantid3d"})

fig.set_layout_engine(layout="tight")
ax.plot_wireframe(ws)
ax.set_title(ws.name())
fig.show()
Expand Down
10 changes: 10 additions & 0 deletions qt/python/mantidqt/mantidqt/plotting/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
plot_md_ws_from_names,
pcolormesh_from_names,
plot_surface,
plot_wireframe,
)

IMAGE_PLOT_OPTIONS = {
Expand Down Expand Up @@ -329,6 +330,15 @@ def test_overplotting_onto_waterfall_plot_with_filled_areas_adds_another_filled_

self.assertEqual(len(fills), 3)

def test_plot_wireframe_creates_a_tight_figure(self):
from matplotlib.layout_engine import TightLayoutEngine

ws = self._test_ws
fig = plot_wireframe([ws])

layout_engine = fig.get_layout_engine()
self.assertIsInstance(layout_engine, TightLayoutEngine)

# ------------- Failure tests -------------

def test_plot_from_names_with_non_plottable_workspaces_returns_None(self):
Expand Down
4 changes: 2 additions & 2 deletions qt/python/mantidqt/mantidqt/project/plotsloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def make_fig(self, plot_dict, create_plot=True):
for cargs_dict in sublist:
if "norm" in cargs_dict and type(cargs_dict["norm"]) is dict:
cargs_dict["norm"] = self.restore_normalise_obj_from_dict(cargs_dict["norm"])
fig, axes_matrix, _, _ = create_subplots(len(creation_args))
fig, axes_matrix, _, _, _ = create_subplots(len(creation_args))
axes_list = axes_matrix.flatten().tolist()
for ax, cargs_list in zip(axes_list, creation_args):
creation_args_copy = copy.deepcopy(cargs_list)
Expand Down Expand Up @@ -331,7 +331,7 @@ def update_properties(self, ax, properties):
self.update_axis(ax.yaxis, properties["yAxisProperties"])

if "spineWidths" in properties:
for (spine, width) in properties["spineWidths"].items():
for spine, width in properties["spineWidths"].items():
ax.spines[spine].set_linewidth(width)

def update_axis(self, axis_, properties):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _setup_table_widget(self):
self.layout.addWidget(self.table)

def _setup_figure_widget(self):
fig, _, _, _ = create_subplots(1)
fig, _, _, _, _ = create_subplots(1)
self.figure = fig
self.figure.canvas = FigureCanvas(self.figure)
toolbar = MantidNavigationToolbar(self.figure.canvas, self)
Expand Down