Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented pad with new-indexes #4974

Closed
wants to merge 16 commits into from
Closed
8 changes: 4 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3789,7 +3789,7 @@ def polyfit(
def pad(
self,
pad_width: Mapping[
Hashable, Union[int, Tuple[Union[int, Sequence], Union[int, Sequence]]]
Hashable, Union[int, Tuple[int, int], Tuple[Sequence, Sequence]]
] = None,
mode: str = "constant",
stat_length: Union[
Expand Down Expand Up @@ -3821,9 +3821,9 @@ def pad(
describing the number of values padded along each dimension.
{dim: pad} is a shortcut for pad_before = pad_after = pad
Note that having np.nan in IndexVariable loses most of the useful
functionalities of xarray. To avoid this problem, an iterable,
such as a list or np.array, can be used for either pad_before or pad_after.
In this case, these values will be used for an IndexVariable and preventing
functionalities of xarray. To avoid this problem, sequences,
such as lists or np.arrays, can be used for pad_before and pad_after.
In this case, these values will be used for an IndexVariable preventing
from the loss of functionalities.
mode : str, default: "constant"
One of the following string values (taken from numpy docs)
Expand Down
38 changes: 21 additions & 17 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6497,7 +6497,8 @@ def polyfit(
def pad(
self,
pad_width: Mapping[
Hashable, Union[int, Tuple[Union[int, Sequence], Union[int, Sequence]]]
Hashable,
Union[int, Tuple[Union[int], Union[int]], Tuple[Sequence, Sequence]],
] = None,
mode: str = "constant",
stat_length: Union[
Expand Down Expand Up @@ -6529,8 +6530,8 @@ def pad(
describing the number of values padded along each dimension.
{dim: pad} is a shortcut for pad_before = pad_after = pad
Note that having np.nan in IndexVariable loses most of the useful
functionalities of xarray. To avoid this problem, a sequence,
such as a list or np.array, can be used for either pad_before or pad_after.
functionalities of xarray. To avoid this problem, sequences,
such as lists or np.arrays, can be used for pad_before and pad_after.
In this case, these values will be used for an IndexVariable preventing
from the loss of functionalities.
mode : str, default: "constant"
Expand Down Expand Up @@ -6683,20 +6684,23 @@ def pad(
end_values=end_values,
reflect_type=reflect_type,
)
elif name in var_pad_width.keys() and not isinstance(
pad_width[name], int
): # dimension coordinates
w0, w1 = pad_width[name] # type: ignore
fill_value_ind = dtypes.get_fill_value(var.dtype)
if isinstance(w0, int):
w0_ = IndexVariable(name, [fill_value_ind] * w0)
else:
w0_ = IndexVariable(name, w0)
if isinstance(w1, int):
w1_ = IndexVariable(name, [fill_value_ind] * w1)
else:
w1_ = IndexVariable(name, w1)
variables[name] = var.concat([w0_, var, w1_], dim=name)
elif (
name in var_pad_width.keys() # dimension coordinates
and isinstance(pad_width[name], Sequence)
and (
isinstance(pad_width[name][0], Sequence) # type: ignore
or isinstance(pad_width[name][1], Sequence) # type: ignore
)
):
pad_start, pad_end = pad_width[name] # type: ignore
if isinstance(pad_start, int) or isinstance(pad_end, int):
# do not allow [Sequence, int] as pad_width
raise TypeError(
"({}, {}) is used for pad_width[{}]. Must be either (int, int) or (Sequence, Sequence).".format(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed that a mixture of sequence and int for pad_width, like

da.pad(dim2=([0, 1, 2], 3))

makes implementation very complex (as we may need to support many options, such as pad_mode).
I just disallowed the mixed argument.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

type(pad_start), type(pad_end), name
)
)
variables[name] = var.pad_indexes(pad_start=pad_start, pad_end=pad_end)
else:
variables[name] = var.pad(
pad_width=var_pad_width,
Expand Down
9 changes: 9 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,15 @@ def pad(

return type(self)(self.dims, array)

def pad_indexes(self, pad_start: Sequence, pad_end: Sequence):
Copy link
Contributor

@dcherian dcherian Mar 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this IndexVariable.pad? and add a test in test_variable.py?

"""
Return a new (Index)Variable with [pad_start, pad_end] padded at the head and tail
of the original array. Used in dataset.pad
"""
start = type(self)(self.dims[0], pad_start)
end = type(self)(self.dims[0], pad_end)
return type(self).concat([start, self, end], dim=self.dims[0])

def _roll_one_dim(self, dim, count):
axis = self.get_axis_num(dim)

Expand Down
9 changes: 7 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5801,7 +5801,7 @@ def test_pad(self):

def test_pad_index(self):
ds = create_test_data(seed=1)
padded = ds.pad(dim2=([0, 1, 2], 0), constant_values=42)
padded = ds.pad(dim2=([0, 1, 2], []), constant_values=42)

assert padded["dim2"].shape == (12,)
assert padded["var1"].shape == (8, 12)
Expand All @@ -5810,7 +5810,7 @@ def test_pad_index(self):
assert dict(padded.dims) == {"dim1": 8, "dim2": 12, "dim3": 10, "time": 20}
assert np.nan not in padded["dim2"]

padded = ds.pad(dim2=(0, [0, 1, 2]), constant_values=42)
padded = ds.pad(dim2=([], [0, 1, 2]), constant_values=42)
assert np.nan not in padded["dim2"]

padded = ds.pad(dim2=([0, 1], [0, 1, 2]), constant_values=42)
Expand All @@ -5819,6 +5819,11 @@ def test_pad_index(self):
padded = ds.pad(dim2=([0, 1], [2]), constant_values=42)
assert np.nan not in padded["dim2"]

def test_pad_index_error(self):
with pytest.raises(TypeError):
ds = create_test_data(seed=1)
ds.pad(dim2=(0, [1, 2]))

def test_pad_index_doc(self):
ds = xr.Dataset({"foo": ("x", range(3))}, coords={"x": [0, 1, 2]})
padded = ds.pad(x=([-1], [3]))
Expand Down