From 1490c1645c414acf8328816af59cb4719a718f4d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 15:01:43 +0200 Subject: [PATCH 001/100] Add support for cross --- xarray/__init__.py | 3 +- xarray/core/computation.py | 121 +++++++++++++++++++++++++++++++ xarray/tests/test_computation.py | 37 ++++++++++ 3 files changed, 160 insertions(+), 1 deletion(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index 3886edc60e6..19d917d70a5 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -18,7 +18,7 @@ from .core.alignment import align, broadcast from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, corr, cov, dot, polyval, where +from .core.computation import apply_ufunc, corr, cov, cross, dot, polyval, where from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset @@ -56,6 +56,7 @@ "dot", "cov", "corr", + "cross", "full_like", "infer_freq", "load_dataarray", diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6010d502c23..d77f2905f87 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1527,6 +1527,127 @@ def dot(*arrays, dims=None, **kwargs): return result.transpose(*[d for d in all_dims if d in result.dims]) +def cross(a, b, spatial_dim=None): + """ + Return the cross product of two (arrays of) vectors. + + Parameters + ---------- + a : array_like + Components of the first vector(s). + b : array_like + Components of the second vector(s). + spatial_dim : something + something + + Examples + -------- + Vector cross-product. + + >>> x = xr.DataArray(np.array([1, 2, 3])) + >>> y = xr.DataArray(np.array([4, 5, 6])) + >>> xr.cross(x, y) + array([-3, 6, -3]) + + One vector with dimension 2. + + >>> a = xr.DataArray(np.array([1, 2]), dims=["x"], coords=dict(x=(["x"], np.array(["x", "z"])))) + >>> b = xr.DataArray(np.array([4, 5, 6]), dims=["x"], coords=dict(x=(["x"], np.array(["x", "y", "z"])))) + >>> xr.cross(a, b) + array([12, -6, -3]) + + + + Multiple vector cross-products. Note that the direction of the + cross product vector is defined by the right-hand rule. + + >>> x = xr.DataArray(np.array([[1, 2, 3], [4, 5, 6]]), dims=("a", "b")) + >>> y = xr.DataArray(np.array([[4, 5, 6], [1, 2, 3]]), dims=("a", "b")) + >>> xr.cross(x, y) + array([[-3, 6, -3], + [ 3, -6, 3]]) + + Change the vector definition of x and y using axisa and axisb. + + >>> x = xr.DataArray(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + >>> y = xr.DataArray(np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]])) + >>> np.cross(x, y) + array([[ -6, 12, -6], + [ 0, 0, 0], + [ 6, -12, 6]]) + >>> np.cross(x, y, axisa=0, axisb=0) + array([[-24, 48, -24], + [-30, 60, -30], + [-36, 72, -36]]) + + See Also + -------- + numpy.cross : Corresponding numpy function + """ + from .dataarray import DataArray + from .variable import Variable + + arrays = [a, b] + for arr in arrays: + if not isinstance(arr, (DataArray)): + raise TypeError( + f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." + ) + + if spatial_dim is None: + # TODO: Find spatial dim default by looking for unique + # (3 or 2)-valued dim? + spatial_dim = arr.dims[-1] + elif spatial_dim not in arr.dims: + raise ValueError(f"Dimension {spatial_dim} not in {arr}.") + + s = arr.sizes[spatial_dim] + if s < 1 or s > 3: + raise ValueError( + "incompatible dimensions for cross product\n" + "(dimension with coords must be 1, 2 or 3)" + ) + + if a.sizes[spatial_dim] == b.sizes[spatial_dim]: + # Arrays have the same size, no need to do anything: + pass + else: + # Arrays have different sizes. Append zeros where the smaller + # array is missing a value, zeros will not affect np.cross: + ind = 1 if a.sizes[spatial_dim] > b.sizes[spatial_dim] else 0 + if a.coords: + # If the array has coords we know which indexes to fill + # with zeros: + arrays[ind] = arrays[ind].reindex_like(arrays[1 - ind], fill_value=0) + elif arrays[ind].sizes[spatial_dim] > 1: + # If it doesn't have coords we can can only that infer that + # it is composite values if the size is 2. + from .concat import concat + + arrays[ind] = concat([a, DataArray([0])], dim=spatial_dim) + else: + # Size is 1, then we do not know if it is a constant or + # composite value: + raise ValueError( + "incompatible dimensions for cross product\n" + "(dimension without coords must be 2 or 3)" + ) + + # Figure out the output dtype: + # output_dtype = np.cross( + # np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype) + # ).dtype + + return apply_ufunc( + np.cross, + *arrays, + # input_core_dims=[[spatial_dim], [spatial_dim]], + # output_core_dims=[[spatial_dim]], + dask="parallelized", + # output_dtypes=[output_dtype], + ) + + def where(cond, x, y): """Return elements from `x` or `y` depending on `cond`. diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index cbfa61a4482..7671083b517 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1898,3 +1898,40 @@ def test_polyval(use_dask, use_datetime): da_pv = xr.polyval(da.x, coeffs) xr.testing.assert_allclose(da, da_pv.T) + + +@pytest.mark.parametrize( + "a, b, ae, be", + [ + [ + xr.DataArray(np.array([1, 2, 3])), + xr.DataArray(np.array([4, 5, 6])), + np.array([1, 2, 3]), + np.array([4, 5, 6]), + ], + [ + xr.DataArray(np.array([1, 2])), + xr.DataArray(np.array([4, 5, 6])), + np.array([1, 2]), + np.array([4, 5, 6]), + ], + [ + xr.DataArray( + np.array([1, 2]), + dims=["ax"], + coords=dict(x=(["ax"], np.array(["x", "z"]))), + ), + xr.DataArray( + np.array([4, 5, 6]), + dims=["ax"], + coords=dict(x=(["ax"], np.array(["x", "y", "z"]))), + ), + np.array([1, 0, 2]), + np.array([4, 5, 6]), + ], + ], +) +def test_cross(a, b, ae, be): + expected = np.cross(ae, be) + actual = xr.cross(a, b) + xr.testing.assert_allclose(expected, actual) \ No newline at end of file From 03db734fbbb948cd548598687b37f06b8a8021f1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 15:02:21 +0200 Subject: [PATCH 002/100] Update test_computation.py --- xarray/tests/test_computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 7671083b517..d42217d8a3c 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1934,4 +1934,4 @@ def test_polyval(use_dask, use_datetime): def test_cross(a, b, ae, be): expected = np.cross(ae, be) actual = xr.cross(a, b) - xr.testing.assert_allclose(expected, actual) \ No newline at end of file + xr.testing.assert_allclose(expected, actual) From c824e3654f69e8d701de995c3f06676acf92dcbd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 15:05:24 +0200 Subject: [PATCH 003/100] Update computation.py --- xarray/core/computation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index d77f2905f87..a3d6a4216e9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1585,7 +1585,6 @@ def cross(a, b, spatial_dim=None): numpy.cross : Corresponding numpy function """ from .dataarray import DataArray - from .variable import Variable arrays = [a, b] for arr in arrays: From 7ce39c7f398bf4c42df6d33751aa35eaba2760e0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 15:10:28 +0200 Subject: [PATCH 004/100] Update computation.py --- xarray/core/computation.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a3d6a4216e9..2cbea719ce1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1551,8 +1551,14 @@ def cross(a, b, spatial_dim=None): One vector with dimension 2. - >>> a = xr.DataArray(np.array([1, 2]), dims=["x"], coords=dict(x=(["x"], np.array(["x", "z"])))) - >>> b = xr.DataArray(np.array([4, 5, 6]), dims=["x"], coords=dict(x=(["x"], np.array(["x", "y", "z"])))) + >>> a = xr.DataArray( + ... np.array([1, 2]), dims=["x"], coords=dict(x=(["x"], np.array(["x", "z"]))) + ... ) + >>> b = xr.DataArray( + ... np.array([4, 5, 6]), + ... dims=["x"], + ... coords=dict(x=(["x"], np.array(["x", "y", "z"]))), + ... ) >>> xr.cross(a, b) array([12, -6, -3]) From 916e661f78328fe2d19b3ec5858ec4473f923caf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 15:32:27 +0200 Subject: [PATCH 005/100] Update test_computation.py --- xarray/tests/test_computation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index d42217d8a3c..579624e1a50 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1918,13 +1918,13 @@ def test_polyval(use_dask, use_datetime): [ xr.DataArray( np.array([1, 2]), - dims=["ax"], - coords=dict(x=(["ax"], np.array(["x", "z"]))), + dims=["axis"], + coords=dict(axis=(["axis"], np.array(["x", "z"]))), ), xr.DataArray( np.array([4, 5, 6]), - dims=["ax"], - coords=dict(x=(["ax"], np.array(["x", "y", "z"]))), + dims=["axis"], + coords=dict(axis=(["axis"], np.array(["x", "y", "z"]))), ), np.array([1, 0, 2]), np.array([4, 5, 6]), From 654ad60577a6be32299219492be6a9d8acae56f0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 15:35:56 +0200 Subject: [PATCH 006/100] Update test_computation.py --- xarray/tests/test_computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 579624e1a50..2526706a501 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1934,4 +1934,4 @@ def test_polyval(use_dask, use_datetime): def test_cross(a, b, ae, be): expected = np.cross(ae, be) actual = xr.cross(a, b) - xr.testing.assert_allclose(expected, actual) + xr.testing.assert_duckarray_allclose(expected, actual) From a6ac578983f6c130cc4f8d2179a926cf1bd01eaa Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 16:08:29 +0200 Subject: [PATCH 007/100] Update test_computation.py --- xarray/tests/test_computation.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 2526706a501..f3ebd0ece28 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1903,6 +1903,7 @@ def test_polyval(use_dask, use_datetime): @pytest.mark.parametrize( "a, b, ae, be", [ + # Basic np.cross tests: [ xr.DataArray(np.array([1, 2, 3])), xr.DataArray(np.array([4, 5, 6])), @@ -1915,6 +1916,22 @@ def test_polyval(use_dask, use_datetime): np.array([1, 2]), np.array([4, 5, 6]), ], + # Test 1 sized arrays with coords: + [ + xr.DataArray( + np.array([1]), + dims=["axis"], + coords=dict(axis=(["axis"], np.array(["z"]))), + ), + xr.DataArray( + np.array([4, 5, 6]), + dims=["axis"], + coords=dict(axis=(["axis"], np.array(["x", "y", "z"]))), + ), + np.array([0, 0, 1]), + np.array([4, 5, 6]), + ], + # Test filling inbetween with coords: [ xr.DataArray( np.array([1, 2]), From e0c1facaba46c20dd540ed44d1ed477fb85b5fb8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 18:41:29 +0200 Subject: [PATCH 008/100] add more tests --- xarray/core/computation.py | 21 ++++++----- xarray/tests/test_computation.py | 63 +++++++++++++++++++++++--------- 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2cbea719ce1..ad33adb5b7c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1552,18 +1552,18 @@ def cross(a, b, spatial_dim=None): One vector with dimension 2. >>> a = xr.DataArray( - ... np.array([1, 2]), dims=["x"], coords=dict(x=(["x"], np.array(["x", "z"]))) + ... np.array([1, 2]), + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))) ... ) >>> b = xr.DataArray( ... np.array([4, 5, 6]), ... dims=["x"], - ... coords=dict(x=(["x"], np.array(["x", "y", "z"]))), + ... coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))), ... ) >>> xr.cross(a, b) array([12, -6, -3]) - - Multiple vector cross-products. Note that the direction of the cross product vector is defined by the right-hand rule. @@ -1620,6 +1620,7 @@ def cross(a, b, spatial_dim=None): # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: ind = 1 if a.sizes[spatial_dim] > b.sizes[spatial_dim] else 0 + if a.coords: # If the array has coords we know which indexes to fill # with zeros: @@ -1639,17 +1640,17 @@ def cross(a, b, spatial_dim=None): ) # Figure out the output dtype: - # output_dtype = np.cross( - # np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype) - # ).dtype + output_dtype = np.cross( + np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype) + ).dtype return apply_ufunc( np.cross, *arrays, - # input_core_dims=[[spatial_dim], [spatial_dim]], - # output_core_dims=[[spatial_dim]], + input_core_dims=[[spatial_dim], [spatial_dim]], + output_core_dims=[[spatial_dim]], dask="parallelized", - # output_dtypes=[output_dtype], + output_dtypes=[output_dtype], ) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index f3ebd0ece28..b3701a4893c 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1901,54 +1901,83 @@ def test_polyval(use_dask, use_datetime): @pytest.mark.parametrize( - "a, b, ae, be", + "a, b, ae, be, spatial_dim, axis", [ - # Basic np.cross tests: [ xr.DataArray(np.array([1, 2, 3])), xr.DataArray(np.array([4, 5, 6])), np.array([1, 2, 3]), np.array([4, 5, 6]), + None, + -1, ], [ xr.DataArray(np.array([1, 2])), xr.DataArray(np.array([4, 5, 6])), np.array([1, 2]), np.array([4, 5, 6]), + None, + -1, ], - # Test 1 sized arrays with coords: - [ + [ # Test spatial dim in the middle: + xr.DataArray( + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), + dims=["time", "cartesian", "var"], + coords=dict( + time=(["time"], np.arange(0, 5)), + cartesian=(["cartesian"], np.array(["x", "y", "z"])), + var=(["var"], np.array([1, 1.5, 2, 2.5])), + ), + ), + xr.DataArray( + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1, + dims=["time", "cartesian", "var"], + coords=dict( + time=(["time"], np.arange(0, 5)), + cartesian=(["cartesian"], np.array(["x", "y", "z"])), + var=(["var"], np.array([1, 1.5, 2, 2.5])), + ), + ), + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), + np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1, + "cartesian", + 1, + ], + [ # Test 1 sized arrays with coords: xr.DataArray( np.array([1]), - dims=["axis"], - coords=dict(axis=(["axis"], np.array(["z"]))), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], np.array(["z"]))), ), xr.DataArray( np.array([4, 5, 6]), - dims=["axis"], - coords=dict(axis=(["axis"], np.array(["x", "y", "z"]))), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))), ), np.array([0, 0, 1]), np.array([4, 5, 6]), + None, + -1, ], - # Test filling inbetween with coords: - [ + [ # Test filling inbetween with coords: xr.DataArray( np.array([1, 2]), - dims=["axis"], - coords=dict(axis=(["axis"], np.array(["x", "z"]))), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))), ), xr.DataArray( np.array([4, 5, 6]), - dims=["axis"], - coords=dict(axis=(["axis"], np.array(["x", "y", "z"]))), + dims=["cartesian"], + coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))), ), np.array([1, 0, 2]), np.array([4, 5, 6]), + None, + -1, ], ], ) -def test_cross(a, b, ae, be): - expected = np.cross(ae, be) - actual = xr.cross(a, b) +def test_cross(a, b, ae, be, spatial_dim, axis): + expected = np.cross(ae, be, axis=axis) + actual = xr.cross(a, b, spatial_dim=spatial_dim) xr.testing.assert_duckarray_allclose(expected, actual) From 7aebae7ebfdd4363e4a63068f449298b2ba528d2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 18:46:51 +0200 Subject: [PATCH 009/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ad33adb5b7c..c3500418d00 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1533,7 +1533,7 @@ def cross(a, b, spatial_dim=None): Parameters ---------- - a : array_like + a, b : DataArray Components of the first vector(s). b : array_like Components of the second vector(s). From be7b2c275bc82d6665d83b3b2aa18cd79976558a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 19:05:11 +0200 Subject: [PATCH 010/100] spatial_dim to dim --- xarray/core/computation.py | 30 ++++++++++++++---------------- xarray/tests/test_computation.py | 8 ++++---- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c3500418d00..c1ff56eb40c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1527,17 +1527,15 @@ def dot(*arrays, dims=None, **kwargs): return result.transpose(*[d for d in all_dims if d in result.dims]) -def cross(a, b, spatial_dim=None): +def cross(a, b, dim=None): """ Return the cross product of two (arrays of) vectors. Parameters ---------- a, b : DataArray - Components of the first vector(s). - b : array_like - Components of the second vector(s). - spatial_dim : something + something + dim : hashable or tuple of hashable something Examples @@ -1599,38 +1597,38 @@ def cross(a, b, spatial_dim=None): f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." ) - if spatial_dim is None: + if dim is None: # TODO: Find spatial dim default by looking for unique # (3 or 2)-valued dim? - spatial_dim = arr.dims[-1] - elif spatial_dim not in arr.dims: - raise ValueError(f"Dimension {spatial_dim} not in {arr}.") + dim = arr.dims[-1] + elif dim not in arr.dims: + raise ValueError(f"Dimension {dim} not in {arr}.") - s = arr.sizes[spatial_dim] + s = arr.sizes[dim] if s < 1 or s > 3: raise ValueError( "incompatible dimensions for cross product\n" "(dimension with coords must be 1, 2 or 3)" ) - if a.sizes[spatial_dim] == b.sizes[spatial_dim]: + if a.sizes[dim] == b.sizes[dim]: # Arrays have the same size, no need to do anything: pass else: # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: - ind = 1 if a.sizes[spatial_dim] > b.sizes[spatial_dim] else 0 + ind = 1 if a.sizes[dim] > b.sizes[dim] else 0 if a.coords: # If the array has coords we know which indexes to fill # with zeros: arrays[ind] = arrays[ind].reindex_like(arrays[1 - ind], fill_value=0) - elif arrays[ind].sizes[spatial_dim] > 1: + elif arrays[ind].sizes[dim] > 1: # If it doesn't have coords we can can only that infer that # it is composite values if the size is 2. from .concat import concat - arrays[ind] = concat([a, DataArray([0])], dim=spatial_dim) + arrays[ind] = concat([a, DataArray([0])], dim=dim) else: # Size is 1, then we do not know if it is a constant or # composite value: @@ -1647,8 +1645,8 @@ def cross(a, b, spatial_dim=None): return apply_ufunc( np.cross, *arrays, - input_core_dims=[[spatial_dim], [spatial_dim]], - output_core_dims=[[spatial_dim]], + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], dask="parallelized", output_dtypes=[output_dtype], ) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index b3701a4893c..74750cf54f1 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1901,7 +1901,7 @@ def test_polyval(use_dask, use_datetime): @pytest.mark.parametrize( - "a, b, ae, be, spatial_dim, axis", + "a, b, ae, be, dim, axis", [ [ xr.DataArray(np.array([1, 2, 3])), @@ -1919,7 +1919,7 @@ def test_polyval(use_dask, use_datetime): None, -1, ], - [ # Test spatial dim in the middle: + [ # Test dim in the middle: xr.DataArray( np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), dims=["time", "cartesian", "var"], @@ -1977,7 +1977,7 @@ def test_polyval(use_dask, use_datetime): ], ], ) -def test_cross(a, b, ae, be, spatial_dim, axis): +def test_cross(a, b, ae, be, dim, axis): expected = np.cross(ae, be, axis=axis) - actual = xr.cross(a, b, spatial_dim=spatial_dim) + actual = xr.cross(a, b, dim=dim) xr.testing.assert_duckarray_allclose(expected, actual) From 4448006ab9035138efd38feeb29e6af373eec182 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 19:09:34 +0200 Subject: [PATCH 011/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c1ff56eb40c..14760844e8f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1552,7 +1552,7 @@ def cross(a, b, dim=None): >>> a = xr.DataArray( ... np.array([1, 2]), ... dims=["cartesian"], - ... coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))) + ... coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))), ... ) >>> b = xr.DataArray( ... np.array([4, 5, 6]), From af8b09cdff318da8a69677e6f932f26956effcd2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 19:22:59 +0200 Subject: [PATCH 012/100] use pad instead of concat --- xarray/core/computation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 14760844e8f..ad3d5904e05 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1626,9 +1626,7 @@ def cross(a, b, dim=None): elif arrays[ind].sizes[dim] > 1: # If it doesn't have coords we can can only that infer that # it is composite values if the size is 2. - from .concat import concat - - arrays[ind] = concat([a, DataArray([0])], dim=dim) + arrays[ind] = arrays[ind].pad({dim: (0, 1)}, constant_values=0) else: # Size is 1, then we do not know if it is a constant or # composite value: From a135e0518198429a6d459e4b47a50c4c1a6567b2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 19:47:09 +0200 Subject: [PATCH 013/100] copy paste np.cross intro --- xarray/core/computation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ad3d5904e05..906658e7b8f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1531,6 +1531,14 @@ def cross(a, b, dim=None): """ Return the cross product of two (arrays of) vectors. + The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular + to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors + are defined by the last axis of `a` and `b` by default, and these axes + can have dimensions 2 or 3. Where the dimension of either `a` or `b` is + 2, the third component of the input vector is assumed to be zero and the + cross product calculated accordingly. In cases where both input vectors + have dimension 2, the z-component of the cross product is returned. + Parameters ---------- a, b : DataArray From 6f17b9bbb84685de7e4c06bd78278c53918e3271 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 21:32:40 +0200 Subject: [PATCH 014/100] Get last dim for each array, which is more inline with np.cross --- xarray/core/computation.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 906658e7b8f..f190141d783 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1598,6 +1598,7 @@ def cross(a, b, dim=None): """ from .dataarray import DataArray + dims = [] arrays = [a, b] for arr in arrays: if not isinstance(arr, (DataArray)): @@ -1608,35 +1609,37 @@ def cross(a, b, dim=None): if dim is None: # TODO: Find spatial dim default by looking for unique # (3 or 2)-valued dim? - dim = arr.dims[-1] - elif dim not in arr.dims: + dims.append(arr.dims[-1]) + elif dim in arr.dims: + dims.append(dim) + else: raise ValueError(f"Dimension {dim} not in {arr}.") - s = arr.sizes[dim] + s = arr.sizes[dims[-1]] if s < 1 or s > 3: raise ValueError( "incompatible dimensions for cross product\n" "(dimension with coords must be 1, 2 or 3)" ) - if a.sizes[dim] == b.sizes[dim]: + if a.sizes[dims[0]] == b.sizes[dims[1]]: # Arrays have the same size, no need to do anything: pass else: # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: - ind = 1 if a.sizes[dim] > b.sizes[dim] else 0 + ind = 1 if a.sizes[dims[0]] > b.sizes[dims[1]] else 0 - if a.coords: + if arrays[ind].coords: # If the array has coords we know which indexes to fill # with zeros: arrays[ind] = arrays[ind].reindex_like(arrays[1 - ind], fill_value=0) - elif arrays[ind].sizes[dim] > 1: - # If it doesn't have coords we can can only that infer that - # it is composite values if the size is 2. - arrays[ind] = arrays[ind].pad({dim: (0, 1)}, constant_values=0) + elif arrays[ind].sizes[dims[ind]] > 1: + # If the array doesn't have coords we can can only infer + # that it is composite values if the size is 2: + arrays[ind] = arrays[ind].pad({dims[ind]: (0, 1)}, constant_values=0) else: - # Size is 1, then we do not know if it is a constant or + # Size is 1, then we do not know if the array is a constant or # composite value: raise ValueError( "incompatible dimensions for cross product\n" @@ -1648,15 +1651,17 @@ def cross(a, b, dim=None): np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype) ).dtype - return apply_ufunc( + c = apply_ufunc( np.cross, *arrays, - input_core_dims=[[dim], [dim]], - output_core_dims=[[dim]], + input_core_dims=[[dims[0]], [dims[1]]], + output_core_dims=[[dims[0]]], dask="parallelized", output_dtypes=[output_dtype], ) + return c + def where(cond, x, y): """Return elements from `x` or `y` depending on `cond`. From 1fadb5fb3a9177ed846eba6d57b8d5110a069045 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 23 May 2021 21:50:21 +0200 Subject: [PATCH 015/100] examples in docs --- xarray/core/computation.py | 76 ++++++++++++++++++++++++++------------ 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f190141d783..a8ed3075c5a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1550,47 +1550,75 @@ def cross(a, b, dim=None): -------- Vector cross-product. - >>> x = xr.DataArray(np.array([1, 2, 3])) - >>> y = xr.DataArray(np.array([4, 5, 6])) - >>> xr.cross(x, y) + >>> a = xr.DataArray([1, 2, 3]) + >>> b = xr.DataArray([4, 5, 6]) + >>> xr.cross(a, b) + array([-3, 6, -3]) + Dimensions without coordinates: dim_0 One vector with dimension 2. >>> a = xr.DataArray( - ... np.array([1, 2]), + ... [1, 2], ... dims=["cartesian"], - ... coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))), + ... coords=dict(cartesian=(["cartesian"], ["x", "y"])), ... ) >>> b = xr.DataArray( - ... np.array([4, 5, 6]), - ... dims=["x"], - ... coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))), + ... [4, 5, 6], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ... ) >>> xr.cross(a, b) + array([12, -6, -3]) + Coordinates: + * cartesian (cartesian) object 'x' 'y' 'z' + + One vector with dimension 2 but coords in other positions. + + >>> a = xr.DataArray( + ... [1, 2], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "z"])), + ... ) + >>> b = xr.DataArray( + ... [4, 5, 6], + ... dims=["cartesian"], + ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), + ... ) + >>> xr.cross(a, b) + + array([-10, 2, 5]) + Coordinates: + * cartesian (cartesian) object 'x' 'y' 'z' Multiple vector cross-products. Note that the direction of the cross product vector is defined by the right-hand rule. - >>> x = xr.DataArray(np.array([[1, 2, 3], [4, 5, 6]]), dims=("a", "b")) - >>> y = xr.DataArray(np.array([[4, 5, 6], [1, 2, 3]]), dims=("a", "b")) - >>> xr.cross(x, y) + >>> a = xr.DataArray( + ... [[1, 2, 3], [4, 5, 6]], + ... dims=("time", "cartesian"), + ... coords=dict( + ... time=(["time"], [0, 1]), + ... cartesian=(["cartesian"], ["x", "y", "z"]), + ... ), + ... ) + >>> b = xr.DataArray( + ... [[4, 5, 6], [1, 2, 3]], + ... dims=("time", "cartesian"), + ... coords=dict( + ... time=(["time"], [0, 1]), + ... cartesian=(["cartesian"], ["x", "y", "z"]), + ... ), + ... ) + >>> xr.cross(a, b) + array([[-3, 6, -3], [ 3, -6, 3]]) - - Change the vector definition of x and y using axisa and axisb. - - >>> x = xr.DataArray(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) - >>> y = xr.DataArray(np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]])) - >>> np.cross(x, y) - array([[ -6, 12, -6], - [ 0, 0, 0], - [ 6, -12, 6]]) - >>> np.cross(x, y, axisa=0, axisb=0) - array([[-24, 48, -24], - [-30, 60, -30], - [-36, 72, -36]]) + Coordinates: + * time (time) int32 0 1 + * cartesian (cartesian) Date: Sun, 23 May 2021 21:57:49 +0200 Subject: [PATCH 016/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a8ed3075c5a..4e23d96f4f1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1617,7 +1617,7 @@ def cross(a, b, dim=None): array([[-3, 6, -3], [ 3, -6, 3]]) Coordinates: - * time (time) int32 0 1 + * time (time) int64 0 1 * cartesian (cartesian) Date: Sun, 23 May 2021 23:03:44 +0200 Subject: [PATCH 017/100] more doc examples --- xarray/core/computation.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 4e23d96f4f1..81bc91fef9d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1548,7 +1548,7 @@ def cross(a, b, dim=None): Examples -------- - Vector cross-product. + Vector cross-product with 3 dimensions. >>> a = xr.DataArray([1, 2, 3]) >>> b = xr.DataArray([4, 5, 6]) @@ -1557,6 +1557,25 @@ def cross(a, b, dim=None): array([-3, 6, -3]) Dimensions without coordinates: dim_0 + Vector cross-product with 2 dimensions, returns in the orthogonal + direction: + + >>> a = xr.DataArray([1, 2]) + >>> b = xr.DataArray([4, 5]) + >>> xr.cross(a, b) + + array(-3) + + Vector cross-product with 3 dimensions but zeros at the last axis + yields the same results as with 2 dimensions: + + >>> a = xr.DataArray([1, 2, 0]) + >>> b = xr.DataArray([4, 5, 0]) + >>> xr.cross(a, b) + + array([ 0, 0, -3]) + Dimensions without coordinates: dim_0 + One vector with dimension 2. >>> a = xr.DataArray( @@ -1674,18 +1693,18 @@ def cross(a, b, dim=None): "(dimension without coords must be 2 or 3)" ) - # Figure out the output dtype: - output_dtype = np.cross( - np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype) - ).dtype - c = apply_ufunc( np.cross, *arrays, input_core_dims=[[dims[0]], [dims[1]]], - output_core_dims=[[dims[0]]], + output_core_dims=[[dims[0]]] if arrays[0].sizes[dims[0]] == 3 else [[]], dask="parallelized", - output_dtypes=[output_dtype], + output_dtypes=[ + np.cross( + np.empty((2, 2), dtype=arrays[0].dtype), + np.empty((2, 2), dtype=arrays[1].dtype), + ).dtype + ], ) return c From dd6056238c0a98f2e101571d67b8fb2257c7aa94 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 24 May 2021 20:58:41 +0200 Subject: [PATCH 018/100] single dim required, tranpose after apply_ufunc --- xarray/core/computation.py | 334 +++++++++++++++---------------- xarray/tests/test_computation.py | 16 +- 2 files changed, 174 insertions(+), 176 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 81bc91fef9d..4a31aa8dbcb 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1382,152 +1382,7 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): return corr -def dot(*arrays, dims=None, **kwargs): - """Generalized dot product for xarray objects. Like np.einsum, but - provides a simpler interface based on array dimensions. - - Parameters - ---------- - *arrays : DataArray or Variable - Arrays to compute. - dims : ..., str or tuple of str, optional - Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. - If not specified, then all the common dimensions are summed over. - **kwargs : dict - Additional keyword arguments passed to numpy.einsum or - dask.array.einsum - - Returns - ------- - DataArray - - Examples - -------- - >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) - >>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2), dims=["a", "b", "c"]) - >>> da_c = xr.DataArray(np.arange(2 * 3).reshape(2, 3), dims=["c", "d"]) - - >>> da_a - - array([[0, 1], - [2, 3], - [4, 5]]) - Dimensions without coordinates: a, b - - >>> da_b - - array([[[ 0, 1], - [ 2, 3]], - - [[ 4, 5], - [ 6, 7]], - - [[ 8, 9], - [10, 11]]]) - Dimensions without coordinates: a, b, c - - >>> da_c - - array([[0, 1, 2], - [3, 4, 5]]) - Dimensions without coordinates: c, d - - >>> xr.dot(da_a, da_b, dims=["a", "b"]) - - array([110, 125]) - Dimensions without coordinates: c - - >>> xr.dot(da_a, da_b, dims=["a"]) - - array([[40, 46], - [70, 79]]) - Dimensions without coordinates: b, c - - >>> xr.dot(da_a, da_b, da_c, dims=["b", "c"]) - - array([[ 9, 14, 19], - [ 93, 150, 207], - [273, 446, 619]]) - Dimensions without coordinates: a, d - - >>> xr.dot(da_a, da_b) - - array([110, 125]) - Dimensions without coordinates: c - - >>> xr.dot(da_a, da_b, dims=...) - - array(235) - """ - from .dataarray import DataArray - from .variable import Variable - - if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): - raise TypeError( - "Only xr.DataArray and xr.Variable are supported." - "Given {}.".format([type(arr) for arr in arrays]) - ) - - if len(arrays) == 0: - raise TypeError("At least one array should be given.") - - if isinstance(dims, str): - dims = (dims,) - - common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) - all_dims = [] - for arr in arrays: - all_dims += [d for d in arr.dims if d not in all_dims] - - einsum_axes = "abcdefghijklmnopqrstuvwxyz" - dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - - if dims is ...: - dims = all_dims - elif dims is None: - # find dimensions that occur more than one times - dim_counts = Counter() - for arr in arrays: - dim_counts.update(arr.dims) - dims = tuple(d for d, c in dim_counts.items() if c > 1) - - dims = tuple(dims) # make dims a tuple - - # dimensions to be parallelized - broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims) - input_core_dims = [ - [d for d in arr.dims if d not in broadcast_dims] for arr in arrays - ] - output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)] - - # construct einsum subscripts, such as '...abc,...ab->...c' - # Note: input_core_dims are always moved to the last position - subscripts_list = [ - "..." + "".join(dim_map[d] for d in ds) for ds in input_core_dims - ] - subscripts = ",".join(subscripts_list) - subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0]) - - join = OPTIONS["arithmetic_join"] - # using "inner" emulates `(a * b).sum()` for all joins (except "exact") - if join != "exact": - join = "inner" - - # subscripts should be passed to np.einsum as arg, not as kwargs. We need - # to construct a partial function for apply_ufunc to work. - func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) - result = apply_ufunc( - func, - *arrays, - input_core_dims=input_core_dims, - output_core_dims=output_core_dims, - join=join, - dask="allowed", - ) - return result.transpose(*[d for d in all_dims if d in result.dims]) - - -def cross(a, b, dim=None): +def cross(a, b, dim): """ Return the cross product of two (arrays of) vectors. @@ -1571,7 +1426,7 @@ def cross(a, b, dim=None): >>> a = xr.DataArray([1, 2, 0]) >>> b = xr.DataArray([4, 5, 0]) - >>> xr.cross(a, b) + >>> xr.cross(a, b, "dim_0") array([ 0, 0, -3]) Dimensions without coordinates: dim_0 @@ -1588,7 +1443,7 @@ def cross(a, b, dim=None): ... dims=["cartesian"], ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ... ) - >>> xr.cross(a, b) + >>> xr.cross(a, b, "cartesian") array([12, -6, -3]) Coordinates: @@ -1606,7 +1461,7 @@ def cross(a, b, dim=None): ... dims=["cartesian"], ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ... ) - >>> xr.cross(a, b) + >>> xr.cross(a, b, "cartesian") array([-10, 2, 5]) Coordinates: @@ -1631,7 +1486,7 @@ def cross(a, b, dim=None): ... cartesian=(["cartesian"], ["x", "y", "z"]), ... ), ... ) - >>> xr.cross(a, b) + >>> xr.cross(a, b, "cartesian") array([[-3, 6, -3], [ 3, -6, 3]]) @@ -1645,7 +1500,7 @@ def cross(a, b, dim=None): """ from .dataarray import DataArray - dims = [] + all_dims = [] arrays = [a, b] for arr in arrays: if not isinstance(arr, (DataArray)): @@ -1653,38 +1508,36 @@ def cross(a, b, dim=None): f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." ) - if dim is None: - # TODO: Find spatial dim default by looking for unique - # (3 or 2)-valued dim? - dims.append(arr.dims[-1]) - elif dim in arr.dims: - dims.append(dim) - else: + # TODO: Find spatial dim default by looking for unique + # (3 or 2)-valued dim? + if dim not in arr.dims: raise ValueError(f"Dimension {dim} not in {arr}.") - s = arr.sizes[dims[-1]] + s = arr.sizes[dim] if s < 1 or s > 3: raise ValueError( "incompatible dimensions for cross product\n" "(dimension with coords must be 1, 2 or 3)" ) - if a.sizes[dims[0]] == b.sizes[dims[1]]: + all_dims += [d for d in arr.dims if d not in all_dims] + + if a.sizes[dim] == b.sizes[dim]: # Arrays have the same size, no need to do anything: pass else: # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: - ind = 1 if a.sizes[dims[0]] > b.sizes[dims[1]] else 0 + ind = 1 if a.sizes[dim] > b.sizes[dim] else 0 - if arrays[ind].coords: - # If the array has coords we know which indexes to fill + if all([arr.coords for arr in arrays]): + # If the arrays have coords we know which indexes to fill # with zeros: arrays[ind] = arrays[ind].reindex_like(arrays[1 - ind], fill_value=0) - elif arrays[ind].sizes[dims[ind]] > 1: + elif arrays[ind].sizes[dim] > 1: # If the array doesn't have coords we can can only infer # that it is composite values if the size is 2: - arrays[ind] = arrays[ind].pad({dims[ind]: (0, 1)}, constant_values=0) + arrays[ind] = arrays[ind].pad({dim: (0, 1)}, constant_values=0) else: # Size is 1, then we do not know if the array is a constant or # composite value: @@ -1696,8 +1549,8 @@ def cross(a, b, dim=None): c = apply_ufunc( np.cross, *arrays, - input_core_dims=[[dims[0]], [dims[1]]], - output_core_dims=[[dims[0]]] if arrays[0].sizes[dims[0]] == 3 else [[]], + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else [[]]], dask="parallelized", output_dtypes=[ np.cross( @@ -1707,7 +1560,152 @@ def cross(a, b, dim=None): ], ) - return c + return c.transpose(*[d for d in all_dims if d in c.dims]) + + +def dot(*arrays, dims=None, **kwargs): + """Generalized dot product for xarray objects. Like np.einsum, but + provides a simpler interface based on array dimensions. + + Parameters + ---------- + *arrays : DataArray or Variable + Arrays to compute. + dims : ..., str or tuple of str, optional + Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. + If not specified, then all the common dimensions are summed over. + **kwargs : dict + Additional keyword arguments passed to numpy.einsum or + dask.array.einsum + + Returns + ------- + DataArray + + Examples + -------- + >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) + >>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2), dims=["a", "b", "c"]) + >>> da_c = xr.DataArray(np.arange(2 * 3).reshape(2, 3), dims=["c", "d"]) + + >>> da_a + + array([[0, 1], + [2, 3], + [4, 5]]) + Dimensions without coordinates: a, b + + >>> da_b + + array([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]], + + [[ 8, 9], + [10, 11]]]) + Dimensions without coordinates: a, b, c + + >>> da_c + + array([[0, 1, 2], + [3, 4, 5]]) + Dimensions without coordinates: c, d + + >>> xr.dot(da_a, da_b, dims=["a", "b"]) + + array([110, 125]) + Dimensions without coordinates: c + + >>> xr.dot(da_a, da_b, dims=["a"]) + + array([[40, 46], + [70, 79]]) + Dimensions without coordinates: b, c + + >>> xr.dot(da_a, da_b, da_c, dims=["b", "c"]) + + array([[ 9, 14, 19], + [ 93, 150, 207], + [273, 446, 619]]) + Dimensions without coordinates: a, d + + >>> xr.dot(da_a, da_b) + + array([110, 125]) + Dimensions without coordinates: c + + >>> xr.dot(da_a, da_b, dims=...) + + array(235) + """ + from .dataarray import DataArray + from .variable import Variable + + if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): + raise TypeError( + "Only xr.DataArray and xr.Variable are supported." + "Given {}.".format([type(arr) for arr in arrays]) + ) + + if len(arrays) == 0: + raise TypeError("At least one array should be given.") + + if isinstance(dims, str): + dims = (dims,) + + common_dims = set.intersection(*[set(arr.dims) for arr in arrays]) + all_dims = [] + for arr in arrays: + all_dims += [d for d in arr.dims if d not in all_dims] + + einsum_axes = "abcdefghijklmnopqrstuvwxyz" + dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + + if dims is ...: + dims = all_dims + elif dims is None: + # find dimensions that occur more than one times + dim_counts = Counter() + for arr in arrays: + dim_counts.update(arr.dims) + dims = tuple(d for d, c in dim_counts.items() if c > 1) + + dims = tuple(dims) # make dims a tuple + + # dimensions to be parallelized + broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims) + input_core_dims = [ + [d for d in arr.dims if d not in broadcast_dims] for arr in arrays + ] + output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)] + + # construct einsum subscripts, such as '...abc,...ab->...c' + # Note: input_core_dims are always moved to the last position + subscripts_list = [ + "..." + "".join(dim_map[d] for d in ds) for ds in input_core_dims + ] + subscripts = ",".join(subscripts_list) + subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0]) + + join = OPTIONS["arithmetic_join"] + # using "inner" emulates `(a * b).sum()` for all joins (except "exact") + if join != "exact": + join = "inner" + + # subscripts should be passed to np.einsum as arg, not as kwargs. We need + # to construct a partial function for apply_ufunc to work. + func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) + result = apply_ufunc( + func, + *arrays, + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + join=join, + dask="allowed", + ) + return result.transpose(*[d for d in all_dims if d in result.dims]) def where(cond, x, y): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 74750cf54f1..c1a11c9918a 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1925,8 +1925,8 @@ def test_polyval(use_dask, use_datetime): dims=["time", "cartesian", "var"], coords=dict( time=(["time"], np.arange(0, 5)), - cartesian=(["cartesian"], np.array(["x", "y", "z"])), - var=(["var"], np.array([1, 1.5, 2, 2.5])), + cartesian=(["cartesian"], ["x", "y", "z"]), + var=(["var"], [1, 1.5, 2, 2.5]), ), ), xr.DataArray( @@ -1934,8 +1934,8 @@ def test_polyval(use_dask, use_datetime): dims=["time", "cartesian", "var"], coords=dict( time=(["time"], np.arange(0, 5)), - cartesian=(["cartesian"], np.array(["x", "y", "z"])), - var=(["var"], np.array([1, 1.5, 2, 2.5])), + cartesian=(["cartesian"], ["x", "y", "z"]), + var=(["var"], [1, 1.5, 2, 2.5]), ), ), np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), @@ -1947,12 +1947,12 @@ def test_polyval(use_dask, use_datetime): xr.DataArray( np.array([1]), dims=["cartesian"], - coords=dict(cartesian=(["cartesian"], np.array(["z"]))), + coords=dict(cartesian=(["cartesian"], ["z"])), ), xr.DataArray( np.array([4, 5, 6]), dims=["cartesian"], - coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))), + coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ), np.array([0, 0, 1]), np.array([4, 5, 6]), @@ -1963,12 +1963,12 @@ def test_polyval(use_dask, use_datetime): xr.DataArray( np.array([1, 2]), dims=["cartesian"], - coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))), + coords=dict(cartesian=(["cartesian"], ["x", "z"])), ), xr.DataArray( np.array([4, 5, 6]), dims=["cartesian"], - coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))), + coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ), np.array([1, 0, 2]), np.array([4, 5, 6]), From a20cb8690dd58d69140d43176a888b45e595ba3c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 24 May 2021 21:11:06 +0200 Subject: [PATCH 019/100] add dims to tests --- xarray/core/computation.py | 12 ++++++------ xarray/tests/test_computation.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 4a31aa8dbcb..c0fa9d963c0 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1407,7 +1407,7 @@ def cross(a, b, dim): >>> a = xr.DataArray([1, 2, 3]) >>> b = xr.DataArray([4, 5, 6]) - >>> xr.cross(a, b) + >>> xr.cross(a, b, "dim_0") array([-3, 6, -3]) Dimensions without coordinates: dim_0 @@ -1417,7 +1417,7 @@ def cross(a, b, dim): >>> a = xr.DataArray([1, 2]) >>> b = xr.DataArray([4, 5]) - >>> xr.cross(a, b) + >>> xr.cross(a, b, "dim_0") array(-3) @@ -1528,16 +1528,16 @@ def cross(a, b, dim): else: # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: - ind = 1 if a.sizes[dim] > b.sizes[dim] else 0 + i = 1 if a.sizes[dim] > b.sizes[dim] else 0 if all([arr.coords for arr in arrays]): # If the arrays have coords we know which indexes to fill # with zeros: - arrays[ind] = arrays[ind].reindex_like(arrays[1 - ind], fill_value=0) - elif arrays[ind].sizes[dim] > 1: + arrays[i] = arrays[i].reindex_like(arrays[1 - i], fill_value=0) + elif arrays[i].sizes[dim] > 1: # If the array doesn't have coords we can can only infer # that it is composite values if the size is 2: - arrays[ind] = arrays[ind].pad({dim: (0, 1)}, constant_values=0) + arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0) else: # Size is 1, then we do not know if the array is a constant or # composite value: diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c1a11c9918a..f6917f005e2 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1908,7 +1908,7 @@ def test_polyval(use_dask, use_datetime): xr.DataArray(np.array([4, 5, 6])), np.array([1, 2, 3]), np.array([4, 5, 6]), - None, + "dim_0", -1, ], [ @@ -1916,7 +1916,7 @@ def test_polyval(use_dask, use_datetime): xr.DataArray(np.array([4, 5, 6])), np.array([1, 2]), np.array([4, 5, 6]), - None, + "dim_0", -1, ], [ # Test dim in the middle: @@ -1956,7 +1956,7 @@ def test_polyval(use_dask, use_datetime): ), np.array([0, 0, 1]), np.array([4, 5, 6]), - None, + "cartesian", -1, ], [ # Test filling inbetween with coords: @@ -1972,7 +1972,7 @@ def test_polyval(use_dask, use_datetime): ), np.array([1, 0, 2]), np.array([4, 5, 6]), - None, + "cartesian", -1, ], ], From 7ce9315882ee220e687f01cd800024e566dc906c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 24 May 2021 21:21:56 +0200 Subject: [PATCH 020/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c0fa9d963c0..f718401d116 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1550,7 +1550,7 @@ def cross(a, b, dim): np.cross, *arrays, input_core_dims=[[dim], [dim]], - output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else [[]]], + output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else []], dask="parallelized", output_dtypes=[ np.cross( From d5a0ea8dffed27b228790edf70e7d5c48bfc2840 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 05:09:48 +0200 Subject: [PATCH 021/100] reduce code --- xarray/core/computation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f718401d116..37f2d86c585 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1522,10 +1522,7 @@ def cross(a, b, dim): all_dims += [d for d in arr.dims if d not in all_dims] - if a.sizes[dim] == b.sizes[dim]: - # Arrays have the same size, no need to do anything: - pass - else: + if a.sizes[dim] != b.sizes[dim]: # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: i = 1 if a.sizes[dim] > b.sizes[dim] else 0 From ef94fa499134acb320468f2d17c744444cb5d845 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 05:28:42 +0200 Subject: [PATCH 022/100] support xr.Variable --- xarray/core/computation.py | 2 +- xarray/tests/test_computation.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 37f2d86c585..c5a0358087a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1527,7 +1527,7 @@ def cross(a, b, dim): # array is missing a value, zeros will not affect np.cross: i = 1 if a.sizes[dim] > b.sizes[dim] else 0 - if all([arr.coords for arr in arrays]): + if all([getattr(arr, "coords", False) for arr in arrays]): # If the arrays have coords we know which indexes to fill # with zeros: arrays[i] = arrays[i].reindex_like(arrays[1 - i], fill_value=0) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index f6917f005e2..d4d6a93c3e8 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1919,6 +1919,22 @@ def test_polyval(use_dask, use_datetime): "dim_0", -1, ], + [ + xr.Variable(dims=["dim_0"], data=np.array([1, 2, 3])), + xr.Variable(dims=["dim_0"], data=np.array([4, 5, 6])), + np.array([1, 2, 3]), + np.array([4, 5, 6]), + "dim_0", + -1, + ], + [ + xr.Variable(dims=["dim_0"], data=np.array([1, 2])), + xr.Variable(dims=["dim_0"], data=np.array([4, 5, 6])), + np.array([1, 2]), + np.array([4, 5, 6]), + "dim_0", + -1, + ], [ # Test dim in the middle: xr.DataArray( np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), From 1a851478aada944c4f060582ab82302c5070d06d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 05:38:33 +0200 Subject: [PATCH 023/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c5a0358087a..607a59f88e1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1503,7 +1503,7 @@ def cross(a, b, dim): all_dims = [] arrays = [a, b] for arr in arrays: - if not isinstance(arr, (DataArray)): + if not isinstance(arr, (DataArray, Variable)): raise TypeError( f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." ) From 2ce3dbe1273f32d1104322e90ecebffa162b1c6c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 05:42:02 +0200 Subject: [PATCH 024/100] Update computation.py --- xarray/core/computation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 607a59f88e1..bb12d158842 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1550,6 +1550,7 @@ def cross(a, b, dim): output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else []], dask="parallelized", output_dtypes=[ + # TODO: Is there a better way of figuring out the dtype? np.cross( np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype), From 53c84c295bc69a6418a77e94b1ed18712bf3bd1c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 05:51:18 +0200 Subject: [PATCH 025/100] reduce code --- xarray/core/computation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bb12d158842..88c779b0634 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1549,12 +1549,9 @@ def cross(a, b, dim): input_core_dims=[[dim], [dim]], output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else []], dask="parallelized", + # TODO: Is there a better way of figuring out the dtype? output_dtypes=[ - # TODO: Is there a better way of figuring out the dtype? - np.cross( - np.empty((2, 2), dtype=arrays[0].dtype), - np.empty((2, 2), dtype=arrays[1].dtype), - ).dtype + np.cross(*[np.empty((2, 2), dtype=arr.dtype) for arr in arrays]).dtype ], ) From dded7206808086dbda416218991d95a0f47dc501 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 06:04:45 +0200 Subject: [PATCH 026/100] docstring explanations --- xarray/core/computation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 88c779b0634..956f9d94954 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1396,10 +1396,10 @@ def cross(a, b, dim): Parameters ---------- - a, b : DataArray - something - dim : hashable or tuple of hashable - something + a, b : DataArray or Variable + Components of the first and second vector(s). + dim : hashable + Dimension to calculate the cross product over. Examples -------- From 705816646a83d9d5296be7e31865dc9ce6ec2d24 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 06:08:07 +0200 Subject: [PATCH 027/100] Use same terms --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 956f9d94954..0bfd9e6d0d1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1412,7 +1412,7 @@ def cross(a, b, dim): array([-3, 6, -3]) Dimensions without coordinates: dim_0 - Vector cross-product with 2 dimensions, returns in the orthogonal + Vector cross-product with 2 dimensions, returns in the perpendicular direction: >>> a = xr.DataArray([1, 2]) From cb57a55eda08f62d6b03806f1a95066480ae2ab3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 06:44:23 +0200 Subject: [PATCH 028/100] docstring formatting --- xarray/core/computation.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0bfd9e6d0d1..71653d71f76 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1386,20 +1386,21 @@ def cross(a, b, dim): """ Return the cross product of two (arrays of) vectors. - The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular - to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors - are defined by the last axis of `a` and `b` by default, and these axes - can have dimensions 2 or 3. Where the dimension of either `a` or `b` is - 2, the third component of the input vector is assumed to be zero and the - cross product calculated accordingly. In cases where both input vectors - have dimension 2, the z-component of the cross product is returned. + The cross product of `a` and `b` in :math:`R^3` is a vector + perpendicular to both `a` and `b`. If `a` and `b` are arrays of + vectors, and these axes can have dimensions 2 or 3. Where the + dimension of either `a` or `b` is 2, the third component of the + input vector is assumed to be zero and the cross product calculated + accordingly. In cases where both input vectors have dimension 2, + the z-component of the cross product is returned. Parameters ---------- a, b : DataArray or Variable Components of the first and second vector(s). dim : hashable - Dimension to calculate the cross product over. + The dimension along which the cross product will be computed. + Must be available in both vectors. Examples -------- From e69ca8184150f66caf529a09fd85fceafe1ab639 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 07:12:41 +0200 Subject: [PATCH 029/100] reduce code --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 71653d71f76..c806465fc2c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1552,7 +1552,7 @@ def cross(a, b, dim): dask="parallelized", # TODO: Is there a better way of figuring out the dtype? output_dtypes=[ - np.cross(*[np.empty((2, 2), dtype=arr.dtype) for arr in arrays]).dtype + np.cross(*[np.empty(3, dtype=arr.dtype) for arr in arrays]).dtype ], ) From 4b2fc72df79946942d99ddd1af76e7ec7eb2a6f5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 25 May 2021 17:35:28 +0200 Subject: [PATCH 030/100] add tests for dask --- xarray/tests/test_computation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index d4d6a93c3e8..30e57ced494 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1900,6 +1900,7 @@ def test_polyval(use_dask, use_datetime): xr.testing.assert_allclose(da, da_pv.T) +@pytest.mark.parametrize("use_dask", [False, True]) @pytest.mark.parametrize( "a, b, ae, be, dim, axis", [ @@ -1993,7 +1994,12 @@ def test_polyval(use_dask, use_datetime): ], ], ) -def test_cross(a, b, ae, be, dim, axis): +def test_cross(a, b, ae, be, dim, axis, use_dask): expected = np.cross(ae, be, axis=axis) actual = xr.cross(a, b, dim=dim) + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + actual = actual.chunk() + xr.testing.assert_duckarray_allclose(expected, actual) From afe572d7b9d28ffc9c684b30334ea18501abe24c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 26 May 2021 07:48:09 +0200 Subject: [PATCH 031/100] simplify check, align used variables --- xarray/core/computation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c806465fc2c..25b5822ce6a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1523,16 +1523,16 @@ def cross(a, b, dim): all_dims += [d for d in arr.dims if d not in all_dims] - if a.sizes[dim] != b.sizes[dim]: + if arrays[0].sizes[dim] != arrays[1].sizes[dim]: # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: - i = 1 if a.sizes[dim] > b.sizes[dim] else 0 + i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 if all([getattr(arr, "coords", False) for arr in arrays]): # If the arrays have coords we know which indexes to fill # with zeros: arrays[i] = arrays[i].reindex_like(arrays[1 - i], fill_value=0) - elif arrays[i].sizes[dim] > 1: + elif arrays[i].sizes[dim] == 2: # If the array doesn't have coords we can can only infer # that it is composite values if the size is 2: arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0) From e137350d7bb3d65e62ea92cf6614909046098054 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 26 May 2021 18:09:39 +0200 Subject: [PATCH 032/100] trim down tests --- xarray/tests/test_computation.py | 44 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 30e57ced494..dc61e3d272c 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1905,34 +1905,34 @@ def test_polyval(use_dask, use_datetime): "a, b, ae, be, dim, axis", [ [ - xr.DataArray(np.array([1, 2, 3])), - xr.DataArray(np.array([4, 5, 6])), - np.array([1, 2, 3]), - np.array([4, 5, 6]), + xr.DataArray([1, 2, 3]), + xr.DataArray([4, 5, 6]), + [1, 2, 3], + [4, 5, 6], "dim_0", -1, ], [ - xr.DataArray(np.array([1, 2])), - xr.DataArray(np.array([4, 5, 6])), - np.array([1, 2]), - np.array([4, 5, 6]), + xr.DataArray([1, 2]), + xr.DataArray([4, 5, 6]), + [1, 2], + [4, 5, 6], "dim_0", -1, ], [ - xr.Variable(dims=["dim_0"], data=np.array([1, 2, 3])), - xr.Variable(dims=["dim_0"], data=np.array([4, 5, 6])), - np.array([1, 2, 3]), - np.array([4, 5, 6]), + xr.Variable(dims=["dim_0"], data=[1, 2, 3]), + xr.Variable(dims=["dim_0"], data=[4, 5, 6]), + [1, 2, 3], + [4, 5, 6], "dim_0", -1, ], [ - xr.Variable(dims=["dim_0"], data=np.array([1, 2])), - xr.Variable(dims=["dim_0"], data=np.array([4, 5, 6])), - np.array([1, 2]), - np.array([4, 5, 6]), + xr.Variable(dims=["dim_0"], data=[1, 2]), + xr.Variable(dims=["dim_0"], data=[4, 5, 6]), + [1, 2], + [4, 5, 6], "dim_0", -1, ], @@ -1971,24 +1971,24 @@ def test_polyval(use_dask, use_datetime): dims=["cartesian"], coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ), - np.array([0, 0, 1]), - np.array([4, 5, 6]), + [0, 0, 1], + [4, 5, 6], "cartesian", -1, ], [ # Test filling inbetween with coords: xr.DataArray( - np.array([1, 2]), + [1, 2], dims=["cartesian"], coords=dict(cartesian=(["cartesian"], ["x", "z"])), ), xr.DataArray( - np.array([4, 5, 6]), + [4, 5, 6], dims=["cartesian"], coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ), - np.array([1, 0, 2]), - np.array([4, 5, 6]), + [1, 0, 2], + [4, 5, 6], "cartesian", -1, ], From 1a26324140075bf6ad4ff5744af8a1a893b05110 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 26 May 2021 21:26:49 +0200 Subject: [PATCH 033/100] Update computation.py --- xarray/core/computation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 25b5822ce6a..d4b1d19c5a6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1517,8 +1517,8 @@ def cross(a, b, dim): s = arr.sizes[dim] if s < 1 or s > 3: raise ValueError( - "incompatible dimensions for cross product\n" - "(dimension with coords must be 1, 2 or 3)" + "Incompatible dimensions for cross product,\n" + "dimension with coords must be 1, 2 or 3." ) all_dims += [d for d in arr.dims if d not in all_dims] @@ -1540,8 +1540,8 @@ def cross(a, b, dim): # Size is 1, then we do not know if the array is a constant or # composite value: raise ValueError( - "incompatible dimensions for cross product\n" - "(dimension without coords must be 2 or 3)" + "Incompatible dimensions for cross product,\n" + "dimension without coords must be 2 or 3." ) c = apply_ufunc( From 531a98b71bb79db5eca4f65d184fb8d31dac7c3c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 27 May 2021 07:41:54 +0200 Subject: [PATCH 034/100] simplify code --- xarray/core/computation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index d4b1d19c5a6..7bcd810d74d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1514,8 +1514,7 @@ def cross(a, b, dim): if dim not in arr.dims: raise ValueError(f"Dimension {dim} not in {arr}.") - s = arr.sizes[dim] - if s < 1 or s > 3: + if not 1 <= arr.sizes[dim] <= 3: raise ValueError( "Incompatible dimensions for cross product,\n" "dimension with coords must be 1, 2 or 3." From 214640649ae71f701d089ef6763adac78b3f081b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 21:21:36 +0200 Subject: [PATCH 035/100] Add type hints --- xarray/core/computation.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7bcd810d74d..45f19bb57a0 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: from .coordinates import Coordinates # noqa + from .dataarray import DataArray from .dataset import Dataset _NO_FILL_VALUE = utils.ReprObject("") @@ -1382,7 +1383,9 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): return corr -def cross(a, b, dim): +def cross( + a: Union["DataArray", "Variable"], b: Union["DataArray", "Variable"], dim: Hashable +) -> Union["DataArray", "Variable"]: """ Return the cross product of two (arrays of) vectors. @@ -1499,16 +1502,9 @@ def cross(a, b, dim): -------- numpy.cross : Corresponding numpy function """ - from .dataarray import DataArray - all_dims = [] arrays = [a, b] for arr in arrays: - if not isinstance(arr, (DataArray, Variable)): - raise TypeError( - f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." - ) - # TODO: Find spatial dim default by looking for unique # (3 or 2)-valued dim? if dim not in arr.dims: From 094047213d39cc1e52496361e36bbdccd61333a4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 21:45:12 +0200 Subject: [PATCH 036/100] less type hints --- xarray/core/computation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 45f19bb57a0..a7d74e4d2e6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1383,9 +1383,7 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): return corr -def cross( - a: Union["DataArray", "Variable"], b: Union["DataArray", "Variable"], dim: Hashable -) -> Union["DataArray", "Variable"]: +def cross(a, b, dim: Hashable) -> Union["DataArray", "Variable"]: """ Return the cross product of two (arrays of) vectors. @@ -1505,6 +1503,11 @@ def cross( all_dims = [] arrays = [a, b] for arr in arrays: + if not isinstance(arr, (DataArray, Variable)): + raise TypeError( + f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." + ) + # TODO: Find spatial dim default by looking for unique # (3 or 2)-valued dim? if dim not in arr.dims: From a7cc565c302c003cf3aaecc158c1694886ece597 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 21:51:37 +0200 Subject: [PATCH 037/100] Update computation.py --- xarray/core/computation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a7d74e4d2e6..6746e47849c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1500,6 +1500,8 @@ def cross(a, b, dim: Hashable) -> Union["DataArray", "Variable"]: -------- numpy.cross : Corresponding numpy function """ + from .dataarray import DataArray + all_dims = [] arrays = [a, b] for arr in arrays: From 1d1f20510c09d649b856f7bc2931864999d03623 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 22:01:53 +0200 Subject: [PATCH 038/100] undo type hints --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6746e47849c..700e052994c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1383,7 +1383,7 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): return corr -def cross(a, b, dim: Hashable) -> Union["DataArray", "Variable"]: +def cross(a, b, dim): """ Return the cross product of two (arrays of) vectors. From 9af7091562ba4c40be6e0ed70ba1058dd9711975 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 28 May 2021 22:03:23 +0200 Subject: [PATCH 039/100] Update computation.py --- xarray/core/computation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 700e052994c..7bcd810d74d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -35,7 +35,6 @@ if TYPE_CHECKING: from .coordinates import Coordinates # noqa - from .dataarray import DataArray from .dataset import Dataset _NO_FILL_VALUE = utils.ReprObject("") From 14decb3cd696b5e699eec8aa07b253ff916ddaac Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 30 May 2021 21:52:40 +0200 Subject: [PATCH 040/100] Add support for datasets --- xarray/core/computation.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7bcd810d74d..4776c7e35ce 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1396,7 +1396,7 @@ def cross(a, b, dim): Parameters ---------- - a, b : DataArray or Variable + a, b : DataArray, Dataset or Variable Components of the first and second vector(s). dim : hashable The dimension along which the cross product will be computed. @@ -1500,13 +1500,24 @@ def cross(a, b, dim): numpy.cross : Corresponding numpy function """ from .dataarray import DataArray + from .dataset import Dataset all_dims = [] arrays = [a, b] - for arr in arrays: - if not isinstance(arr, (DataArray, Variable)): + for i, arr in enumerate(arrays): + if isinstance(arr, Dataset): + is_dataset = True + # TODO: How make sure this temporary dimension is matches + # the orther dataset? + arrays[i] = arr = arr.to_stacked_array( + variable_dim=dim, new_dim="variable", sample_dims=arr.dims + ).unstack("variable") + elif isinstance(arr, (DataArray, Variable)): + is_dataset = False + else: raise TypeError( - f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}." + "Only xr.DataArray, xr.Dataset and xr.Variable are supported, " + f"got {type(arr)}." ) # TODO: Find spatial dim default by looking for unique @@ -1554,8 +1565,11 @@ def cross(a, b, dim): np.cross(*[np.empty(3, dtype=arr.dtype) for arr in arrays]).dtype ], ) + c = c.transpose(*[d for d in all_dims if d in c.dims]) + if is_dataset: + c = c.stack(variable=[dim]).to_unstacked_dataset("variable") - return c.transpose(*[d for d in all_dims if d in c.dims]) + return c def dot(*arrays, dims=None, **kwargs): From 6f73c3224da2c7f84df9fe4c3dd1554ffdd81a60 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 2 Jun 2021 22:42:59 +0200 Subject: [PATCH 041/100] determine dtype with np.result_type --- xarray/core/computation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 4776c7e35ce..dbab51c0ccd 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1560,10 +1560,7 @@ def cross(a, b, dim): input_core_dims=[[dim], [dim]], output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else []], dask="parallelized", - # TODO: Is there a better way of figuring out the dtype? - output_dtypes=[ - np.cross(*[np.empty(3, dtype=arr.dtype) for arr in arrays]).dtype - ], + output_dtypes=[np.result_type(*arrays)], ) c = c.transpose(*[d for d in all_dims if d in c.dims]) if is_dataset: From 72330ceae22bc7c21937ac88e01598f202f85a69 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 6 Jun 2021 11:38:51 +0200 Subject: [PATCH 042/100] test datasets, daskify the inputs not the results --- xarray/tests/test_computation.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index dc61e3d272c..55675544f6f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1936,6 +1936,22 @@ def test_polyval(use_dask, use_datetime): "dim_0", -1, ], + [ + xr.Dataset({0: ("dim_0", [1]), 1: ("dim_0", [2]), 2: ("dim_0", [3])}) + .to_stacked_array( + variable_dim="cartesian", new_dim="variable", sample_dims=("dim_0",) + ) + .unstack("variable"), + xr.Dataset({0: ("dim_0", [4]), 1: ("dim_0", [5]), 2: ("dim_0", [6])}) + .to_stacked_array( + variable_dim="cartesian", new_dim="variable", sample_dims=("dim_0",) + ) + .unstack("variable"), + [1, 2, 3], + [4, 5, 6], + "cartesian", + -1, + ], [ # Test dim in the middle: xr.DataArray( np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), @@ -1996,10 +2012,16 @@ def test_polyval(use_dask, use_datetime): ) def test_cross(a, b, ae, be, dim, axis, use_dask): expected = np.cross(ae, be, axis=axis) - actual = xr.cross(a, b, dim=dim) + if use_dask: if not has_dask: pytest.skip("test for dask.") - actual = actual.chunk() + a = a.chunk() + b = b.chunk() + + actual = xr.cross(a, b, dim=dim) + + if isinstance(actual, xr.Dataset): + actual = actual.stack(variable=[dim]).to_unstacked_dataset("variable") xr.testing.assert_duckarray_allclose(expected, actual) From bce2f3ecd74115eebc5ff1f48fd4b470613793ca Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 6 Jun 2021 11:45:33 +0200 Subject: [PATCH 043/100] rechunk padded values, handle 1 sized datasets --- xarray/core/computation.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index dbab51c0ccd..d4e129d02c9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1507,8 +1507,9 @@ def cross(a, b, dim): for i, arr in enumerate(arrays): if isinstance(arr, Dataset): is_dataset = True - # TODO: How make sure this temporary dimension is matches - # the orther dataset? + # Turn the dataset to a stacked dataarray to follow the + # normal code path. Then at the end turn it back to a + # dataset. arrays[i] = arr = arr.to_stacked_array( variable_dim=dim, new_dim="variable", sample_dims=arr.dims ).unstack("variable") @@ -1546,6 +1547,8 @@ def cross(a, b, dim): # If the array doesn't have coords we can can only infer # that it is composite values if the size is 2: arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0) + if is_duck_dask_array(arrays[i].data): + arrays[i] = arrays[i].chunk({dim: -1}) else: # Size is 1, then we do not know if the array is a constant or # composite value: @@ -1565,6 +1568,9 @@ def cross(a, b, dim): c = c.transpose(*[d for d in all_dims if d in c.dims]) if is_dataset: c = c.stack(variable=[dim]).to_unstacked_dataset("variable") + c = c.expand_dims( + [dim for ds in arrays for dim, size in ds.sizes.items() if size == 1] + ) return c From 1636d251084e8af0eac9946031e81ef68053d74b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 6 Jun 2021 13:41:59 +0200 Subject: [PATCH 044/100] expand only unique dims, squeeze out dims in tests --- xarray/core/computation.py | 4 ++-- xarray/tests/test_computation.py | 20 +++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index d4e129d02c9..b838b5d8341 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1506,10 +1506,10 @@ def cross(a, b, dim): arrays = [a, b] for i, arr in enumerate(arrays): if isinstance(arr, Dataset): - is_dataset = True # Turn the dataset to a stacked dataarray to follow the # normal code path. Then at the end turn it back to a # dataset. + is_dataset = True arrays[i] = arr = arr.to_stacked_array( variable_dim=dim, new_dim="variable", sample_dims=arr.dims ).unstack("variable") @@ -1569,7 +1569,7 @@ def cross(a, b, dim): if is_dataset: c = c.stack(variable=[dim]).to_unstacked_dataset("variable") c = c.expand_dims( - [dim for ds in arrays for dim, size in ds.sizes.items() if size == 1] + list({d: s for ds in arrays for d, s in ds.sizes.items() if s == 1}) ) return c diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 55675544f6f..55d18e1597b 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1937,16 +1937,8 @@ def test_polyval(use_dask, use_datetime): -1, ], [ - xr.Dataset({0: ("dim_0", [1]), 1: ("dim_0", [2]), 2: ("dim_0", [3])}) - .to_stacked_array( - variable_dim="cartesian", new_dim="variable", sample_dims=("dim_0",) - ) - .unstack("variable"), - xr.Dataset({0: ("dim_0", [4]), 1: ("dim_0", [5]), 2: ("dim_0", [6])}) - .to_stacked_array( - variable_dim="cartesian", new_dim="variable", sample_dims=("dim_0",) - ) - .unstack("variable"), + xr.Dataset({0: ("dim_0", [1]), 1: ("dim_0", [2]), 2: ("dim_0", [3])}), + xr.Dataset({0: ("dim_0", [4]), 1: ("dim_0", [5]), 2: ("dim_0", [6])}), [1, 2, 3], [4, 5, 6], "cartesian", @@ -2022,6 +2014,12 @@ def test_cross(a, b, ae, be, dim, axis, use_dask): actual = xr.cross(a, b, dim=dim) if isinstance(actual, xr.Dataset): - actual = actual.stack(variable=[dim]).to_unstacked_dataset("variable") + actual = ( + actual.to_stacked_array( + variable_dim=dim, new_dim="variable", sample_dims=actual.dims + ) + .unstack("variable") + .squeeze() + ) xr.testing.assert_duckarray_allclose(expected, actual) From b5b97a0d66cb56d80c01fdaaf642b3e904d6db25 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 6 Jun 2021 15:39:45 +0200 Subject: [PATCH 045/100] rechunk along the dim --- xarray/core/computation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b838b5d8341..f36388cc4a4 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1513,6 +1513,8 @@ def cross(a, b, dim): arrays[i] = arr = arr.to_stacked_array( variable_dim=dim, new_dim="variable", sample_dims=arr.dims ).unstack("variable") + if is_duck_dask_array(arr.data): + arrays[i] = arr = arr.chunk({dim: -1}) elif isinstance(arr, (DataArray, Variable)): is_dataset = False else: From 02364ca2346ccad0af71e04b747bfc1b5733689c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Jun 2021 21:05:19 +0200 Subject: [PATCH 046/100] Attempt typing again --- xarray/core/computation.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1530aa07f9e..80b6aaee5d0 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -20,6 +20,7 @@ Optional, Sequence, Tuple, + TypeVar, Union, ) @@ -37,6 +38,8 @@ from .coordinates import Coordinates # noqa from .dataset import Dataset + T_DSorDAorVar = TypeVar("T_DSorDAorVar", Dataset, DataArray, Variable) + _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) @@ -1393,7 +1396,11 @@ def _get_valid_values(da, other): return corr -def cross(a, b, dim): +def cross( + a: T_DSorDAorVar, + b: T_DSorDAorVar, + dim: Hashable, +) -> T_DSorDAorVar: """ Return the cross product of two (arrays of) vectors. @@ -1551,15 +1558,19 @@ def cross(a, b, dim): # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 + array_large, array_small = array[i], array[1 - i] - if all([getattr(arr, "coords", False) for arr in arrays]): + if getattr(array_large, "coords", False) and getattr( + array_small, "coords", False + ): + # if all([getattr(arr, "coords", False) for arr in arrays]): # If the arrays have coords we know which indexes to fill # with zeros: - arrays[i] = arrays[i].reindex_like(arrays[1 - i], fill_value=0) - elif arrays[i].sizes[dim] == 2: + arrays[i] = array_small.reindex_like(array_large, fill_value=0) + elif array_small.sizes[dim] == 2: # If the array doesn't have coords we can can only infer # that it is composite values if the size is 2: - arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0) + arrays[i] = array_small.pad({dim: (0, 1)}, constant_values=0) if is_duck_dask_array(arrays[i].data): arrays[i] = arrays[i].chunk({dim: -1}) else: From ed44400138438547de33bab14c8182c5815f97ab Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Jun 2021 21:14:38 +0200 Subject: [PATCH 047/100] Update __init__.py --- xarray/__init__.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index c4271e2432f..e838dd29785 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -18,7 +18,16 @@ from .core.alignment import align, broadcast from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, corr, cov, cross, dot, polyval, unify_chunks, where +from .core.computation import ( + apply_ufunc, + corr, + cov, + cross, + dot, + polyval, + unify_chunks, + where, +) from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset From 4fe9737c95c9260a6db7685907deef3181272b50 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Jun 2021 21:26:27 +0200 Subject: [PATCH 048/100] Update computation.py --- xarray/core/computation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a739c087bb4..7f0a4d4b0cc 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1398,10 +1398,10 @@ def _get_valid_values(da, other): def cross( - a: T_DSorDAorVar, - b: T_DSorDAorVar, + a: "T_DSorDAorVar", + b: "T_DSorDAorVar", dim: Hashable, -) -> T_DSorDAorVar: +) -> "T_DSorDAorVar": """ Return the cross product of two (arrays of) vectors. @@ -1559,7 +1559,7 @@ def cross( # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 - array_large, array_small = array[i], array[1 - i] + array_large, array_small = arrays[i], arrays[1 - i] if getattr(array_large, "coords", False) and getattr( array_small, "coords", False From ec05780c78a64c70b73cf9e91c8618255215f537 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Jun 2021 21:30:55 +0200 Subject: [PATCH 049/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7f0a4d4b0cc..44bcde71fca 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1400,7 +1400,7 @@ def _get_valid_values(da, other): def cross( a: "T_DSorDAorVar", b: "T_DSorDAorVar", - dim: Hashable, + dim: str, ) -> "T_DSorDAorVar": """ Return the cross product of two (arrays of) vectors. From 36c5956dfbe1baf1683b2c016c77f2c964fb102d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Jun 2021 21:37:56 +0200 Subject: [PATCH 050/100] test fixing type in to_stacked_array --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 48925b70b66..31f17e99c96 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3867,7 +3867,7 @@ def to_stacked_array( self, new_dim: Hashable, sample_dims: Sequence[Hashable], - variable_dim: str = "variable", + variable_dim: Hashable = "variable", name: Hashable = None, ) -> "DataArray": """Combine variables of differing dimensionality into a DataArray From cbf289cffdc26c8d96f790013fc604ccd550e48c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Jun 2021 21:49:24 +0200 Subject: [PATCH 051/100] test fixing to_stacked_array --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 31f17e99c96..a7b10ea8f86 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3866,7 +3866,7 @@ def stack( def to_stacked_array( self, new_dim: Hashable, - sample_dims: Sequence[Hashable], + sample_dims: Mapping[Hashable, Sequence[Hashable]], variable_dim: Hashable = "variable", name: Hashable = None, ) -> "DataArray": From 4cfd5be654194eb20523dadb998b67a07d1447e5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 18 Jun 2021 18:45:59 +0200 Subject: [PATCH 052/100] small is large --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 44bcde71fca..c2b47055a68 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1559,7 +1559,7 @@ def cross( # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 - array_large, array_small = arrays[i], arrays[1 - i] + array_small, array_large = arrays[i], arrays[1 - i] if getattr(array_large, "coords", False) and getattr( array_small, "coords", False From 658a59fb3f6319140d847b41b8b08f0c6265647d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 18 Jun 2021 19:33:36 +0200 Subject: [PATCH 053/100] Update computation.py --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c2b47055a68..af2c37dbad9 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1521,8 +1521,8 @@ def cross( from .dataarray import DataArray from .dataset import Dataset - all_dims = [] - arrays = [a, b] + all_dims: List[Hashable] = [] + arrays: List["T_DSorDAorVar"] = [a, b] for i, arr in enumerate(arrays): if isinstance(arr, Dataset): # Turn the dataset to a stacked dataarray to follow the From ab5ae2054a5ba80afe1acfc081bb164bc08a7039 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 18 Jun 2021 23:18:13 +0200 Subject: [PATCH 054/100] Update xarray/core/computation.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index af2c37dbad9..981684a399c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1569,7 +1569,7 @@ def cross( # with zeros: arrays[i] = array_small.reindex_like(array_large, fill_value=0) elif array_small.sizes[dim] == 2: - # If the array doesn't have coords we can can only infer + # If the array doesn't have coords we can only infer # that it is composite values if the size is 2: arrays[i] = array_small.pad({dim: (0, 1)}, constant_values=0) if is_duck_dask_array(arrays[i].data): From d65ca418f3db8eee0e044b0d9d6bb842217abe3a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 19 Jun 2021 14:11:53 +0200 Subject: [PATCH 055/100] obfuscate variable_dim some --- xarray/core/computation.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 981684a399c..a63566dc7d4 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1400,7 +1400,7 @@ def _get_valid_values(da, other): def cross( a: "T_DSorDAorVar", b: "T_DSorDAorVar", - dim: str, + dim: Hashable, ) -> "T_DSorDAorVar": """ Return the cross product of two (arrays of) vectors. @@ -1530,8 +1530,8 @@ def cross( # dataset. is_dataset = True arrays[i] = arr = arr.to_stacked_array( - variable_dim=dim, new_dim="variable", sample_dims=arr.dims - ).unstack("variable") + variable_dim=dim, new_dim="stacked__dim", sample_dims=arr.dims + ).unstack("stacked__dim") if is_duck_dask_array(arr.data): arrays[i] = arr = arr.chunk({dim: -1}) elif isinstance(arr, (DataArray, Variable)): @@ -1561,10 +1561,9 @@ def cross( i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 array_small, array_large = arrays[i], arrays[1 - i] - if getattr(array_large, "coords", False) and getattr( - array_small, "coords", False + if getattr(array_small, "coords", False) and getattr( + array_large, "coords", False ): - # if all([getattr(arr, "coords", False) for arr in arrays]): # If the arrays have coords we know which indexes to fill # with zeros: arrays[i] = array_small.reindex_like(array_large, fill_value=0) @@ -1592,7 +1591,7 @@ def cross( ) c = c.transpose(*[d for d in all_dims if d in c.dims]) if is_dataset: - c = c.stack(variable=[dim]).to_unstacked_dataset("variable") + c = c.stack(stacked__dim=[dim]).to_unstacked_dataset("stacked__dim") c = c.expand_dims( list({d: s for ds in arrays for d, s in ds.sizes.items() if s == 1}) ) From 20eef039e90a09cd7a2b9b5c329fb0463fb50659 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 19 Jun 2021 14:21:25 +0200 Subject: [PATCH 056/100] Update computation.py --- xarray/core/computation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a63566dc7d4..eb4067a88ef 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1398,10 +1398,10 @@ def _get_valid_values(da, other): def cross( - a: "T_DSorDAorVar", - b: "T_DSorDAorVar", + a: T_DSorDAorVar, + b: T_DSorDAorVar, dim: Hashable, -) -> "T_DSorDAorVar": +) -> T_DSorDAorVar: """ Return the cross product of two (arrays of) vectors. @@ -1522,7 +1522,7 @@ def cross( from .dataset import Dataset all_dims: List[Hashable] = [] - arrays: List["T_DSorDAorVar"] = [a, b] + arrays: List[T_DSorDAorVar] = [a, b] for i, arr in enumerate(arrays): if isinstance(arr, Dataset): # Turn the dataset to a stacked dataarray to follow the From 274af32042ccc4da8a3c7608293e9b747e551c74 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 19 Jun 2021 14:26:12 +0200 Subject: [PATCH 057/100] undo to_stacked_array changes --- xarray/core/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7b10ea8f86..48925b70b66 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3866,8 +3866,8 @@ def stack( def to_stacked_array( self, new_dim: Hashable, - sample_dims: Mapping[Hashable, Sequence[Hashable]], - variable_dim: Hashable = "variable", + sample_dims: Sequence[Hashable], + variable_dim: str = "variable", name: Hashable = None, ) -> "DataArray": """Combine variables of differing dimensionality into a DataArray From f3523035cf564c53e00ea098552ad89115c88c14 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 19 Jun 2021 14:35:34 +0200 Subject: [PATCH 058/100] test sample_dims typing --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 48925b70b66..f855d6f3f00 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3866,7 +3866,7 @@ def stack( def to_stacked_array( self, new_dim: Hashable, - sample_dims: Sequence[Hashable], + sample_dims: Mapping[Hashable, int], variable_dim: str = "variable", name: Hashable = None, ) -> "DataArray": From 0a773cbf78ee340c88860af50bf5074b0725e72e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 19 Jun 2021 14:47:32 +0200 Subject: [PATCH 059/100] to_stacked_array fixes --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f855d6f3f00..0be9cf8dadc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3867,7 +3867,7 @@ def to_stacked_array( self, new_dim: Hashable, sample_dims: Mapping[Hashable, int], - variable_dim: str = "variable", + variable_dim: Hashable = "variable", name: Hashable = None, ) -> "DataArray": """Combine variables of differing dimensionality into a DataArray From d8da29fd534d90541219bea28f8b24b4037ca9c7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 19 Jun 2021 20:33:21 +0200 Subject: [PATCH 060/100] add reindex_like check --- xarray/core/computation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index eb4067a88ef..30051b97d07 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1561,8 +1561,11 @@ def cross( i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 array_small, array_large = arrays[i], arrays[1 - i] - if getattr(array_small, "coords", False) and getattr( - array_large, "coords", False + if ( + getattr(array_small, "coords", False) + and getattr(array_large, "coords", False) + and hasattr(array_small, "reindex_like") + and hasattr(array_large, "reindex_like") ): # If the arrays have coords we know which indexes to fill # with zeros: From 54a76c13f47953f895e331ae50afd6dec855abc1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 20 Jun 2021 20:09:07 +0200 Subject: [PATCH 061/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 30051b97d07..e6ba6fc5127 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1522,7 +1522,7 @@ def cross( from .dataset import Dataset all_dims: List[Hashable] = [] - arrays: List[T_DSorDAorVar] = [a, b] + arrays: List[T_DSorDAorVar, T_DSorDAorVar] = [a, b] for i, arr in enumerate(arrays): if isinstance(arr, Dataset): # Turn the dataset to a stacked dataarray to follow the From 0a2dc2ef24e69428c5ea2de7a853725dcf1247e4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 20 Jun 2021 20:15:17 +0200 Subject: [PATCH 062/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index e6ba6fc5127..30051b97d07 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1522,7 +1522,7 @@ def cross( from .dataset import Dataset all_dims: List[Hashable] = [] - arrays: List[T_DSorDAorVar, T_DSorDAorVar] = [a, b] + arrays: List[T_DSorDAorVar] = [a, b] for i, arr in enumerate(arrays): if isinstance(arr, Dataset): # Turn the dataset to a stacked dataarray to follow the From b3592f36ace0ac347c7cfe4ce9c56479448f2a20 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 20 Jun 2021 20:23:15 +0200 Subject: [PATCH 063/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 30051b97d07..8d546a80709 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1575,7 +1575,7 @@ def cross( # that it is composite values if the size is 2: arrays[i] = array_small.pad({dim: (0, 1)}, constant_values=0) if is_duck_dask_array(arrays[i].data): - arrays[i] = arrays[i].chunk({dim: -1}) + arrays[i] = arrays[i].chunk({dim: -1.0}) else: # Size is 1, then we do not know if the array is a constant or # composite value: From 06772dadea907165f9be88cce7851312b74e4bf5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 20 Jun 2021 21:03:21 +0200 Subject: [PATCH 064/100] test forcing int type in chunk() --- xarray/core/dataarray.py | 8 ++++---- xarray/core/dataset.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index eab4413d5ce..6768046f0b0 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1039,10 +1039,10 @@ def chunks(self) -> Optional[Tuple[Tuple[int, ...], ...]]: def chunk( self, chunks: Union[ - Number, - Tuple[Number, ...], - Tuple[Tuple[Number, ...], ...], - Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]], + int, + Tuple[int, ...], + Tuple[Tuple[int, ...], ...], + Mapping[Hashable, Union[None, int, Tuple[int, ...]]], ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0be9cf8dadc..746335d0020 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2085,9 +2085,9 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: def chunk( self, chunks: Union[ - Number, + int, str, - Mapping[Hashable, Union[None, Number, str, Tuple[Number, ...]]], + Mapping[Hashable, Union[None, int, str, Tuple[int, ...]]], ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, From cfd11f75552d0005f1a8acaf5808b6ec7f383105 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 20 Jun 2021 21:09:58 +0200 Subject: [PATCH 065/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 8d546a80709..30051b97d07 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1575,7 +1575,7 @@ def cross( # that it is composite values if the size is 2: arrays[i] = array_small.pad({dim: (0, 1)}, constant_values=0) if is_duck_dask_array(arrays[i].data): - arrays[i] = arrays[i].chunk({dim: -1.0}) + arrays[i] = arrays[i].chunk({dim: -1}) else: # Size is 1, then we do not know if the array is a constant or # composite value: From 90553edeff5b4f7f38d28c8d03c2249429cacc98 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 21 Jun 2021 20:08:49 +0200 Subject: [PATCH 066/100] test collection in to_stacked_array --- xarray/core/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 19ae746b3bf..5dc81a8439c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -12,6 +12,7 @@ TYPE_CHECKING, Any, Callable, + Collection, DefaultDict, Dict, Hashable, @@ -3870,7 +3871,7 @@ def stack( def to_stacked_array( self, new_dim: Hashable, - sample_dims: Mapping[Hashable, int], + sample_dims: Collection, variable_dim: Hashable = "variable", name: Hashable = None, ) -> "DataArray": From 6eed96e48689097f7529642bfcb8b4a0aa64574f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 21 Jun 2021 21:15:05 +0200 Subject: [PATCH 067/100] Update computation.py --- xarray/core/computation.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 148982dab74..3b34614029b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1559,21 +1559,17 @@ def cross( # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 - array_small, array_large = arrays[i], arrays[1 - i] - - if ( - getattr(array_small, "coords", False) - and getattr(array_large, "coords", False) - and hasattr(array_small, "reindex_like") - and hasattr(array_large, "reindex_like") - ): + + if all([getattr(arr, "coords", False) for arr in arrays]): # If the arrays have coords we know which indexes to fill # with zeros: - arrays[i] = array_small.reindex_like(array_large, fill_value=0) - elif array_small.sizes[dim] == 2: + arrays[i] = arrays[i].reindex_like( + arrays[1 - i], fill_value=0 + ) # type: DataArray + elif arrays[i].sizes[dim] == 2: # If the array doesn't have coords we can only infer # that it is composite values if the size is 2: - arrays[i] = array_small.pad({dim: (0, 1)}, constant_values=0) + arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0) if is_duck_dask_array(arrays[i].data): arrays[i] = arrays[i].chunk({dim: -1}) else: From d3648e54c25710e2c01f593d0bf2001f0791081c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 22 Jun 2021 23:15:28 +0200 Subject: [PATCH 068/100] Update computation.py --- xarray/core/computation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 3b34614029b..1a4c5441b6f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1529,11 +1529,14 @@ def cross( # normal code path. Then at the end turn it back to a # dataset. is_dataset = True - arrays[i] = arr = arr.to_stacked_array( + arr = arr.to_stacked_array( variable_dim=dim, new_dim="stacked__dim", sample_dims=arr.dims ).unstack("stacked__dim") + if is_duck_dask_array(arr.data): - arrays[i] = arr = arr.chunk({dim: -1}) + arr = arr.chunk({dim: -1}) + + arrays[i] = arr elif isinstance(arr, (DataArray, Variable)): is_dataset = False else: From c639aa33622b867996a6dc1f4914255d53b2b0b6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 22 Jun 2021 23:19:39 +0200 Subject: [PATCH 069/100] Update computation.py --- xarray/core/computation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1a4c5441b6f..037dbe7c196 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1529,9 +1529,10 @@ def cross( # normal code path. Then at the end turn it back to a # dataset. is_dataset = True - arr = arr.to_stacked_array( + arr_ = arr.to_stacked_array( variable_dim=dim, new_dim="stacked__dim", sample_dims=arr.dims ).unstack("stacked__dim") + arr = arr_ if is_duck_dask_array(arr.data): arr = arr.chunk({dim: -1}) @@ -1568,7 +1569,7 @@ def cross( # with zeros: arrays[i] = arrays[i].reindex_like( arrays[1 - i], fill_value=0 - ) # type: DataArray + ) # type: ignore elif arrays[i].sizes[dim] == 2: # If the array doesn't have coords we can only infer # that it is composite values if the size is 2: From 4c636f545414f1413396ca25c278ddc94151ec1e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 22 Jun 2021 23:40:31 +0200 Subject: [PATCH 070/100] Update computation.py --- xarray/core/computation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 037dbe7c196..7a6bf39c0cd 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1522,17 +1522,16 @@ def cross( from .dataset import Dataset all_dims: List[Hashable] = [] - arrays: List[T_DSorDAorVar] = [a, b] + arrays: List[Any] = [a, b] for i, arr in enumerate(arrays): if isinstance(arr, Dataset): # Turn the dataset to a stacked dataarray to follow the # normal code path. Then at the end turn it back to a # dataset. is_dataset = True - arr_ = arr.to_stacked_array( + arr = arr.to_stacked_array( variable_dim=dim, new_dim="stacked__dim", sample_dims=arr.dims ).unstack("stacked__dim") - arr = arr_ if is_duck_dask_array(arr.data): arr = arr.chunk({dim: -1}) @@ -1564,12 +1563,15 @@ def cross( # array is missing a value, zeros will not affect np.cross: i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 - if all([getattr(arr, "coords", False) for arr in arrays]): + if all( + getattr(arr, "coords", False) and not isinstance(arr, Variable) + for arr in arrays + ): # If the arrays have coords we know which indexes to fill # with zeros: arrays[i] = arrays[i].reindex_like( arrays[1 - i], fill_value=0 - ) # type: ignore + ) elif arrays[i].sizes[dim] == 2: # If the array doesn't have coords we can only infer # that it is composite values if the size is 2: From 3bea9361bb09826c33f8c92b52c62e5e899b13f5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 22 Jun 2021 23:45:53 +0200 Subject: [PATCH 071/100] Update computation.py --- xarray/core/computation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7a6bf39c0cd..43790c633e6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1564,14 +1564,13 @@ def cross( i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 if all( + # The variable check is only used to make mypy happy: getattr(arr, "coords", False) and not isinstance(arr, Variable) for arr in arrays ): # If the arrays have coords we know which indexes to fill # with zeros: - arrays[i] = arrays[i].reindex_like( - arrays[1 - i], fill_value=0 - ) + arrays[i] = arrays[i].reindex_like(arrays[1 - i], fill_value=0) elif arrays[i].sizes[dim] == 2: # If the array doesn't have coords we can only infer # that it is composite values if the size is 2: From 12da913bf1ec35101d43b5a5a1d6cb6504d4e34c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 24 Jun 2021 20:38:59 +0200 Subject: [PATCH 072/100] whats new and api.rst --- doc/api.rst | 1 + doc/whats-new.rst | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index bb3a99bfbb0..fd7bd5bc04c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -31,6 +31,7 @@ Top-level functions ones_like cov corr + cross dot polyval map_blocks diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c89c41da0b1..8a2c7863dc2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,7 +21,8 @@ v0.18.3 (unreleased) New Features ~~~~~~~~~~~~ - +- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`). + By `Jimmy Westling `_. - Added :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` (:issue:`5454`, :pull:`5475`). By `Deepak Cherian `_. - Xarray now uses consolidated metadata by default when writing and reading Zarr From ea062e6de0aed604d52469014a9e40071327d75c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 24 Jun 2021 21:09:27 +0200 Subject: [PATCH 073/100] Update whats-new.rst --- doc/whats-new.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8a2c7863dc2..a95c5eff5a6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,7 @@ v0.18.3 (unreleased) New Features ~~~~~~~~~~~~ + - New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`). By `Jimmy Westling `_. - Added :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` (:issue:`5454`, :pull:`5475`). From 629df59d29fca2b10d5dbb0a1b045de8b912da8c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 26 Jul 2021 23:22:26 +0200 Subject: [PATCH 074/100] Output as dataset if any input is a dataset --- xarray/core/computation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7a8be2f58a0..039a001ffb1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1523,12 +1523,13 @@ def cross( all_dims: List[Hashable] = [] arrays: List[Any] = [a, b] + output_as_dataset = False for i, arr in enumerate(arrays): if isinstance(arr, Dataset): # Turn the dataset to a stacked dataarray to follow the # normal code path. Then at the end turn it back to a # dataset. - is_dataset = True + output_as_dataset = True arr = arr.to_stacked_array( variable_dim=dim, new_dim="stacked__dim", sample_dims=arr.dims ).unstack("stacked__dim") @@ -1538,7 +1539,7 @@ def cross( arrays[i] = arr elif isinstance(arr, (DataArray, Variable)): - is_dataset = False + pass else: raise TypeError( "Only xr.DataArray, xr.Dataset and xr.Variable are supported, " @@ -1594,7 +1595,7 @@ def cross( output_dtypes=[np.result_type(*arrays)], ) c = c.transpose(*[d for d in all_dims if d in c.dims]) - if is_dataset: + if output_as_dataset: c = c.stack(stacked__dim=[dim]).to_unstacked_dataset("stacked__dim") c = c.expand_dims( list({d: s for ds in arrays for d, s in ds.sizes.items() if s == 1}) From 972c7dc7f47d303930a5626e7cc230c7d80011bf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 26 Jul 2021 23:27:51 +0200 Subject: [PATCH 075/100] Simplify the if terms instead of using pass. --- xarray/core/computation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 039a001ffb1..56a53acaa28 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1538,9 +1538,7 @@ def cross( arr = arr.chunk({dim: -1}) arrays[i] = arr - elif isinstance(arr, (DataArray, Variable)): - pass - else: + elif not isinstance(arr, (DataArray, Variable)): raise TypeError( "Only xr.DataArray, xr.Dataset and xr.Variable are supported, " f"got {type(arr)}." From 49967d4985b996d595d0482853f9caa1260c45e4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Aug 2021 22:31:35 +0200 Subject: [PATCH 076/100] Update computation.py --- xarray/core/computation.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 79fa5431d42..c49d0bfa9f2 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1387,11 +1387,7 @@ def _get_valid_values(da, other): return corr -def cross( - a: T_DSorDAorVar, - b: T_DSorDAorVar, - dim: Hashable, -) -> T_DSorDAorVar: +def cross(a: T_Xarray, b: T_Xarray, dim: Hashable) -> T_Xarray: """ Return the cross product of two (arrays of) vectors. From 6ab7d193b0d8e7dbeb4e30aa66830113b575daf7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Aug 2021 23:26:41 +0200 Subject: [PATCH 077/100] Remove support for datasets --- xarray/core/computation.py | 49 +++++++++++++------------------- xarray/tests/test_computation.py | 8 ------ 2 files changed, 19 insertions(+), 38 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c49d0bfa9f2..a897bad1c8e 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: from .coordinates import Coordinates from .dataset import Dataset - from .types import T_Xarray + from .types import DaCompatible, T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -1387,7 +1387,7 @@ def _get_valid_values(da, other): return corr -def cross(a: T_Xarray, b: T_Xarray, dim: Hashable) -> T_Xarray: +def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: """ Return the cross product of two (arrays of) vectors. @@ -1401,7 +1401,7 @@ def cross(a: T_Xarray, b: T_Xarray, dim: Hashable) -> T_Xarray: Parameters ---------- - a, b : DataArray, Dataset or Variable + a, b : DataArray or Variable Components of the first and second vector(s). dim : hashable The dimension along which the cross product will be computed. @@ -1500,36 +1500,30 @@ def cross(a: T_Xarray, b: T_Xarray, dim: Hashable) -> T_Xarray: * time (time) int64 0 1 * cartesian (cartesian) >> ds_a = xr.Dataset(data_vars=dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) + >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) + >>> c = xr.cross(ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian") + >>> ds_c = c.to_dataset(dim="cartesian") + >>> print(ds_c) + + Dimensions: (dim_0: 1) + Dimensions without coordinates: dim_0 + Data variables: + x (dim_0) int32 -3 + y (dim_0) int32 6 + z (dim_0) int32 -3 + See Also -------- numpy.cross : Corresponding numpy function """ - from .dataarray import DataArray - from .dataset import Dataset all_dims: List[Hashable] = [] arrays: List[Any] = [a, b] - output_as_dataset = False for i, arr in enumerate(arrays): - if isinstance(arr, Dataset): - # Turn the dataset to a stacked dataarray to follow the - # normal code path. Then at the end turn it back to a - # dataset. - output_as_dataset = True - arr = arr.to_stacked_array( - variable_dim=dim, new_dim="stacked__dim", sample_dims=arr.dims - ).unstack("stacked__dim") - - if is_duck_dask_array(arr.data): - arr = arr.chunk({dim: -1}) - - arrays[i] = arr - elif not isinstance(arr, (DataArray, Variable)): - raise TypeError( - "Only xr.DataArray, xr.Dataset and xr.Variable are supported, " - f"got {type(arr)}." - ) - # TODO: Find spatial dim default by looking for unique # (3 or 2)-valued dim? if dim not in arr.dims: @@ -1579,11 +1573,6 @@ def cross(a: T_Xarray, b: T_Xarray, dim: Hashable) -> T_Xarray: output_dtypes=[np.result_type(*arrays)], ) c = c.transpose(*[d for d in all_dims if d in c.dims]) - if output_as_dataset: - c = c.stack(stacked__dim=[dim]).to_unstacked_dataset("stacked__dim") - c = c.expand_dims( - list({d: s for ds in arrays for d, s in ds.sizes.items() if s == 1}) - ) return c diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 1a3b6b561b8..c865db71a67 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1968,14 +1968,6 @@ def test_polyval(use_dask, use_datetime) -> None: "dim_0", -1, ], - [ - xr.Dataset({0: ("dim_0", [1]), 1: ("dim_0", [2]), 2: ("dim_0", [3])}), - xr.Dataset({0: ("dim_0", [4]), 1: ("dim_0", [5]), 2: ("dim_0", [6])}), - [1, 2, 3], - [4, 5, 6], - "cartesian", - -1, - ], [ # Test dim in the middle: xr.DataArray( np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)), From 20a6cb619529e7df39024a143ab81cbc2f3e6a40 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Aug 2021 23:31:43 +0200 Subject: [PATCH 078/100] Update computation.py --- xarray/core/computation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a897bad1c8e..91d75d68c67 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1503,9 +1503,11 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: Cross can used by on Datasets by converting to DataArrays and then back to Datasets: - >>> ds_a = xr.Dataset(data_vars=dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) + >>> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) - >>> c = xr.cross(ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian") + >>> c = xr.cross( + ... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian" + ... ) >>> ds_c = c.to_dataset(dim="cartesian") >>> print(ds_c) From ba3fa9c41a44fcbe84bc90ff3a73319327f42883 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Aug 2021 23:38:17 +0200 Subject: [PATCH 079/100] Add some typing to test. --- xarray/tests/test_computation.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c865db71a67..3161a97a477 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2026,7 +2026,7 @@ def test_polyval(use_dask, use_datetime) -> None: ], ], ) -def test_cross(a, b, ae, be, dim, axis, use_dask): +def test_cross(a, b, ae, be, dim : str, axis : int, use_dask : bool) -> None: expected = np.cross(ae, be, axis=axis) if use_dask: @@ -2036,14 +2036,4 @@ def test_cross(a, b, ae, be, dim, axis, use_dask): b = b.chunk() actual = xr.cross(a, b, dim=dim) - - if isinstance(actual, xr.Dataset): - actual = ( - actual.to_stacked_array( - variable_dim=dim, new_dim="variable", sample_dims=actual.dims - ) - .unstack("variable") - .squeeze() - ) - xr.testing.assert_duckarray_allclose(expected, actual) From 8b192f22f35e661713bfb72dca0c7f73d49d23f7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Aug 2021 23:40:19 +0200 Subject: [PATCH 080/100] doctest fix --- xarray/core/computation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 91d75d68c67..08b6210a504 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1514,9 +1514,9 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: Dimensions: (dim_0: 1) Dimensions without coordinates: dim_0 Data variables: - x (dim_0) int32 -3 - y (dim_0) int32 6 - z (dim_0) int32 -3 + x (dim_0) int64 -3 + y (dim_0) int64 6 + z (dim_0) int64 -3 See Also -------- From a27965cf5e555cb718bdfb058a7d54718aff47c1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Aug 2021 23:43:41 +0200 Subject: [PATCH 081/100] lint --- xarray/tests/test_computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 3161a97a477..01c21dc0496 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2026,7 +2026,7 @@ def test_polyval(use_dask, use_datetime) -> None: ], ], ) -def test_cross(a, b, ae, be, dim : str, axis : int, use_dask : bool) -> None: +def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None: expected = np.cross(ae, be, axis=axis) if use_dask: From b058084f9be07c920a8b4621b5671ac093b589fb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 3 Oct 2021 17:56:00 +0200 Subject: [PATCH 082/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 08b6210a504..195ceb85401 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1508,8 +1508,7 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: >>> c = xr.cross( ... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian" ... ) - >>> ds_c = c.to_dataset(dim="cartesian") - >>> print(ds_c) + >>> c.to_dataset(dim="cartesian") Dimensions: (dim_0: 1) Dimensions without coordinates: dim_0 From f007ed5cec98cbada7f87e5d2271ca4f089bea4d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 5 Oct 2021 21:10:02 +0200 Subject: [PATCH 083/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 195ceb85401..0a88e526990 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1522,21 +1522,17 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: numpy.cross : Corresponding numpy function """ - all_dims: List[Hashable] = [] - arrays: List[Any] = [a, b] - for i, arr in enumerate(arrays): - # TODO: Find spatial dim default by looking for unique - # (3 or 2)-valued dim? - if dim not in arr.dims: - raise ValueError(f"Dimension {dim} not in {arr}.") - - if not 1 <= arr.sizes[dim] <= 3: - raise ValueError( - "Incompatible dimensions for cross product,\n" - "dimension with coords must be 1, 2 or 3." - ) + if dim not in a.dims: + raise ValueError(f"Dimension {dim!r} not on a") + elif dim not in b.dims: + raise ValueError(f"Dimension {dim!r} not on b") - all_dims += [d for d in arr.dims if d not in all_dims] + if not 1 <= a.sizes[dim] <= 3: + raise ValueError(f"The size of {dim!r} on a must be 1, 2, or 3 to be compatible with a cross product but is {a.sizes[dim]}") + elif not 1 <= b.sizes[dim] <= 3: + raise ValueError(f"The size of {dim!r} on b must be 1, 2, or 3 to be compatible with a cross product but is {b.sizes[dim]}") + + all_dims = list(dict.fromkeys(a.dims + b.dims)) if arrays[0].sizes[dim] != arrays[1].sizes[dim]: # Arrays have different sizes. Append zeros where the smaller From e88ae9dfec343db485ad4516f4a6cc719d357791 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 5 Oct 2021 21:44:16 +0200 Subject: [PATCH 084/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 41 +++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0a88e526990..ba6f53a07f1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1534,31 +1534,26 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: all_dims = list(dict.fromkeys(a.dims + b.dims)) - if arrays[0].sizes[dim] != arrays[1].sizes[dim]: - # Arrays have different sizes. Append zeros where the smaller - # array is missing a value, zeros will not affect np.cross: - i = 1 if arrays[0].sizes[dim] > arrays[1].sizes[dim] else 0 - - if all( - # The variable check is only used to make mypy happy: - getattr(arr, "coords", False) and not isinstance(arr, Variable) - for arr in arrays - ): - # If the arrays have coords we know which indexes to fill - # with zeros: - arrays[i] = arrays[i].reindex_like(arrays[1 - i], fill_value=0) - elif arrays[i].sizes[dim] == 2: - # If the array doesn't have coords we can only infer - # that it is composite values if the size is 2: - arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0) - if is_duck_dask_array(arrays[i].data): - arrays[i] = arrays[i].chunk({dim: -1}) + if a.sizes[dim] != b.sizes[dim]: + if dim in getattr(a, "coords", {}) and dim in getattr(b, "coords", {}): + # align with a fill value of 0 + a, b = xr.align( + a, + b, + fill_value=0, + join="outer", + exclude=set(all_dims) - {dim}, + ) + elif min(a.sizes[dim], b.sizes[dim]) == 2: + # coords for dim are missing on one array or both + if a.sizes[dim] < b.sizes[dim]: + a = a.pad({dim: (0, 1)}, constant_values=0) + else: + b = b.pad({dim: (0, 1)}, constant_values=0) else: - # Size is 1, then we do not know if the array is a constant or - # composite value: raise ValueError( - "Incompatible dimensions for cross product,\n" - "dimension without coords must be 2 or 3." + f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:" + " dimensions without coordinates must have have a length of 2 or 3" ) c = apply_ufunc( From 9aaee2be70769de2b631bf2b17550fd345d21449 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 5 Oct 2021 22:08:56 +0200 Subject: [PATCH 085/100] Update computation.py --- xarray/core/computation.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ba6f53a07f1..5a868dd8136 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1528,16 +1528,26 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: raise ValueError(f"Dimension {dim!r} not on b") if not 1 <= a.sizes[dim] <= 3: - raise ValueError(f"The size of {dim!r} on a must be 1, 2, or 3 to be compatible with a cross product but is {a.sizes[dim]}") + raise ValueError( + f"The size of {dim!r} on a must be 1, 2, or 3 to be " + f"compatible with a cross product but is {a.sizes[dim]}" + ) elif not 1 <= b.sizes[dim] <= 3: - raise ValueError(f"The size of {dim!r} on b must be 1, 2, or 3 to be compatible with a cross product but is {b.sizes[dim]}") + raise ValueError( + f"The size of {dim!r} on b must be 1, 2, or 3 to be " + f"compatible with a cross product but is {b.sizes[dim]}" + ) all_dims = list(dict.fromkeys(a.dims + b.dims)) if a.sizes[dim] != b.sizes[dim]: + # Arrays have different sizes. Append zeros where the smaller + # array is missing a value, zeros will not affect np.cross: + if dim in getattr(a, "coords", {}) and dim in getattr(b, "coords", {}): - # align with a fill value of 0 - a, b = xr.align( + # If the arrays have coords we know which indexes to fill + # with zeros: + a, b = align( a, b, fill_value=0, @@ -1545,11 +1555,16 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: exclude=set(all_dims) - {dim}, ) elif min(a.sizes[dim], b.sizes[dim]) == 2: - # coords for dim are missing on one array or both + # If the array doesn't have coords we can only infer + # that it has composite values if the size is at least 2. + # Once padded, rechunk the padded array because apply_ufunc + # requires core dimensions not to be chunked: if a.sizes[dim] < b.sizes[dim]: a = a.pad({dim: (0, 1)}, constant_values=0) + a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a else: b = b.pad({dim: (0, 1)}, constant_values=0) + b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b else: raise ValueError( f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:" @@ -1558,11 +1573,12 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: c = apply_ufunc( np.cross, - *arrays, + a, + b, input_core_dims=[[dim], [dim]], - output_core_dims=[[dim] if arrays[0].sizes[dim] == 3 else []], + output_core_dims=[[dim] if a.sizes[dim] == 3 else []], dask="parallelized", - output_dtypes=[np.result_type(*arrays)], + output_dtypes=[np.result_type(*a, b)], ) c = c.transpose(*[d for d in all_dims if d in c.dims]) From 5d6ecba62c124d7183e25cf5351478248e9db2d4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 5 Oct 2021 22:57:19 +0200 Subject: [PATCH 086/100] Update computation.py --- xarray/core/computation.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 5a868dd8136..99bc0d92188 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: from .coordinates import Coordinates from .dataset import Dataset - from .types import DaCompatible, T_Xarray + from .types import T_DataArray, T_Variable, T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -1387,7 +1387,9 @@ def _get_valid_values(da, other): return corr -def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: +def cross( + a: Union[T_DataArray, T_Variable], b: Union[T_DataArray, T_Variable], dim: Hashable +) -> Union[T_DataArray, T_Variable]: """ Return the cross product of two (arrays of) vectors. @@ -1409,7 +1411,7 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: Examples -------- - Vector cross-product with 3 dimensions. + Vector cross-product with 3 dimensions: >>> a = xr.DataArray([1, 2, 3]) >>> b = xr.DataArray([4, 5, 6]) @@ -1437,7 +1439,7 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: array([ 0, 0, -3]) Dimensions without coordinates: dim_0 - One vector with dimension 2. + One vector with dimension 2: >>> a = xr.DataArray( ... [1, 2], @@ -1455,7 +1457,7 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: Coordinates: * cartesian (cartesian) object 'x' 'y' 'z' - One vector with dimension 2 but coords in other positions. + One vector with dimension 2 but coords in other positions: >>> a = xr.DataArray( ... [1, 2], @@ -1471,10 +1473,10 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: array([-10, 2, 5]) Coordinates: - * cartesian (cartesian) object 'x' 'y' 'z' + * cartesian (cartesian) >> a = xr.DataArray( ... [[1, 2, 3], [4, 5, 6]], @@ -1578,7 +1580,7 @@ def cross(a: DaCompatible, b: DaCompatible, dim: Hashable) -> DaCompatible: input_core_dims=[[dim], [dim]], output_core_dims=[[dim] if a.sizes[dim] == 3 else []], dask="parallelized", - output_dtypes=[np.result_type(*a, b)], + output_dtypes=[np.result_type(a, b)], ) c = c.transpose(*[d for d in all_dims if d in c.dims]) From 71fc9c14e0416de603c8e0639d7db3493042408a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 5 Oct 2021 23:10:24 +0200 Subject: [PATCH 087/100] Update computation.py --- xarray/core/computation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 99bc0d92188..a7bcdc18cb3 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1546,7 +1546,12 @@ def cross( # Arrays have different sizes. Append zeros where the smaller # array is missing a value, zeros will not affect np.cross: - if dim in getattr(a, "coords", {}) and dim in getattr(b, "coords", {}): + if ( + isinstance(a, T_DataArray) # Only used to make mypy happy. + and dim in getattr(a, "coords", {}) + and isinstance(b, T_DataArray) # Only used to make mypy happy. + and dim in getattr(b, "coords", {}) + ): # If the arrays have coords we know which indexes to fill # with zeros: a, b = align( From a98b2e3d03373487fc687bb02dc22b3e56386e8f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 5 Oct 2021 23:15:28 +0200 Subject: [PATCH 088/100] Update computation.py --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a7bcdc18cb3..d97be83560d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1547,9 +1547,9 @@ def cross( # array is missing a value, zeros will not affect np.cross: if ( - isinstance(a, T_DataArray) # Only used to make mypy happy. + not isinstance(a, Variable) # Only used to make mypy happy. and dim in getattr(a, "coords", {}) - and isinstance(b, T_DataArray) # Only used to make mypy happy. + and not isinstance(b, Variable) # Only used to make mypy happy. and dim in getattr(b, "coords", {}) ): # If the arrays have coords we know which indexes to fill From c95817b346881c988f6202edac5ab9ff339a005f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 6 Oct 2021 07:31:16 +0200 Subject: [PATCH 089/100] Update computation.py --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index d97be83560d..67e5ee083f6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1455,7 +1455,7 @@ def cross( array([12, -6, -3]) Coordinates: - * cartesian (cartesian) object 'x' 'y' 'z' + * cartesian (cartesian) Date: Thu, 7 Oct 2021 20:32:21 +0200 Subject: [PATCH 090/100] Can't narrow types with old type Seems using bounds in typevar makes it impossible to narrow the type using isinstance checks. --- xarray/core/computation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 67e5ee083f6..8f973331e64 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -36,8 +36,9 @@ if TYPE_CHECKING: from .coordinates import Coordinates + from .dataarray import DataArray from .dataset import Dataset - from .types import T_DataArray, T_Variable, T_Xarray + from .types import T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -1388,8 +1389,8 @@ def _get_valid_values(da, other): def cross( - a: Union[T_DataArray, T_Variable], b: Union[T_DataArray, T_Variable], dim: Hashable -) -> Union[T_DataArray, T_Variable]: + a: Union[DataArray, Variable], b: Union[DataArray, Variable], dim: Hashable +) -> Union[DataArray, Variable]: """ Return the cross product of two (arrays of) vectors. From 316b93531fa1dae09628631c335e4e66e22b7fe2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 7 Oct 2021 20:37:13 +0200 Subject: [PATCH 091/100] dim now keyword only --- xarray/core/computation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 8f973331e64..0461f74bc69 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1389,7 +1389,7 @@ def _get_valid_values(da, other): def cross( - a: Union[DataArray, Variable], b: Union[DataArray, Variable], dim: Hashable + a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable ) -> Union[DataArray, Variable]: """ Return the cross product of two (arrays of) vectors. @@ -1416,7 +1416,7 @@ def cross( >>> a = xr.DataArray([1, 2, 3]) >>> b = xr.DataArray([4, 5, 6]) - >>> xr.cross(a, b, "dim_0") + >>> xr.cross(a, b, dim="dim_0") array([-3, 6, -3]) Dimensions without coordinates: dim_0 @@ -1426,7 +1426,7 @@ def cross( >>> a = xr.DataArray([1, 2]) >>> b = xr.DataArray([4, 5]) - >>> xr.cross(a, b, "dim_0") + >>> xr.cross(a, b, dim="dim_0") array(-3) @@ -1435,7 +1435,7 @@ def cross( >>> a = xr.DataArray([1, 2, 0]) >>> b = xr.DataArray([4, 5, 0]) - >>> xr.cross(a, b, "dim_0") + >>> xr.cross(a, b, dim="dim_0") array([ 0, 0, -3]) Dimensions without coordinates: dim_0 @@ -1452,7 +1452,7 @@ def cross( ... dims=["cartesian"], ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ... ) - >>> xr.cross(a, b, "cartesian") + >>> xr.cross(a, b, dim="cartesian") array([12, -6, -3]) Coordinates: @@ -1470,7 +1470,7 @@ def cross( ... dims=["cartesian"], ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ... ) - >>> xr.cross(a, b, "cartesian") + >>> xr.cross(a, b, dim="cartesian") array([-10, 2, 5]) Coordinates: @@ -1495,7 +1495,7 @@ def cross( ... cartesian=(["cartesian"], ["x", "y", "z"]), ... ), ... ) - >>> xr.cross(a, b, "cartesian") + >>> xr.cross(a, b, dim="cartesian") array([[-3, 6, -3], [ 3, -6, 3]]) From 3b5b0306618eeb102061d026fdb6d43f8c8e050c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 7 Oct 2021 20:47:23 +0200 Subject: [PATCH 092/100] use all_dims in transpose --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0461f74bc69..cb39f5e4765 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1588,7 +1588,7 @@ def cross( dask="parallelized", output_dtypes=[np.result_type(a, b)], ) - c = c.transpose(*[d for d in all_dims if d in c.dims]) + c = c.transpose(*all_dims) return c From 34b300de7c8ed185bf30a78461223ea0fd0e1dd3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 7 Oct 2021 21:52:50 +0200 Subject: [PATCH 093/100] if in transpose indeed needed if a and b has size 2 it's needed. --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cb39f5e4765..0461f74bc69 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1588,7 +1588,7 @@ def cross( dask="parallelized", output_dtypes=[np.result_type(a, b)], ) - c = c.transpose(*all_dims) + c = c.transpose(*[d for d in all_dims if d in c.dims]) return c From cf13bf93387db649c445056a8a0716549a5cd09c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 10 Oct 2021 18:00:26 +0200 Subject: [PATCH 094/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0461f74bc69..c25ef50d4e2 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1503,8 +1503,8 @@ def cross( * time (time) int64 0 1 * cartesian (cartesian) >> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) From f2167a677841b6e0a8a83c84473aa92aa044031d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 10 Oct 2021 18:02:31 +0200 Subject: [PATCH 095/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c25ef50d4e2..69ecc92a748 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1392,7 +1392,7 @@ def cross( a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable ) -> Union[DataArray, Variable]: """ - Return the cross product of two (arrays of) vectors. + Compute the cross product of two (arrays of) vectors. The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular to both `a` and `b`. If `a` and `b` are arrays of From 570a806f05e02bb5dca7b370e6801ef574943e69 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 10 Oct 2021 18:07:27 +0200 Subject: [PATCH 096/100] Update xarray/core/computation.py Co-authored-by: keewis --- xarray/core/computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 69ecc92a748..119d21f8e13 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1588,7 +1588,7 @@ def cross( dask="parallelized", output_dtypes=[np.result_type(a, b)], ) - c = c.transpose(*[d for d in all_dims if d in c.dims]) + c = c.transpose(*all_dims, missing_dims="ignore") return c From 6f57ed61a5bb2b602a0a555e891a642d589fd9ac Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 10 Oct 2021 20:29:58 +0200 Subject: [PATCH 097/100] Update computation.py --- xarray/core/computation.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 119d21f8e13..013ece66314 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1395,12 +1395,13 @@ def cross( Compute the cross product of two (arrays of) vectors. The cross product of `a` and `b` in :math:`R^3` is a vector - perpendicular to both `a` and `b`. If `a` and `b` are arrays of - vectors, and these axes can have dimensions 2 or 3. Where the - dimension of either `a` or `b` is 2, the third component of the - input vector is assumed to be zero and the cross product calculated - accordingly. In cases where both input vectors have dimension 2, - the z-component of the cross product is returned. + perpendicular to both `a` and `b`. The vectors in `a` and `b` are + defined by the values along the dimension `dim` and these axes can + have sizes 1, 2 or 3. Where the size of either `a` or `b` is + 1 or 2, the remaining components of the input vector is assumed to + be zero and the cross product calculated accordingly. In cases where + both input vectors have dimension 2, the z-component of the cross + product is returned. Parameters ---------- From 52a986b2550c77e0aff4ad50c8f609d82dcc6e8b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 10 Oct 2021 20:31:29 +0200 Subject: [PATCH 098/100] Update computation.py --- xarray/core/computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 013ece66314..9d74f8ceed2 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1396,8 +1396,8 @@ def cross( The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular to both `a` and `b`. The vectors in `a` and `b` are - defined by the values along the dimension `dim` and these axes can - have sizes 1, 2 or 3. Where the size of either `a` or `b` is + defined by the values along the dimension `dim` and can have sizes + 1, 2 or 3. Where the size of either `a` or `b` is 1 or 2, the remaining components of the input vector is assumed to be zero and the cross product calculated accordingly. In cases where both input vectors have dimension 2, the z-component of the cross From fa78e741c3f7f0f3af15f4663c71fac843abb87b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 10 Oct 2021 20:35:05 +0200 Subject: [PATCH 099/100] add todo comments --- xarray/core/computation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9d74f8ceed2..1925325750d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1570,9 +1570,11 @@ def cross( # requires core dimensions not to be chunked: if a.sizes[dim] < b.sizes[dim]: a = a.pad({dim: (0, 1)}, constant_values=0) + # TODO: Should pad or apply_ufunc handle correct chunking? a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a else: b = b.pad({dim: (0, 1)}, constant_values=0) + # TODO: Should pad or apply_ufunc handle correct chunking? b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b else: raise ValueError( From 70d2a4ba235b6f8ea281711fd2f4be488c6660b4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 27 Dec 2021 01:39:16 +0100 Subject: [PATCH 100/100] Update whats-new.rst --- doc/whats-new.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e8faf6cc6cd..bd6097d61fe 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,7 +23,6 @@ New Features ~~~~~~~~~~~~ - New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`). By `Jimmy Westling `_. -- Add :py:meth:`var`, :py:meth:`std` and :py:meth:`sum_of_squares` to :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted`. Breaking changes