Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cross product #5365

Merged
merged 120 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
120 commits
Select commit Hold shift + click to select a range
1490c16
Add support for cross
Illviljan May 23, 2021
03db734
Update test_computation.py
Illviljan May 23, 2021
c824e36
Update computation.py
Illviljan May 23, 2021
7ce39c7
Update computation.py
Illviljan May 23, 2021
916e661
Update test_computation.py
Illviljan May 23, 2021
654ad60
Update test_computation.py
Illviljan May 23, 2021
a6ac578
Update test_computation.py
Illviljan May 23, 2021
e0c1fac
add more tests
Illviljan May 23, 2021
7aebae7
Update xarray/core/computation.py
Illviljan May 23, 2021
b85e236
Merge branch 'master' into Illviljan-cross
Illviljan May 23, 2021
2b54a42
Merge branch 'Illviljan-cross' of https://github.com/Illviljan/xarray…
Illviljan May 23, 2021
be7b2c2
spatial_dim to dim
Illviljan May 23, 2021
4448006
Update computation.py
Illviljan May 23, 2021
af8b09c
use pad instead of concat
Illviljan May 23, 2021
a135e05
copy paste np.cross intro
Illviljan May 23, 2021
6f17b9b
Get last dim for each array, which is more inline with np.cross
Illviljan May 23, 2021
1fadb5f
examples in docs
Illviljan May 23, 2021
57239a4
Update computation.py
Illviljan May 23, 2021
265ef82
more doc examples
Illviljan May 23, 2021
dd60562
single dim required, tranpose after apply_ufunc
Illviljan May 24, 2021
a20cb86
add dims to tests
Illviljan May 24, 2021
7ce9315
Update computation.py
Illviljan May 24, 2021
d5a0ea8
reduce code
Illviljan May 25, 2021
ef94fa4
support xr.Variable
Illviljan May 25, 2021
1a85147
Update computation.py
Illviljan May 25, 2021
2ce3dbe
Update computation.py
Illviljan May 25, 2021
53c84c2
reduce code
Illviljan May 25, 2021
dded720
docstring explanations
Illviljan May 25, 2021
7058166
Use same terms
Illviljan May 25, 2021
cb57a55
docstring formatting
Illviljan May 25, 2021
e69ca81
reduce code
Illviljan May 25, 2021
4b2fc72
add tests for dask
Illviljan May 25, 2021
afe572d
simplify check, align used variables
Illviljan May 26, 2021
e137350
trim down tests
Illviljan May 26, 2021
1a26324
Update computation.py
Illviljan May 26, 2021
531a98b
simplify code
Illviljan May 27, 2021
2146406
Add type hints
Illviljan May 28, 2021
0940472
less type hints
Illviljan May 28, 2021
a7cc565
Update computation.py
Illviljan May 28, 2021
1d1f205
undo type hints
Illviljan May 28, 2021
9af7091
Update computation.py
Illviljan May 28, 2021
14decb3
Add support for datasets
Illviljan May 30, 2021
6f73c32
determine dtype with np.result_type
Illviljan Jun 2, 2021
72330ce
test datasets, daskify the inputs not the results
Illviljan Jun 6, 2021
bce2f3e
rechunk padded values, handle 1 sized datasets
Illviljan Jun 6, 2021
1636d25
expand only unique dims, squeeze out dims in tests
Illviljan Jun 6, 2021
b5b97a0
rechunk along the dim
Illviljan Jun 6, 2021
f77780f
Merge branch 'master' into Illviljan-cross
Illviljan Jun 7, 2021
02364ca
Attempt typing again
Illviljan Jun 17, 2021
e842c75
Merge branch 'master' into Illviljan-cross
Illviljan Jun 17, 2021
ed44400
Update __init__.py
Illviljan Jun 17, 2021
4fe9737
Update computation.py
Illviljan Jun 17, 2021
ec05780
Update computation.py
Illviljan Jun 17, 2021
36c5956
test fixing type in to_stacked_array
Illviljan Jun 17, 2021
cbf289c
test fixing to_stacked_array
Illviljan Jun 17, 2021
4cfd5be
small is large
Illviljan Jun 18, 2021
658a59f
Update computation.py
Illviljan Jun 18, 2021
ab5ae20
Update xarray/core/computation.py
Illviljan Jun 18, 2021
d65ca41
obfuscate variable_dim some
Illviljan Jun 19, 2021
20eef03
Update computation.py
Illviljan Jun 19, 2021
274af32
undo to_stacked_array changes
Illviljan Jun 19, 2021
f352303
test sample_dims typing
Illviljan Jun 19, 2021
0a773cb
to_stacked_array fixes
Illviljan Jun 19, 2021
d8da29f
add reindex_like check
Illviljan Jun 19, 2021
54a76c1
Update computation.py
Illviljan Jun 20, 2021
0a2dc2e
Update computation.py
Illviljan Jun 20, 2021
b3592f3
Update computation.py
Illviljan Jun 20, 2021
06772da
test forcing int type in chunk()
Illviljan Jun 20, 2021
cfd11f7
Update computation.py
Illviljan Jun 20, 2021
8451a9e
Merge branch 'master' into Illviljan-cross
Illviljan Jun 21, 2021
90553ed
test collection in to_stacked_array
Illviljan Jun 21, 2021
6eed96e
Update computation.py
Illviljan Jun 21, 2021
d3648e5
Update computation.py
Illviljan Jun 22, 2021
c639aa3
Update computation.py
Illviljan Jun 22, 2021
4c636f5
Update computation.py
Illviljan Jun 22, 2021
3bea936
Update computation.py
Illviljan Jun 22, 2021
4fc7fcb
Merge branch 'master' into Illviljan-cross
Illviljan Jun 23, 2021
19e8f93
Merge branch 'master' into Illviljan-cross
Illviljan Jun 24, 2021
f71a6f1
Merge branch 'main' into Illviljan-cross
Illviljan Jun 24, 2021
d4070ab
Merge branch 'main' into Illviljan-cross
Illviljan Jun 24, 2021
12da913
whats new and api.rst
Illviljan Jun 24, 2021
ea062e6
Update whats-new.rst
Illviljan Jun 24, 2021
ebd89e6
Merge branch 'main' into Illviljan-cross
Illviljan Jul 2, 2021
3c7122b
Merge branch 'main' into Illviljan-cross
Illviljan Jul 5, 2021
9af1198
Merge branch 'main' into Illviljan-cross
Illviljan Jul 18, 2021
27262e6
Merge branch 'main' into Illviljan-cross
Illviljan Jul 22, 2021
cc91e7c
Merge branch 'main' into Illviljan-cross
Illviljan Jul 25, 2021
629df59
Output as dataset if any input is a dataset
Illviljan Jul 26, 2021
972c7dc
Simplify the if terms instead of using pass.
Illviljan Jul 26, 2021
3c4ace0
Merge branch 'main' into Illviljan-cross
Illviljan Aug 30, 2021
49967d4
Update computation.py
Illviljan Aug 30, 2021
6ab7d19
Remove support for datasets
Illviljan Aug 30, 2021
20a6cb6
Update computation.py
Illviljan Aug 30, 2021
ba3fa9c
Add some typing to test.
Illviljan Aug 30, 2021
8b192f2
doctest fix
Illviljan Aug 30, 2021
a27965c
lint
Illviljan Aug 30, 2021
5ec65d2
Merge branch 'main' into Illviljan-cross
Illviljan Sep 8, 2021
b058084
Update xarray/core/computation.py
Illviljan Oct 3, 2021
f007ed5
Update xarray/core/computation.py
Illviljan Oct 5, 2021
e88ae9d
Update xarray/core/computation.py
Illviljan Oct 5, 2021
9aaee2b
Update computation.py
Illviljan Oct 5, 2021
5d6ecba
Update computation.py
Illviljan Oct 5, 2021
71fc9c1
Update computation.py
Illviljan Oct 5, 2021
a98b2e3
Update computation.py
Illviljan Oct 5, 2021
c95817b
Update computation.py
Illviljan Oct 6, 2021
408eb39
Can't narrow types with old type
Illviljan Oct 7, 2021
316b935
dim now keyword only
Illviljan Oct 7, 2021
3b5b030
use all_dims in transpose
Illviljan Oct 7, 2021
f9c5404
Merge branch 'main' into Illviljan-cross
Illviljan Oct 7, 2021
34b300d
if in transpose indeed needed
Illviljan Oct 7, 2021
cf13bf9
Update xarray/core/computation.py
Illviljan Oct 10, 2021
f2167a6
Update xarray/core/computation.py
Illviljan Oct 10, 2021
570a806
Update xarray/core/computation.py
Illviljan Oct 10, 2021
6f57ed6
Update computation.py
Illviljan Oct 10, 2021
52a986b
Update computation.py
Illviljan Oct 10, 2021
fa78e74
add todo comments
Illviljan Oct 10, 2021
f2d98b6
Merge branch 'main' into Illviljan-cross
Illviljan Oct 31, 2021
7449cd7
Merge branch 'main' into Illviljan-cross
Illviljan Dec 27, 2021
70d2a4b
Update whats-new.rst
Illviljan Dec 27, 2021
e6020e3
Merge branch 'main' into Illviljan-cross
Illviljan Dec 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
from .core.computation import apply_ufunc, corr, cov, dot, polyval, where
from .core.computation import apply_ufunc, corr, cov, cross, dot, polyval, where
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
Expand Down Expand Up @@ -56,6 +56,7 @@
"dot",
"cov",
"corr",
"cross",
"full_like",
"infer_freq",
"load_dataarray",
Expand Down
131 changes: 131 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,137 @@ def dot(*arrays, dims=None, **kwargs):
return result.transpose(*[d for d in all_dims if d in result.dims])


