Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add join='override' #3175

Merged
merged 19 commits into from
Aug 16, 2019
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ New functions/methods
Enhancements
~~~~~~~~~~~~

- Added ``join='override'`` to :py:func:`~xarray.align`. This only checks that index sizes are equal among objects,
skips checking indexes for equality, and rewrites the indexes to be those of the first object. By `Deepak Cherian <https://github.com/dcherian>`_.
- :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset` now support the ``join`` kwarg.
It is passed down to :py:func:`~xarray.align`. By `Deepak Cherian <https://github.com/dcherian>`_.
- In :py:meth:`~xarray.Dataset.to_zarr`, passing ``mode`` is not mandatory if
Expand Down
23 changes: 22 additions & 1 deletion xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def _get_joiner(join):
# We cannot return a function to "align" in this case, because it needs
# access to the dimension name to give a good error message.
return None
elif join == "override":
# We rewrite all indexes and then use join='left'
return operator.itemgetter(0)
else:
raise ValueError("invalid value for join: %s" % join)

Expand All @@ -57,7 +60,7 @@ def align(
----------
*objects : Dataset or DataArray
Objects to align.
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
join : {'outer', 'inner', 'left', 'right', 'exact', 'override'}, optional
Method for joining the indexes of the passed objects along each
dimension:

Expand All @@ -67,6 +70,9 @@ def align(
- 'right': use indexes from the last object with each dimension
- 'exact': instead of aligning, raise `ValueError` when indexes to be
aligned are not equal
- 'override': if indexes are of the same size, rewrite indexes to be
those of the first object. Indexes for the same dimension must have
the same size in all objects; otherwise a `ValueError` is raised.
copy : bool, optional
If ``copy=True``, data in the return values is always copied. If
``copy=False`` and reindexing is unnecessary, or can be performed with
Expand Down Expand Up @@ -99,6 +105,21 @@ def align(
obj, = objects
return (obj.copy(deep=copy),)

if join == "override":
dcherian marked this conversation as resolved.
Show resolved Hide resolved
objects = list(objects)
new_indexes = dict()
for index, obj in enumerate(objects[1:]):
for dim in objects[0].indexes:
dcherian marked this conversation as resolved.
Show resolved Hide resolved
if dim not in exclude:
if obj.indexes[dim].size != objects[0].indexes[dim].size:
raise ValueError(
"Indexes along dimension %r are not equal."
" Cannot use join='override'." % dim
)
new_indexes[dim] = objects[0].indexes[dim]

objects[index + 1] = obj._overwrite_indexes(new_indexes)
dcherian marked this conversation as resolved.
Show resolved Hide resolved

all_indexes = defaultdict(list)
unlabeled_dim_sizes = defaultdict(set)
for obj in objects:
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def _replace_maybe_drop_dims(
)
return self._replace(variable, coords, name)

def _replace_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
if not len(indexes):
return self
coords = self._coords.copy()
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ def test_concat_join_kwarg(self):
{"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)},
coords={"x": [0, 1], "y": [0.0001]},
)
expected["override"] = Dataset(
{"a": (("x", "y"), np.array([0, 0], ndmin=2).T)},
coords={"x": [0, 1], "y": [0]},
)

with raises_regex(ValueError, "indexes along dimension 'y'"):
actual = concat([ds1, ds2], join="exact", dim="x")
Expand Down Expand Up @@ -399,6 +403,10 @@ def test_concat_join_kwarg(self):
{"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)},
coords={"x": [0, 1], "y": [0.0001]},
)
expected["override"] = Dataset(
{"a": (("x", "y"), np.array([0, 0], ndmin=2).T)},
coords={"x": [0, 1], "y": [0]},
)

with raises_regex(ValueError, "indexes along dimension 'y'"):
actual = concat([ds1, ds2], join="exact", dim="x")
Expand Down
27 changes: 27 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3113,6 +3113,33 @@ def test_align_copy(self):
assert_identical(x, x2)
assert source_ndarray(x2.data) is not source_ndarray(x.data)

def test_align_override(self):
    """Exercise ``align(..., join="override")`` on DataArray inputs.

    Three scenarios are covered:

    * equal-sized ``x`` indexes: the second array takes over the first
      array's ``x`` labels while the first array is returned unchanged;
    * ``exclude="x"``: both arrays come back untouched;
    * differing ``x`` sizes: a ``ValueError`` is raised.
    """
    x_labels = [0, 1, 2]
    y_labels = [1, 2, 3]

    left = DataArray([1, 2, 3], dims="x", coords={"x": x_labels})
    right = DataArray(
        np.arange(9).reshape((3, 3)),
        dims=["x", "y"],
        coords={"x": [0.1, 1.1, 2.1], "y": y_labels},
    )
    # Same data as ``right`` but carrying the ``x`` labels of ``left``.
    expected_right = DataArray(
        np.arange(9).reshape(3, 3),
        dims=["x", "y"],
        coords={"x": x_labels, "y": y_labels},
    )

    # Overriding: only ``right`` has its ``x`` index rewritten.
    aligned = align(left, right, join="override")
    new_left, new_right = aligned
    assert_identical(left, new_left)
    assert_identical(new_right, expected_right)

    # Excluding ``x`` leaves both objects exactly as they were passed in.
    aligned = align(left, right, exclude="x", join="override")
    new_left, new_right = aligned
    assert_identical(left, new_left)
    assert_identical(right, new_right)

    # A length-1 ``x`` cannot be overridden onto a length-3 ``x``.
    short_left = left.isel(x=0).expand_dims("x")
    with raises_regex(ValueError, "Indexes along dimension 'x' are not equal."):
        align(short_left, right, join="override")

def test_align_exclude(self):
x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])])
y = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, 20]), ("b", [5, 6])])
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,25 @@ def test_align_exact(self):
with raises_regex(ValueError, "indexes .* not equal"):
xr.align(left, right, join="exact")

def test_align_override(self):
    """Exercise ``xr.align(..., join="override")`` on Dataset inputs.

    Three scenarios are covered:

    * equal-sized ``x`` indexes: the second dataset takes over the first
      dataset's ``x`` labels while the first dataset is returned unchanged;
    * ``exclude="x"``: both datasets come back untouched;
    * differing ``x`` sizes: a ``ValueError`` is raised.
    """
    y_labels = [1, 2, 3]

    left = xr.Dataset(coords={"x": [0, 1, 2]})
    right = xr.Dataset(coords={"x": [0.1, 1.1, 2.1], "y": y_labels})
    # ``right`` with its ``x`` labels replaced by those of ``left``.
    expected_right = xr.Dataset(coords={"x": [0, 1, 2], "y": y_labels})

    # Overriding: only ``right`` has its ``x`` index rewritten.
    aligned = xr.align(left, right, join="override")
    new_left, new_right = aligned
    assert_identical(left, new_left)
    assert_identical(new_right, expected_right)

    # Excluding ``x`` leaves both objects exactly as they were passed in.
    aligned = xr.align(left, right, exclude="x", join="override")
    new_left, new_right = aligned
    assert_identical(left, new_left)
    assert_identical(right, new_right)

    # A length-1 ``x`` cannot be overridden onto a length-3 ``x``.
    short_left = left.isel(x=0).expand_dims("x")
    with raises_regex(ValueError, "Indexes along dimension 'x' are not equal."):
        xr.align(short_left, right, join="override")

def test_align_exclude(self):
x = Dataset(
{
Expand Down