Skip to content

Commit

Permalink
REF: make plotting less stateful (pandas-dev#55837)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Nov 6, 2023
1 parent 5d82d8b commit 56a4d57
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 26 deletions.
3 changes: 2 additions & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from collections.abc import Collection

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

from pandas._typing import MatplotlibColor
Expand Down Expand Up @@ -177,7 +178,7 @@ def maybe_color_bp(self, bp) -> None:
if not self.kwds.get("capprops"):
setp(bp["caps"], color=caps, alpha=1)

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
if self.subplots:
self._return_obj = pd.Series(dtype=object)

Expand Down
50 changes: 26 additions & 24 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.axis import Axis
from matplotlib.figure import Figure

from pandas._typing import (
IndexLabel,
Expand Down Expand Up @@ -241,7 +242,8 @@ def __init__(
self.stacked = kwds.pop("stacked", False)

self.ax = ax
self.fig = fig
# TODO: deprecate fig keyword as it is ignored, not passed in tests
# as of 2023-11-05
self.axes = np.array([], dtype=object) # "real" version get set in `generate`

# parse errorbar input if given
Expand Down Expand Up @@ -449,11 +451,11 @@ def draw(self) -> None:
def generate(self) -> None:
self._args_adjust()
self._compute_plot_data()
self._setup_subplots()
self._make_plot()
fig = self._setup_subplots()
self._make_plot(fig)
self._add_table()
self._make_legend()
self._adorn_subplots()
self._adorn_subplots(fig)

for ax in self.axes:
self._post_plot_logic_common(ax, self.data)
Expand Down Expand Up @@ -495,7 +497,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
new_ax.set_yscale("symlog")
return new_ax

def _setup_subplots(self):
def _setup_subplots(self) -> Figure:
if self.subplots:
naxes = (
self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
Expand Down Expand Up @@ -538,8 +540,8 @@ def _setup_subplots(self):
elif self.logy == "sym" or self.loglog == "sym":
[a.set_yscale("symlog") for a in axes]

self.fig = fig
self.axes = axes
return fig

@property
def result(self):
Expand Down Expand Up @@ -637,7 +639,7 @@ def _compute_plot_data(self):

self.data = numeric_data.apply(self._convert_to_ndarray)

def _make_plot(self):
def _make_plot(self, fig: Figure):
raise AbstractMethodError(self)

def _add_table(self) -> None:
Expand Down Expand Up @@ -672,11 +674,11 @@ def _post_plot_logic_common(self, ax, data):
def _post_plot_logic(self, ax, data) -> None:
"""Post process for each axes. Overridden in child classes"""

def _adorn_subplots(self):
def _adorn_subplots(self, fig: Figure):
"""Common post process unrelated to data"""
if len(self.axes) > 0:
all_axes = self._get_subplots()
nrows, ncols = self._get_axes_layout()
all_axes = self._get_subplots(fig)
nrows, ncols = self._get_axes_layout(fig)
handle_shared_axes(
axarr=all_axes,
nplots=len(all_axes),
Expand Down Expand Up @@ -723,7 +725,7 @@ def _adorn_subplots(self):
for ax, title in zip(self.axes, self.title):
ax.set_title(title)
else:
self.fig.suptitle(self.title)
fig.suptitle(self.title)
else:
if is_list_like(self.title):
msg = (
Expand Down Expand Up @@ -1114,17 +1116,17 @@ def _get_errorbars(
errors[kw] = err
return errors

def _get_subplots(self):
def _get_subplots(self, fig: Figure):
from matplotlib.axes import Subplot

return [
ax
for ax in self.fig.get_axes()
for ax in fig.get_axes()
if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
]

def _get_axes_layout(self) -> tuple[int, int]:
axes = self._get_subplots()
def _get_axes_layout(self, fig: Figure) -> tuple[int, int]:
axes = self._get_subplots(fig)
x_set = set()
y_set = set()
for ax in axes:
Expand Down Expand Up @@ -1172,7 +1174,7 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

def _plot_colorbar(self, ax: Axes, **kwds):
def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds):
# Addresses issues #10611 and #10678:
# When plotting scatterplots and hexbinplots in IPython
# inline backend the colorbar axis height tends not to
Expand All @@ -1189,7 +1191,7 @@ def _plot_colorbar(self, ax: Axes, **kwds):
# use the last one which contains the latest information
# about the ax
img = ax.collections[-1]
return self.fig.colorbar(img, ax=ax, **kwds)
return fig.colorbar(img, ax=ax, **kwds)


class ScatterPlot(PlanePlot):
Expand All @@ -1209,7 +1211,7 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
c = self.data.columns[c]
self.c = c

def _make_plot(self):
def _make_plot(self, fig: Figure):
x, y, c, data = self.x, self.y, self.c, self.data
ax = self.axes[0]

Expand Down Expand Up @@ -1274,7 +1276,7 @@ def _make_plot(self):
)
if cb:
cbar_label = c if c_is_column else ""
cbar = self._plot_colorbar(ax, label=cbar_label)
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
if color_by_categorical:
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
cbar.ax.set_yticklabels(self.data[c].cat.categories)
Expand Down Expand Up @@ -1306,7 +1308,7 @@ def __init__(self, data, x, y, C=None, **kwargs) -> None:
C = self.data.columns[C]
self.C = C

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
x, y, data, C = self.x, self.y, self.data, self.C
ax = self.axes[0]
# pandas uses colormap, matplotlib uses cmap.
Expand All @@ -1321,7 +1323,7 @@ def _make_plot(self) -> None:

ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap, **self.kwds)
if cb:
self._plot_colorbar(ax)
self._plot_colorbar(ax, fig=fig)

def _make_legend(self) -> None:
pass
Expand Down Expand Up @@ -1358,7 +1360,7 @@ def _is_ts_plot(self) -> bool:
def _use_dynamic_x(self):
return use_dynamic_x(self._get_ax(0), self.data)

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
if self._is_ts_plot():
data = maybe_convert_index(self._get_ax(0), self.data)

Expand Down Expand Up @@ -1680,7 +1682,7 @@ def _plot( # type: ignore[override]
def _start_base(self):
return self.bottom

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors()
ncolors = len(colors)

Expand Down Expand Up @@ -1842,7 +1844,7 @@ def _args_adjust(self) -> None:
def _validate_color_args(self) -> None:
pass

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
self.kwds.setdefault("colors", colors)

Expand Down
3 changes: 2 additions & 1 deletion pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from pandas._typing import PlottingOrientation

Expand Down Expand Up @@ -113,7 +114,7 @@ def _plot( # type: ignore[override]
cls._update_stacker(ax, stacking_id, n)
return patches

def _make_plot(self) -> None:
def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors()
stacking_id = self._get_stacking_id()

Expand Down

0 comments on commit 56a4d57

Please sign in to comment.