Add Ellipsis typehint to reductions #7048

Merged (12 commits) on Sep 28, 2022
431 changes: 238 additions & 193 deletions xarray/core/_reductions.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions xarray/core/computation.py
@@ -40,7 +40,7 @@
from .coordinates import Coordinates
from .dataarray import DataArray
from .dataset import Dataset
-from .types import CombineAttrsOptions, Ellipsis, JoinOptions
+from .types import CombineAttrsOptions, JoinOptions

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
@@ -1624,7 +1624,7 @@ def cross(

def dot(
*arrays,
-dims: str | Iterable[Hashable] | Ellipsis | None = None,
+dims: str | Iterable[Hashable] | ellipsis | None = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
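For context, the lowercase `ellipsis` is the typeshed name for `type(...)`; the custom `Ellipsis` alias it replaces shadowed the builtin constant. A quick sketch of what the widened `dims` hint permits at the call site (array values here are illustrative, not from the PR):

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))
b = xr.DataArray(np.arange(6).reshape(3, 2), dims=("y", "z"))

xr.dot(a, b)            # dims=None: contract over the shared dim "y" -> dims (x, z)
xr.dot(a, b, dims="y")  # a named dim, same contraction spelled explicitly
xr.dot(a, b, dims=...)  # `...` sums over *all* dimensions -> 0-d result
```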
41 changes: 21 additions & 20 deletions xarray/core/dataarray.py
@@ -78,7 +78,7 @@
from .types import (
CoarsenBoundaryOptions,
DatetimeUnitOptions,
-Ellipsis,
+Dims,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
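The `Dims` alias is defined in `xarray/core/types.py`, which this view does not render; based on how it is used below, it is presumably close to the following sketch (treat the exact definition as an assumption):

```python
# Sketch of xarray/core/types.py (assumed, not shown in this diff)
from typing import Hashable, Iterable, Union

Dims = Union[str, Iterable[Hashable], None]

# The lowercase `ellipsis` in signatures such as `dim: Dims | ellipsis = None`
# names `type(...)` and exists only in typeshed, not at runtime, so these
# annotations rely on `from __future__ import annotations` being in effect.
```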
@@ -869,30 +869,30 @@ def coords(self) -> DataArrayCoordinates:
@overload
def reset_coords(
self: T_DataArray,
-names: Hashable | Iterable[Hashable] | None = None,
+names: Dims = None,
drop: Literal[False] = False,
) -> Dataset:
...

@overload
def reset_coords(
self: T_DataArray,
-names: Hashable | Iterable[Hashable] | None = None,
+names: Dims = None,
*,
drop: Literal[True],
) -> T_DataArray:
...

def reset_coords(
self: T_DataArray,
-names: Hashable | Iterable[Hashable] | None = None,
+names: Dims = None,
drop: bool = False,
) -> T_DataArray | Dataset:
"""Given names of coordinates, reset them to become variables.

Parameters
----------
-names : Hashable or iterable of Hashable, optional
+names : str, Iterable of Hashable or None, optional
Name(s) of non-index coordinates in this dataset to reset into
variables. By default, all non-index coordinates are reset.
drop : bool, default: False
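The `Literal[False]`/`Literal[True]` overloads above are what let a type checker pick `Dataset` vs `T_DataArray` from the `drop` argument; a small usage sketch (data made up):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros(2), dims="x", coords={"x": [0, 1], "aux": 0}, name="v")

ds = da.reset_coords()                   # drop=False -> Dataset; "aux" becomes a data variable
da2 = da.reset_coords("aux", drop=True)  # drop=True  -> DataArray with "aux" removed
```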
@@ -2352,7 +2352,7 @@ def stack(
# https://github.com/python/mypy/issues/12846 is resolved
def unstack(
self,
-dim: Hashable | Sequence[Hashable] | None = None,
+dim: Dims = None,
fill_value: Any = dtypes.NA,
sparse: bool = False,
) -> DataArray:
@@ -2364,7 +2364,7 @@

Parameters
----------
-dim : Hashable or sequence of Hashable, optional
+dim : str, Iterable of Hashable or None, optional
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value : scalar or dict-like, default: nan
@@ -2888,9 +2888,9 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray:
def reduce(
self: T_DataArray,
func: Callable[..., Any],
-dim: None | Hashable | Iterable[Hashable] = None,
+dim: Dims | ellipsis = None,
*,
-axis: None | int | Sequence[int] = None,
+axis: int | Sequence[int] | None = None,
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
@@ -2903,8 +2903,9 @@ def reduce(
Function which can be called in the form
`f(x, axis=axis, **kwargs)` to return the result of reducing an
np.ndarray over an integer valued axis.
-dim : Hashable or Iterable of Hashable, optional
-Dimension(s) over which to apply `func`.
+dim : "...", str, Iterable of Hashable or None, optional
+Dimension(s) over which to apply `func`. By default `func` is
+applied over all dimensions.
axis : int or sequence of int, optional
Axis(es) over which to repeatedly apply `func`. Only one of the
'dim' and 'axis' arguments can be supplied. If neither are
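A sketch of the three `dim` spellings the new `Dims | ellipsis` hint admits, assuming (as the docstring states) that both `None` and `...` reduce over every dimension:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.ones((2, 3)), dims=("x", "y"))

da.reduce(np.sum)            # dim=None -> reduce over all dims, 0-d result
da.reduce(np.sum, dim="x")   # a single named dim, keeps "y"
da.reduce(np.sum, dim=...)   # `...` -> all dims; the new hint covers this value
```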
@@ -3770,7 +3771,7 @@ def imag(self: T_DataArray) -> T_DataArray:
def dot(
self: T_DataArray,
other: T_DataArray,
-dims: str | Iterable[Hashable] | Ellipsis | None = None,
+dims: Dims | ellipsis = None,
) -> T_DataArray:
"""Perform dot product of two DataArrays along their shared dims.

@@ -3780,7 +3781,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
-dims : ..., str or Iterable of Hashable, optional
+dims : ..., str, Iterable of Hashable or None, optional
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
If not specified, then all the common dimensions are summed over.

@@ -3890,7 +3891,7 @@ def sortby(
def quantile(
self: T_DataArray,
q: ArrayLike,
-dim: str | Iterable[Hashable] | None = None,
+dim: Dims = None,
method: QUANTILE_METHODS = "linear",
keep_attrs: bool | None = None,
skipna: bool | None = None,
@@ -4774,7 +4775,7 @@ def idxmax(
# https://github.com/python/mypy/issues/12846 is resolved
def argmin(
self,
-dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
+dim: Dims | ellipsis = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
@@ -4790,7 +4791,7 @@

Parameters
----------
-dim : Hashable, sequence of Hashable, None or ..., optional
+dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the minimum. By default, finds minimum over
all dimensions - for now returning an int for backward compatibility, but
this is deprecated, in future will return a dict with indices for all
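Illustrating the accepted `dim` values and the return types the note above describes (sample data is hypothetical):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.array([[3, 1], [0, 2]]), dims=("x", "y"))

da.argmin()         # no dim: flat index over all dims (the deprecated behaviour)
da.argmin(dim="y")  # index along "y" for each "x"
da.argmin(dim=...)  # dict of indices, one DataArray per dimension
```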
@@ -4879,7 +4880,7 @@ def argmin(
# https://github.com/python/mypy/issues/12846 is resolved
def argmax(
self,
-dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
+dim: Dims | ellipsis = None,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
@@ -4895,7 +4896,7 @@

Parameters
----------
-dim : Hashable, sequence of Hashable, None or ..., optional
+dim : "...", str, Iterable of Hashable or None, optional
The dimensions over which to find the maximum. By default, finds maximum over
all dimensions - for now returning an int for backward compatibility, but
this is deprecated, in future will return a dict with indices for all
@@ -5063,7 +5064,7 @@ def curvefit(
self,
coords: str | DataArray | Iterable[str | DataArray],
func: Callable[..., Any],
-reduce_dims: Hashable | Iterable[Hashable] | None = None,
+reduce_dims: Dims = None,
skipna: bool = True,
p0: dict[str, Any] | None = None,
bounds: dict[str, Any] | None = None,
@@ -5088,7 +5089,7 @@
array of length `len(x)`. `params` are the fittable parameters which are optimized
by scipy curve_fit. `x` can also be specified as a sequence containing multiple
coordinates, e.g. `f((x0, x1), *params)`.
-reduce_dims : Hashable or sequence of Hashable
+reduce_dims : str, Iterable of Hashable or None, optional
Additional dimension(s) over which to aggregate while fitting. For example,
calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
aggregate all lat and lon points and fit the specified function along the
70 changes: 35 additions & 35 deletions xarray/core/dataset.py
@@ -107,7 +107,7 @@
CombineAttrsOptions,
CompatOptions,
DatetimeUnitOptions,
-Ellipsis,
+Dims,
ErrorOptions,
ErrorOptionsWithWarn,
InterpOptions,
@@ -1698,14 +1698,14 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset:

def reset_coords(
self: T_Dataset,
-names: Hashable | Iterable[Hashable] | None = None,
+names: Dims = None,
drop: bool = False,
) -> T_Dataset:
"""Given names of coordinates, reset them to become variables

Parameters
----------
-names : hashable or iterable of hashable, optional
+names : str, Iterable of Hashable or None, optional
Name(s) of non-index coordinates in this dataset to reset into
variables. By default, all non-index coordinates are reset.
drop : bool, default: False
@@ -4256,7 +4256,7 @@ def _get_stack_index(

def _stack_once(
self: T_Dataset,
-dims: Sequence[Hashable | Ellipsis],
+dims: Sequence[Hashable | ellipsis],
new_dim: Hashable,
index_cls: type[Index],
create_index: bool | None = True,
@@ -4315,10 +4315,10 @@ def _stack_once(

def stack(
self: T_Dataset,
-dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
+dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
-**dimensions_kwargs: Sequence[Hashable | Ellipsis],
+**dimensions_kwargs: Sequence[Hashable | ellipsis],
) -> T_Dataset:
"""
Stack any number of existing dimensions into a single new dimension.
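The `Sequence[Hashable | ellipsis]` element type above reflects that a stack spec may contain `...` as a stand-in for all remaining dimensions; a sketch with made-up data:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"v": (("x", "y"), np.arange(6).reshape(2, 3))})

ds.stack(z=["x", "y"])  # an explicit sequence of dims
ds.stack(z=[...])       # `...` expands to every unlisted dim, hence `Hashable | ellipsis`
```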
@@ -4569,7 +4569,7 @@ def _unstack_full_reindex(

def unstack(
self: T_Dataset,
-dim: Hashable | Iterable[Hashable] | None = None,
+dim: Dims = None,
fill_value: Any = xrdtypes.NA,
sparse: bool = False,
) -> T_Dataset:
@@ -4581,7 +4581,7 @@ def unstack(

Parameters
----------
-dim : hashable or iterable of hashable, optional
+dim : str, Iterable of Hashable or None, optional
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value : scalar or dict-like, default: nan
@@ -4659,15 +4659,13 @@ def unstack(
for v in nonindexes
)

-for dim in dims:
+for d in dims:
if needs_full_reindex:
result = result._unstack_full_reindex(
-dim, stacked_indexes[dim], fill_value, sparse
+d, stacked_indexes[d], fill_value, sparse
)
else:
-result = result._unstack_once(
-dim, stacked_indexes[dim], fill_value, sparse
-)
+result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)
return result

def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset:
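The `dim` to `d` rename above is presumably needed because the parameter is now typed `Dims` while the loop yields plain `Hashable`s, and mypy rejects re-binding a variable at an incompatible type. A standalone sketch of the pattern (not xarray code):

```python
from __future__ import annotations

from collections.abc import Hashable, Iterable

def unstack(dim: str | Iterable[Hashable] | None = None) -> None:
    dims: set[Hashable] = {"x", "y"}
    for d in dims:  # re-using `dim` here would fail type checking:
        ...         # Hashable is not assignable to str | Iterable[Hashable] | None
```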
@@ -5065,15 +5063,15 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset:

def drop_dims(
self: T_Dataset,
-drop_dims: Hashable | Iterable[Hashable],
+drop_dims: str | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
) -> T_Dataset:
"""Drop dimensions and associated variables from this dataset.

Parameters
----------
-drop_dims : hashable or iterable of hashable
+drop_dims : str or Iterable of Hashable
Dimension or dimensions to drop.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', raises a ValueError error if any of the
@@ -5504,7 +5502,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset:
def reduce(
self: T_Dataset,
func: Callable,
-dim: Hashable | Iterable[Hashable] = None,
+dim: Dims | ellipsis = None,
*,
keep_attrs: bool | None = None,
keepdims: bool = False,
@@ -5519,8 +5517,8 @@ def reduce(
Function which can be called in the form
`f(x, axis=axis, **kwargs)` to return the result of reducing an
np.ndarray over an integer valued axis.
-dim : str or sequence of str, optional
-Dimension(s) over which to apply `func`. By default `func` is
+dim : str, Iterable of Hashable or None, optional
+Dimension(s) over which to apply `func`. By default `func` is
applied over all dimensions.
keep_attrs : bool or None, optional
If True, the dataset's attributes (`attrs`) will be copied from
@@ -5578,18 +5576,15 @@
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
):
-reduce_maybe_single: Hashable | None | list[Hashable]
-if len(reduce_dims) == 1:
-# unpack dimensions for the benefit of functions
-# like np.argmin which can't handle tuple arguments
-(reduce_maybe_single,) = reduce_dims
-elif len(reduce_dims) == var.ndim:
-# prefer to aggregate over axis=None rather than
-# axis=(0, 1) if they will be equivalent, because
-# the former is often more efficient
-reduce_maybe_single = None
-else:
-reduce_maybe_single = reduce_dims
+# prefer to aggregate over axis=None rather than
+# axis=(0, 1) if they will be equivalent, because
+# the former is often more efficient
+# keep single-element dims as list, to support Hashables
+reduce_maybe_single = (
+None
+if len(reduce_dims) == var.ndim and var.ndim != 1
+else reduce_dims
+)
variables[name] = var.reduce(
func,
dim=reduce_maybe_single,
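The retained comment about preferring `axis=None` is easy to verify with plain NumPy, while the new `var.ndim != 1` guard keeps single-element `dim` lists as lists so that Hashable dim names are not unpacked (per the added comment). Checking the first point:

```python
import numpy as np

arr = np.arange(6).reshape(2, 3)
np.sum(arr, axis=None)    # 15 -- one pass over the flattened buffer
np.sum(arr, axis=(0, 1))  # 15 -- equivalent result, often slower
```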
@@ -6698,7 +6693,7 @@ def sortby(
def quantile(
self: T_Dataset,
q: ArrayLike,
-dim: str | Iterable[Hashable] | None = None,
+dim: Dims = None,
method: QUANTILE_METHODS = "linear",
numeric_only: bool = False,
keep_attrs: bool = None,
@@ -8030,7 +8025,9 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
# Return int index if single dimension is passed, and is not part of a
# sequence
argmin_func = getattr(duck_array_ops, "argmin")
-return self.reduce(argmin_func, dim=dim, **kwargs)
+return self.reduce(
+argmin_func, dim=None if dim is None else [dim], **kwargs
+)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
@@ -8088,7 +8085,9 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
# Return int index if single dimension is passed, and is not part of a
# sequence
argmax_func = getattr(duck_array_ops, "argmax")
-return self.reduce(argmax_func, dim=dim, **kwargs)
+return self.reduce(
+argmax_func, dim=None if dim is None else [dim], **kwargs
+)
else:
raise ValueError(
"When dim is a sequence or ..., DataArray.argmin() returns a dict. "
@@ -8196,7 +8195,7 @@ def curvefit(
self: T_Dataset,
coords: str | DataArray | Iterable[str | DataArray],
func: Callable[..., Any],
-reduce_dims: Hashable | Iterable[Hashable] | None = None,
+reduce_dims: Dims = None,
skipna: bool = True,
p0: dict[str, Any] | None = None,
bounds: dict[str, Any] | None = None,
@@ -8221,7 +8220,7 @@
array of length `len(x)`. `params` are the fittable parameters which are optimized
by scipy curve_fit. `x` can also be specified as a sequence containing multiple
coordinates, e.g. `f((x0, x1), *params)`.
-reduce_dims : hashable or sequence of hashable
+reduce_dims : str, Iterable of Hashable or None, optional
Additional dimension(s) over which to aggregate while fitting. For example,
calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will
aggregate all lat and lon points and fit the specified function along the
@@ -8272,6 +8271,7 @@ def curvefit(
if kwargs is None:
kwargs = {}

+reduce_dims_: list[Hashable]
if not reduce_dims:
reduce_dims_ = []
elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable):
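The new `reduce_dims_: list[Hashable]` annotation pins down the normalization whose branches the diff cuts off; the full pattern is presumably the usual `Dims`-to-list coercion, sketched here rather than copied from the source:

```python
from collections.abc import Hashable, Iterable

def _normalize_dims(reduce_dims) -> list[Hashable]:
    """Coerce a Dims-style argument to a concrete list (hypothetical helper)."""
    if not reduce_dims:
        return []                 # None or empty -> no extra dims
    if isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable):
        return [reduce_dims]      # a single str / Hashable -> one-element list
    return list(reduce_dims)      # any other iterable -> list of its members
```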