Skip to content

Commit

Permalink
rechunk padded values, handle 1 sized datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Jun 6, 2021
1 parent 72330ce commit bce2f3e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,8 +1507,9 @@ def cross(a, b, dim):
for i, arr in enumerate(arrays):
if isinstance(arr, Dataset):
is_dataset = True
# TODO: How make sure this temporary dimension is matches
# the orther dataset?
# Turn the dataset to a stacked dataarray to follow the
# normal code path. Then at the end turn it back to a
# dataset.
arrays[i] = arr = arr.to_stacked_array(
variable_dim=dim, new_dim="variable", sample_dims=arr.dims
).unstack("variable")
Expand Down Expand Up @@ -1546,6 +1547,8 @@ def cross(a, b, dim):
# If the array doesn't have coords we can can only infer
# that it is composite values if the size is 2:
arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0)
if is_duck_dask_array(arrays[i].data):
arrays[i] = arrays[i].chunk({dim: -1})
else:
# Size is 1, then we do not know if the array is a constant or
# composite value:
Expand All @@ -1565,6 +1568,9 @@ def cross(a, b, dim):
c = c.transpose(*[d for d in all_dims if d in c.dims])
if is_dataset:
c = c.stack(variable=[dim]).to_unstacked_dataset("variable")
c = c.expand_dims(
[dim for ds in arrays for dim, size in ds.sizes.items() if size == 1]
)

return c

Expand Down

0 comments on commit bce2f3e

Please sign in to comment.