diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index d9c85a19ce6..96f0ba9a4a6 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -111,7 +111,7 @@ Internal Changes
~~~~~~~~~~~~~~~~
- Added integration tests against `pint `_.
- (:pull:`3238`, :pull:`3447`) by `Justus Magin `_.
+ (:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin `_.
.. note::
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index 8eed1f0dbe3..fd9e9b039ac 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -1045,6 +1045,36 @@ def test_comparisons(self, func, variation, unit, dtype):
assert expected == result
+ @pytest.mark.xfail(reason="blocked by `where`")
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, id="compatible_unit"),
+ pytest.param(unit_registry.m, id="identical_unit"),
+ ),
+ )
+ def test_broadcast_like(self, unit, dtype):
+ array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
+ array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa
+
+ x1 = np.arange(2) * unit_registry.m
+ x2 = np.arange(2) * unit
+ y1 = np.array([0]) * unit_registry.m
+ y2 = np.arange(3) * unit
+
+ arr1 = xr.DataArray(data=array1, coords={"x": x1, "y": y1}, dims=("x", "y"))
+ arr2 = xr.DataArray(data=array2, coords={"x": x2, "y": y2}, dims=("x", "y"))
+
+ expected = attach_units(
+ strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1)
+ )
+ result = arr1.broadcast_like(arr2)
+
+ assert_equal_with_units(expected, result)
+
@pytest.mark.parametrize(
"unit",
(
@@ -1303,6 +1333,49 @@ def test_squeeze(self, shape, dtype):
np.squeeze(array, axis=index), data_array.squeeze(dim=name)
)
+ @pytest.mark.xfail(
+ reason="indexes strip units and head / tail / thin only support integers"
+ )
+ @pytest.mark.parametrize(
+ "unit,error",
+ (
+ pytest.param(1, DimensionalityError, id="no_unit"),
+ pytest.param(
+ unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+ ),
+ pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, None, id="compatible_unit"),
+ pytest.param(unit_registry.m, None, id="identical_unit"),
+ ),
+ )
+ @pytest.mark.parametrize(
+ "func",
+ (method("head", x=7, y=3), method("tail", x=7, y=3), method("thin", x=7, y=3)),
+ ids=repr,
+ )
+ def test_head_tail_thin(self, func, unit, error, dtype):
+ array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
+
+ coords = {
+ "x": np.arange(10) * unit_registry.m,
+ "y": np.arange(5) * unit_registry.m,
+ }
+
+ arr = xr.DataArray(data=array, coords=coords, dims=("x", "y"))
+
+ kwargs = {name: value * unit for name, value in func.kwargs.items()}
+
+ if error is not None:
+ with pytest.raises(error):
+ func(arr, **kwargs)
+
+ return
+
+ expected = attach_units(func(strip_units(arr)), extract_units(arr))
+ result = func(arr, **kwargs)
+
+ assert_equal_with_units(expected, result)
+
@pytest.mark.parametrize(
"unit,error",
(
@@ -2472,6 +2545,40 @@ def test_comparisons(self, func, variation, unit, dtype):
assert expected == result
+ @pytest.mark.xfail(reason="blocked by `where`")
+ @pytest.mark.parametrize(
+ "unit",
+ (
+ pytest.param(1, id="no_unit"),
+ pytest.param(unit_registry.dimensionless, id="dimensionless"),
+ pytest.param(unit_registry.s, id="incompatible_unit"),
+ pytest.param(unit_registry.cm, id="compatible_unit"),
+ pytest.param(unit_registry.m, id="identical_unit"),
+ ),
+ )
+ def test_broadcast_like(self, unit, dtype):
+ array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
+ array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa
+
+ x1 = np.arange(2) * unit_registry.m
+ x2 = np.arange(2) * unit
+ y1 = np.array([0]) * unit_registry.m
+ y2 = np.arange(3) * unit
+
+ ds1 = xr.Dataset(
+ data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1}
+ )
+ ds2 = xr.Dataset(
+ data_vars={"a": (("x", "y"), array2)}, coords={"x": x2, "y": y2}
+ )
+
+ expected = attach_units(
+ strip_units(ds1).broadcast_like(strip_units(ds2)), extract_units(ds1)
+ )
+ result = ds1.broadcast_like(ds2)
+
+ assert_equal_with_units(expected, result)
+
@pytest.mark.parametrize(
"unit",
(