diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py36-min-nep18.yml index a2245e89b41..48b9c057260 100644 --- a/ci/requirements/py36-min-nep18.yml +++ b/ci/requirements/py36-min-nep18.yml @@ -11,7 +11,7 @@ dependencies: - msgpack-python=0.6 # remove once distributed is bumped. distributed GH3491 - numpy=1.17 - pandas=0.25 - - pint=0.11 + - pint - pip - pytest - pytest-cov diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bcff60ce4df..4b5bb1e491f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -72,7 +72,7 @@ New Features - Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`) By `Kai Mühlbauer `_. -- More support for unit aware arrays with pint (:pull:`3643`) +- More support for unit aware arrays with pint (:pull:`3643`, :pull:`3975`) By `Justus Magin `_. - Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even without ``append_dim``, as long as dimension sizes do not change. diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 6f4f9f768d9..b477e8cccb2 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,9 +7,8 @@ import pytest import xarray as xr -from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE -from xarray.testing import assert_allclose, assert_identical +from xarray.testing import assert_allclose, assert_equal, assert_identical from .test_variable import _PAD_XR_NP_ARGS, VariableSubclassobjects @@ -27,11 +26,6 @@ pytest.mark.skipif( not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled" ), - # TODO: remove this once pint has a released version with __array_function__ - pytest.mark.skipif( - not hasattr(unit_registry.Quantity, "__array_function__"), - reason="pint does not implement __array_function__ yet", - ), # pytest.mark.filterwarnings("ignore:::pint[.*]"), ] @@ -51,10 +45,23 @@ def dimensionality(obj): def compatible_mappings(first, second): return { key: is_compatible(unit1, unit2) - for key, (unit1, unit2) in merge_mappings(first, second) + for key, (unit1, unit2) in zip_mappings(first, second) } +def merge_mappings(base, *mappings): + result = base.copy() + for m in mappings: + result.update(m) + + return result + + +def zip_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + def array_extract_units(obj): if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): obj = obj.data @@ -257,50 +264,11 @@ def assert_units_equal(a, b): assert extract_units(a) == extract_units(b) -def assert_equal_with_units(a, b): - # works like xr.testing.assert_equal, but also explicitly checks units - # so, it is more like assert_identical - __tracebackhide__ = True - - if isinstance(a, xr.Dataset) or isinstance(b, xr.Dataset): - a_units = extract_units(a) - b_units = extract_units(b) - - a_without_units = strip_units(a) - b_without_units = strip_units(b) - - assert a_without_units.equals(b_without_units), formatting.diff_dataset_repr( - a, b, "equals" - ) - assert a_units == b_units - else: - a = a if not isinstance(a, (xr.DataArray, xr.Variable)) else a.data - b = b if not isinstance(b, (xr.DataArray, xr.Variable)) else b.data - - assert type(a) == type(b) or ( - isinstance(a, Quantity) and isinstance(b, Quantity) - ) - - # workaround until pint implements allclose in __array_function__ - if isinstance(a, Quantity) or isinstance(b, Quantity): - assert ( - hasattr(a, "magnitude") and hasattr(b, "magnitude") - ) and np.allclose(a.magnitude, b.magnitude, equal_nan=True) - assert (hasattr(a, "units") and hasattr(b, "units")) and a.units == b.units - else: - assert np.allclose(a, b, equal_nan=True) - - @pytest.fixture(params=[float, int]) def dtype(request): return request.param -def merge_mappings(*mappings): - for key in set(mappings[0]).intersection(*mappings[1:]): - yield key, tuple(m[key] for m in mappings) - - def merge_args(default_args, new_args): from itertools import zip_longest @@ -427,7 +395,7 @@ def test_apply_ufunc_dataset(dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "unit,error", @@ -518,7 +486,7 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "unit,error", @@ -939,7 +907,7 @@ def test_concat_dataset(variant, unit, error, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "unit,error", @@ -1050,7 +1018,7 @@ def test_merge_dataarray(variant, unit, error, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "unit,error", @@ -1430,7 +1398,7 @@ def example_1d_objects(self): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) def test_real_and_imag(self): super().test_real_and_imag() @@ -1474,7 +1442,7 @@ def test_aggregation(self, func, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) def test_aggregate_complex(self): variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m) @@ -1486,7 +1454,7 @@ def test_aggregate_complex(self): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "func", @@ -1788,7 +1756,7 @@ def test_isel(self, indices, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "unit,error", @@ -1928,7 +1896,7 @@ def test_squeeze(self, dtype): pytest.param( method("quantile", q=[0.25, 0.75]), marks=pytest.mark.xfail( - LooseVersion(pint.__version__) < "0.12", + LooseVersion(pint.__version__) <= "0.12", reason="quantile / nanquantile not implemented yet", ), ), @@ -2268,7 +2236,7 @@ def test_repr(self, func, variant, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose", + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose", ) @pytest.mark.parametrize( "func", @@ -3331,7 +3299,7 @@ def test_head_tail_thin(self, func, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( @@ -3408,7 +3376,7 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( @@ -3577,7 +3545,7 @@ def test_stacking_reordering(self, func, dtype): pytest.param( method("quantile", q=[0.25, 0.75]), marks=pytest.mark.xfail( - LooseVersion(pint.__version__) < "0.12", + LooseVersion(pint.__version__) <= "0.12", reason="quantile / nanquantile not implemented yet", ), ), @@ -3614,7 +3582,7 @@ def test_computation(self, func, dtype): # TODO: remove once pint==0.12 has been released @pytest.mark.xfail( - LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose" + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" ) @pytest.mark.parametrize( "func", @@ -3630,7 +3598,9 @@ def test_computation(self, func, dtype): ), pytest.param( method("rolling_exp", y=3), - marks=pytest.mark.xfail(reason="units not supported by numbagg"), + marks=pytest.mark.xfail( + reason="numbagg functions are not supported by pint" + ), ), ), ids=repr, @@ -3676,7 +3646,7 @@ def test_resample(self, dtype): pytest.param( method("quantile", q=[0.25, 0.5, 0.75], dim="x"), marks=pytest.mark.xfail( - LooseVersion(pint.__version__) < "0.12", + LooseVersion(pint.__version__) <= "0.12", reason="quantile / nanquantile not implemented yet", ), ), @@ -3711,15 +3681,16 @@ def test_grouped_operations(self, func, dtype): xr.testing.assert_identical(expected, actual) +@pytest.mark.filterwarnings("error::pint.UnitStrippedWarning") class TestDataset: @pytest.mark.parametrize( "unit,error", ( - pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param(1, xr.MergeError, id="no_unit"), pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" + unit_registry.dimensionless, xr.MergeError, id="dimensionless" ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"), pytest.param(unit_registry.mm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="same_unit"), ), @@ -3728,11 +3699,10 @@ class TestDataset: "shared", ( "nothing", - pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), pytest.param( - "coords", - marks=pytest.mark.xfail(reason="reindex does not work with pint yet"), + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") ), + "coords", ), ) def test_init(self, shared, unit, error, dtype): @@ -3740,60 +3710,53 @@ def test_init(self, shared, unit, error, dtype): scaled_unit = unit_registry.mm a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa - b = np.linspace(-1, 0, 12).astype(dtype) * unit_registry.Pa - - raw_x = np.arange(a.shape[0]) - x = raw_x * original_unit - x2 = x.to(scaled_unit) - - raw_y = np.arange(b.shape[0]) - y = raw_y * unit - y_units = unit if isinstance(y, unit_registry.Quantity) else None - if isinstance(y, unit_registry.Quantity): - if y.check(scaled_unit): - y2 = y.to(scaled_unit) - else: - y2 = y * 1000 - y2_units = y2.units - else: - y2 = y * 1000 - y2_units = None + b = np.linspace(-1, 0, 10).astype(dtype) * unit_registry.degK + + values_a = np.arange(a.shape[0]) + dim_a = values_a * original_unit + coord_a = dim_a.to(scaled_unit) + + values_b = np.arange(b.shape[0]) + dim_b = values_b * unit + coord_b = ( + dim_b.to(scaled_unit) + if unit_registry.is_compatible_with(dim_b, scaled_unit) + and unit != scaled_unit + else dim_b * 1000 + ) variants = { - "nothing": ({"x": x, "x2": ("x", x2)}, {"y": y, "y2": ("y", y2)}), - "dims": ( - {"x": x, "x2": ("x", strip_units(x2))}, - {"x": y, "y2": ("x", strip_units(y2))}, + "nothing": ({}, {}), + "dims": ({"x": dim_a}, {"x": dim_b}), + "coords": ( + {"x": values_a, "y": ("x", coord_a)}, + {"x": values_b, "y": ("x", coord_b)}, ), - "coords": ({"x": raw_x, "y": ("x", x2)}, {"x": raw_y, "y": ("x", y2)}), } coords_a, coords_b = variants.get(shared) dims_a, dims_b = ("x", "y") if shared == "nothing" else ("x", "x") - arr1 = xr.DataArray(data=a, coords=coords_a, dims=dims_a) - arr2 = xr.DataArray(data=b, coords=coords_b, dims=dims_b) + a = xr.DataArray(data=a, coords=coords_a, dims=dims_a) + b = xr.DataArray(data=b, coords=coords_b, dims=dims_b) + if error is not None and shared != "nothing": with pytest.raises(error): - xr.Dataset(data_vars={"a": arr1, "b": arr2}) + xr.Dataset(data_vars={"a": a, "b": b}) return - actual = xr.Dataset(data_vars={"a": arr1, "b": arr2}) + actual = xr.Dataset(data_vars={"a": a, "b": b}) - expected_units = { - "a": a.units, - "b": b.units, - "x": x.units, - "x2": x2.units, - "y": y_units, - "y2": y2_units, - } + units = merge_mappings( + extract_units(a.rename("a")), extract_units(b.rename("b")) + ) expected = attach_units( - xr.Dataset(data_vars={"a": strip_units(arr1), "b": strip_units(arr2)}), - expected_units, + xr.Dataset(data_vars={"a": strip_units(a), "b": strip_units(b)}), units ) - assert_equal_with_units(actual, expected) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) @@ -3801,48 +3764,45 @@ def test_init(self, shared, unit, error, dtype): @pytest.mark.parametrize( "variant", ( + "data", pytest.param( - "with_dims", + "dims", marks=pytest.mark.xfail(reason="units in indexes are not supported"), ), - pytest.param("with_coords"), - pytest.param("without_coords"), + "coords", ), ) - @pytest.mark.filterwarnings("error:::pint[.*]") def test_repr(self, func, variant, dtype): - array1 = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.Pa - array2 = np.linspace(0, 1, 10, dtype=dtype) * unit_registry.degK + unit1, unit2 = ( + (unit_registry.Pa, unit_registry.degK) if variant == "data" else (1, 1) + ) + + array1 = np.linspace(1, 2, 10, dtype=dtype) * unit1 + array2 = np.linspace(0, 1, 10, dtype=dtype) * unit2 x = np.arange(len(array1)) * unit_registry.s y = x.to(unit_registry.ms) variants = { - "with_dims": {"x": x}, - "with_coords": {"y": ("x", y)}, - "without_coords": {}, + "dims": {"x": x}, + "coords": {"y": ("x", y)}, + "data": {}, } - data_array = xr.Dataset( + ds = xr.Dataset( data_vars={"a": ("x", array1), "b": ("x", array2)}, coords=variants.get(variant), ) # FIXME: this just checks that the repr does not raise # warnings or errors, but does not check the result - func(data_array) + func(ds) @pytest.mark.parametrize( "func", ( - pytest.param( - function("all"), - marks=pytest.mark.xfail(reason="not implemented by pint"), - ), - pytest.param( - function("any"), - marks=pytest.mark.xfail(reason="not implemented by pint"), - ), + function("all"), + function("any"), function("argmax"), function("argmin"), function("max"), @@ -3850,28 +3810,19 @@ def test_repr(self, func, variant, dtype): function("mean"), pytest.param( function("median"), - marks=pytest.mark.xfail( - reason="np.median does not work with dataset yet" - ), + marks=pytest.mark.xfail(reason="median does not work with dataset yet"), ), function("sum"), pytest.param( function("prod"), - marks=pytest.mark.xfail(reason="not implemented by pint"), + marks=pytest.mark.xfail(reason="prod does not work with dataset yet"), ), function("std"), function("var"), function("cumsum"), - pytest.param( - function("cumprod"), - marks=pytest.mark.xfail(reason="fails within xarray"), - ), - pytest.param( - method("all"), marks=pytest.mark.xfail(reason="not implemented by pint") - ), - pytest.param( - method("any"), marks=pytest.mark.xfail(reason="not implemented by pint") - ), + function("cumprod"), + method("all"), + method("any"), method("argmax"), method("argmin"), method("max"), @@ -3881,68 +3832,49 @@ def test_repr(self, func, variant, dtype): method("sum"), pytest.param( method("prod"), - marks=pytest.mark.xfail(reason="not implemented by pint"), + marks=pytest.mark.xfail(reason="prod does not work with dataset yet"), ), method("std"), method("var"), method("cumsum"), - pytest.param( - method("cumprod"), marks=pytest.mark.xfail(reason="fails within xarray") - ), + method("cumprod"), ), ids=repr, ) def test_aggregation(self, func, dtype): - unit_a = ( - unit_registry.Pa if func.name != "cumprod" else unit_registry.dimensionless - ) - unit_b = ( - unit_registry.kg / unit_registry.m ** 3 + unit_a, unit_b = ( + (unit_registry.Pa, unit_registry.degK) if func.name != "cumprod" - else unit_registry.dimensionless - ) - a = xr.DataArray(data=np.linspace(0, 1, 10).astype(dtype) * unit_a, dims="x") - b = xr.DataArray(data=np.linspace(-1, 0, 10).astype(dtype) * unit_b, dims="x") - x = xr.DataArray(data=np.arange(10).astype(dtype) * unit_registry.m, dims="x") - y = xr.DataArray( - data=np.arange(10, 20).astype(dtype) * unit_registry.s, dims="x" + else (unit_registry.dimensionless, unit_registry.dimensionless) ) - ds = xr.Dataset(data_vars={"a": a, "b": b}, coords={"x": x, "y": y}) + a = np.linspace(0, 1, 10).astype(dtype) * unit_a + b = np.linspace(-1, 0, 10).astype(dtype) * unit_b + + ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + + units_a = array_extract_units(func(a)) + units_b = array_extract_units(func(b)) + units = {"a": units_a, "b": units_b} actual = func(ds) - expected = attach_units( - func(strip_units(ds)), - { - "a": extract_units(func(a)).get(None), - "b": extract_units(func(b)).get(None), - }, - ) + expected = attach_units(func(strip_units(ds)), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize("property", ("imag", "real")) def test_numpy_properties(self, property, dtype): - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray( - data=np.linspace(0, 1, 10) * unit_registry.Pa, dims="x" - ), - "b": xr.DataArray( - data=np.linspace(-1, 0, 15) * unit_registry.Pa, dims="y" - ), - }, - coords={ - "x": np.arange(10) * unit_registry.m, - "y": np.arange(15) * unit_registry.s, - }, - ) + a = np.linspace(0, 1, 10) * unit_registry.Pa + b = np.linspace(-1, 0, 15) * unit_registry.degK + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) units = extract_units(ds) actual = getattr(ds, property) expected = attach_units(getattr(strip_units(ds), property), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", @@ -3956,31 +3888,19 @@ def test_numpy_properties(self, property, dtype): ids=repr, ) def test_numpy_methods(self, func, dtype): - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray( - data=np.linspace(1, -1, 10) * unit_registry.Pa, dims="x" - ), - "b": xr.DataArray( - data=np.linspace(-1, 1, 15) * unit_registry.Pa, dims="y" - ), - }, - coords={ - "x": np.arange(10) * unit_registry.m, - "y": np.arange(15) * unit_registry.s, - }, - ) - units = { - "a": array_extract_units(func(ds.a)), - "b": array_extract_units(func(ds.b)), - "x": unit_registry.m, - "y": unit_registry.s, - } + a = np.linspace(1, -1, 10) * unit_registry.Pa + b = np.linspace(-1, 1, 15) * unit_registry.degK + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) + + units_a = array_extract_units(func(a)) + units_b = array_extract_units(func(b)) + units = {"a": units_a, "b": units_b} actual = func(ds) expected = attach_units(func(strip_units(ds)), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr) @pytest.mark.parametrize( @@ -3997,21 +3917,13 @@ def test_numpy_methods(self, func, dtype): ) def test_numpy_methods_with_args(self, func, unit, error, dtype): data_unit = unit_registry.m - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=np.arange(10) * data_unit, dims="x"), - "b": xr.DataArray(data=np.arange(15) * data_unit, dims="y"), - }, - coords={ - "x": np.arange(10) * unit_registry.m, - "y": np.arange(15) * unit_registry.s, - }, - ) + a = np.linspace(0, 10, 15) * unit_registry.m + b = np.linspace(-2, 12, 20) * unit_registry.m + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) units = extract_units(ds) kwargs = { - key: (value * unit if isinstance(value, (int, float)) else value) - for key, value in func.kwargs.items() + key: array_attach_units(value, unit) for key, value in func.kwargs.items() } if error is not None: @@ -4028,7 +3940,8 @@ def test_numpy_methods_with_args(self, func, unit, error, dtype): actual = func(ds, **kwargs) expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", (method("isnull"), method("notnull"), method("count")), ids=repr @@ -4058,22 +3971,13 @@ def test_missing_value_detection(self, func, dtype): * unit_registry.Pa ) - x = np.arange(array1.shape[0]) * unit_registry.m - y = np.arange(array1.shape[1]) * unit_registry.m - z = np.arange(array2.shape[0]) * unit_registry.m - - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("z", "x")), - }, - coords={"x": x, "y": y, "z": z}, - ) + ds = xr.Dataset({"a": (("x", "y"), array1), "b": (("z", "x"), array2)}) expected = func(strip_units(ds)) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="ffill and bfill lose the unit") @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr) @@ -4087,23 +3991,14 @@ def test_missing_value_filling(self, func, dtype): * unit_registry.Pa ) - x = np.arange(len(array1)) - - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("y", array2)}) + units = extract_units(ds) - expected = attach_units( - func(strip_units(ds), dim="x"), - {"a": unit_registry.degK, "b": unit_registry.Pa}, - ) + expected = attach_units(func(strip_units(ds), dim="x"), units) actual = func(ds, dim="x") - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", @@ -4113,14 +4008,7 @@ def test_missing_value_filling(self, func, dtype): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param( - unit_registry.cm, - None, - id="compatible_unit", - marks=pytest.mark.xfail( - reason="where converts the array, not the fill value" - ), - ), + pytest.param(unit_registry.cm, None, id="compatible_unit",), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4141,30 +4029,26 @@ def test_fillna(self, fill_value, unit, error, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.m ) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - } - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + value = fill_value * unit + units = extract_units(ds) if error is not None: with pytest.raises(error): - ds.fillna(value=fill_value * unit) + ds.fillna(value=value) return - actual = ds.fillna(value=fill_value * unit) + actual = ds.fillna(value=value) expected = attach_units( strip_units(ds).fillna( - value=strip_units( - convert_units(fill_value * unit, {None: unit_registry.m}) - ) + value=strip_units(convert_units(value, {None: unit_registry.m})) ), - {"a": unit_registry.m, "b": unit_registry.m}, + units, ) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) def test_dropna(self, dtype): array1 = ( @@ -4175,22 +4059,14 @@ def test_dropna(self, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.Pa ) - x = np.arange(len(array1)) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) - expected = attach_units( - strip_units(ds).dropna(dim="x"), - {"a": unit_registry.degK, "b": unit_registry.Pa}, - ) + expected = attach_units(strip_units(ds).dropna(dim="x"), units) actual = ds.dropna(dim="x") - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit", @@ -4211,34 +4087,28 @@ def test_isin(self, unit, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.m ) - x = np.arange(len(array1)) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype) values = raw_values * unit - if ( - isinstance(values, unit_registry.Quantity) - and values.check(unit_registry.m) - and unit != unit_registry.m - ): - raw_values = values.to(unit_registry.m).magnitude + converted_values = ( + convert_units(values, {None: unit_registry.m}) + if is_compatible(unit, unit_registry.m) + else values + ) - expected = strip_units(ds).isin(raw_values) - if not isinstance(values, unit_registry.Quantity) or not values.check( - unit_registry.m - ): + expected = strip_units(ds).isin(strip_units(converted_values)) + # TODO: use `unit_registry.is_compatible_with(unit, unit_registry.m)` instead. + # Needs `pint>=0.12.1`, though, so we probably should wait until that is released. + if not is_compatible(unit, unit_registry.m): expected.a[:] = False expected.b[:] = False + actual = ds.isin(values) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "variant", ("masking", "replacing_scalar", "replacing_array", "dropping") @@ -4260,13 +4130,8 @@ def test_where(self, variant, unit, error, dtype): array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": np.arange(len(array1))}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) condition = ds < 0.5 * original_unit other = np.linspace(-2, -1, 10).astype(dtype) * unit @@ -4288,15 +4153,13 @@ def test_where(self, variant, unit, error, dtype): for key, value in kwargs.items() } - expected = attach_units( - strip_units(ds).where(**kwargs_without_units), - {"a": original_unit, "b": original_unit}, - ) + expected = attach_units(strip_units(ds).where(**kwargs_without_units), units,) actual = ds.where(**kwargs) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="interpolate strips units") + @pytest.mark.xfail(reason="interpolate_na uses numpy.vectorize") def test_interpolate_na(self, dtype): array1 = ( np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) @@ -4306,24 +4169,15 @@ def test_interpolate_na(self, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.Pa ) - x = np.arange(len(array1)) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) - expected = attach_units( - strip_units(ds).interpolate_na(dim="x"), - {"a": unit_registry.degK, "b": unit_registry.Pa}, - ) + expected = attach_units(strip_units(ds).interpolate_na(dim="x"), units,) actual = ds.interpolate_na(dim="x") - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="wrong argument order for `where`") @pytest.mark.parametrize( "unit,error", ( @@ -4336,31 +4190,40 @@ def test_interpolate_na(self, dtype): pytest.param(unit_registry.m, None, id="same_unit"), ), ) - def test_combine_first(self, unit, error, dtype): + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ), + ) + def test_combine_first(self, variant, unit, error, dtype): + variants = { + "data": (unit_registry.m, unit, 1, 1), + "dims": (1, 1, unit_registry.m, unit), + } + data_unit, other_data_unit, dims_unit, other_dims_unit = variants.get(variant) + array1 = ( - np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) - * unit_registry.m + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) * data_unit ) array2 = ( - np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) - * unit_registry.m + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * data_unit ) - x = np.arange(len(array1)) + x = np.arange(len(array1)) * dims_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, + data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}, ) - other_array1 = np.ones_like(array1) * unit - other_array2 = -1 * np.ones_like(array2) * unit + units = extract_units(ds) + + other_array1 = np.ones_like(array1) * other_data_unit + other_array2 = np.full_like(array2, fill_value=-1) * other_data_unit + other_x = (np.arange(array1.shape[0]) + 5) * other_dims_unit other = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=other_array1, dims="x"), - "b": xr.DataArray(data=other_array2, dims="x"), - }, - coords={"x": np.arange(array1.shape[0])}, + data_vars={"a": ("x", other_array1), "b": ("x", other_array2)}, + coords={"x": other_x}, ) if error is not None: @@ -4370,16 +4233,13 @@ def test_combine_first(self, unit, error, dtype): return expected = attach_units( - strip_units(ds).combine_first( - strip_units( - convert_units(other, {"a": unit_registry.m, "b": unit_registry.m}) - ) - ), - {"a": unit_registry.m, "b": unit_registry.m}, + strip_units(ds).combine_first(strip_units(convert_units(other, units))), + units, ) actual = ds.combine_first(other) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit", @@ -4392,7 +4252,7 @@ def test_combine_first(self, unit, error, dtype): ), ) @pytest.mark.parametrize( - "variation", + "variant", ( "data", pytest.param( @@ -4401,50 +4261,67 @@ def test_combine_first(self, unit, error, dtype): "coords", ), ) - @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) - def test_comparisons(self, func, variation, unit, dtype): - def is_compatible(a, b): - a = a if a is not None else 1 - b = b if b is not None else 1 - quantity = np.arange(5) * a - - return a == b or quantity.check(b) - + @pytest.mark.parametrize( + "func", + ( + method("equals"), + pytest.param( + method("identical"), + marks=pytest.mark.skip("behaviour of identical is unclear"), + ), + ), + ids=repr, + ) + def test_comparisons(self, func, variant, unit, dtype): array1 = np.linspace(0, 5, 10).astype(dtype) array2 = np.linspace(-5, 0, 10).astype(dtype) coord = np.arange(len(array1)).astype(dtype) - original_unit = unit_registry.m - quantity1 = array1 * original_unit - quantity2 = array2 * original_unit - x = coord * original_unit - y = coord * original_unit + variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } + data_unit, dim_unit, coord_unit = variants.get(variant) - units = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} - data_unit, dim_unit, coord_unit = units.get(variation) + a = array1 * data_unit + b = array2 * data_unit + x = coord * dim_unit + y = coord * coord_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=quantity1, dims="x"), - "b": xr.DataArray(data=quantity2, dims="x"), - }, - coords={"x": x, "y": ("x", y)}, + data_vars={"a": ("x", a), "b": ("x", b)}, coords={"x": x, "y": ("x", y)}, ) + units = extract_units(ds) + + other_variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + other_data_unit, other_dim_unit, other_coord_unit = other_variants.get(variant) other_units = { - "a": data_unit if quantity1.check(data_unit) else None, - "b": data_unit if quantity2.check(data_unit) else None, - "x": dim_unit if x.check(dim_unit) else None, - "y": coord_unit if y.check(coord_unit) else None, + "a": other_data_unit, + "b": other_data_unit, + "x": other_dim_unit, + "y": other_coord_unit, } - other = attach_units(strip_units(convert_units(ds, other_units)), other_units) - units = extract_units(ds) + to_convert = { + key: unit if is_compatible(unit, reference) else None + for key, (unit, reference) in zip_mappings(units, other_units) + } + # convert units where possible, then attach all units to the converted dataset + other = attach_units(strip_units(convert_units(ds, to_convert)), other_units) other_units = extract_units(other) + # make sure all units are compatible and only then try to + # convert and compare values equal_ds = all( - is_compatible(units[name], other_units[name]) for name in units.keys() + is_compatible(unit, other_unit) + for _, (unit, other_unit) in zip_mappings(units, other_units) ) and (strip_units(ds).equals(strip_units(convert_units(other, units)))) equal_units = units == other_units expected = equal_ds and (func.name != "identical" or equal_units) @@ -4453,6 +4330,9 @@ def is_compatible(a, b): assert expected == actual + # TODO: eventually use another decorator / wrapper function that + # applies a filter to the parametrize combinations: + # we only need a single test for data @pytest.mark.parametrize( "unit", ( @@ -4463,14 +4343,29 @@ def is_compatible(a, b): pytest.param(unit_registry.m, id="identical_unit"), ), ) - def test_broadcast_like(self, unit, dtype): - array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa - array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ), + ) + def test_broadcast_like(self, variant, unit, dtype): + variants = { + "data": ((unit_registry.m, unit), (1, 1)), + "dims": ((1, 1), (unit_registry.m, unit)), + } + (data_unit1, data_unit2), (dim_unit1, dim_unit2) = variants.get(variant) - x1 = np.arange(2) * unit_registry.m - x2 = np.arange(2) * unit - y1 = np.array([0]) * unit_registry.m - y2 = np.arange(3) * unit + array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * data_unit1 + array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit2 + + x1 = np.arange(2) * dim_unit1 + x2 = np.arange(2) * dim_unit2 + y1 = np.array([0]) * dim_unit1 + y2 = np.arange(3) * dim_unit2 ds1 = xr.Dataset( data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1} @@ -4484,7 +4379,8 @@ def test_broadcast_like(self, unit, dtype): ) actual = ds1.broadcast_like(ds2) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit", @@ -4497,32 +4393,25 @@ def test_broadcast_like(self, unit, dtype): ), ) def test_broadcast_equals(self, unit, dtype): + # TODO: does this use indexes? left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m left_array2 = np.zeros(shape=(3, 6), dtype=dtype) * unit_registry.m right_array1 = np.ones(shape=(2,)) * unit - right_array2 = np.ones(shape=(3,)) * unit + right_array2 = np.zeros(shape=(3,)) * unit left = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=left_array1, dims=("x", "y")), - "b": xr.DataArray(data=left_array2, dims=("y", "z")), - } - ) - right = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=right_array1, dims="x"), - "b": xr.DataArray(data=right_array2, dims="y"), - } + {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)}, ) + right = xr.Dataset({"a": ("x", right_array1), "b": ("y", right_array2)}) - units = { - **extract_units(left), - **({} if left_array1.check(unit) else {"a": None, "b": None}), - } - expected = strip_units(left).broadcast_equals( - strip_units(convert_units(right, units)) - ) & left_array1.check(unit) + units = merge_mappings( + extract_units(left), + {} if is_compatible(left_array1, unit) else {"a": None, "b": None}, + ) + expected = is_compatible(left_array1, unit) and strip_units( + left + ).broadcast_equals(strip_units(convert_units(right, units))) actual = left.broadcast_equals(right) assert expected == actual @@ -4532,68 +4421,74 @@ def test_broadcast_equals(self, unit, dtype): (method("unstack"), method("reset_index", "v"), method("reorder_levels")), ids=repr, ) - def test_stacking_stacked(self, func, dtype): - array1 = ( - np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m - ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ), + ) + def test_stacking_stacked(self, variant, func, dtype): + variants = { + "data": (unit_registry.m, 1), + "dims": (1, unit_registry.m), + } + data_unit, dim_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit array2 = ( np.linspace(-10, 0, 5 * 10 * 15).reshape(5, 10, 15).astype(dtype) - * unit_registry.m + * data_unit ) - x = np.arange(array1.shape[0]) - y = np.arange(array1.shape[1]) - z = np.arange(array2.shape[2]) + x = np.arange(array1.shape[0]) * dim_unit + y = np.arange(array1.shape[1]) * dim_unit + z = np.arange(array2.shape[2]) * dim_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - }, + data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)}, coords={"x": x, "y": y, "z": z}, ) + units = extract_units(ds) stacked = ds.stack(v=("x", "y")) - expected = attach_units( - func(strip_units(stacked)), {"a": unit_registry.m, "b": unit_registry.m} - ) + expected = attach_units(func(strip_units(stacked)), units) actual = func(stacked) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="does not work with quantities yet") + @pytest.mark.xfail( + reason="stacked dimension's labels have to be hashable, but is a numpy.array" + ) def test_to_stacked_array(self, dtype): - labels = np.arange(5).astype(dtype) * unit_registry.s - arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} + labels = range(5) * unit_registry.s + arrays = { + name: np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + for name in labels + } - ds = xr.Dataset( - data_vars={ - name: xr.DataArray(data=array, dims="x") - for name, array in arrays.items() - } - ) + ds = xr.Dataset({name: ("x", array) for name, array in arrays.items()}) + units = {None: unit_registry.m, "y": unit_registry.s} func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) actual = func(ds).rename(None) - expected = attach_units( - func(strip_units(ds)).rename(None), - {None: unit_registry.m, "y": unit_registry.s}, - ) + expected = attach_units(func(strip_units(ds)).rename(None), units,) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("transpose", "y", "x", "z1", "z2"), - method("stack", a=("x", "y")), + method("stack", u=("x", "y")), method("set_index", x="x2"), - pytest.param( - method("shift", x=2), - marks=pytest.mark.xfail(reason="tries to concatenate nan arrays"), - ), + method("shift", x=2), method("roll", x=2, roll_coords=False), method("sortby", "x2"), ), @@ -4618,20 +4513,19 @@ def test_stacking_reordering(self, func, dtype): ds = xr.Dataset( data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y", "z1")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z2")), + "a": (("x", "y", "z1"), array1), + "b": (("x", "y", "z2"), array2), }, coords={"x": x, "y": y, "z1": z1, "z2": z2, "x2": ("x", x2)}, ) + units = extract_units(ds) - expected = attach_units( - func(strip_units(ds)), {"a": unit_registry.Pa, "b": unit_registry.degK} - ) + expected = attach_units(func(strip_units(ds)), units) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="indexes strip units") @pytest.mark.parametrize( "indices", ( @@ -4643,22 +4537,14 @@ def test_isel(self, indices, dtype): array1 = np.arange(10).astype(dtype) * unit_registry.s array2 = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa - x = np.arange(len(array1)) * unit_registry.m - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) - expected = attach_units( - strip_units(ds).isel(x=indices), - {"a": unit_registry.s, "b": unit_registry.Pa, "x": unit_registry.m}, - ) + expected = attach_units(strip_units(ds).isel(x=indices), units) actual = ds.isel(x=indices) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -4675,7 +4561,7 @@ def test_isel(self, indices, dtype): pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4694,20 +4580,24 @@ def test_sel(self, raw_values, unit, error, dtype): values = raw_values * unit - if error is not None and not ( - isinstance(raw_values, (int, float)) and x.check(unit) - ): + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: with pytest.raises(error): ds.sel(x=values) return expected = attach_units( - strip_units(ds).sel(x=strip_units(convert_units(values, {None: x.units}))), - {"a": array1.units, "b": array2.units, "x": x.units}, + strip_units(ds).sel( + x=strip_units(convert_units(values, {None: unit_registry.m})) + ), + extract_units(ds), ) actual = ds.sel(x=values) - assert_equal_with_units(expected, actual) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -4724,7 +4614,7 @@ def test_sel(self, raw_values, unit, error, dtype): pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4743,9 +4633,9 @@ def test_drop_sel(self, raw_values, unit, error, dtype): values = raw_values * unit - if error is not None and not ( - isinstance(raw_values, (int, float)) and x.check(unit) - ): + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: with pytest.raises(error): ds.drop_sel(x=values) @@ -4753,12 +4643,14 @@ def test_drop_sel(self, raw_values, unit, error, dtype): expected = attach_units( strip_units(ds).drop_sel( - x=strip_units(convert_units(values, {None: x.units})) + x=strip_units(convert_units(values, {None: unit_registry.m})) ), extract_units(ds), ) actual = ds.drop_sel(x=values) - assert_equal_with_units(expected, actual) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -4775,7 +4667,7 @@ def test_drop_sel(self, raw_values, unit, error, dtype): pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4794,9 +4686,9 @@ def test_loc(self, raw_values, unit, error, dtype): values = raw_values * unit - if error is not None and not ( - isinstance(raw_values, (int, float)) and x.check(unit) - ): + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: with pytest.raises(error): ds.loc[{"x": values}] @@ -4804,12 +4696,14 @@ def test_loc(self, raw_values, unit, error, dtype): expected = attach_units( strip_units(ds).loc[ - {"x": strip_units(convert_units(values, {None: x.units}))} + {"x": strip_units(convert_units(values, {None: unit_registry.m}))} ], - {"a": array1.units, "b": array2.units, "x": x.units}, + extract_units(ds), ) actual = ds.loc[{"x": values}] - assert_equal_with_units(expected, actual) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", @@ -4820,14 +4714,34 @@ def test_loc(self, raw_values, unit, error, dtype): ), ids=repr, ) - def test_head_tail_thin(self, func, dtype): - array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK - array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_head_tail_thin(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit_a, unit_b), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_a + array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_b coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.m, + "x": np.arange(10) * dim_unit, + "y": np.arange(5) * dim_unit, + "z": np.arange(8) * dim_unit, + "u": ("x", np.linspace(0, 1, 10) * coord_unit), + "v": ("y", np.linspace(1, 2, 5) * coord_unit), + "w": ("z", np.linspace(-1, 0, 8) * coord_unit), } ds = xr.Dataset( @@ -4841,8 +4755,10 @@ def test_head_tail_thin(self, func, dtype): expected = attach_units(func(strip_units(ds)), extract_units(ds)) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) + @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all")) @pytest.mark.parametrize( "shape", ( @@ -4853,13 +4769,9 @@ def test_head_tail_thin(self, func, dtype): pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"), ), ) - def test_squeeze(self, shape, dtype): + def test_squeeze(self, shape, dim, dtype): names = "xyzt" - coords = { - name: np.arange(length).astype(dtype) - * (unit_registry.m if name != "t" else unit_registry.s) - for name, length in zip(names, shape) - } + dim_lengths = dict(zip(names, shape)) array1 = ( np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK ) @@ -4869,74 +4781,59 @@ def test_squeeze(self, shape, dtype): ds = xr.Dataset( data_vars={ - "a": xr.DataArray(data=array1, dims=tuple(names[: len(shape)])), - "b": xr.DataArray(data=array2, dims=tuple(names[: len(shape)])), + "a": (tuple(names[: len(shape)]), array1), + "b": (tuple(names[: len(shape)]), array2), }, - coords=coords, ) units = extract_units(ds) - expected = attach_units(strip_units(ds).squeeze(), units) + kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {} - actual = ds.squeeze() - assert_equal_with_units(actual, expected) + expected = attach_units(strip_units(ds).squeeze(**kwargs), units) - # try squeezing the dimensions separately - names = tuple(dim for dim, coord in coords.items() if len(coord) == 1) - for name in names: - expected = attach_units(strip_units(ds).squeeze(dim=name), units) - actual = ds.squeeze(dim=name) - assert_equal_with_units(actual, expected) + actual = ds.squeeze(**kwargs) - @pytest.mark.xfail(reason="ignores units") + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( - "unit,error", + "func", ( - pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" + method("interp"), marks=pytest.mark.xfail(reason="uses scipy") ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + method("reindex"), ), + ids=repr, ) - def test_interp(self, unit, error): - array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK - array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa - - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.s, + def test_interp_reindex(self, func, variant, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), } + data_unit, coord_unit = variants.get(variant) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) - - new_coords = (np.arange(10) + 0.5) * unit + array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit - if error is not None: - with pytest.raises(error): - ds.interp(x=new_coords) + y = np.arange(10) * coord_unit - return + x = np.arange(10) + new_x = np.arange(8) + 0.5 - units = extract_units(ds) - expected = attach_units( - strip_units(ds).interp(x=strip_units(convert_units(new_coords, units))), - units, + ds = xr.Dataset( + {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)} ) - actual = ds.interp(x=new_coords) + units = extract_units(ds) - assert_equal_with_units(actual, expected) + expected = attach_units(func(strip_units(ds), x=new_x), units) + actual = func(ds, x=new_x) - @pytest.mark.xfail(reason="ignores units") + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( @@ -4949,106 +4846,67 @@ def test_interp(self, unit, error): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_interp_like(self, unit, error, dtype): - array1 = ( - np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) - - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.m, - } + @pytest.mark.parametrize("func", (method("interp"), method("reindex")), ids=repr) + def test_interp_reindex_indexing(self, func, unit, error, dtype): + array1 = np.linspace(-1, 0, 10).astype(dtype) + array2 = np.linspace(0, 1, 10).astype(dtype) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) + x = np.arange(10) * unit_registry.m + new_x = (np.arange(8) + 0.5) * unit - other = xr.Dataset( - data_vars={ - "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")), - "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")), - }, - coords={ - "x": (np.arange(20) + 0.3) * unit, - "y": (np.arange(10) - 0.2) * unit, - "z": (np.arange(15) + 0.4) * unit, - }, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + units = extract_units(ds) if error is not None: with pytest.raises(error): - ds.interp_like(other) + func(ds, x=new_x) return - units = extract_units(ds) - expected = attach_units( - strip_units(ds).interp_like(strip_units(convert_units(other, units))), units - ) - actual = ds.interp_like(other) + expected = attach_units(func(strip_units(ds), x=new_x), units) + actual = func(ds, x=new_x) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="indexes don't support units") + @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( - "unit,error", + "func", ( - pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" + method("interp_like"), marks=pytest.mark.xfail(reason="uses scipy") ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + method("reindex_like"), ), + ids=repr, ) - def test_reindex(self, unit, error, dtype): - array1 = ( - np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(1, 2, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) - - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.s, + def test_interp_reindex_like(self, func, variant, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), } + data_unit, coord_unit = variants.get(variant) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) - - new_coords = (np.arange(10) + 0.5) * unit + array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit - if error is not None: - with pytest.raises(error): - ds.reindex(x=new_coords) + y = np.arange(10) * coord_unit - return + x = np.arange(10) + new_x = np.arange(8) + 0.5 - expected = attach_units( - strip_units(ds).reindex( - x=strip_units(convert_units(new_coords, {None: coords["x"].units})) - ), - extract_units(ds), + ds = xr.Dataset( + {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)} ) - actual = ds.reindex(x=new_coords) + units = extract_units(ds) + + other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x}) - assert_equal_with_units(actual, expected) + expected = attach_units(func(strip_units(ds), other), units) + actual = func(ds, other) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -5063,54 +4921,32 @@ def test_reindex(self, unit, error, dtype): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_reindex_like(self, unit, error, dtype): - array1 = ( - np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) + @pytest.mark.parametrize( + "func", (method("interp_like"), method("reindex_like")), ids=repr + ) + def test_interp_reindex_like_indexing(self, func, unit, error, dtype): + array1 = np.linspace(-1, 0, 10).astype(dtype) + array2 = np.linspace(0, 1, 10).astype(dtype) - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.m, - } + x = np.arange(10) * unit_registry.m + new_x = (np.arange(8) + 0.5) * unit - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + units = extract_units(ds) - other = xr.Dataset( - data_vars={ - "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")), - "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")), - }, - coords={ - "x": (np.arange(20) + 0.3) * unit, - "y": (np.arange(10) - 0.2) * unit, - "z": (np.arange(15) + 0.4) * unit, - }, - ) + other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x}) if error is not None: with pytest.raises(error): - ds.reindex_like(other) + func(ds, other) return - units = extract_units(ds) - expected = attach_units( - strip_units(ds).reindex_like(strip_units(convert_units(other, units))), - units, - ) - actual = ds.reindex_like(other) + expected = attach_units(func(strip_units(ds), other), units) + actual = func(ds, other) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", @@ -5120,30 +4956,46 @@ def test_reindex_like(self, unit, error, dtype): method("integrate", coord="x"), pytest.param( method("quantile", q=[0.25, 0.75]), - marks=pytest.mark.xfail(reason="nanquantile not implemented"), + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", + reason="nanquantile not implemented yet", + ), ), method("reduce", func=np.sum, dim="x"), method("map", np.fabs), ), ids=repr, ) - def test_computation(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) - x = np.arange(10) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_computation(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2 + x = np.arange(4) * dim_unit + y = np.arange(5) * dim_unit + z = np.arange(3) * dim_unit ds = xr.Dataset( data_vars={ "a": xr.DataArray(data=array1, dims=("x", "y")), "b": xr.DataArray(data=array2, dims=("x", "z")), }, - coords={"x": x, "y": y, "z": z}, + coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)}, ) units = extract_units(ds) @@ -5151,69 +5003,105 @@ def test_computation(self, func, dtype): expected = attach_units(func(strip_units(ds)), units) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("groupby", "x"), - method("groupby_bins", "x", bins=4), + pytest.param( + method("groupby_bins", "x", bins=2), + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", + reason="needs assert_allclose but that does not work with pint", + ), + ), method("coarsen", x=2), pytest.param( method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units") ), pytest.param( method("rolling_exp", x=3), - marks=pytest.mark.xfail(reason="uses numbagg which strips units"), + marks=pytest.mark.xfail( + reason="numbagg functions are not supported by pint" + ), ), ), ids=repr, ) - def test_computation_objects(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) - * unit_registry.Pa - ) - x = np.arange(10) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_computation_objects(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2 + x = np.arange(4) * dim_unit + y = np.arange(5) * dim_unit + z = np.arange(3) * dim_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - }, - coords={"x": x, "y": y, "z": z}, + data_vars={"a": (("x", "y"), array1), "b": (("x", "z"), array2)}, + coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)}, ) units = extract_units(ds) args = [] if func.name != "groupby" else ["y"] - reduce_func = method("mean", *args) - expected = attach_units(reduce_func(func(strip_units(ds))), units) - actual = reduce_func(func(ds)) + expected = attach_units(func(strip_units(ds)).mean(*args), units) + actual = func(ds).mean(*args) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + # TODO: remove once pint 0.12 has been released + if LooseVersion(pint.__version__) <= "0.12": + assert_equal(expected, actual) + else: + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_resample(self, variant, dtype): + # TODO: move this to test_computation_objects + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit2 - def test_resample(self, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) t = pd.date_range("10-09-2010", periods=array1.shape[0], freq="1y") - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + y = np.arange(5) * dim_unit + z = np.arange(8) * dim_unit + + u = np.linspace(-1, 0, 5) * coord_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("time", "y")), - "b": xr.DataArray(data=array2, dims=("time", "z")), - }, - coords={"time": t, "y": y, "z": z}, + data_vars={"a": (("time", "y"), array1), "b": (("time", "z"), array2)}, + coords={"time": t, "y": y, "z": z, "u": ("y", u)}, ) units = extract_units(ds) @@ -5222,43 +5110,59 @@ def test_resample(self, dtype): expected = attach_units(func(strip_units(ds)).mean(), units) actual = func(ds).mean() - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("assign", c=lambda ds: 10 * ds.b), - method("assign_coords", v=("x", np.arange(10) * unit_registry.s)), + method("assign_coords", v=("x", np.arange(5) * unit_registry.s)), method("first"), method("last"), pytest.param( method("quantile", q=[0.25, 0.5, 0.75], dim="x"), - marks=pytest.mark.xfail(reason="nanquantile not implemented"), + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", + reason="nanquantile not implemented", + ), ), ), ids=repr, ) - def test_grouped_operations(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) - * unit_registry.Pa - ) - x = np.arange(10) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_grouped_operations(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2 + x = np.arange(5) * dim_unit + y = np.arange(4) * dim_unit + z = np.arange(3) * dim_unit + + u = np.linspace(-1, 0, 4) * coord_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - }, - coords={"x": x, "y": y, "z": z}, + data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)}, + coords={"x": x, "y": y, "z": z, "u": ("y", u)}, ) - units = extract_units(ds) - units.update({"c": unit_registry.Pa, "v": unit_registry.s}) + + assigned_units = {"c": unit2, "v": unit_registry.s} + units = merge_mappings(extract_units(ds), assigned_units) stripped_kwargs = { name: strip_units(value) for name, value in func.kwargs.items() @@ -5268,20 +5172,26 @@ def test_grouped_operations(self, func, dtype): ) actual = func(ds.groupby("y")) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("pipe", lambda ds: ds * 10), method("assign", d=lambda ds: ds.b * 10), - method("assign_coords", y2=("y", np.arange(5) * unit_registry.mm)), + method("assign_coords", y2=("y", np.arange(4) * unit_registry.mm)), method("assign_attrs", attr1="value"), method("rename", x2="x_mm"), method("rename_vars", c="temperature"), method("rename_dims", x="offset_x"), - method("swap_dims", {"x": "x2"}), - method("expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1), + method("swap_dims", {"x": "u"}), + pytest.param( + method( + "expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1 + ), + marks=pytest.mark.xfail(reason="indexes don't support units"), + ), method("drop_vars", "x"), method("drop_dims", "z"), method("set_coords", names="c"), @@ -5290,40 +5200,55 @@ def test_grouped_operations(self, func, dtype): ), ids=repr, ) - def test_content_manipulation(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) - * unit_registry.m ** 3 - ) - array2 = ( - np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) - * unit_registry.Pa - ) - array3 = np.linspace(0, 10, 10).astype(dtype) * unit_registry.degK + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_content_manipulation(self, func, variant, dtype): + variants = { + "data": ( + (unit_registry.m ** 3, unit_registry.Pa, unit_registry.degK), + 1, + 1, + ), + "dims": ((1, 1, 1), unit_registry.m, 1), + "coords": ((1, 1, 1), 1, unit_registry.m), + } + (unit1, unit2, unit3), dim_unit, coord_unit = variants.get(variant) - x = np.arange(10) * unit_registry.m - x2 = x.to(unit_registry.mm) - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2 + array3 = np.linspace(0, 10, 5).astype(dtype) * unit3 + + x = np.arange(5) * dim_unit + y = np.arange(4) * dim_unit + z = np.arange(3) * dim_unit + + x2 = np.linspace(-1, 0, 5) * coord_unit ds = xr.Dataset( data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - "c": xr.DataArray(data=array3, dims="x"), + "a": (("x", "y"), array1), + "b": (("x", "y", "z"), array2), + "c": ("x", array3), }, coords={"x": x, "y": y, "z": z, "x2": ("x", x2)}, ) - units = { - **extract_units(ds), - **{ - "y2": unit_registry.mm, - "x_mm": unit_registry.mm, - "offset_x": unit_registry.m, - "d": unit_registry.Pa, - "temperature": unit_registry.degK, - }, + + new_units = { + "y2": unit_registry.mm, + "x_mm": coord_unit, + "offset_x": unit_registry.m, + "d": unit2, + "temperature": unit3, } + units = merge_mappings(extract_units(ds), new_units) stripped_kwargs = { key: strip_units(value) for key, value in func.kwargs.items() @@ -5331,7 +5256,8 @@ def test_content_manipulation(self, func, dtype): expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", @@ -5356,25 +5282,29 @@ def test_content_manipulation(self, func, dtype): ), ) def test_merge(self, variant, unit, error, dtype): - original_data_unit = unit_registry.m - original_dim_unit = unit_registry.m - original_coord_unit = unit_registry.m + left_variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } - variants = { - "data": (unit, original_dim_unit, original_coord_unit), - "dims": (original_data_unit, unit, original_coord_unit), - "coords": (original_data_unit, original_dim_unit, unit), + left_data_unit, left_dim_unit, left_coord_unit = left_variants.get(variant) + + right_variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), } - data_unit, dim_unit, coord_unit = variants.get(variant) + right_data_unit, right_dim_unit, right_coord_unit = right_variants.get(variant) - left_array = np.arange(10).astype(dtype) * original_data_unit - right_array = np.arange(-5, 5).astype(dtype) * data_unit + left_array = np.arange(10).astype(dtype) * left_data_unit + right_array = np.arange(-5, 5).astype(dtype) * right_data_unit - left_dim = np.arange(10, 20) * original_dim_unit - right_dim = np.arange(5, 15) * dim_unit + left_dim = np.arange(10, 20) * left_dim_unit + right_dim = np.arange(5, 15) * right_dim_unit - left_coord = np.arange(-10, 0) * original_coord_unit - right_coord = np.arange(-15, -5) * coord_unit + left_coord = np.arange(-10, 0) * left_coord_unit + right_coord = np.arange(-15, -5) * right_coord_unit left = xr.Dataset( data_vars={"a": ("x", left_array)}, @@ -5397,4 +5327,5 @@ def test_merge(self, variant, unit, error, dtype): expected = attach_units(strip_units(left).merge(strip_units(converted)), units) actual = left.merge(right) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual)