diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ff67ea20073..355e19ee2a8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,8 @@ New Features By `Thomas Hirtz `_. - allow passing a function to ``combine_attrs`` (:pull:`4896`). By `Justus Magin `_. +- Allow plotting categorical data (:pull:`5464`). + By `Jimmy Westling `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f530427562a..10ebcc07664 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -925,18 +925,26 @@ def imshow(x, y, z, ax, **kwargs): "imshow requires 1D coordinates, try using pcolormesh or contour(f)" ) - # Centering the pixels- Assumes uniform spacing - try: - xstep = (x[1] - x[0]) / 2.0 - except IndexError: - # Arbitrary default value, similar to matplotlib behaviour - xstep = 0.1 - try: - ystep = (y[1] - y[0]) / 2.0 - except IndexError: - ystep = 0.1 - left, right = x[0] - xstep, x[-1] + xstep - bottom, top = y[-1] + ystep, y[0] - ystep + def _center_pixels(x): + """Center the pixels on the coordinates.""" + if np.issubdtype(x.dtype, str): + # When using strings as inputs imshow converts it to + # integers. Choose extent values which puts the indices in + # in the center of the pixels: + return 0 - 0.5, len(x) - 0.5 + + try: + # Center the pixels assuming uniform spacing: + xstep = 0.5 * (x[1] - x[0]) + except IndexError: + # Arbitrary default value, similar to matplotlib behaviour: + xstep = 0.1 + + return x[0] - xstep, x[-1] + xstep + + # Center the pixels: + left, right = _center_pixels(x) + top, bottom = _center_pixels(y) defaults = {"origin": "upper", "interpolation": "nearest"} @@ -967,6 +975,13 @@ def imshow(x, y, z, ax, **kwargs): primitive = ax.imshow(z, **defaults) + # If x or y are strings the ticklabels have been replaced with + # integer indices. Replace them back to strings: + for axis, v in [("x", x), ("y", y)]: + if np.issubdtype(v.dtype, str): + getattr(ax, f"set_{axis}ticks")(np.arange(len(v))) + getattr(ax, f"set_{axis}ticklabels")(v) + return primitive @@ -1011,9 +1026,13 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): else: infer_intervals = True - if infer_intervals and ( - (np.shape(x)[0] == np.shape(z)[1]) - or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + if ( + infer_intervals + and not np.issubdtype(x.dtype, str) + and ( + (np.shape(x)[0] == np.shape(z)[1]) + or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + ) ): if len(x.shape) == 1: x = _infer_interval_breaks(x, check_monotonic=True) @@ -1022,7 +1041,11 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): x = _infer_interval_breaks(x, axis=1) x = _infer_interval_breaks(x, axis=0) - if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]): + if ( + infer_intervals + and not np.issubdtype(y.dtype, str) + and (np.shape(y)[0] == np.shape(z)[0]) + ): if len(y.shape) == 1: y = _infer_interval_breaks(y, check_monotonic=True) else: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 416d56aa620..db85a5908c0 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -604,7 +604,14 @@ def _ensure_plottable(*args): Raise exception if there is anything in args that can't be plotted on an axis by matplotlib. """ - numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool_] + numpy_types = [ + np.floating, + np.integer, + np.timedelta64, + np.datetime64, + np.bool_, + np.str_, + ] other_types = [datetime] try: import cftime diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index bbd7c31fa16..e833654138a 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -684,10 +684,9 @@ def test_format_string(self): def test_can_pass_in_axis(self): self.pass_in_axis(self.darray.plot.line) - def test_nonnumeric_index_raises_typeerror(self): + def test_nonnumeric_index(self): a = DataArray([1, 2, 3], {"letter": ["a", "b", "c"]}, dims="letter") - with pytest.raises(TypeError, match=r"[Pp]lot"): - a.plot.line() + a.plot.line() def test_primitive_returned(self): p = self.darray.plot.line() @@ -1162,9 +1161,13 @@ def test_3d_raises_valueerror(self): with pytest.raises(ValueError, match=r"DataArray must be 2d"): self.plotfunc(a) - def test_nonnumeric_index_raises_typeerror(self): + def test_nonnumeric_index(self): a = DataArray(easy_array((3, 2)), coords=[["a", "b", "c"], ["d", "e"]]) - with pytest.raises(TypeError, match=r"[Pp]lot"): + if self.plotfunc.__name__ == "surface": + # ax.plot_surface errors with nonnumerics: + with pytest.raises(Exception): + self.plotfunc(a) + else: self.plotfunc(a) def test_multiindex_raises_typeerror(self):