diff --git a/python/pyspark/pandas/plot/matplotlib.py b/python/pyspark/pandas/plot/matplotlib.py index 91387805c421f..030623605e513 100644 --- a/python/pyspark/pandas/plot/matplotlib.py +++ b/python/pyspark/pandas/plot/matplotlib.py @@ -392,6 +392,12 @@ def _make_plot(self): kwds = self.kwds.copy() label = pprint_thing(label if len(label) > 1 else label[0]) + # `if hasattr(...)` makes plotting compatible with pandas < 1.3, see pandas-dev/pandas#40078. + label = ( + self._mark_right_label(label, index=i) + if hasattr(self, "_mark_right_label") + else label + ) kwds["label"] = label style, kwds = self._apply_style_colors(colors, kwds, i, label) @@ -400,7 +406,10 @@ def _make_plot(self): kwds = self._make_plot_keywords(kwds, y) artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds) - self._add_legend_handle(artists[0], label, index=i) + # `if hasattr(...)` makes plotting compatible with pandas < 1.3, see pandas-dev/pandas#40078. + self._append_legend_handles_labels(artists[0], label) if hasattr( + self, "_append_legend_handles_labels" + ) else self._add_legend_handle(artists[0], label, index=i) @classmethod def _plot(cls, ax, y, style=None, bins=None, bottom=0, column_num=0, stacking_id=None, **kwds): @@ -483,6 +492,12 @@ def _make_plot(self): kwds = self.kwds.copy() label = pprint_thing(label if len(label) > 1 else label[0]) + # `if hasattr(...)` makes plotting compatible with pandas < 1.3, see pandas-dev/pandas#40078. + label = ( + self._mark_right_label(label, index=i) + if hasattr(self, "_mark_right_label") + else label + ) kwds["label"] = label style, kwds = self._apply_style_colors(colors, kwds, i, label) @@ -491,7 +506,10 @@ def _make_plot(self): kwds = self._make_plot_keywords(kwds, y) artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds) - self._add_legend_handle(artists[0], label, index=i) + # `if hasattr(...)` makes plotting compatible with pandas < 1.3, see pandas-dev/pandas#40078. + self._append_legend_handles_labels(artists[0], label) if hasattr( + self, "_append_legend_handles_labels" + ) else self._add_legend_handle(artists[0], label, index=i) def _get_ind(self, y): return KdePlotBase.get_ind(y, self.ind)