def cross(a, b, dim=None):
"""
Return the cross product of two (arrays of) vectors.
Illviljan marked this conversation as resolved.
Show resolved Hide resolved

The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular
to both `a` and `b`. If `a` and `b` are arrays of vectors, the vectors
are defined by the last axis of `a` and `b` by default, and these axes
can have dimensions 2 or 3. Where the dimension of either `a` or `b` is
2, the third component of the input vector is assumed to be zero and the
cross product calculated accordingly. In cases where both input vectors
have dimension 2, the z-component of the cross product is returned.

Parameters
----------
a, b : DataArray
something
dim : hashable or tuple of hashable
something

Examples
--------
Vector cross-product.

>>> x = xr.DataArray(np.array([1, 2, 3]))
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
>>> y = xr.DataArray(np.array([4, 5, 6]))
>>> xr.cross(x, y)
array([-3, 6, -3])
Illviljan marked this conversation as resolved.
Show resolved Hide resolved

One vector with dimension 2.

>>> a = xr.DataArray(
... np.array([1, 2]),
... dims=["cartesian"],
... coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))),
... )
>>> b = xr.DataArray(
... np.array([4, 5, 6]),
... dims=["x"],
... coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))),
... )
>>> xr.cross(a, b)
array([12, -6, -3])

Multiple vector cross-products. Note that the direction of the
cross product vector is defined by the right-hand rule.

>>> x = xr.DataArray(np.array([[1, 2, 3], [4, 5, 6]]), dims=("a", "b"))
>>> y = xr.DataArray(np.array([[4, 5, 6], [1, 2, 3]]), dims=("a", "b"))
>>> xr.cross(x, y)
array([[-3, 6, -3],
[ 3, -6, 3]])

Change the vector definition of x and y using axisa and axisb.

>>> x = xr.DataArray(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
>>> y = xr.DataArray(np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]]))
>>> np.cross(x, y)
array([[ -6, 12, -6],
[ 0, 0, 0],
[ 6, -12, 6]])
>>> np.cross(x, y, axisa=0, axisb=0)
array([[-24, 48, -24],
[-30, 60, -30],
[-36, 72, -36]])

See Also
--------
numpy.cross : Corresponding numpy function
"""
from .dataarray import DataArray

