diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml
index 4e4f8550e16..3f10a158f91 100644
--- a/ci/requirements/py36-min-all-deps.yml
+++ b/ci/requirements/py36-min-all-deps.yml
@@ -31,6 +31,7 @@ dependencies:
- numba=0.44
- numpy=1.14
- pandas=0.24
+ # - pint # See py36-min-nep18.yml
- pip
- pseudonetcdf=3.0
- pydap=3.2
diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py36-min-nep18.yml
index 5b291cf554c..fc9523ce249 100644
--- a/ci/requirements/py36-min-nep18.yml
+++ b/ci/requirements/py36-min-nep18.yml
@@ -2,7 +2,7 @@ name: xarray-tests
channels:
- conda-forge
dependencies:
- # Optional dependencies that require NEP18, such as sparse,
+ # Optional dependencies that require NEP18, such as sparse and pint,
# require drastically newer packages than everything else
- python=3.6
- coveralls
@@ -10,6 +10,7 @@ dependencies:
- distributed=2.4
- numpy=1.17
- pandas=0.24
+ - pint=0.9 # Actually not enough as it doesn't implement __array_function__yet!
- pytest
- pytest-cov
- pytest-env
diff --git a/ci/requirements/py36.yml b/ci/requirements/py36.yml
index cc91e8a12da..820160b19cc 100644
--- a/ci/requirements/py36.yml
+++ b/ci/requirements/py36.yml
@@ -27,6 +27,7 @@ dependencies:
- numba
- numpy
- pandas
+ - pint
- pip
- pseudonetcdf
- pydap
diff --git a/ci/requirements/py37-windows.yml b/ci/requirements/py37-windows.yml
index bf485b59a49..1d150d9f2af 100644
--- a/ci/requirements/py37-windows.yml
+++ b/ci/requirements/py37-windows.yml
@@ -27,6 +27,7 @@ dependencies:
- numba
- numpy
- pandas
+ - pint
- pip
- pseudonetcdf
- pydap
diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml
index 5c9a1cec5b5..4a7aaf7d32b 100644
--- a/ci/requirements/py37.yml
+++ b/ci/requirements/py37.yml
@@ -27,6 +27,7 @@ dependencies:
- numba
- numpy
- pandas
+ - pint
- pip
- pseudonetcdf
- pydap
diff --git a/doc/installing.rst b/doc/installing.rst
index b1bf072dbe1..0c5e8916ca3 100644
--- a/doc/installing.rst
+++ b/doc/installing.rst
@@ -66,6 +66,15 @@ For plotting
Alternative data containers
~~~~~~~~~~~~~~~~~~~~~~~~~~~
- `sparse `_: for sparse arrays
+- `pint `_: for units of measure
+
+ .. note::
+
+ At the moment of writing, xarray requires a `highly experimental version of pint
+ `_ (install with
+ ``pip install git+https://github.com/andrewgsavage/pint.git@refs/pull/6/head)``.
+ Even with it, interaction with non-numpy array libraries, e.g. dask or sparse, is broken.
+
- Any numpy-like objects that support
`NEP-18 `_.
Note that while such libraries theoretically should work, they are untested.
@@ -85,7 +94,7 @@ dependencies:
(`NEP-29 `_)
- **pandas:** 12 months
- **scipy:** 12 months
-- **sparse** and other libraries that rely on
+- **sparse, pint** and other libraries that rely on
`NEP-18 `_
for integration: very latest available versions only, until the technology will have
matured. This extends to dask when used in conjunction with any of these libraries.
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 2202c91408b..6c09b44940b 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -18,6 +18,19 @@ What's New
v0.14.1 (unreleased)
--------------------
+New Features
+~~~~~~~~~~~~
+- Added integration tests against `pint `_.
+ (:pull:`3238`) by `Justus Magin `_.
+
+ .. note::
+
+ At the moment of writing, these tests *as well as the ability to use pint in general*
+ require `a highly experimental version of pint
+ `_ (install with
+ ``pip install git+https://github.com/andrewgsavage/pint.git@refs/pull/6/head)``.
+ Even with it, interaction with non-numpy array libraries, e.g. dask or sparse, is broken.
+
Documentation
~~~~~~~~~~~~~
diff --git a/setup.cfg b/setup.cfg
index 6293d331477..eee8b2477b2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -73,6 +73,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-pandas.*]
ignore_missing_imports = True
+[mypy-pint.*]
+ignore_missing_imports = True
[mypy-PseudoNetCDF.*]
ignore_missing_imports = True
[mypy-pydap.*]
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
new file mode 100644
index 00000000000..15bb40ce4b2
--- /dev/null
+++ b/xarray/tests/test_units.py
@@ -0,0 +1,1636 @@
+import operator
+
+import numpy as np
+import pandas as pd
+import pytest
+
+import xarray as xr
+from xarray.core import formatting
+from xarray.core.npcompat import IS_NEP18_ACTIVE
+
+pint = pytest.importorskip("pint")
+DimensionalityError = pint.errors.DimensionalityError
+
+
+unit_registry = pint.UnitRegistry()
+Quantity = unit_registry.Quantity
+
+pytestmark = [
+ pytest.mark.skipif(
+ not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled"
+ ),
+ # TODO: remove this once pint has a released version with __array_function__
+ pytest.mark.skipif(
+ not hasattr(unit_registry.Quantity, "__array_function__"),
+ reason="pint does not implement __array_function__ yet",
+ ),
+ # pytest.mark.filterwarnings("ignore:::pint[.*]"),
+]
+
+
+def array_extract_units(obj):
+ raw = obj.data if hasattr(obj, "data") else obj
+ try:
+ return raw.units
+ except AttributeError:
+ return None
+
+
+def array_strip_units(array):
+ try:
+ return array.magnitude
+ except AttributeError:
+ return array
+
+
+def array_attach_units(data, unit, convert_from=None):
+ try:
+ unit, convert_from = unit
+ except TypeError:
+ pass
+
+ if isinstance(data, Quantity):
+ if not convert_from:
+ raise ValueError(
+ "cannot attach unit {unit} to quantity ({data.units})".format(
+ unit=unit, data=data
+ )
+ )
+ elif isinstance(convert_from, unit_registry.Unit):
+ data = data.magnitude
+ elif convert_from is True: # intentionally accept exactly true
+ if data.check(unit):
+ convert_from = data.units
+ data = data.magnitude
+ else:
+ raise ValueError(
+ "cannot convert quantity ({data.units}) to {unit}".format(
+ unit=unit, data=data
+ )
+ )
+ else:
+ raise ValueError(
+ "cannot convert from invalid unit {convert_from}".format(
+ convert_from=convert_from
+ )
+ )
+
+ # to make sure we also encounter the case of "equal if converted"
+ if convert_from is not None:
+ quantity = (data * convert_from).to(
+ unit
+ if isinstance(unit, unit_registry.Unit)
+ else unit_registry.dimensionless
+ )
+ else:
+ try:
+ quantity = data * unit
+ except np.core._exceptions.UFuncTypeError:
+ if unit != 1:
+ raise
+
+ quantity = data
+
+ return quantity
+
+
+def extract_units(obj):
+ if isinstance(obj, xr.Dataset):
+ vars_units = {
+ name: array_extract_units(value) for name, value in obj.data_vars.items()
+ }
+ coords_units = {
+ name: array_extract_units(value) for name, value in obj.coords.items()
+ }
+
+ units = {**vars_units, **coords_units}
+ elif isinstance(obj, xr.DataArray):
+ vars_units = {obj.name: array_extract_units(obj)}
+ coords_units = {
+ name: array_extract_units(value) for name, value in obj.coords.items()
+ }
+
+ units = {**vars_units, **coords_units}
+ elif isinstance(obj, Quantity):
+ vars_units = {"": array_extract_units(obj)}
+
+ units = {**vars_units}
+ else:
+ units = {}
+
+ return units
+
+
+def strip_units(obj):
+ if isinstance(obj, xr.Dataset):
+ data_vars = {name: strip_units(value) for name, value in obj.data_vars.items()}
+ coords = {name: strip_units(value) for name, value in obj.coords.items()}
+
+ new_obj = xr.Dataset(data_vars=data_vars, coords=coords)
+ elif isinstance(obj, xr.DataArray):
+ data = array_strip_units(obj.data)
+ coords = {
+ name: (
+ (value.dims, array_strip_units(value.data))
+ if isinstance(value.data, Quantity)
+ else value # to preserve multiindexes
+ )
+ for name, value in obj.coords.items()
+ }
+
+ new_obj = xr.DataArray(name=obj.name, data=data, coords=coords, dims=obj.dims)
+ elif hasattr(obj, "magnitude"):
+ new_obj = obj.magnitude
+ else:
+ new_obj = obj
+
+ return new_obj
+
+
+def attach_units(obj, units):
+ if not isinstance(obj, (xr.DataArray, xr.Dataset)):
+ return array_attach_units(obj, units.get("data", 1))
+
+ if isinstance(obj, xr.Dataset):
+ data_vars = {
+ name: attach_units(value, units) for name, value in obj.data_vars.items()
+ }
+
+ coords = {
+ name: attach_units(value, units) for name, value in obj.coords.items()
+ }
+
+ new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
+ else:
+ # try the array name, "data" and None, then fall back to dimensionless
+ data_units = (
+ units.get(obj.name, None)
+ or units.get("data", None)
+ or units.get(None, None)
+ or 1
+ )
+
+ data = array_attach_units(obj.data, data_units)
+
+ coords = {
+ name: (
+ (value.dims, array_attach_units(value.data, units.get(name) or 1))
+ if name in units
+ # to preserve multiindexes
+ else value
+ )
+ for name, value in obj.coords.items()
+ }
+ dims = obj.dims
+ attrs = obj.attrs
+
+ new_obj = xr.DataArray(
+ name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims
+ )
+
+ return new_obj
+
+
+def assert_equal_with_units(a, b):
+ # works like xr.testing.assert_equal, but also explicitly checks units
+ # so, it is more like assert_identical
+ __tracebackhide__ = True
+
+ if isinstance(a, xr.Dataset) or isinstance(b, xr.Dataset):
+ a_units = extract_units(a)
+ b_units = extract_units(b)
+
+ a_without_units = strip_units(a)
+ b_without_units = strip_units(b)
+
+ assert a_without_units.equals(b_without_units), formatting.diff_dataset_repr(
+ a, b, "equals"
+ )
+ assert a_units == b_units
+ else:
+ a = a if not isinstance(a, (xr.DataArray, xr.Variable)) else a.data
+ b = b if not isinstance(b, (xr.DataArray, xr.Variable)) else b.data
+
+ assert type(a) == type(b) or (
+ isinstance(a, Quantity) and isinstance(b, Quantity)
+ )
+
+ # workaround until pint implements allclose in __array_function__
+ if isinstance(a, Quantity) or isinstance(b, Quantity):
+ assert (
+ hasattr(a, "magnitude") and hasattr(b, "magnitude")
+ ) and np.allclose(a.magnitude, b.magnitude, equal_nan=True)
+ assert (hasattr(a, "units") and hasattr(b, "units")) and a.units == b.units
+ else:
+ assert np.allclose(a, b, equal_nan=True)
+
+
+@pytest.fixture(params=[float, int])
+def dtype(request):
+ return request.param
+
+
+class method:
+ def __init__(self, name, *args, **kwargs):
+ self.name = name
+ self.args = args
+ self.kwargs = kwargs
+
+ def __call__(self, obj, *args, **kwargs):
+ from collections.abc import Callable
+ from functools import partial
+
+ all_args = list(self.args) + list(args)
+ all_kwargs = {**self.kwargs, **kwargs}
+
+ func = getattr(obj, self.name, None)
+ if func is None or not isinstance(func, Callable):
+ # fall back to module level numpy functions if not a xarray object
+ if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
+ numpy_func = getattr(np, self.name)
+ func = partial(numpy_func, obj)
+ # remove typical xr args like "dim"
+ exclude_kwargs = ("dim", "dims")
+ all_kwargs = {
+ key: value
+ for key, value in all_kwargs.items()
+ if key not in exclude_kwargs
+ }
+ else:
+ raise AttributeError(
+ "{obj} has no method named '{self.name}'".format(obj=obj, self=self)
+ )
+
+ return func(*all_args, **all_kwargs)
+
+ def __repr__(self):
+ return "method_{self.name}".format(self=self)
+
+
+class function:
+ def __init__(self, name):
+ self.name = name
+ self.func = getattr(np, name)
+
+ def __call__(self, *args, **kwargs):
+ return self.func(*args, **kwargs)
+
+ def __repr__(self):
+ return "function_{self.name}".format(self=self)
+
+
+@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
+def test_replication(func, dtype):
+ array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
+ data_array = xr.DataArray(data=array, dims="x")
+
+ numpy_func = getattr(np, func.__name__)
+ expected = xr.DataArray(data=numpy_func(array), dims="x")
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+
+@pytest.mark.xfail(
+ reason="np.full_like on Variable strips the unit and pint does not allow mixed args"
+)
+@pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.ms, None, id="compatible_unit"),
+ pytest.param(unit_registry.s, None, id="identical_unit"),
+ ),
+)
+def test_replication_full_like(unit, error, dtype):
+ array = np.linspace(0, 5, 10) * unit_registry.s
+ data_array = xr.DataArray(data=array, dims="x")
+
+ fill_value = -1 * unit
+ if error is not None:
+ with pytest.raises(error):
+ xr.full_like(data_array, fill_value=fill_value)
+ else:
+ result = xr.full_like(data_array, fill_value=fill_value)
+ expected = np.full_like(array, fill_value=fill_value)
+
+ assert_equal_with_units(expected, result)
+
+
+class TestDataArray:
+ @pytest.mark.filterwarnings("error:::pint[.*]")
+ @pytest.mark.parametrize(
+ "variant",
+ (
+ pytest.param(
+ "with_dims",
+ marks=pytest.mark.xfail(reason="units in indexes are not supported"),
+ ),
+ pytest.param("with_coords"),
+ pytest.param("without_coords"),
+ ),
+ )
+ def test_init(self, variant, dtype):
+ array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m
+
+ x = np.arange(len(array)) * unit_registry.s
+ y = x.to(unit_registry.ms)
+
+ variants = {
+ "with_dims": {"x": x},
+ "with_coords": {"y": ("x", y)},
+ "without_coords": {},
+ }
+
+ kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)}
+ data_array = xr.DataArray(**kwargs)
+
+ assert isinstance(data_array.data, Quantity)
+ assert all(
+ {
+ name: isinstance(coord.data, Quantity)
+ for name, coord in data_array.coords.items()
+ }.values()
+ )
+
+ @pytest.mark.filterwarnings("error:::pint[.*]")
+ @pytest.mark.parametrize(
+ "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr"))
+ )
+ @pytest.mark.parametrize(
+ "variant",
+ (
+ pytest.param(
+ "with_dims",
+ marks=pytest.mark.xfail(reason="units in indexes are not supported"),
+ ),
+ pytest.param("with_coords"),
+ pytest.param("without_coords"),
+ ),
+ )
+ def test_repr(self, func, variant, dtype):
+ array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m
+ x = np.arange(len(array)) * unit_registry.s
+ y = x.to(unit_registry.ms)
+
+ variants = {
+ "with_dims": {"x": x},
+ "with_coords": {"y": ("x", y)},
+ "without_coords": {},
+ }
+
+ kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)}
+ data_array = xr.DataArray(**kwargs)
+
+ # FIXME: this just checks that the repr does not raise
+ # warnings or errors, but does not check the result
+ func(data_array)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ function("all"),
+ marks=pytest.mark.xfail(reason="not implemented by pint yet"),
+ ),
+ pytest.param(
+ function("any"),
+ marks=pytest.mark.xfail(reason="not implemented by pint yet"),
+ ),
+ pytest.param(
+ function("argmax"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantity with ndarrays in nanops not implemented"
+ ),
+ ),
+ pytest.param(
+ function("argmin"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantity with ndarrays in nanops not implemented"
+ ),
+ ),
+ function("max"),
+ function("mean"),
+ pytest.param(
+ function("median"),
+ marks=pytest.mark.xfail(
+ reason="np.median on DataArray strips the units"
+ ),
+ ),
+ function("min"),
+ pytest.param(
+ function("prod"),
+ marks=pytest.mark.xfail(reason="not implemented by pint yet"),
+ ),
+ pytest.param(
+ function("sum"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantity with ndarrays in nanops not implemented"
+ ),
+ ),
+ function("std"),
+ function("var"),
+ function("cumsum"),
+ pytest.param(
+ function("cumprod"),
+ marks=pytest.mark.xfail(reason="not implemented by pint yet"),
+ ),
+ pytest.param(
+ method("all"),
+ marks=pytest.mark.xfail(reason="not implemented by pint yet"),
+ ),
+ pytest.param(
+ method("any"),
+ marks=pytest.mark.xfail(reason="not implemented by pint yet"),
+ ),
+ pytest.param(
+ method("argmax"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantities with ndarrays in nanops not implemented"
+ ),
+ ),
+ pytest.param(
+ method("argmin"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantities with ndarrays in nanops not implemented"
+ ),
+ ),
+ method("max"),
+ method("mean"),
+ method("median"),
+ method("min"),
+ pytest.param(
+ method("prod"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantity with ndarrays in nanops not implemented"
+ ),
+ ),
+ pytest.param(
+ method("sum"),
+ marks=pytest.mark.xfail(
+ reason="comparison of quantity with ndarrays in nanops not implemented"
+ ),
+ ),
+ method("std"),
+ method("var"),
+ method("cumsum"),
+ pytest.param(
+ method("cumprod"),
+ marks=pytest.mark.xfail(reason="pint does not implement cumprod yet"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_aggregation(self, func, dtype):
+ array = np.arange(10).astype(dtype) * unit_registry.m
+ data_array = xr.DataArray(data=array)
+
+ expected = xr.DataArray(data=func(array))
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(operator.neg, id="negate"),
+ pytest.param(abs, id="absolute"),
+ pytest.param(
+ np.round,
+ id="round",
+ marks=pytest.mark.xfail(reason="pint does not implement round"),
+ ),
+ ),
+ )
+ def test_unary_operations(self, func, dtype):
+ array = np.arange(10).astype(dtype) * unit_registry.m
+ data_array = xr.DataArray(data=array)
+
+ expected = xr.DataArray(data=func(array))
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(lambda x: 2 * x, id="multiply"),
+ pytest.param(lambda x: x + x, id="add"),
+ pytest.param(lambda x: x[0] + x, id="add scalar"),
+ pytest.param(
+ lambda x: x.T @ x,
+ id="matrix multiply",
+ marks=pytest.mark.xfail(
+ reason="pint does not support matrix multiplication yet"
+ ),
+ ),
+ ),
+ )
+ def test_binary_operations(self, func, dtype):
+ array = np.arange(10).astype(dtype) * unit_registry.m
+ data_array = xr.DataArray(data=array)
+
+ expected = xr.DataArray(data=func(array))
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "comparison",
+ (
+ pytest.param(operator.lt, id="less_than"),
+ pytest.param(operator.ge, id="greater_equal"),
+ pytest.param(operator.eq, id="equal"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, ValueError, id="without_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incorrect_unit"),
+ pytest.param(unit_registry.m, None, id="correct_unit"),
+ ),
+ )
+ def test_comparison_operations(self, comparison, unit, error, dtype):
+ array = (
+ np.array([10.1, 5.2, 6.5, 8.0, 21.3, 7.1, 1.3]).astype(dtype)
+ * unit_registry.m
+ )
+ data_array = xr.DataArray(data=array)
+
+ value = 8
+ to_compare_with = value * unit
+
+ # incompatible units are all not equal
+ if error is not None and comparison is not operator.eq:
+ with pytest.raises(error):
+ comparison(array, to_compare_with)
+
+ with pytest.raises(error):
+ comparison(data_array, to_compare_with)
+ else:
+ result = comparison(data_array, to_compare_with)
+ # pint compares incompatible arrays to False, so we need to extend
+ # the multiplication works for both scalar and array results
+ expected = xr.DataArray(
+ data=comparison(array, to_compare_with)
+ * np.ones_like(array, dtype=bool)
+ )
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "units,error",
+ (
+ pytest.param(unit_registry.dimensionless, None, id="dimensionless"),
+ pytest.param(unit_registry.m, DimensionalityError, id="incorrect unit"),
+ pytest.param(unit_registry.degree, None, id="correct unit"),
+ ),
+ )
+ def test_univariate_ufunc(self, units, error, dtype):
+ array = np.arange(10).astype(dtype) * units
+ data_array = xr.DataArray(data=array)
+
+ if error is not None:
+ with pytest.raises(error):
+ np.sin(data_array)
+ else:
+ expected = xr.DataArray(data=np.sin(array))
+ result = np.sin(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="pint's implementation of `np.maximum` strips units")
+ def test_bivariate_ufunc(self, dtype):
+ unit = unit_registry.m
+ array = np.arange(10).astype(dtype) * unit
+ data_array = xr.DataArray(data=array)
+
+ expected = xr.DataArray(np.maximum(array, 0 * unit))
+
+ assert_equal_with_units(expected, np.maximum(data_array, 0 * unit))
+ assert_equal_with_units(expected, np.maximum(0 * unit, data_array))
+
+ @pytest.mark.parametrize("property", ("T", "imag", "real"))
+ def test_numpy_properties(self, property, dtype):
+ array = (
+ np.arange(5 * 10).astype(dtype)
+ + 1j * np.linspace(-1, 0, 5 * 10).astype(dtype)
+ ).reshape(5, 10) * unit_registry.s
+ data_array = xr.DataArray(data=array, dims=("x", "y"))
+
+ expected = xr.DataArray(
+ data=getattr(array, property),
+ dims=("x", "y")[:: 1 if property != "T" else -1],
+ )
+ result = getattr(data_array, property)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("conj"),
+ method("argsort"),
+ method("conjugate"),
+ method("round"),
+ pytest.param(
+ method("rank", dim="x"),
+ marks=pytest.mark.xfail(reason="pint does not implement rank yet"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_numpy_methods(self, func, dtype):
+ array = np.arange(10).astype(dtype) * unit_registry.m
+ data_array = xr.DataArray(data=array, dims="x")
+
+ expected = xr.DataArray(func(array), dims="x")
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func", (method("clip", min=3, max=8), method("searchsorted", v=5)), ids=repr
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_numpy_methods_with_args(self, func, unit, error, dtype):
+ array = np.arange(10).astype(dtype) * unit_registry.m
+ data_array = xr.DataArray(data=array)
+
+ scalar_types = (int, float)
+ kwargs = {
+ key: (value * unit if isinstance(value, scalar_types) else value)
+ for key, value in func.kwargs.items()
+ }
+
+ if error is not None:
+ with pytest.raises(error):
+ func(data_array, **kwargs)
+ else:
+ expected = func(array, **kwargs)
+ if func.name not in ["searchsorted"]:
+ expected = xr.DataArray(data=expected)
+ result = func(data_array, **kwargs)
+
+ if func.name in ["searchsorted"]:
+ assert np.allclose(expected, result)
+ else:
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func", (method("isnull"), method("notnull"), method("count")), ids=repr
+ )
+ def test_missing_value_detection(self, func, dtype):
+ array = (
+ np.array(
+ [
+ [1.4, 2.3, np.nan, 7.2],
+ [np.nan, 9.7, np.nan, np.nan],
+ [2.1, np.nan, np.nan, 4.6],
+ [9.9, np.nan, 7.2, 9.1],
+ ]
+ )
+ * unit_registry.degK
+ )
+ x = np.arange(array.shape[0]) * unit_registry.m
+ y = np.arange(array.shape[1]) * unit_registry.m
+
+ data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
+
+ expected = func(strip_units(data_array))
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="ffill and bfill lose units in data")
+ @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr)
+ def test_missing_value_filling(self, func, dtype):
+ array = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.degK
+ )
+ x = np.arange(len(array))
+ data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"])
+
+ result_without_units = func(strip_units(data_array), dim="x")
+ result = xr.DataArray(
+ data=result_without_units.data * unit_registry.degK,
+ coords={"x": x},
+ dims=["x"],
+ )
+
+ expected = attach_units(
+ func(strip_units(data_array), dim="x"), {"data": unit_registry.degK}
+ )
+ result = func(data_array, dim="x")
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="fillna drops the unit")
+ @pytest.mark.parametrize(
+ "fill_value",
+ (
+ pytest.param(
+ -1,
+ id="python scalar",
+ marks=pytest.mark.xfail(
+ reason="python scalar cannot be converted using astype()"
+ ),
+ ),
+ pytest.param(np.array(-1), id="numpy scalar"),
+ pytest.param(np.array([-1]), id="numpy array"),
+ ),
+ )
+ def test_fillna(self, fill_value, dtype):
+ unit = unit_registry.m
+ array = np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) * unit
+ data_array = xr.DataArray(data=array)
+
+ expected = attach_units(
+ strip_units(data_array).fillna(value=fill_value), {"data": unit}
+ )
+ result = data_array.fillna(value=fill_value * unit)
+
+ assert_equal_with_units(expected, result)
+
+ def test_dropna(self, dtype):
+ array = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.m
+ )
+ x = np.arange(len(array))
+ data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"])
+
+ expected = attach_units(
+ strip_units(data_array).dropna(dim="x"), {"data": unit_registry.m}
+ )
+ result = data_array.dropna(dim="x")
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="pint does not implement `numpy.isin`")
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, id="compatible_unit"),
+ pytest.param(unit_registry.m, id="same_unit"),
+ ),
+ )
+ def test_isin(self, unit, dtype):
+ array = (
+ np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
+ * unit_registry.m
+ )
+ data_array = xr.DataArray(data=array, dims="x")
+
+ raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype)
+ values = raw_values * unit
+
+ result_without_units = strip_units(data_array).isin(raw_values)
+ if unit != unit_registry.m:
+ result_without_units[:] = False
+ result_with_units = data_array.isin(values)
+
+ assert_equal_with_units(result_without_units, result_with_units)
+
+ @pytest.mark.parametrize(
+ "variant",
+ (
+ pytest.param(
+ "masking",
+ marks=pytest.mark.xfail(reason="nan not compatible with quantity"),
+ ),
+ pytest.param(
+ "replacing_scalar",
+ marks=pytest.mark.xfail(reason="scalar not convertible using astype"),
+ ),
+ pytest.param(
+ "replacing_array",
+ marks=pytest.mark.xfail(
+ reason="replacing using an array drops the units"
+ ),
+ ),
+ pytest.param(
+ "dropping",
+ marks=pytest.mark.xfail(reason="nan not compatible with quantity"),
+ ),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="same_unit"),
+ ),
+ )
+ def test_where(self, variant, unit, error, dtype):
+ def _strip_units(mapping):
+ return {key: array_strip_units(value) for key, value in mapping.items()}
+
+ original_unit = unit_registry.m
+ array = np.linspace(0, 1, 10).astype(dtype) * original_unit
+
+ data_array = xr.DataArray(data=array)
+
+ condition = data_array < 0.5 * original_unit
+ other = np.linspace(-2, -1, 10).astype(dtype) * unit
+ variant_kwargs = {
+ "masking": {"cond": condition},
+ "replacing_scalar": {"cond": condition, "other": -1 * unit},
+ "replacing_array": {"cond": condition, "other": other},
+ "dropping": {"cond": condition, "drop": True},
+ }
+ kwargs = variant_kwargs.get(variant)
+ kwargs_without_units = _strip_units(kwargs)
+
+ if variant not in ("masking", "dropping") and error is not None:
+ with pytest.raises(error):
+ data_array.where(**kwargs)
+ else:
+ expected = attach_units(
+ strip_units(array).where(**kwargs_without_units),
+ {"data": original_unit},
+ )
+ result = data_array.where(**kwargs)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="interpolate strips units")
+ def test_interpolate_na(self, dtype):
+ array = (
+ np.array([-1.03, 0.1, 1.4, np.nan, 2.3, np.nan, np.nan, 9.1])
+ * unit_registry.m
+ )
+ x = np.arange(len(array))
+ data_array = xr.DataArray(data=array, coords={"x": x}, dims="x").astype(dtype)
+
+ expected = attach_units(
+ strip_units(data_array).interpolate_na(dim="x"), {"data": unit_registry.m}
+ )
+ result = data_array.interpolate_na(dim="x")
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="uses DataArray.where, which currently fails")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_combine_first(self, unit, error, dtype):
+ array = np.zeros(shape=(2, 2), dtype=dtype) * unit_registry.m
+ other_array = np.ones_like(array) * unit
+
+ data_array = xr.DataArray(
+ data=array, coords={"x": ["a", "b"], "y": [-1, 0]}, dims=["x", "y"]
+ )
+ other = xr.DataArray(
+ data=other_array, coords={"x": ["b", "c"], "y": [0, 1]}, dims=["x", "y"]
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.combine_first(other)
+ else:
+ expected = attach_units(
+ strip_units(data_array).combine_first(strip_units(other)),
+ {"data": unit_registry.m},
+ )
+ result = data_array.combine_first(other)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(
+ unit_registry.cm,
+ id="compatible_unit",
+ marks=pytest.mark.xfail(reason="identical does not check units yet"),
+ ),
+ pytest.param(unit_registry.m, id="identical_unit"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "variation",
+ (
+ "data",
+ pytest.param(
+ "dims", marks=pytest.mark.xfail(reason="units in indexes not supported")
+ ),
+ "coords",
+ ),
+ )
+ @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr)
+ def test_comparisons(self, func, variation, unit, dtype):
+ data = np.linspace(0, 5, 10).astype(dtype)
+ coord = np.arange(len(data)).astype(dtype)
+
+ base_unit = unit_registry.m
+ quantity = data * base_unit
+ x = coord * base_unit
+ y = coord * base_unit
+
+ units = {
+ "data": (unit, base_unit, base_unit),
+ "dims": (base_unit, unit, base_unit),
+ "coords": (base_unit, base_unit, unit),
+ }
+ data_unit, dim_unit, coord_unit = units.get(variation)
+
+ data_array = xr.DataArray(
+ data=quantity, coords={"x": x, "y": ("x", y)}, dims="x"
+ )
+
+ other = attach_units(
+ strip_units(data_array),
+ {
+ None: (data_unit, base_unit if quantity.check(data_unit) else None),
+ "x": (dim_unit, base_unit if x.check(dim_unit) else None),
+ "y": (coord_unit, base_unit if y.check(coord_unit) else None),
+ },
+ )
+
+ # TODO: test dim coord once indexes leave units intact
+ # also, express this in terms of calls on the raw data array
+ # and then check the units
+ equal_arrays = (
+ np.all(quantity == other.data)
+ and (np.all(x == other.x.data) or True) # dims can't be checked yet
+ and np.all(y == other.y.data)
+ )
+ equal_units = (
+ data_unit == unit_registry.m
+ and coord_unit == unit_registry.m
+ and dim_unit == unit_registry.m
+ )
+ expected = equal_arrays and (func.name != "identical" or equal_units)
+ result = func(data_array, other)
+
+ assert expected == result
+
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, id="compatible_unit"),
+ pytest.param(unit_registry.m, id="identical_unit"),
+ ),
+ )
+ def test_broadcast_equals(self, unit, dtype):
+ left_array = np.ones(shape=(2, 2), dtype=dtype) * unit_registry.m
+ right_array = array_attach_units(
+ np.ones(shape=(2,), dtype=dtype),
+ unit,
+ convert_from=unit_registry.m if left_array.check(unit) else None,
+ )
+
+ left = xr.DataArray(data=left_array, dims=("x", "y"))
+ right = xr.DataArray(data=right_array, dims="x")
+
+ expected = np.all(left_array == right_array[:, None])
+ result = left.broadcast_equals(right)
+
+ assert expected == result
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("pipe", lambda da: da * 10),
+ method("assign_coords", y2=("y", np.arange(10) * unit_registry.mm)),
+ method("assign_attrs", attr1="value"),
+ method("rename", x2="x_mm"),
+ method("swap_dims", {"x": "x2"}),
+ method(
+ "expand_dims",
+ dim={"z": np.linspace(10, 20, 12) * unit_registry.s},
+ axis=1,
+ ),
+ method("drop", labels="x"),
+ method("reset_coords", names="x2"),
+ method("copy"),
+ pytest.param(
+ method("astype", np.float32),
+ marks=pytest.mark.xfail(reason="units get stripped"),
+ ),
+ pytest.param(
+ method("item", 1), marks=pytest.mark.xfail(reason="units get stripped")
+ ),
+ ),
+ ids=repr,
+ )
+ def test_content_manipulation(self, func, dtype):
+ quantity = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
+ * unit_registry.pascal
+ )
+ x = np.arange(quantity.shape[0]) * unit_registry.m
+ y = np.arange(quantity.shape[1]) * unit_registry.m
+ x2 = x.to(unit_registry.mm)
+
+ data_array = xr.DataArray(
+ name="data",
+ data=quantity,
+ coords={"x": x, "x2": ("x", x2), "y": y},
+ dims=("x", "y"),
+ )
+
+ stripped_kwargs = {
+ key: array_strip_units(value) for key, value in func.kwargs.items()
+ }
+ expected = attach_units(
+ func(strip_units(data_array), **stripped_kwargs),
+ {
+ "data": quantity.units,
+ "x": x.units,
+ "x_mm": x2.units,
+ "x2": x2.units,
+ "y": y.units,
+ },
+ )
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ method("drop", labels=np.array([1, 5]), dim="x"),
+ marks=pytest.mark.xfail(
+ reason="selecting using incompatible units does not raise"
+ ),
+ ),
+ pytest.param(method("copy", data=np.arange(20))),
+ ),
+ ids=repr,
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, KeyError, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_content_manipulation_with_units(self, func, unit, error, dtype):
+ quantity = np.linspace(0, 10, 20, dtype=dtype) * unit_registry.pascal
+ x = np.arange(len(quantity)) * unit_registry.m
+
+ data_array = xr.DataArray(name="data", data=quantity, coords={"x": x}, dims="x")
+
+ kwargs = {
+ key: (value * unit if isinstance(value, np.ndarray) else value)
+ for key, value in func.kwargs.items()
+ }
+ stripped_kwargs = func.kwargs
+
+ expected = attach_units(
+ func(strip_units(data_array), **stripped_kwargs),
+ {"data": quantity.units if func.name == "drop" else unit, "x": x.units},
+ )
+ if error is not None and func.name == "drop":
+ with pytest.raises(error):
+ func(data_array, **kwargs)
+ else:
+ result = func(data_array, **kwargs)
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "indices",
+ (
+ pytest.param(4, id="single index"),
+ pytest.param([5, 2, 9, 1], id="multiple indices"),
+ ),
+ )
+ def test_isel(self, indices, dtype):
+ array = np.arange(10).astype(dtype) * unit_registry.s
+ x = np.arange(len(array)) * unit_registry.m
+
+ data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"])
+
+ expected = attach_units(
+ strip_units(data_array).isel(x=indices),
+ {"data": unit_registry.s, "x": unit_registry.m},
+ )
+ result = data_array.isel(x=indices)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(
+ reason="xarray does not support duck arrays in dimension coordinates"
+ )
+ @pytest.mark.parametrize(
+ "values",
+ (
+ pytest.param(12, id="single value"),
+ pytest.param([10, 5, 13], id="list of multiple values"),
+ pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "units,error",
+ (
+ pytest.param(1, KeyError, id="no units"),
+ pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
+ pytest.param(unit_registry.degree, KeyError, id="incorrect unit"),
+ pytest.param(unit_registry.s, None, id="correct unit"),
+ ),
+ )
+ def test_sel(self, values, units, error, dtype):
+ array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m
+ x = np.arange(len(array)) * unit_registry.s
+ data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"])
+
+ values_with_units = values * units
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.sel(x=values_with_units)
+ else:
+ result_array = array[values]
+ result_data_array = data_array.sel(x=values_with_units)
+ assert_equal_with_units(result_array, result_data_array)
+
+ @pytest.mark.xfail(
+ reason="xarray does not support duck arrays in dimension coordinates"
+ )
+ @pytest.mark.parametrize(
+ "values",
+ (
+ pytest.param(12, id="single value"),
+ pytest.param([10, 5, 13], id="list of multiple values"),
+ pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "units,error",
+ (
+ pytest.param(1, KeyError, id="no units"),
+ pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
+ pytest.param(unit_registry.degree, KeyError, id="incorrect unit"),
+ pytest.param(unit_registry.s, None, id="correct unit"),
+ ),
+ )
+ def test_loc(self, values, units, error, dtype):
+ array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m
+ x = np.arange(len(array)) * unit_registry.s
+ data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"])
+
+ values_with_units = values * units
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.loc[values_with_units]
+ else:
+ result_array = array[values]
+ result_data_array = data_array.loc[values_with_units]
+ assert_equal_with_units(result_array, result_data_array)
+
+ @pytest.mark.xfail(reason="tries to coerce using asarray")
+ @pytest.mark.parametrize(
+ "shape",
+ (
+ pytest.param((10, 20), id="nothing squeezable"),
+ pytest.param((10, 20, 1), id="last dimension squeezable"),
+ pytest.param((10, 1, 20), id="middle dimension squeezable"),
+ pytest.param((1, 10, 20), id="first dimension squeezable"),
+ pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"),
+ ),
+ )
+ def test_squeeze(self, shape, dtype):
+ names = "xyzt"
+ coords = {
+ name: np.arange(length).astype(dtype)
+ * (unit_registry.m if name != "t" else unit_registry.s)
+ for name, length in zip(names, shape)
+ }
+ array = np.arange(10 * 20).astype(dtype).reshape(shape) * unit_registry.J
+ data_array = xr.DataArray(
+ data=array, coords=coords, dims=tuple(names[: len(shape)])
+ )
+
+ result_array = array.squeeze()
+ result_data_array = data_array.squeeze()
+ assert_equal_with_units(result_array, result_data_array)
+
+ # try squeezing the dimensions separately
+ names = tuple(dim for dim, coord in coords.items() if len(coord) == 1)
+ for index, name in enumerate(names):
+ assert_equal_with_units(
+ np.squeeze(array, axis=index), data_array.squeeze(dim=name)
+ )
+
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, None, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, None, id="dimensionless"),
+ pytest.param(unit_registry.s, None, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_interp(self, unit, error):
+ array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ new_coords = (np.arange(10) + 0.5) * unit
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ }
+
+ data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.interp(x=new_coords)
+ else:
+ new_coords_ = (
+ new_coords.magnitude if hasattr(new_coords, "magnitude") else new_coords
+ )
+ result_array = strip_units(data_array).interp(
+ x=new_coords_ * unit_registry.degK
+ )
+ result_data_array = data_array.interp(x=new_coords)
+
+ assert_equal_with_units(result_array, result_data_array)
+
+ @pytest.mark.xfail(reason="tries to coerce using asarray")
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, None, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, None, id="dimensionless"),
+ pytest.param(unit_registry.s, None, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_interp_like(self, unit, error):
+ array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ coords = {
+ "x": (np.arange(10) + 0.3) * unit_registry.m,
+ "y": (np.arange(5) + 0.3) * unit_registry.m,
+ }
+
+ data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
+ new_data_array = xr.DataArray(
+ data=np.empty((20, 10)),
+ coords={"x": np.arange(20) * unit, "y": np.arange(10) * unit},
+ dims=("x", "y"),
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.interp_like(new_data_array)
+ else:
+ result_array = (
+ xr.DataArray(
+ data=array.magnitude,
+ coords={name: value.magnitude for name, value in coords.items()},
+ dims=("x", "y"),
+ ).interp_like(strip_units(new_data_array))
+ * unit_registry.degK
+ )
+ result_data_array = data_array.interp_like(new_data_array)
+
+ assert_equal_with_units(result_array, result_data_array)
+
+ @pytest.mark.xfail(
+ reason="pint does not implement np.result_type in __array_function__ yet"
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, None, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, None, id="dimensionless"),
+ pytest.param(unit_registry.s, None, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_reindex(self, unit, error):
+ array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ new_coords = (np.arange(10) + 0.5) * unit
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ }
+
+ data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.interp(x=new_coords)
+ else:
+ result_array = strip_units(data_array).reindex(
+ x=(
+ new_coords.magnitude
+ if hasattr(new_coords, "magnitude")
+ else new_coords
+ )
+ * unit_registry.degK
+ )
+ result_data_array = data_array.reindex(x=new_coords)
+
+ assert_equal_with_units(result_array, result_data_array)
+
+ @pytest.mark.xfail(
+ reason="pint does not implement np.result_type in __array_function__ yet"
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, None, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, None, id="dimensionless"),
+ pytest.param(unit_registry.s, None, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ def test_reindex_like(self, unit, error):
+ array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+ coords = {
+ "x": (np.arange(10) + 0.3) * unit_registry.m,
+ "y": (np.arange(5) + 0.3) * unit_registry.m,
+ }
+
+ data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
+ new_data_array = xr.DataArray(
+ data=np.empty((20, 10)),
+ coords={"x": np.arange(20) * unit, "y": np.arange(10) * unit},
+ dims=("x", "y"),
+ )
+
+ if error is not None:
+ with pytest.raises(error):
+ data_array.reindex_like(new_data_array)
+ else:
+ expected = attach_units(
+ strip_units(data_array).reindex_like(strip_units(new_data_array)),
+ {
+ "data": unit_registry.degK,
+ "x": unit_registry.m,
+ "y": unit_registry.m,
+ },
+ )
+ result = data_array.reindex_like(new_data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (method("unstack"), method("reset_index", "z"), method("reorder_levels")),
+ ids=repr,
+ )
+ def test_stacking_stacked(self, func, dtype):
+ array = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
+ )
+ x = np.arange(array.shape[0])
+ y = np.arange(array.shape[1])
+
+ data_array = xr.DataArray(
+ name="data", data=array, coords={"x": x, "y": y}, dims=("x", "y")
+ )
+ stacked = data_array.stack(z=("x", "y"))
+
+ expected = attach_units(func(strip_units(stacked)), {"data": unit_registry.m})
+ result = func(stacked)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="indexes strip the label units")
+ def test_to_unstacked_dataset(self, dtype):
+ array = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
+ * unit_registry.pascal
+ )
+ x = np.arange(array.shape[0]) * unit_registry.m
+ y = np.arange(array.shape[1]) * unit_registry.s
+
+ data_array = xr.DataArray(
+ data=array, coords={"x": x, "y": y}, dims=("x", "y")
+ ).stack(z=("x", "y"))
+
+ func = method("to_unstacked_dataset", dim="z")
+
+ expected = attach_units(
+ func(strip_units(data_array)),
+ {"y": y.units, **dict(zip(x.magnitude, [array.units] * len(y)))},
+ ).rename({elem.magnitude: elem for elem in x})
+ result = func(data_array)
+
+ print(data_array, expected, result, sep="\n")
+
+ assert_equal_with_units(expected, result)
+
+ assert False
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("transpose", "y", "x", "z"),
+ method("stack", a=("x", "y")),
+ method("set_index", x="x2"),
+ pytest.param(
+ method("shift", x=2), marks=pytest.mark.xfail(reason="strips units")
+ ),
+ pytest.param(
+ method("roll", x=2, roll_coords=False),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ method("sortby", "x2"),
+ ),
+ ids=repr,
+ )
+ def test_stacking_reordering(self, func, dtype):
+ array = (
+ np.linspace(0, 10, 2 * 5 * 10).reshape(2, 5, 10).astype(dtype)
+ * unit_registry.m
+ )
+ x = np.arange(array.shape[0])
+ y = np.arange(array.shape[1])
+ z = np.arange(array.shape[2])
+ x2 = np.linspace(0, 1, array.shape[0])[::-1]
+
+ data_array = xr.DataArray(
+ name="data",
+ data=array,
+ coords={"x": x, "y": y, "z": z, "x2": ("x", x2)},
+ dims=("x", "y", "z"),
+ )
+
+ expected = attach_units(
+ func(strip_units(data_array)), {"data": unit_registry.m}
+ )
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ method("diff", dim="x"),
+ method("differentiate", coord="x"),
+ method("integrate", dim="x"),
+ pytest.param(
+ method("quantile", q=[0.25, 0.75]),
+ marks=pytest.mark.xfail(
+ reason="pint does not implement nanpercentile yet"
+ ),
+ ),
+ pytest.param(
+ method("reduce", func=np.sum, dim="x"),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ pytest.param(
+ lambda x: x.dot(x),
+ id="method_dot",
+ marks=pytest.mark.xfail(reason="pint does not implement einsum"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_computation(self, func, dtype):
+ array = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
+ )
+
+ x = np.arange(array.shape[0]) * unit_registry.m
+ y = np.arange(array.shape[1]) * unit_registry.s
+
+ data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
+ units = extract_units(data_array)
+
+ expected = attach_units(func(strip_units(data_array)), units)
+ result = func(data_array)
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ method("groupby", "y"), marks=pytest.mark.xfail(reason="strips units")
+ ),
+ pytest.param(
+ method("groupby_bins", "y", bins=4),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ method("coarsen", y=2),
+ pytest.param(
+ method("rolling", y=3), marks=pytest.mark.xfail(reason="strips units")
+ ),
+ pytest.param(
+ method("rolling_exp", y=3),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_computation_objects(self, func, dtype):
+ array = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
+ )
+
+ x = np.arange(array.shape[0]) * unit_registry.m
+ y = np.arange(array.shape[1]) * 3 * unit_registry.s
+
+ data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
+ units = extract_units(data_array)
+
+ expected = attach_units(func(strip_units(data_array)).mean(), units)
+ result = func(data_array).mean()
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.xfail(reason="strips units")
+ def test_resample(self, dtype):
+ array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
+
+ time = pd.date_range("10-09-2010", periods=len(array), freq="1y")
+ data_array = xr.DataArray(data=array, coords={"time": time}, dims="time")
+ units = extract_units(data_array)
+
+ func = method("resample", time="6m")
+
+ expected = attach_units(func(strip_units(data_array)).mean(), units)
+ result = func(data_array).mean()
+
+ assert_equal_with_units(expected, result)
+
+ @pytest.mark.parametrize(
+ "func",
+ (
+ pytest.param(
+ method("assign_coords", {"z": (["x"], np.arange(5) * unit_registry.s)}),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ pytest.param(method("first")),
+ pytest.param(method("last")),
+ pytest.param(
+ method("quantile", q=[0.25, 0.5, 0.75], dim="x"),
+ marks=pytest.mark.xfail(reason="strips units"),
+ ),
+ ),
+ ids=repr,
+ )
+ def test_grouped_operations(self, func, dtype):
+ array = (
+ np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
+ )
+
+ x = np.arange(array.shape[0]) * unit_registry.m
+ y = np.arange(array.shape[1]) * 3 * unit_registry.s
+
+ data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
+ units = extract_units(data_array)
+
+ expected = attach_units(func(strip_units(data_array).groupby("y")), units)
+ result = func(data_array.groupby("y"))
+
+ assert_equal_with_units(expected, result)