diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9c16fb74a7b..f7b5354e9ee 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,11 @@ v0.17.1 (unreleased) New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.pad` and :py:meth:`Dataset.pad` now accept a tuple of sequences + as a ``pad_width`` value. In this case, the given values are used as the newly added coordinate labels + of the IndexVariable. + By `Keisuke Fujii <https://github.com/fujiisoup>`_. + - Allow passing ``combine_attrs`` to :py:meth:`Dataset.merge` (:pull:`4895`). By `Justus Magin <https://github.com/keewis>`_. - Support for `dask.graph_manipulation diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index dd871eb21bc..053f1bf2f87 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3788,7 +3788,9 @@ def polyfit( def pad( self, - pad_width: Mapping[Hashable, Union[int, Tuple[int, int]]] = None, + pad_width: Mapping[ + Hashable, Union[int, Tuple[int, int], Tuple[Sequence, Sequence]] + ] = None, mode: str = "constant", stat_length: Union[ int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] @@ -3818,6 +3820,11 @@ def pad( Mapping with the form of {dim: (pad_before, pad_after)} describing the number of values padded along each dimension. {dim: pad} is a shortcut for pad_before = pad_after = pad + Note that an IndexVariable containing np.nan loses most of the useful + functionality of xarray. To avoid this problem, sequences, + such as lists or np.arrays, can be used for pad_before and pad_after. + In this case, these values are used as the labels of the padded + IndexVariable, so no functionality is lost. 
 mode : str, default: "constant" One of the following string values (taken from numpy docs) @@ -3942,6 +3949,20 @@ def pad( * x (x) float64 nan 0.0 1.0 nan * y (y) int64 10 20 30 40 z (x) float64 nan 100.0 200.0 nan + + Specify coordinate labels for padded values by passing a tuple of sequences: + + >>> da.pad(x=([-2, -1], [2])) + <xarray.DataArray (x: 5, y: 4)> + array([[nan, nan, nan, nan], + [nan, nan, nan, nan], + [ 0., 1., 2., 3.], + [10., 11., 12., 13.], + [nan, nan, nan, nan]]) + Coordinates: + * x (x) int64 -2 -1 0 1 2 + * y (y) int64 10 20 30 40 + z (x) float64 nan nan 100.0 200.0 nan """ ds = self._to_temp_dataset().pad( pad_width=pad_width, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index db45157e7c1..5f0cb05a010 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6526,7 +6526,10 @@ def polyfit( def pad( self, - pad_width: Mapping[Hashable, Union[int, Tuple[int, int]]] = None, + pad_width: Mapping[ + Hashable, + Union[int, Tuple[Union[int], Union[int]], Tuple[Sequence, Sequence]], + ] = None, mode: str = "constant", stat_length: Union[ int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] @@ -6552,10 +6555,15 @@ def pad( Parameters ---------- - pad_width : mapping of hashable to tuple of int + pad_width : mapping of hashable to int, tuple of int or tuple of Sequence Mapping with the form of {dim: (pad_before, pad_after)} describing the number of values padded along each dimension. {dim: pad} is a shortcut for pad_before = pad_after = pad + Note that an IndexVariable containing np.nan loses most of the useful + functionality of xarray. To avoid this problem, sequences, + such as lists or np.arrays, can be used for pad_before and pad_after. + In this case, these values are used as the labels of the padded + IndexVariable, so no functionality is lost. mode : str, default: "constant" One of the following string values (taken from numpy docs). 
@@ -6652,6 +6660,14 @@ def pad( Dimensions without coordinates: x Data variables: foo (x) float64 nan 0.0 1.0 2.0 3.0 4.0 nan nan + >>> ds = xr.Dataset({"foo": ("x", range(3))}, coords={"x": [0, 1, 2]}) + >>> ds.pad(x=([-1], [3])) + <xarray.Dataset> + Dimensions: (x: 5) + Coordinates: + * x (x) int64 -1 0 1 2 3 + Data variables: + foo (x) float64 nan 0.0 1.0 2.0 nan """ pad_width = either_dict_or_kwargs(pad_width, pad_width_kwargs, "pad") @@ -6668,8 +6684,25 @@ def pad( coord_pad_options = {} variables = {} + + # standardize pad_width + pad_width_standardized = {} # type: Dict[Hashable, Tuple[int, int]] + for k, v in pad_width.items(): + if not isinstance(v, int): + # if pad_width is a tuple of sequences, we use their lengths for + # pad_width_standardized + # mypy does not know the length here and infers Tuple[int, ...] + # see https://github.com/python/mypy/issues/7509 + pad_width_standardized[k] = tuple( # type: ignore + len(v1) if isinstance(v1, Sequence) else v1 for v1 in v + ) + else: # just an int + pad_width_standardized[k] = (v, v) + for name, var in self.variables.items(): - var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} + var_pad_width = { + k: v for k, v in pad_width_standardized.items() if k in var.dims + } if not var_pad_width: variables[name] = var elif name in self.data_vars: @@ -6681,6 +6714,23 @@ def pad( end_values=end_values, reflect_type=reflect_type, ) + elif ( + name in var_pad_width.keys() # dimension coordinates + and isinstance(pad_width[name], Sequence) + and ( + isinstance(pad_width[name][0], Sequence) # type: ignore + or isinstance(pad_width[name][1], Sequence) # type: ignore + ) + ): + pad_start, pad_end = pad_width[name] # type: ignore + if isinstance(pad_start, int) or isinstance(pad_end, int): + # do not allow [Sequence, int] as pad_width + raise TypeError( + "({}, {}) is used for pad_width[{}]. 
Must be either (int, int) or (Sequence, Sequence).".format( + type(pad_start), type(pad_end), name + ) + ) + variables[name] = var.pad_indexes(pad_start=pad_start, pad_end=pad_end) else: variables[name] = var.pad( pad_width=var_pad_width, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c59cbf1f3e4..8096cdb03a6 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1355,6 +1355,15 @@ def pad( return type(self)(self.dims, array) + def pad_indexes(self, pad_start: Sequence, pad_end: Sequence): + """ + Return a new (Index)Variable with [pad_start, pad_end] padded at the head and tail + of the original array. Used in dataset.pad + """ + start = type(self)(self.dims[0], pad_start) + end = type(self)(self.dims[0], pad_end) + return type(self).concat([start, self, end], dim=self.dims[0]) + def _roll_one_dim(self, dim, count): axis = self.get_axis_num(dim) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 13cd03acf96..30c4b23deed 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5787,16 +5787,55 @@ def test_polyfit_warnings(self): def test_pad(self): ds = create_test_data(seed=1) - padded = ds.pad(dim2=(1, 1), constant_values=42) + for width in [(1, 1), 1]: + padded = ds.pad(dim2=width, constant_values=42) - assert padded["dim2"].shape == (11,) - assert padded["var1"].shape == (8, 11) - assert padded["var2"].shape == (8, 11) + assert padded["dim2"].shape == (11,) + assert padded["var1"].shape == (8, 11) + assert padded["var2"].shape == (8, 11) + assert padded["var3"].shape == (10, 8) + assert dict(padded.dims) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} + + np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) + np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) + + def test_pad_index(self): + ds = create_test_data(seed=1) + padded = ds.pad(dim2=([0, 1, 2], []), constant_values=42) + + assert padded["dim2"].shape == (12,) + assert 
padded["var1"].shape == (8, 12) + assert padded["var2"].shape == (8, 12) assert padded["var3"].shape == (10, 8) - assert dict(padded.dims) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} + assert dict(padded.dims) == {"dim1": 8, "dim2": 12, "dim3": 10, "time": 20} + assert np.nan not in padded["dim2"] + + padded = ds.pad(dim2=([], [0, 1, 2]), constant_values=42) + assert np.nan not in padded["dim2"] + + padded = ds.pad(dim2=([0, 1], [0, 1, 2]), constant_values=42) + assert np.nan not in padded["dim2"] - np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) - np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) + padded = ds.pad(dim2=([0, 1], [2]), constant_values=42) + assert np.nan not in padded["dim2"] + + def test_pad_index_error(self): + with pytest.raises(TypeError): + ds = create_test_data(seed=1) + ds.pad(dim2=(0, [1, 2])) + + def test_pad_index_doc(self): + ds = xr.Dataset({"foo": ("x", range(3))}, coords={"x": [0, 1, 2]}) + padded = ds.pad(x=([-1], [3])) + assert np.nan not in padded["x"] + + da = xr.DataArray( + [[0, 1, 2, 3], [10, 11, 12, 13]], + dims=["x", "y"], + coords={"x": [0, 1], "y": [10, 20, 30, 40], "z": ("x", [100, 200])}, + ) + padded = da.pad(x=([-2, -1], [2])) + assert np.nan not in padded["x"] def test_astype_attrs(self): data = create_test_data(seed=123)