Skip to content

Commit

Permalink
Merge pull request #77 from n-claes/bugfix/mpl_geometry
Browse files Browse the repository at this point in the history
Pylbo: improved subplot addition
  • Loading branch information
n-claes authored Sep 8, 2021
2 parents 9beff37 + b0ad1d3 commit 775e11b
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 62 deletions.
2 changes: 1 addition & 1 deletion post_processing/pylbo/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from packaging import version

# The current pylbo version number.
__version__ = "1.1.2"
__version__ = "1.1.3"


class VersionHandler:
Expand Down
25 changes: 25 additions & 0 deletions post_processing/pylbo/utilities/toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@
from pylbo._version import _mpl_version


def get_axis_geometry(ax):
"""
Retrieves the geometry of a given matplotlib axis.
Parameters
----------
ax : ~matplotlib.axes.Axes
The axis to retrieve the geometry from.
Returns
-------
tuple
The geometry of the given matplotlib axis.
"""
if _mpl_version >= "3.4":
axis_geometry = ax.get_subplotspec().get_geometry()[0:3]
else:
# this is 1-based indexing by default, use 0-based here for consistency
# with subplotspec in matplotlib 3.4+
axis_geometry = transform_to_numpy(ax.get_geometry())
axis_geometry[-1] -= 1
axis_geometry = tuple(axis_geometry)
return axis_geometry


def add_pickradius_to_item(item, pickradius):
"""
Makes a matplotlib artist pickable and adds a pickradius.
Expand Down
87 changes: 40 additions & 47 deletions post_processing/pylbo/visualisation/figure_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from functools import wraps
from pylbo.utilities.logger import pylboLogger
from pylbo._version import _mpl_version
from pylbo.utilities.toolbox import get_axis_geometry, transform_to_numpy


def refresh_plot(f):
Expand Down Expand Up @@ -150,8 +150,6 @@ class FigureWindow:
Scaling to apply to the x-axis.
y_scaling : int, float, complex, np.ndarray
Scaling to apply to the y-axis.
enabled : bool
If `False` (default), the current figure will not be plotted at plt.show().
"""

figure_stack = FigureContainer()
Expand Down Expand Up @@ -246,12 +244,13 @@ def _enable_interface(self, handle):
{"cid": callback_id, "kind": callback_kind, "method": callback_method}
)

def _add_subplot_axes(self, ax, loc="right", share=None):
def _add_subplot_axes(self, ax, loc="right", share=None, apply_tight_layout=True):
"""
Adds a new subplot to the given axes object, depending on the loc argument.
On return, `tight_layout()` is called to update the figure's gridspec.
This is meant to split a single axis, and should not be used for multiple
subplots.
Adds a new subplot to a given matplotlib subplot, essentially "splitting" the
axis into two. Position and placement depend on the loc argument.
When called on a more complex subplot layout the overall gridspec remains
untouched, only the `ax` object has its gridspec modified.
On return, `tight_layout()` is called by default to prevent overlapping labels.
Parameters
----------
Expand All @@ -264,62 +263,56 @@ def _add_subplot_axes(self, ax, loc="right", share=None):
to the left and the new axis is added on the right.
share : str
Can be "x", "y" or "all". This locks axes zooming between both subplots.
apply_tight_layout : bool
If `True` applies `tight_layout()` to the figure on return.
Raises
------
ValueError
- If the subplot is not a single axis
- If the loc argument is invalid.
If the loc argument is invalid.
Returns
-------
new_ax : ~matplotlib.axes.Axes
~matplotlib.axes.Axes
The axes instance that was added.
"""
if _mpl_version >= "3.4":
axes_geometry = ax.get_subplotspec().get_geometry()[0:2]
else:
axes_geometry = ax.get_geometry()[0:2]

