Skip to content

Commit

Permalink
Accept missing_dims for Variable.tranpose and Dataset.transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
gcaria committed Jul 7, 2021
1 parent 8090513 commit 5489eff
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 13 deletions.
22 changes: 15 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4543,7 +4543,11 @@ def drop_dims(
drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims}
return self.drop_vars(drop_vars)

def transpose(self, *dims: Hashable) -> "Dataset":
def transpose(
self,
*dims: Hashable,
missing_dims: str = "raise",
) -> "Dataset":
"""Return a new Dataset object with all array dimensions transposed.
Although the order of dimensions on each array will change, the dataset
Expand All @@ -4554,6 +4558,12 @@ def transpose(self, *dims: Hashable) -> "Dataset":
*dims : hashable, optional
By default, reverse the dimensions on each array. Otherwise,
reorder the dimensions to this order.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
Dataset:
- "raise": raise an exception
- "warn": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
Returns
-------
Expand All @@ -4572,12 +4582,10 @@ def transpose(self, *dims: Hashable) -> "Dataset":
numpy.transpose
DataArray.transpose
"""
if dims:
if set(dims) ^ set(self.dims) and ... not in dims:
raise ValueError(
f"arguments to transpose ({dims}) must be "
f"permuted dataset dimensions ({tuple(self.dims)})"
)
# Use infix_dims to check once for missing dimensions
if len(dims) != 0:
_ = list(infix_dims(dims, self.dims, missing_dims))

ds = self.copy()
for name, var in self._variables.items():
var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,)))
Expand Down
16 changes: 14 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,14 +1378,24 @@ def roll(self, shifts=None, **shifts_kwargs):
result = result._roll_one_dim(dim, count)
return result

def transpose(self, *dims) -> "Variable":
def transpose(
self,
*dims,
missing_dims: str = "raise",
) -> "Variable":
"""Return a new Variable object with transposed dimensions.
Parameters
----------
*dims : str, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
Variable:
- "raise": raise an exception
- "warn": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
Returns
-------
Expand All @@ -1404,7 +1414,9 @@ def transpose(self, *dims) -> "Variable":
"""
if len(dims) == 0:
dims = self.dims[::-1]
dims = tuple(infix_dims(dims, self.dims))
else:
dims = tuple(infix_dims(dims, self.dims, missing_dims))

if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
Expand Down
17 changes: 13 additions & 4 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5195,10 +5195,19 @@ def test_dataset_transpose(self):
expected_dims = tuple(d for d in new_order if d in ds[k].dims)
assert actual[k].dims == expected_dims

with pytest.raises(ValueError, match=r"permuted"):
ds.transpose("dim1", "dim2", "dim3")
with pytest.raises(ValueError, match=r"permuted"):
ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim")
# test missing dimension, raise error
with pytest.raises(ValueError):
ds.transpose(..., "not_a_dim")

# test missing dimension, ignore error
actual = ds.transpose(..., "not_a_dim", missing_dims="ignore")
expected_ell = ds.transpose(...)
assert_identical(expected_ell, actual)

# test missing dimension, raise warning
with pytest.warns(UserWarning):
actual = ds.transpose(..., "not_a_dim", missing_dims="warn")
assert_identical(expected_ell, actual)

assert "T" not in dir(ds)

Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,20 @@ def test_transpose(self):
w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x))
assert_identical(w, w3.transpose("a", "b", "c", "d"))

# test missing dimension, raise error
with pytest.raises(ValueError):
v.transpose(..., "not_a_dim")

# test missing dimension, ignore error
actual = v.transpose(..., "not_a_dim", missing_dims="ignore")
expected_ell = v.transpose(...)
assert_identical(expected_ell, actual)

# test missing dimension, raise warning
with pytest.warns(UserWarning):
v.transpose(..., "not_a_dim", missing_dims="warn")
assert_identical(expected_ell, actual)

def test_transpose_0d(self):
for value in [
3.5,
Expand Down

0 comments on commit 5489eff

Please sign in to comment.