arrays = [a, b]
for arr in arrays:
if not isinstance(arr, (DataArray)):
raise TypeError(
f"Only xr.DataArray and xr.Variable are supported, got {type(arr)}."
)

if dim is None:
# TODO: Find spatial dim default by looking for unique
# (3 or 2)-valued dim?
dim = arr.dims[-1]
elif dim not in arr.dims:
raise ValueError(f"Dimension {dim} not in {arr}.")

s = arr.sizes[dim]
if s < 1 or s > 3:
raise ValueError(
"incompatible dimensions for cross product\n"
"(dimension with coords must be 1, 2 or 3)"
)

if a.sizes[dim] == b.sizes[dim]:
# Arrays have the same size, no need to do anything:
pass
else:
# Arrays have different sizes. Append zeros where the smaller
# array is missing a value, zeros will not affect np.cross:
ind = 1 if a.sizes[dim] > b.sizes[dim] else 0

if a.coords:
# If the array has coords we know which indexes to fill
# with zeros:
arrays[ind] = arrays[ind].reindex_like(arrays[1 - ind], fill_value=0)
elif arrays[ind].sizes[dim] > 1:
# If it doesn't have coords we can can only that infer that
# it is composite values if the size is 2.
arrays[ind] = arrays[ind].pad({dim: (0, 1)}, constant_values=0)
else:
# Size is 1, then we do not know if it is a constant or
# composite value:
raise ValueError(
"incompatible dimensions for cross product\n"
"(dimension without coords must be 2 or 3)"
)

