Skip to content

Commit

Permalink
Add support for cross product (#5365)
Browse files Browse the repository at this point in the history
* Add support for cross

* Update test_computation.py

* Update computation.py

* Update computation.py

* Update test_computation.py

* Update test_computation.py

* Update test_computation.py

* add more tests

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* spatial_dim to dim

* Update computation.py

* use pad instead of concat

* copy paste np.cross intro

* Get last dim for each array, which is more in line with np.cross

* examples in docs

* Update computation.py

* more doc examples

* single dim required, transpose after apply_ufunc

* add dims to tests

* Update computation.py

* reduce code

* support xr.Variable

* Update computation.py

* Update computation.py

* reduce code

* docstring explanations

* Use same terms

* docstring formatting

* reduce code

* add tests for dask

* simplify check, align used variables

* trim down tests

* Update computation.py

* simplify code

* Add type hints

* less type hints

* Update computation.py

* undo type hints

* Update computation.py

* Add support for datasets

* determine dtype with np.result_type

* test datasets, daskify the inputs not the results

* rechunk padded values, handle 1 sized datasets

* expand only unique dims, squeeze out dims in tests

* rechunk along the dim

* Attempt typing again

* Update __init__.py

* Update computation.py

* Update computation.py

* test fixing type in to_stacked_array

* test fixing to_stacked_array

* small is large

* Update computation.py

* Update xarray/core/computation.py

Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>

* obfuscate variable_dim some

* Update computation.py

* undo to_stacked_array changes

* test sample_dims typing

* to_stacked_array fixes

* add reindex_like check

* Update computation.py

* Update computation.py

* Update computation.py

* test forcing int type in chunk()

* Update computation.py

* test collection in to_stacked_array

* Update computation.py

* Update computation.py

* Update computation.py

* Update computation.py

* Update computation.py

* whats new and api.rst

* Update whats-new.rst

* Output as dataset if any input is a dataset

* Simplify the if terms instead of using pass.

* Update computation.py

* Remove support for datasets

* Update computation.py

* Add some typing to test.

* doctest fix

* lint

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* Update computation.py

* Update computation.py

* Update computation.py

* Update computation.py

* Update computation.py

* Can't narrow types with old type

Seems using bounds in typevar makes it impossible to narrow the type using isinstance checks.

* dim now keyword only

* use all_dims in transpose

* if in transpose indeed needed

If a and b both have size 2, it's needed.

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* Update xarray/core/computation.py

Co-authored-by: keewis <keewis@users.noreply.github.com>

* Update computation.py

* Update computation.py

* add todo comments

* Update whats-new.rst

Co-authored-by: keewis <keewis@users.noreply.github.com>
Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 29, 2021
1 parent 92ac89f commit 379b5b7
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Top-level functions
ones_like
cov
corr
cross
dot
polyval
map_blocks
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ v0.21.0 (unreleased)

New Features
~~~~~~~~~~~~
- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
By `Jimmy Westling <https://github.com/illviljan>`_.


Breaking changes
Expand Down
12 changes: 11 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
from .core.computation import (
apply_ufunc,
corr,
cov,
cross,
dot,
polyval,
unify_chunks,
where,
)
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
Expand Down Expand Up @@ -60,6 +69,7 @@
"dot",
"cov",
"corr",
"cross",
"full_like",
"get_options",
"infer_freq",
Expand Down
209 changes: 209 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

if TYPE_CHECKING:
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_Xarray

Expand Down Expand Up @@ -1373,6 +1374,214 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
return corr


def cross(
    a: Union[DataArray, Variable], b: Union[DataArray, Variable], *, dim: Hashable
) -> Union[DataArray, Variable]:
    """
    Compute the cross product of two (arrays of) vectors.

    The cross product of `a` and `b` in :math:`R^3` is a vector
    perpendicular to both `a` and `b`. The vectors in `a` and `b` are
    defined by the values along the dimension `dim` and can have sizes
    1, 2 or 3. Where the size of either `a` or `b` is
    1 or 2, the remaining components of the input vector are assumed to
    be zero and the cross product calculated accordingly. In cases where
    both input vectors have dimension 2, the z-component of the cross
    product is returned.

    Parameters
    ----------
    a, b : DataArray or Variable
        Components of the first and second vector(s).
    dim : hashable
        The dimension along which the cross product will be computed.
        Must be available in both vectors.

    Returns
    -------
    DataArray or Variable
        The cross product. `dim` is kept in the result only when the
        (aligned or padded) vectors have size 3 along `dim`; otherwise
        it is dropped. Remaining dimensions follow the order in which
        they first appear in ``a.dims + b.dims``.

    Raises
    ------
    ValueError
        If `dim` is missing from `a` or `b`, if the size of `dim` on
        either input is not 1, 2 or 3, or if the sizes differ, the inputs
        do not both have coordinates along `dim`, and the smaller size
        is 1 (so the missing components cannot be inferred).

    Examples
    --------
    Vector cross-product with 3 dimensions:

    >>> a = xr.DataArray([1, 2, 3])
    >>> b = xr.DataArray([4, 5, 6])
    >>> xr.cross(a, b, dim="dim_0")
    <xarray.DataArray (dim_0: 3)>
    array([-3,  6, -3])
    Dimensions without coordinates: dim_0

    Vector cross-product with 2 dimensions, returns in the perpendicular
    direction:

    >>> a = xr.DataArray([1, 2])
    >>> b = xr.DataArray([4, 5])
    >>> xr.cross(a, b, dim="dim_0")
    <xarray.DataArray ()>
    array(-3)

    Vector cross-product with 3 dimensions but zeros at the last axis
    yields the same results as with 2 dimensions:

    >>> a = xr.DataArray([1, 2, 0])
    >>> b = xr.DataArray([4, 5, 0])
    >>> xr.cross(a, b, dim="dim_0")
    <xarray.DataArray (dim_0: 3)>
    array([ 0,  0, -3])
    Dimensions without coordinates: dim_0

    One vector with dimension 2:

    >>> a = xr.DataArray(
    ...     [1, 2],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "y"])),
    ... )
    >>> b = xr.DataArray(
    ...     [4, 5, 6],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
    ... )
    >>> xr.cross(a, b, dim="cartesian")
    <xarray.DataArray (cartesian: 3)>
    array([12, -6, -3])
    Coordinates:
      * cartesian  (cartesian) <U1 'x' 'y' 'z'

    One vector with dimension 2 but coords in other positions:

    >>> a = xr.DataArray(
    ...     [1, 2],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "z"])),
    ... )
    >>> b = xr.DataArray(
    ...     [4, 5, 6],
    ...     dims=["cartesian"],
    ...     coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])),
    ... )
    >>> xr.cross(a, b, dim="cartesian")
    <xarray.DataArray (cartesian: 3)>
    array([-10,   2,   5])
    Coordinates:
      * cartesian  (cartesian) <U1 'x' 'y' 'z'

    Multiple vector cross-products. Note that the direction of the
    cross product vector is defined by the right-hand rule:

    >>> a = xr.DataArray(
    ...     [[1, 2, 3], [4, 5, 6]],
    ...     dims=("time", "cartesian"),
    ...     coords=dict(
    ...         time=(["time"], [0, 1]),
    ...         cartesian=(["cartesian"], ["x", "y", "z"]),
    ...     ),
    ... )
    >>> b = xr.DataArray(
    ...     [[4, 5, 6], [1, 2, 3]],
    ...     dims=("time", "cartesian"),
    ...     coords=dict(
    ...         time=(["time"], [0, 1]),
    ...         cartesian=(["cartesian"], ["x", "y", "z"]),
    ...     ),
    ... )
    >>> xr.cross(a, b, dim="cartesian")
    <xarray.DataArray (time: 2, cartesian: 3)>
    array([[-3,  6, -3],
           [ 3, -6,  3]])
    Coordinates:
      * time       (time) int64 0 1
      * cartesian  (cartesian) <U1 'x' 'y' 'z'

    Cross can be called on Datasets by converting to DataArrays and later
    back to a Dataset:

    >>> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3])))
    >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6])))
    >>> c = xr.cross(
    ...     ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian"
    ... )
    >>> c.to_dataset(dim="cartesian")
    <xarray.Dataset>
    Dimensions:  (dim_0: 1)
    Dimensions without coordinates: dim_0
    Data variables:
        x        (dim_0) int64 -3
        y        (dim_0) int64 6
        z        (dim_0) int64 -3

    See Also
    --------
    numpy.cross : Corresponding numpy function
    """

    if dim not in a.dims:
        raise ValueError(f"Dimension {dim!r} not on a")
    elif dim not in b.dims:
        raise ValueError(f"Dimension {dim!r} not on b")

    if not 1 <= a.sizes[dim] <= 3:
        raise ValueError(
            f"The size of {dim!r} on a must be 1, 2, or 3 to be "
            f"compatible with a cross product but is {a.sizes[dim]}"
        )
    elif not 1 <= b.sizes[dim] <= 3:
        raise ValueError(
            f"The size of {dim!r} on b must be 1, 2, or 3 to be "
            f"compatible with a cross product but is {b.sizes[dim]}"
        )

    # Order-preserving union of the input dims; used at the end to restore
    # a predictable dimension order on the result.
    all_dims = list(dict.fromkeys(a.dims + b.dims))

    if a.sizes[dim] != b.sizes[dim]:
        # Arrays have different sizes. Append zeros where the smaller
        # array is missing a value, zeros will not affect np.cross:

        if (
            not isinstance(a, Variable)  # Only used to make mypy happy.
            and dim in getattr(a, "coords", {})
            and not isinstance(b, Variable)  # Only used to make mypy happy.
            and dim in getattr(b, "coords", {})
        ):
            # If the arrays have coords we know which indexes to fill
            # with zeros. Every dimension except `dim` is excluded so
            # that only `dim` is aligned here:
            a, b = align(
                a,
                b,
                fill_value=0,
                join="outer",
                exclude=set(all_dims) - {dim},
            )
        elif min(a.sizes[dim], b.sizes[dim]) == 2:
            # If the array doesn't have coords we can only infer
            # that it has composite values if the size is at least 2.
            # Once padded, rechunk the padded array because apply_ufunc
            # requires core dimensions not to be chunked:
            if a.sizes[dim] < b.sizes[dim]:
                a = a.pad({dim: (0, 1)}, constant_values=0)
                # TODO: Should pad or apply_ufunc handle correct chunking?
                a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a
            else:
                b = b.pad({dim: (0, 1)}, constant_values=0)
                # TODO: Should pad or apply_ufunc handle correct chunking?
                b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b
        else:
            raise ValueError(
                f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:"
                " dimensions without coordinates must have a length of 2 or 3"
            )

    c = apply_ufunc(
        np.cross,
        a,
        b,
        input_core_dims=[[dim], [dim]],
        # The output keeps `dim` only for 3-component vectors.
        output_core_dims=[[dim] if a.sizes[dim] == 3 else []],
        dask="parallelized",
        output_dtypes=[np.result_type(a, b)],
    )
    # apply_ufunc moves core dims last; restore the input ordering.
    # `dim` may be absent from the result, hence missing_dims="ignore".
    c = c.transpose(*all_dims, missing_dims="ignore")

    return c


def dot(*arrays, dims=None, **kwargs):
"""Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
Expand Down
107 changes: 107 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,3 +1952,110 @@ def test_polyval(use_dask, use_datetime) -> None:
da_pv = xr.polyval(da.x, coeffs)

xr.testing.assert_allclose(da, da_pv.T)


@pytest.mark.parametrize("use_dask", [False, True])
@pytest.mark.parametrize(
    "a, b, ae, be, dim, axis",
    [
        (
            xr.DataArray([1, 2, 3]),
            xr.DataArray([4, 5, 6]),
            [1, 2, 3],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        (
            xr.DataArray([1, 2]),
            xr.DataArray([4, 5, 6]),
            [1, 2],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        (
            xr.Variable(dims=["dim_0"], data=[1, 2, 3]),
            xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
            [1, 2, 3],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        (
            xr.Variable(dims=["dim_0"], data=[1, 2]),
            xr.Variable(dims=["dim_0"], data=[4, 5, 6]),
            [1, 2],
            [4, 5, 6],
            "dim_0",
            -1,
        ),
        (  # Cross-product dimension sits in the middle of the array:
            xr.DataArray(
                np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
                dims=["time", "cartesian", "var"],
                coords={
                    "time": (["time"], np.arange(0, 5)),
                    "cartesian": (["cartesian"], ["x", "y", "z"]),
                    "var": (["var"], [1, 1.5, 2, 2.5]),
                },
            ),
            xr.DataArray(
                np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
                dims=["time", "cartesian", "var"],
                coords={
                    "time": (["time"], np.arange(0, 5)),
                    "cartesian": (["cartesian"], ["x", "y", "z"]),
                    "var": (["var"], [1, 1.5, 2, 2.5]),
                },
            ),
            np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)),
            np.arange(0, 5 * 3 * 4).reshape((5, 3, 4)) + 1,
            "cartesian",
            1,
        ),
        (  # Size-1 array whose coordinate says which component it holds:
            xr.DataArray(
                np.array([1]),
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["z"])},
            ),
            xr.DataArray(
                np.array([4, 5, 6]),
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["x", "y", "z"])},
            ),
            [0, 0, 1],
            [4, 5, 6],
            "cartesian",
            -1,
        ),
        (  # Coordinates require a zero to be filled in between values:
            xr.DataArray(
                [1, 2],
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["x", "z"])},
            ),
            xr.DataArray(
                [4, 5, 6],
                dims=["cartesian"],
                coords={"cartesian": (["cartesian"], ["x", "y", "z"])},
            ),
            [1, 0, 2],
            [4, 5, 6],
            "cartesian",
            -1,
        ),
    ],
)
def test_cross(a, b, ae, be, dim: str, axis: int, use_dask: bool) -> None:
    """Compare ``xr.cross`` with ``np.cross`` on the equivalent plain inputs.

    ``ae`` and ``be`` are the array-likes handed to ``np.cross`` (along
    ``axis``) to build the expected result for ``a`` and ``b``. With
    ``use_dask`` the xarray inputs are chunked before calling ``xr.cross``.
    """
    expected = np.cross(ae, be, axis=axis)

    if use_dask:
        if not has_dask:
            pytest.skip("test for dask.")
        # Chunk the inputs (not the result) so the dask code path is used.
        a, b = a.chunk(), b.chunk()

    actual = xr.cross(a, b, dim=dim)
    xr.testing.assert_duckarray_allclose(expected, actual)

0 comments on commit 379b5b7

Please sign in to comment.