diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9d3e64badb8..dc487f85c02 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -44,6 +44,8 @@ Bug fixes - Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4. By `Anderson Banihirwe `_. +- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`) + By `Deepak Cherian `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index a288f195e32..a4ad30a9f40 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd +from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, @@ -666,17 +667,6 @@ def newplotfunc( darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb ) - # better to pass the ndarrays directly to plotting functions - xval = darray[xlab].values - yval = darray[ylab].values - - # check if we need to broadcast one dimension - if xval.ndim < yval.ndim: - xval = np.broadcast_to(xval, yval.shape) - - if yval.ndim < xval.ndim: - yval = np.broadcast_to(yval, xval.shape) - # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names if imshow_rgb: @@ -690,8 +680,17 @@ def newplotfunc( elif darray[xlab].dims[-1] == darray.dims[0]: darray = darray.transpose(transpose_coords=True) - # Pass the data as a masked ndarray too - zval = darray.to_masked_array(copy=False) + # better to pass the ndarrays directly to plotting functions + # Pass the data as a masked ndarray + if darray[xlab].ndim == 1 and darray[ylab].ndim == 1: + xval = darray[xlab].values + yval = darray[ylab].values + zval = darray.to_masked_array(copy=False) + else: + xval, yval, zval = map( + lambda x: x.values, broadcast(darray[xlab], darray[ylab], darray) + ) + zval = np.ma.masked_array(zval, mask=pd.isnull(zval), copy=False) # Replace pd.Intervals if contained in xval or yval. xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 3ac45a9720f..5c594ea193f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2145,3 +2145,19 @@ def test_yticks_kwarg(self, da): da.plot(yticks=np.arange(5)) expected = np.arange(5) assert np.all(plt.gca().get_yticks() == expected) + + +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"]) +def test_plot_transposed_nondim_coord(plotfunc): + x = np.linspace(0, 10, 101) + h = np.linspace(3, 7, 101) + s = np.linspace(0, 1, 51) + z = s[:, np.newaxis] * h[np.newaxis, :] + da = xr.DataArray( + np.sin(x) * np.cos(z), + dims=["s", "x"], + coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)}, + ) + getattr(da.plot, plotfunc)(x="x", y="zt") + getattr(da.plot, plotfunc)(x="zt", y="x")