# Figure out the output dtype:
output_dtype = np.cross(
np.empty((2, 2), dtype=arrays[0].dtype), np.empty((2, 2), dtype=arrays[1].dtype)
).dtype

return apply_ufunc(
np.cross,
*arrays,
input_core_dims=[[dim], [dim]],
output_core_dims=[[dim]],
dask="parallelized",
output_dtypes=[output_dtype],
)


def where(cond, x, y):
"""Return elements from `x` or `y` depending on `cond`.

Expand Down
83 changes: 83 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,3 +1898,86 @@ def test_polyval(use_dask, use_datetime):
da_pv = xr.polyval(da.x, coeffs)

xr.testing.assert_allclose(da, da_pv.T)


@pytest.mark.parametrize(
"a, b, ae, be, dim, axis",
[
[
xr.DataArray(np.array([1, 2, 3])),
xr.DataArray(np.array([4, 5, 6])),
np.array([1, 2, 3]),
np.array([4, 5, 6]),
None,
-1,
],
[
xr.DataArray(np.array([1, 2])),
xr.DataArray(np.array([4, 5, 6])),
np.array([1, 2]),
np.array([4, 5, 6]),
None,
-1,
],
[ # Test dim in the middle:
xr.DataArray(
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
dims=["time", "cartesian", "var"],
coords=dict(
time=(["time"], np.arange(0, 5)),
cartesian=(["cartesian"], np.array(["x", "y", "z"])),
var=(["var"], np.array([1, 1.5, 2, 2.5])),
),
),
xr.DataArray(
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
dims=["time", "cartesian", "var"],
coords=dict(
time=(["time"], np.arange(0, 5)),
cartesian=(["cartesian"], np.array(["x", "y", "z"])),
var=(["var"], np.array([1, 1.5, 2, 2.5])),
),
),
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
"cartesian",
1,
],
[ # Test 1 sized arrays with coords:
xr.DataArray(
np.array([1]),
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], np.array(["z"]))),
),
xr.DataArray(
np.array([4, 5, 6]),
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))),
),
np.array([0, 0, 1]),
np.array([4, 5, 6]),
None,
-1,
],
[ # Test filling inbetween with coords:
xr.DataArray(
np.array([1, 2]),
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], np.array(["x", "z"]))),
),
xr.DataArray(
np.array([4, 5, 6]),
dims=["cartesian"],
coords=dict(cartesian=(["cartesian"], np.array(["x", "y", "z"]))),
),
np.array([1, 0, 2]),
np.array([4, 5, 6]),
None,
-1,
],
],
)
def test_cross(a, b, ae, be, dim, axis):
expected = np.cross(ae, be, axis=axis)
actual = xr.cross(a, b, dim=dim)
xr.testing.assert_duckarray_allclose(expected, actual)