Skip to content

Commit

Permalink
Add Ellipsis typehint to reductions (#7048)
Browse files (browse the repository at this point in the history)
  • Loading branch information
headtr1ck authored Sep 28, 2022
1 parent e678a1d commit 226c23b
Show file tree
Hide file tree
Showing 14 changed files with 448 additions and 340 deletions.
431 changes: 238 additions & 193 deletions xarray/core/_reductions.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
from .types import CombineAttrsOptions, Ellipsis, JoinOptions
from .types import CombineAttrsOptions, JoinOptions

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -1624,7 +1624,7 @@ def cross(

def dot(
*arrays,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
dims: str | Iterable[Hashable] | ellipsis | None = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
Expand Down
41 changes: 21 additions & 20 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
from .types import (
CoarsenBoundaryOptions,
DatetimeUnitOptions,
Ellipsis,
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -900,30 +900,30 @@ def coords(self) -> DataArrayCoordinates:
@overload
def reset_coords(
self: T_DataArray,
names: Hashable | Iterable[Hashable] | None = None,
names: Dims = None,
drop: Literal[False] = False,
) -> Dataset:
...

@overload
def reset_coords(
self: T_DataArray,
names: Hashable | Iterable[Hashable] | None = None,
names: Dims = None,
*,
drop: Literal[True],
) -> T_DataArray:
...

def reset_coords(
self: T_DataArray,
names: Hashable | Iterable[Hashable] | None = None,
names: Dims = None,
drop: bool = False,
) -> T_DataArray | Dataset:
"""Given names of coordinates, reset them to become variables.
Parameters
----------
names : Hashable or iterable of Hashable, optional
names : str, Iterable of Hashable or None, optional
Name(s) of non-index coordinates in this dataset to reset into
variables. By default, all non-index coordinates are reset.
drop : bool, default: False
Expand Down Expand Up @@ -2574,7 +2574,7 @@ def stack(
# https://github.com/python/mypy/issues/12846 is resolved
def unstack(
self,
dim: Hashable | Sequence[Hashable] | None = None,
dim: Dims = None,
fill_value: Any = dtypes.NA,
sparse: bool = False,
) -> DataArray:
Expand All @@ -2586,7 +2586,7 @@ def unstack(
Parameters
----------
dim : Hashable or sequence of Hashable, optional
dim : str, Iterable of Hashable or None, optional
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value : scalar or dict-like, default: nan
Expand Down Expand Up @@ -3400,9 +3400,9 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray:
def reduce(
self: T_DataArray,
func: Callable[..., Any],
dim: None | Hashable | Iterable[Hashable] = None,
dim: Dims | ellipsis = None,
*,
axis: None | int | Sequence[int] = None,
axis: int | Sequence[int] | None = None,
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
Expand All @@ -3415,8 +3415,9 @@ def reduce(
Function which can be called in the form
`f(x, axis=axis, **kwargs)` to return the result of reducing an
np.ndarray over an integer valued axis.
dim : Hashable or Iterable of Hashable, optional
Dimension(s) over which to apply `func`.
dim : "...", str, Iterable of Hashable or None, optional
Dimension(s) over which to apply `func`. By default `func` is
applied over all dimensions.
axis : int or sequence of int, optional
Axis(es) over which to repeatedly apply `func`. Only one of the
'dim' and 'axis' arguments can be supplied. If neither are
Expand Down Expand Up @@ -4386,7 +4387,7 @@ def imag(self: T_DataArray) -> T_DataArray:
def dot(
self: T_DataArray,
other: T_DataArray,
dims: str | Iterable[Hashable] | Ellipsis | None = None,
dims: Dims | ellipsis = None,
) -> T_DataArray:
"""Perform dot product of two DataArrays along their shared dims.
Expand All @@ -4396,7 +4397,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims : ..., str or Iterable of Hashable, optional
dims : ..., str, Iterable of Hashable or None, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.
Expand Down Expand Up @@ -4506,7 +4507,7 @@ def sortby(
def quantile(
self: T_DataArray,
q: ArrayLike,
dim: str | Iterable[Hashable] | None = None,
dim: Dims = None,
method: QUANTILE_METHODS = "linear",
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand Down Expand Up @@ -5390,7 +5391,7 @@ def idxmax(
# https://github.com/python/mypy/issues/12846 is resolved
def argmin(
self,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
dim: Dims | ellipsis = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand All @@ -5406,7 +5407,7 @@ def argmin(
Parameters
----------
dim : Hashable, sequence of Hashable, None or ..., optional
dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the minimum. By default, finds minimum over
all dimensions - for now returning an int for backward compatibility, but
this is deprecated, in future will return a dict with indices for all
Expand Down Expand Up @@ -5495,7 +5496,7 @@ def argmin(
# https://github.com/python/mypy/issues/12846 is resolved
def argmax(
self,
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
dim: Dims | ellipsis = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
Expand All @@ -5511,7 +5512,7 @@ def argmax(
Parameters
----------
dim : Hashable, sequence of Hashable, None or ..., optional
dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the maximum. By default, finds maximum over
all dimensions - for now returning an int for backward compatibility, but
this is deprecated, in future will return a dict with indices for all
Expand Down Expand Up @@ -5679,7 +5680,7 @@ def curvefit(
self,
coords: str | DataArray | Iterable[str | DataArray],
func: Callable[..., Any],
reduce_dims: Hashable | Iterable[Hashable] | None = None,
reduce_dims: Dims = None,
skipna: bool = True,
p0: dict[str, Any] | None = None,
bounds: dict[str, Any] | None = None,
Expand All @@ -5704,7 +5705,7 @@ def curvefit(
array of length `len(x)`. `params` are the fittable parameters which are optimized
by scipy curve_fit. `x` can also be specified as a sequence containing multiple
coordinates, e.g. `f((x0, x1), *params)`.
reduce_dims : Hashable or sequence of Hashable
reduce_dims : str, Iterable of Hashable or None, optional
Additional dimension(s) over which to aggregate while fitting. For example,
calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
aggregate all lat and lon points and fit the specified function along the
Expand Down
70 changes: 35 additions & 35 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
CombineAttrsOptions,
CompatOptions,
DatetimeUnitOptions,
Ellipsis,
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
Expand Down Expand Up @@ -1698,14 +1698,14 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas

def reset_coords(
self: T_Dataset,
names: Hashable | Iterable[Hashable] | None = None,
names: Dims = None,
drop: bool = False,
) -> T_Dataset:
"""Given names of coordinates, reset them to become variables
Parameters
----------
names : hashable or iterable of hashable, optional
names : str, Iterable of Hashable or None, optional
Name(s) of non-index coordinates in this dataset to reset into
variables. By default, all non-index coordinates are reset.
drop : bool, default: False
Expand Down Expand Up @@ -4457,7 +4457,7 @@ def _get_stack_index(

def _stack_once(
self: T_Dataset,
dims: Sequence[Hashable | Ellipsis],
dims: Sequence[Hashable | ellipsis],
new_dim: Hashable,
index_cls: type[Index],
create_index: bool | None = True,
Expand Down Expand Up @@ -4516,10 +4516,10 @@ def _stack_once(

def stack(
self: T_Dataset,
dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
**dimensions_kwargs: Sequence[Hashable | Ellipsis],
**dimensions_kwargs: Sequence[Hashable | ellipsis],
) -> T_Dataset:
"""
Stack any number of existing dimensions into a single new dimension.
Expand Down Expand Up @@ -4770,7 +4770,7 @@ def _unstack_full_reindex(

def unstack(
self: T_Dataset,
dim: Hashable | Iterable[Hashable] | None = None,
dim: Dims = None,
fill_value: Any = xrdtypes.NA,
sparse: bool = False,
) -> T_Dataset:
Expand All @@ -4782,7 +4782,7 @@ def unstack(
Parameters
----------
dim : hashable or iterable of hashable, optional
dim : str, Iterable of Hashable or None, optional
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value : scalar or dict-like, default: nan
Expand Down Expand Up @@ -4860,15 +4860,13 @@ def unstack(
for v in nonindexes
)

for dim in dims:
for d in dims:
if needs_full_reindex:
result = result._unstack_full_reindex(
dim, stacked_indexes[dim], fill_value, sparse
d, stacked_indexes[d], fill_value, sparse
)
else:
result = result._unstack_once(
dim, stacked_indexes[dim], fill_value, sparse
)
result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)
return result

def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset:
Expand Down Expand Up @@ -5324,15 +5322,15 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset:

def drop_dims(
self: T_Dataset,
drop_dims: Hashable | Iterable[Hashable],
drop_dims: str | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
) -> T_Dataset:
"""Drop dimensions and associated variables from this dataset.
Parameters
----------
drop_dims : hashable or iterable of hashable
drop_dims : str or Iterable of Hashable
Dimension or dimensions to drop.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', raises a ValueError error if any of the
Expand Down Expand Up @@ -5763,7 +5761,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset:
def reduce(
self: T_Dataset,
func: Callable,
dim: Hashable | Iterable[Hashable] = None,
dim: Dims | ellipsis = None,
*,
keep_attrs: bool | None = None,
keepdims: bool = False,
Expand All @@ -5778,8 +5776,8 @@ def reduce(
Function which can be called in the form
`f(x, axis=axis, **kwargs)` to return the result of reducing an
np.ndarray over an integer valued axis.
dim : str or sequence of str, optional
Dimension(s) over which to apply `func`. By default `func` is
dim : str, Iterable of Hashable or None, optional
Dimension(s) over which to apply `func`. By default `func` is
applied over all dimensions.
keep_attrs : bool or None, optional
If True, the dataset's attributes (`attrs`) will be copied from
Expand Down Expand Up @@ -5837,18 +5835,15 @@ def reduce(
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
):
reduce_maybe_single: Hashable | None | list[Hashable]
if len(reduce_dims) == 1:
# unpack dimensions for the benefit of functions
# like np.argmin which can't handle tuple arguments
(reduce_maybe_single,) = reduce_dims
elif len(reduce_dims) == var.ndim:
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
# the former is often more efficient
reduce_maybe_single = None
else:
reduce_maybe_single = reduce_dims
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
# the former is often more efficient
# keep single-element dims as list, to support Hashables
reduce_maybe_single = (
None
if len(reduce_dims) == var.ndim and var.ndim != 1
else reduce_dims
)
variables[name] = var.reduce(
func,
dim=reduce_maybe_single,
Expand Down Expand Up @@ -6957,7 +6952,7 @@ def sortby(
def quantile(
self: T_Dataset,
q: ArrayLike,
dim: str | Iterable[Hashable] | None = None,
dim: Dims = None,
method: QUANTILE_METHODS = "linear",
numeric_only: bool = False,
keep_attrs: bool = None,
Expand Down Expand Up @@ -8303,7 +8298,9 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
# Return int index if single dimension is passed, and is not part of a
# sequence
argmin_func = getattr(duck_array_ops, "argmin")
return self.reduce(argmin_func, dim=dim, **kwargs)
return self.reduce(
argmin_func, dim=None if dim is None else [dim], **kwargs
)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
Expand Down Expand Up @@ -8361,7 +8358,9 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
# Return int index if single dimension is passed, and is not part of a
# sequence
argmax_func = getattr(duck_array_ops, "argmax")
return self.reduce(argmax_func, dim=dim, **kwargs)
return self.reduce(
argmax_func, dim=None if dim is None else [dim], **kwargs
)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
Expand Down Expand Up @@ -8469,7 +8468,7 @@ def curvefit(
self: T_Dataset,
coords: str | DataArray | Iterable[str | DataArray],
func: Callable[..., Any],
reduce_dims: Hashable | Iterable[Hashable] | None = None,
reduce_dims: Dims = None,
skipna: bool = True,
p0: dict[str, Any] | None = None,
bounds: dict[str, Any] | None = None,
Expand All @@ -8494,7 +8493,7 @@ def curvefit(
array of length `len(x)`. `params` are the fittable parameters which are optimized
by scipy curve_fit. `x` can also be specified as a sequence containing multiple
coordinates, e.g. `f((x0, x1), *params)`.
reduce_dims : hashable or sequence of hashable
reduce_dims : str, Iterable of Hashable or None, optional
Additional dimension(s) over which to aggregate while fitting. For example,
calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
aggregate all lat and lon points and fit the specified function along the
Expand Down Expand Up @@ -8545,6 +8544,7 @@ def curvefit(
if kwargs is None:
kwargs = {}

reduce_dims_: list[Hashable]
if not reduce_dims:
reduce_dims_ = []
elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable):
Expand Down
Loading

0 comments on commit 226c23b

Please sign in to comment.