Skip to content
forked from pydata/xarray

Commit

Permalink
Add join='override'
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 1, 2019
1 parent 1757dff commit 66f907e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
22 changes: 21 additions & 1 deletion xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def _get_joiner(join):
# We cannot return a function to "align" in this case, because it needs
# access to the dimension name to give a good error message.
return None
elif join == 'override':
# We rewrite all indexes and then use join='left'
return operator.itemgetter(0)
else:
raise ValueError('invalid value for join: %s' % join)

Expand All @@ -60,7 +63,7 @@ def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(),
----------
*objects : Dataset or DataArray
Objects to align.
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
join : {'outer', 'inner', 'left', 'right', 'exact', 'override'}, optional
Method for joining the indexes of the passed objects along each
dimension:
Expand All @@ -70,6 +73,9 @@ def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(),
- 'right': use indexes from the last object with each dimension
- 'exact': instead of aligning, raise `ValueError` when indexes to be
aligned are not equal
- 'override': if indexes are of the same size, rewrite indexes to be
those of the first object. Indexes for the same dimension must have
the same size in all objects.
copy : bool, optional
If ``copy=True``, data in the return values is always copied. If
``copy=False`` and reindexing is unnecessary, or can be performed with
Expand Down Expand Up @@ -102,6 +108,20 @@ def align(*objects, join='inner', copy=True, indexes=None, exclude=frozenset(),
obj, = objects
return (obj.copy(deep=copy),)

if join == 'override':
objects = list(objects)
new_indexes = dict()
for index, obj in enumerate(objects[1:]):
for dim in objects[0].indexes:
if dim not in exclude:
if obj.indexes[dim].size != objects[0].indexes[dim].size:
raise ValueError("Cannot use join='override' when "
"all indexes don't have the same "
"shape.")
new_indexes[dim] = objects[0].indexes[dim]

objects[index + 1] = obj._overwrite_indexes(new_indexes)

all_indexes = defaultdict(list)
unlabeled_dim_sizes = defaultdict(set)
for obj in objects:
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def _replace_maybe_drop_dims(
if set(v.dims) <= allowed_dims)
return self._replace(variable, coords, name)

def _replace_indexes(
def _overwrite_indexes(
self,
indexes: Mapping[Hashable, Any]
) -> 'DataArray':
Expand Down
26 changes: 26 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,32 @@ def test_align_copy(self):
assert_identical(x, x2)
assert source_ndarray(x2.data) is not source_ndarray(x.data)

def test_align_override(self):
    """Exercise ``align(..., join='override')`` on two DataArrays.

    Three scenarios are checked:

    * ``x`` has the same size in both arrays, so the second array ends
      up with the first array's ``x`` coordinate values;
    * ``x`` is excluded, so both arrays come back unchanged;
    * ``x`` has different sizes, so a ``ValueError`` is raised.
    """
    left = DataArray([1, 2, 3], dims='x', coords={'x': [0, 1, 2]})
    right = DataArray(
        np.arange(9).reshape((3, 3)),
        dims=['x', 'y'],
        coords={'x': [0.1, 1.1, 2.1], 'y': [1, 2, 3]},
    )
    # Same data as ``right``, but labelled with ``left``'s ``x`` values.
    expected_right = DataArray(
        np.arange(9).reshape(3, 3),
        dims=['x', 'y'],
        coords={'x': [0, 1, 2], 'y': [1, 2, 3]},
    )

    # Plain override along the shared dimension.
    aligned_left, aligned_right = align(left, right, join='override')
    assert_identical(left, aligned_left)
    assert_identical(aligned_right, expected_right)

    # With ``x`` excluded, neither input should be modified.
    aligned_left, aligned_right = align(
        left, right, exclude='x', join='override')
    assert_identical(left, aligned_left)
    assert_identical(right, aligned_right)

    # A length-1 ``x`` on the left cannot override a length-3 ``x``.
    short_left = left.isel(x=0).expand_dims('x')
    with raises_regex(ValueError, "all indexes don't have the same shape"):
        align(short_left, right, join='override')

def test_align_exclude(self):
x = DataArray([[1, 2], [3, 4]],
coords=[('a', [-1, -2]), ('b', [3, 4])])
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,28 @@ def test_align_exact(self):
with raises_regex(ValueError, 'indexes .* not equal'):
xr.align(left, right, join='exact')

def test_align_override(self):
    """Exercise ``xr.align(..., join='override')`` on two Datasets.

    Three scenarios are checked:

    * ``x`` has the same size in both datasets, so the second dataset
      ends up with the first dataset's ``x`` coordinate values;
    * ``x`` is excluded, so both datasets come back unchanged;
    * ``x`` has different sizes, so a ``ValueError`` is raised.
    """
    left = xr.Dataset(coords={'x': [0, 1, 2]})
    right = xr.Dataset(
        coords={'x': [0.1, 1.1, 2.1], 'y': [1, 2, 3]})
    # ``right`` relabelled with ``left``'s ``x`` values; ``y`` is unchanged.
    expected_right = xr.Dataset(
        coords={'x': [0, 1, 2], 'y': [1, 2, 3]})

    # Plain override along the shared dimension.
    aligned_left, aligned_right = xr.align(left, right, join='override')
    assert_identical(left, aligned_left)
    assert_identical(aligned_right, expected_right)

    # With ``x`` excluded, neither input should be modified.
    aligned_left, aligned_right = xr.align(
        left, right, exclude='x', join='override')
    assert_identical(left, aligned_left)
    assert_identical(right, aligned_right)

    # A length-1 ``x`` on the left cannot override a length-3 ``x``.
    short_left = left.isel(x=0).expand_dims('x')
    with raises_regex(ValueError, "all indexes don't have the same shape"):
        xr.align(short_left, right, join='override')

def test_align_exclude(self):
x = Dataset({'foo': DataArray([[1, 2], [3, 4]], dims=['x', 'y'],
coords={'x': [1, 2], 'y': [3, 4]})})
Expand Down

0 comments on commit 66f907e

Please sign in to comment.