Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept int value in head, thin and tail #3298

Merged
merged 5 commits into from
Sep 14, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ def sel(

def head(
self,
indexers: Union[Mapping[Hashable, Any], int] = None,
indexers: Union[Mapping[Hashable, int], int] = None,
**indexers_kwargs: Any
) -> "DataArray":
"""Return a new DataArray whose data is given by the the first `n`
Expand All @@ -1053,7 +1053,7 @@ def head(

def tail(
self,
indexers: Union[Mapping[Hashable, Any], int] = None,
indexers: Union[Mapping[Hashable, int], int] = None,
**indexers_kwargs: Any
) -> "DataArray":
"""Return a new DataArray whose data is given by the the last `n`
Expand All @@ -1070,7 +1070,7 @@ def tail(

def thin(
self,
indexers: Union[Mapping[Hashable, Any], int] = None,
indexers: Union[Mapping[Hashable, int], int] = None,
**indexers_kwargs: Any
) -> "DataArray":
"""Return a new DataArray whose data is given by each `n` value
Expand Down
36 changes: 26 additions & 10 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,7 +2010,7 @@ def sel(

def head(
self,
indexers: Union[Mapping[Hashable, Any], int] = None,
indexers: Union[Mapping[Hashable, int], int] = None,
**indexers_kwargs: Any
) -> "Dataset":
"""Returns a new dataset with the first `n` values of each array
Expand Down Expand Up @@ -2041,12 +2041,17 @@ def head(
if isinstance(indexers, int):
indexers = {dim: indexers for dim in self.dims}
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "head")
for v in indexers.values():
if not isinstance(v, int):
raise TypeError("indexer value must be an integer")
elif v < 0:
raise ValueError("indexer value must be positive")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is help to include a little more context in error messages if possible. In this case, you could include offending the name and value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, Something along these lines maybe?

"expected integer as indexer value, found type %r for dim %r" % (type(v), k)

and

"expected positive integer as indexer value for dim %r" % k

The k and v come from iterating over indexers.items()

indexers = {k: slice(val) for k, val in indexers.items()}
return self.isel(indexers)

def tail(
self,
indexers: Union[Mapping[Hashable, Any], int] = None,
indexers: Union[Mapping[Hashable, int], int] = None,
**indexers_kwargs: Any
) -> "Dataset":
"""Returns a new dataset with the last `n` values of each array
Expand Down Expand Up @@ -2077,6 +2082,11 @@ def tail(
if isinstance(indexers, int):
indexers = {dim: indexers for dim in self.dims}
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "tail")
for v in indexers.values():
if not isinstance(v, int):
raise TypeError("indexer value must be an integer")
elif v < 0:
raise ValueError("indexer value must be positive")
indexers = {
k: slice(-val, None) if val != 0 else slice(val)
for k, val in indexers.items()
Expand All @@ -2085,7 +2095,7 @@ def tail(

def thin(
self,
indexers: Union[Mapping[Hashable, Any], int] = None,
indexers: Union[Mapping[Hashable, int], int] = None,
**indexers_kwargs: Any
) -> "Dataset":
"""Returns a new dataset with each array indexed along every `n`th
Expand All @@ -2108,16 +2118,22 @@ def thin(
Dataset.tail
DataArray.thin
"""
if not indexers_kwargs:
if indexers is None:
indexers = 5
if not isinstance(indexers, int) and not is_dict_like(indexers):
raise TypeError("indexers must be a dict or a single integer")
if (
not indexers_kwargs
and not isinstance(indexers, int)
and not is_dict_like(indexers)
):
raise TypeError("indexers must be a dict or a single integer")
if isinstance(indexers, int):
indexers = {dim: indexers for dim in self.dims}
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "thin")
if 0 in indexers.values():
raise ValueError("step cannot be zero")
for v in indexers.values():
if not isinstance(v, int):
raise TypeError("indexer value must be an integer")
elif v < 0:
raise ValueError("indexer value must be positive")
elif v == 0:
raise ValueError("step cannot be zero")
indexers = {k: slice(None, None, val) for k, val in indexers.items()}
return self.isel(indexers)

Expand Down
24 changes: 16 additions & 8 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,11 @@ def test_head(self):
self.dv.isel({dim: slice(5) for dim in self.dv.dims}), self.dv.head()
)
with raises_regex(TypeError, "must be a dict or a single int"):
self.dv.head(3.1)
self.dv.head([3])
with raises_regex(TypeError, "must be an int"):
self.dv.head(x=3.1)
with raises_regex(ValueError, "must be positive"):
self.dv.head(-3)

def test_tail(self):
assert_equal(self.dv.isel(x=slice(-5, None)), self.dv.tail(x=5))
Expand All @@ -1025,22 +1029,26 @@ def test_tail(self):
self.dv.isel({dim: slice(-5, None) for dim in self.dv.dims}), self.dv.tail()
)
with raises_regex(TypeError, "must be a dict or a single int"):
self.dv.tail(3.1)
self.dv.tail([3])
with raises_regex(TypeError, "must be an int"):
self.dv.tail(x=3.1)
with raises_regex(ValueError, "must be positive"):
self.dv.tail(-3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very thorough tests! Thank you!


def test_thin(self):
assert_equal(self.dv.isel(x=slice(None, None, 5)), self.dv.thin(x=5))
assert_equal(
self.dv.isel({dim: slice(None, None, 6) for dim in self.dv.dims}),
self.dv.thin(6),
)
assert_equal(
self.dv.isel({dim: slice(None, None, 5) for dim in self.dv.dims}),
self.dv.thin(),
)
with raises_regex(TypeError, "must be a dict or a single int"):
self.dv.thin([3])
with raises_regex(TypeError, "must be an int"):
self.dv.thin(x=3.1)
with raises_regex(ValueError, "must be positive"):
self.dv.thin(-3)
with raises_regex(ValueError, "cannot be zero"):
self.dv.thin(time=0)
with raises_regex(TypeError, "must be a dict or a single int"):
self.dv.thin(3.1)

def test_loc(self):
self.ds["x"] = ("x", np.array(list("abcdefghij")))
Expand Down
24 changes: 16 additions & 8 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,11 @@ def test_head(self):
assert_equal(expected, actual)

with raises_regex(TypeError, "must be a dict or a single int"):
data.head(3.1)
data.head([3])
with raises_regex(TypeError, "must be an int"):
data.head(dim2=3.1)
with raises_regex(ValueError, "must be positive"):
data.head(time=-3)

def test_tail(self):
data = create_test_data()
Expand All @@ -1453,7 +1457,11 @@ def test_tail(self):
assert_equal(expected, actual)

with raises_regex(TypeError, "must be a dict or a single int"):
data.tail(3.1)
data.tail([3])
with raises_regex(TypeError, "must be an int"):
data.tail(dim2=3.1)
with raises_regex(ValueError, "must be positive"):
data.tail(time=-3)

def test_thin(self):
data = create_test_data()
Expand All @@ -1466,14 +1474,14 @@ def test_thin(self):
actual = data.thin(6)
assert_equal(expected, actual)

expected = data.isel({dim: slice(None, None, 5) for dim in data.dims})
actual = data.thin()
assert_equal(expected, actual)

with raises_regex(TypeError, "must be a dict or a single int"):
data.thin([3])
with raises_regex(TypeError, "must be an int"):
data.thin(dim2=3.1)
with raises_regex(ValueError, "cannot be zero"):
data.thin(time=0)
with raises_regex(TypeError, "must be a dict or a single int"):
data.thin(3.1)
with raises_regex(ValueError, "must be positive"):
data.thin(time=-3)

@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_sel_fancy(self):
Expand Down