Skip to content

Commit

Permalink
Plots get labels from pint arrays (#5561)
Browse files Browse the repository at this point in the history
* test labels come from pint units

* values demotes pint arrays before returning

* plot labels look for pint units first

* pre-commit

* added to_numpy() and as_numpy() methods

* remove special-casing of cupy arrays in .values in favour of using .to_numpy()

* .values -> .to_numpy()

* lint

* Fix mypy (I think?)

* added Dataset.as_numpy()

* improved docstrings

* add what's new

* add to API docs

* linting

* fix failures by only importing pint when needed

* refactor pycompat into class

* pycompat import changes applied to plotting code

* what's new

* compute instead of load

* added tests

* fixed sparse test

* tests and fixes for ds.as_numpy()

* fix sparse tests

* fix linting

* tests for Variable

* test IndexVariable too

* use numpy.asarray to avoid a copy

* also convert coords

* Force tests again after #5600

Co-authored-by: Maximilian Roos <m@maxroos.com>
  • Loading branch information
TomNicholas and max-sixty authored Jul 21, 2021
1 parent c5ee050 commit 92cb751
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 13 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ New Features
By `Elle Smith <https://github.com/ellesmith88>`_.
- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,7 +2784,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
result : MaskedArray
Masked where invalid values (nan or inf) occur.
"""
values = self.values # only compute lazy arrays once
values = self.to_numpy() # only compute lazy arrays once
isnull = pd.isnull(values)
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

Expand Down
1 change: 1 addition & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ def to_numpy(self) -> np.ndarray:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
# TODO an entrypoint so array libraries can choose coercion method?
data = self.data

# TODO first attempt to call .to_numpy() once some libraries implement it
if isinstance(data, dask_array_type):
data = data.compute()
Expand Down
10 changes: 5 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def line(

# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
xplt.values, yplt.values, kwargs
xplt.to_numpy(), yplt.to_numpy(), kwargs
)
xlabel = label_from_attrs(xplt, extra=x_suffix)
ylabel = label_from_attrs(yplt, extra=y_suffix)
Expand All @@ -449,7 +449,7 @@ def line(
ax.set_title(darray._title_for_slice())

if darray.ndim == 2 and add_legend:
ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label)
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)

# Rotate dates on xlabels
# Do this without calling autofmt_xdate so that x-axes ticks
Expand Down Expand Up @@ -551,7 +551,7 @@ def hist(
"""
ax = get_axis(figsize, size, aspect, ax)

no_nan = np.ravel(darray.values)
no_nan = np.ravel(darray.to_numpy())
no_nan = no_nan[pd.notnull(no_nan)]

primitive = ax.hist(no_nan, **kwargs)
Expand Down Expand Up @@ -1153,8 +1153,8 @@ def newplotfunc(
dims = (yval.dims[0], xval.dims[0])

# better to pass the ndarrays directly to plotting functions
xval = xval.values
yval = yval.values
xval = xval.to_numpy()
yval = yval.to_numpy()

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
Expand Down
21 changes: 15 additions & 6 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd

from ..core.options import OPTIONS
from ..core.pycompat import DuckArrayModule
from ..core.utils import is_scalar

try:
Expand Down Expand Up @@ -474,12 +475,20 @@ def label_from_attrs(da, extra=""):
else:
name = ""

if da.attrs.get("units"):
units = " [{}]".format(da.attrs["units"])
elif da.attrs.get("unit"):
units = " [{}]".format(da.attrs["unit"])
def _get_units_from_attrs(da):
if da.attrs.get("units"):
units = " [{}]".format(da.attrs["units"])
elif da.attrs.get("unit"):
units = " [{}]".format(da.attrs["unit"])
else:
units = ""
return units

pint_array_type = DuckArrayModule("pint").type
if isinstance(da.data, pint_array_type):
units = " [{}]".format(str(da.data.units))
else:
units = ""
units = _get_units_from_attrs(da)

return "\n".join(textwrap.wrap(name + extra + units, 30))

Expand Down Expand Up @@ -896,7 +905,7 @@ def _get_nice_quiver_magnitude(u, v):
import matplotlib as mpl

ticker = mpl.ticker.MaxNLocator(3)
mean = np.mean(np.hypot(u.values, v.values))
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
magnitude = ticker.tick_values(0, mean)[-2]
return magnitude

Expand Down
40 changes: 39 additions & 1 deletion xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,22 @@
import pandas as pd
import pytest

try:
import matplotlib.pyplot as plt
except ImportError:
pass

import xarray as xr
from xarray.core import dtypes, duck_array_ops

from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
from . import (
assert_allclose,
assert_duckarray_allclose,
assert_equal,
assert_identical,
requires_matplotlib,
)
from .test_plot import PlotTestCase
from .test_variable import _PAD_XR_NP_ARGS

pint = pytest.importorskip("pint")
Expand Down Expand Up @@ -5564,3 +5576,29 @@ def test_merge(self, variant, unit, error, dtype):

assert_units_equal(expected, actual)
assert_equal(expected, actual)


@requires_matplotlib
class TestPlots(PlotTestCase):
def test_units_in_line_plot_labels(self):
arr = np.linspace(1, 10, 3) * unit_registry.Pa
# TODO make coord a Quantity once unit-aware indexes supported
x_coord = xr.DataArray(
np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"}
)
da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure")

da.plot.line()

ax = plt.gca()
assert ax.get_ylabel() == "pressure [pascal]"
assert ax.get_xlabel() == "x [meters]"

def test_units_in_2d_plot_labels(self):
arr = np.ones((2, 3)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")

fig, (ax, cax) = plt.subplots(1, 2)
ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)

assert cax.get_ylabel() == "pressure [pascal]"

0 comments on commit 92cb751

Please sign in to comment.