if axes_geometry != (1, 1):
raise ValueError(
f"Something went wrong when adding a subplot. Expected "
f"axes with geometry (1, 1) but got {axes_geometry}."
)
if loc == "top":
geom_ax = (2, 1, 2)
geom_new_ax = (2, 1, 1)
elif loc == "right":
geom_ax = (1, 2, 1)
geom_new_ax = (1, 2, 2)
elif loc == "bottom":
geom_ax = (2, 1, 1)
geom_new_ax = (2, 1, 2)
elif loc == "left":
geom_ax = (1, 2, 2)
geom_new_ax = (1, 2, 1)
else:
raise ValueError(
f"invalid loc = {loc}, expected 'top', 'right', 'bottom' or 'left'"
)
sharex = None
sharey = None
if share in ("x", "all"):
sharex = ax
if share in ("y", "all"):
sharey = ax
# handle deprecation of `change_geometry` in mpl 3.4
if _mpl_version >= "3.4":
spec = gridspec.GridSpec(*geom_ax[0:2], figure=self.fig)
# geom_ax[-1] - 1 determines position of plot in window, index 0 based
ax.set_subplotspec(gridspec.SubplotSpec(spec, geom_ax[-1] - 1))
new_ax = self.fig.add_subplot(spec[geom_new_ax[-1] - 1])

if loc == "right":
subplot_geometry = (1, 2)
old_new_position = (0, 1)
elif loc == "left":
subplot_geometry = (1, 2)
old_new_position = (1, 0)
elif loc == "top":
subplot_geometry = (2, 1)
old_new_position = (1, 0)
elif loc == "bottom":
subplot_geometry = (2, 1)
old_new_position = (0, 1)
else:
ax.change_geometry(*geom_ax)
new_ax = self.fig.add_subplot(*geom_new_ax, sharex=sharex, sharey=sharey)
# this will update the figure's gridspec
raise ValueError(
f"invalid loc={loc}, expected ['top', 'right', 'bottom', 'left']"
)

_geometry = transform_to_numpy(get_axis_geometry(ax))
subplot_index = _geometry[-1]
gspec_outer = gridspec.GridSpec(*_geometry[0:2], figure=self.fig)
gspec_inner = gridspec.GridSpecFromSubplotSpec(
*subplot_geometry, subplot_spec=gspec_outer[subplot_index]
)
ax.set_subplotspec(gspec_inner[old_new_position[0]])
new_axis = self.fig.add_subplot(
gspec_inner[old_new_position[1]], sharex=sharex, sharey=sharey
)
if apply_tight_layout:
self.fig.tight_layout()
return new_ax
return new_axis

def draw(self):
"""
Expand Down
3 changes: 1 addition & 2 deletions post_processing/pylbo/visualisation/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ def __init__(self, dataset, figsize, **kwargs):
self.dataset = dataset
self.kwargs = kwargs

self.ax.change_geometry(1, 2, 1)
self.ax2 = self.fig.add_subplot(122)
self.ax2 = self._add_subplot_axes(self.ax, loc="right")
self.draw()

def draw(self):
Expand Down
6 changes: 2 additions & 4 deletions post_processing/pylbo/visualisation/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,8 @@ def _use_custom_axes(self):
"""Splits the original 1x2 plot into a 2x2 plot."""
if self._axes_set:
return
self.panel1.ax.change_geometry(2, 2, 1)
self.panel1._ef_ax = self.panel1.fig.add_subplot(2, 2, 3)
self.panel2.ax.change_geometry(2, 2, 2)
self.panel2._ef_ax = self.panel2.fig.add_subplot(2, 2, 4)
self.panel1._ef_ax = self.panel1._add_subplot_axes(self.panel1.ax, loc="bottom")
self.panel2._ef_ax = self.panel2._add_subplot_axes(self.panel2.ax, loc="bottom")
self._axes_set = True

def add_eigenfunctions(self):
Expand Down
2 changes: 1 addition & 1 deletion post_processing/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

package_name = "pylbo"
required_packages = ["numpy", "matplotlib", "f90nml", "tqdm", "psutil"]
required_packages = ["numpy", "matplotlib", "f90nml", "tqdm", "psutil", "packaging"]

version_filepath = (Path(__file__).parent / "pylbo/_version.py").resolve()
VERSION = None
Expand Down
Loading

0 comments on commit 775e11b

Please sign in to comment.