From 707231e3ae17d2dcb68975e5189c42421a88d0df Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 13 Oct 2024 14:03:21 +0900 Subject: [PATCH] Reimplement DataTree aggregations (#9589) * Reimplement DataTree aggregations They now allow for dimensions that are missing on particular nodes, and use Xarray's standard generate_aggregations machinery, like aggregations for DataArray and Dataset. Fixes https://github.com/pydata/xarray/issues/8949, https://github.com/pydata/xarray/issues/8963 * add API docs on DataTree aggregations * remove incorrectly added sel methods * fix docstring reprs * mypy fix * fix self import * remove unimplemented agg methods * replace dim_arg_to_dims_set with parse_dims * add parse_dims_as_set * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mypy errors * change tests to match slightly different error now thrown --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: TomNicholas --- doc/api.rst | 38 +- xarray/core/_aggregations.py | 1308 ++++++++++++++++++++++++++ xarray/core/computation.py | 9 +- xarray/core/dataset.py | 14 +- xarray/core/datatree.py | 40 +- xarray/core/utils.py | 52 +- xarray/namedarray/_aggregations.py | 61 +- xarray/tests/test_dataset.py | 4 +- xarray/tests/test_datatree.py | 40 +- xarray/tests/test_utils.py | 10 +- xarray/util/generate_aggregations.py | 44 +- 11 files changed, 1507 insertions(+), 113 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index add1eed8944..2e671f1de69 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -828,30 +828,26 @@ Index into all nodes in the subtree simultaneously. .. DataTree.polyfit .. DataTree.curvefit -.. Aggregation -.. ----------- +Aggregation +----------- -.. Aggregate data in all nodes in the subtree simultaneously. +Aggregate data in all nodes in the subtree simultaneously. -.. .. autosummary:: -.. :toctree: generated/ +.. autosummary:: + :toctree: generated/ -.. DataTree.all -.. DataTree.any -.. DataTree.argmax -.. DataTree.argmin -.. DataTree.idxmax -.. DataTree.idxmin -.. DataTree.max -.. DataTree.min -.. DataTree.mean -.. DataTree.median -.. DataTree.prod -.. DataTree.sum -.. DataTree.std -.. DataTree.var -.. DataTree.cumsum -.. DataTree.cumprod + DataTree.all + DataTree.any + DataTree.max + DataTree.min + DataTree.mean + DataTree.median + DataTree.prod + DataTree.sum + DataTree.std + DataTree.var + DataTree.cumsum + DataTree.cumprod .. ndarray methods .. --------------- diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index b557ad44a32..6b1029791ea 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -19,6 +19,1314 @@ flox_available = module_available("flox") +class DataTreeAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.count() + + Group: / + Dimensions: () + Data variables: + foo int64 8B 5 + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.all() + + Group: / + Dimensions: () + Data variables: + foo bool 1B False + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict( + ... foo=( + ... "time", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + ... ), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.any() + + Group: / + Dimensions: () + Data variables: + foo bool 1B True + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.max() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 3.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.max(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.min() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.min(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.mean() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.6 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.mean(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.prod() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.prod(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.prod(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 0.0 + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.sum() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.sum(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> dt.sum(skipna=True, min_count=2) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 8.0 + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.std() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.02 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.std(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> dt.std(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.14 + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.var() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.04 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.var(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> dt.var(skipna=True, ddof=1) + + Group: / + Dimensions: () + Data variables: + foo float64 8B 1.3 + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.median() + + Group: / + Dimensions: () + Data variables: + foo float64 8B 2.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.median(skipna=False) + + Group: / + Dimensions: () + Data variables: + foo float64 8B nan + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumsum() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 8.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumsum(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 3.0 6.0 6.0 8.0 nan + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this DataTree's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool or None, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False, the new object will be + returned without attributes. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataTree + New DataTree with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + DataTree.cumulative + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Note that the methods on the ``cumulative`` method are more performant (with numbagg installed) + and better supported. ``cumsum`` and ``cumprod`` may be deprecated + in the future. + + Examples + -------- + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", np.array([1, 2, 3, 0, 2, np.nan]))), + ... coords=dict( + ... time=( + ... "time", + ... pd.date_range("2001-01-01", freq="ME", periods=6), + ... ), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... ) + >>> dt + + Group: / + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> dt.cumprod() + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 0.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> dt.cumprod(skipna=False) + + Group: / + Dimensions: (time: 6) + Dimensions without coordinates: time + Data variables: + foo (time) float64 48B 1.0 2.0 6.0 0.0 0.0 nan + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + class DatasetAggregations: __slots__ = () diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 91a184d55cd..e2a6676252a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -31,7 +31,7 @@ from xarray.core.merge import merge_attrs, merge_coordinates_without_align from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import Dims, T_DataArray -from xarray.core.utils import is_dict_like, is_scalar, parse_dims +from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array @@ -1841,16 +1841,15 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} + dot_dims: set[Hashable] if dim is None: # find dimensions that occur more than once dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) - dim = tuple(d for d, c in dim_counts.items() if c > 1) + dot_dims = {d for d, c in dim_counts.items() if c > 1} else: - dim = parse_dims(dim, all_dims=tuple(all_dims)) - - dot_dims: set[Hashable] = set(dim) + dot_dims = parse_dims_as_set(dim, all_dims=set(all_dims)) # dimensions to be parallelized broadcast_dims = common_dims - dot_dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7dedd2ed07..e0cd92bab6e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -118,6 +118,7 @@ is_duck_dask_array, is_scalar, maybe_wrap_array, + parse_dims_as_set, ) from xarray.core.variable import ( IndexVariable, @@ -6986,18 +6987,7 @@ def reduce( " Please use 'dim' instead." ) - if dim is None or dim is ...: - dims = set(self.dims) - elif isinstance(dim, str) or not isinstance(dim, Iterable): - dims = {dim} - else: - dims = set(dim) - - missing_dimensions = tuple(d for d in dims if d not in self.dims) - if missing_dimensions: - raise ValueError( - f"Dimensions {missing_dimensions} not found in data dimensions {tuple(self.dims)}" - ) + dims = parse_dims_as_set(dim, set(self._dims.keys())) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e8d96f111e2..7861d86a539 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload from xarray.core import utils +from xarray.core._aggregations import DataTreeAggregations from xarray.core.alignment import align from xarray.core.common import TreeAttrAccessMixin from xarray.core.coordinates import Coordinates, DataTreeCoordinates @@ -41,6 +42,7 @@ drop_dims_from_indexers, either_dict_or_kwargs, maybe_wrap_array, + parse_dims_as_set, ) from xarray.core.variable import Variable @@ -57,6 +59,7 @@ from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.merge import CoercibleMapping, CoercibleValue from xarray.core.types import ( + Dims, ErrorOptions, ErrorOptionsWithWarn, NetcdfWriteModes, @@ -399,6 +402,7 @@ def map( # type: ignore[override] class DataTree( NamedNode["DataTree"], + DataTreeAggregations, TreeAttrAccessMixin, Mapping[str, "DataArray | DataTree"], ): @@ -1612,6 +1616,38 @@ def to_zarr( **kwargs, ) + def _get_all_dims(self) -> set: + all_dims = set() + for node in self.subtree: + all_dims.update(node._node_dims) + return all_dims + + def reduce( + self, + func: Callable, + dim: Dims = None, + *, + keep_attrs: bool | None = None, + keepdims: bool = False, + numeric_only: bool = False, + **kwargs: Any, + ) -> Self: + """Reduce this tree by applying `func` along some dimension(s).""" + dims = parse_dims_as_set(dim, self._get_all_dims()) + result = {} + for node in self.subtree: + reduce_dims = [d for d in node._node_dims if d in dims] + node_result = node.dataset.reduce( + func, + reduce_dims, + keep_attrs=keep_attrs, + keepdims=keepdims, + numeric_only=numeric_only, + **kwargs, + ) + result[node.path] = node_result + return type(self).from_dict(result, name=self.name) + def _selective_indexing( self, func: Callable[[Dataset, Mapping[Any, Any]], Dataset], @@ -1622,9 +1658,7 @@ def _selective_indexing( dimensions and inherited coordinates gracefully by only applying indexing at each node selectively. """ - all_dims = set() - for node in self.subtree: - all_dims.update(node._node_dims) + all_dims = self._get_all_dims() indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims) result = {} diff --git a/xarray/core/utils.py b/xarray/core/utils.py index e5168342e1e..e2781366265 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -59,6 +59,7 @@ MutableMapping, MutableSet, Sequence, + Set, ValuesView, ) from enum import Enum @@ -831,7 +832,7 @@ def drop_dims_from_indexers( @overload -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -841,7 +842,7 @@ def parse_dims( @overload -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -850,7 +851,7 @@ def parse_dims( ) -> tuple[Hashable, ...] | None | EllipsisType: ... -def parse_dims( +def parse_dims_as_tuple( dim: Dims, all_dims: tuple[Hashable, ...], *, @@ -891,6 +892,47 @@ def parse_dims( return tuple(dim) +@overload +def parse_dims_as_set( + dim: Dims, + all_dims: set[Hashable], + *, + check_exists: bool = True, + replace_none: Literal[True] = True, +) -> set[Hashable]: ... + + +@overload +def parse_dims_as_set( + dim: Dims, + all_dims: set[Hashable], + *, + check_exists: bool = True, + replace_none: Literal[False], +) -> set[Hashable] | None | EllipsisType: ... + + +def parse_dims_as_set( + dim: Dims, + all_dims: set[Hashable], + *, + check_exists: bool = True, + replace_none: bool = True, +) -> set[Hashable] | None | EllipsisType: + """Like parse_dims_as_tuple, but returning a set instead of a tuple.""" + # TODO: Consider removing parse_dims_as_tuple? + if dim is None or dim is ...: + if replace_none: + return all_dims + return dim + if isinstance(dim, str): + dim = {dim} + dim = set(dim) + if check_exists: + _check_dims(dim, all_dims) + return dim + + @overload def parse_ordered_dims( dim: Dims, @@ -958,7 +1000,7 @@ def parse_ordered_dims( return dims[:idx] + other_dims + dims[idx + 1 :] else: # mypy cannot resolve that the sequence cannot contain "..." - return parse_dims( # type: ignore[call-overload] + return parse_dims_as_tuple( # type: ignore[call-overload] dim=dim, all_dims=all_dims, check_exists=check_exists, @@ -966,7 +1008,7 @@ def parse_ordered_dims( ) -def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None: +def _check_dims(dim: Set[Hashable], all_dims: Set[Hashable]) -> None: wrong_dims = (dim - all_dims) - {...} if wrong_dims: wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims) diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py index 48001422386..139cea83b5b 100644 --- a/xarray/namedarray/_aggregations.py +++ b/xarray/namedarray/_aggregations.py @@ -61,10 +61,7 @@ def count( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -116,8 +113,7 @@ def all( -------- >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x", - ... np.array([True, True, True, True, True, False], dtype=bool), + ... "x", np.array([True, True, True, True, True, False], dtype=bool) ... ) >>> na Size: 6B @@ -170,8 +166,7 @@ def any( -------- >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x", - ... np.array([True, True, True, True, True, False], dtype=bool), + ... "x", np.array([True, True, True, True, True, False], dtype=bool) ... ) >>> na Size: 6B @@ -230,10 +225,7 @@ def max( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -298,10 +290,7 @@ def min( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -370,10 +359,7 @@ def mean( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -449,10 +435,7 @@ def prod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -535,10 +518,7 @@ def sum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -618,10 +598,7 @@ def std( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -701,10 +678,7 @@ def var( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -780,10 +754,7 @@ def median( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -857,10 +828,7 @@ def cumsum( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) @@ -934,10 +902,7 @@ def cumprod( Examples -------- >>> from xarray.namedarray.core import NamedArray - >>> na = NamedArray( - ... "x", - ... np.array([1, 2, 3, 0, 2, np.nan]), - ... ) + >>> na = NamedArray("x", np.array([1, 2, 3, 0, 2, np.nan])) >>> na Size: 48B array([ 1., 2., 3., 0., 2., nan]) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1178498de19..eafc11b630c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5615,7 +5615,7 @@ def test_reduce_bad_dim(self) -> None: data = create_test_data() with pytest.raises( ValueError, - match=r"Dimensions \('bad_dim',\) not found in data dimensions", + match=re.escape("Dimension(s) 'bad_dim' do not exist"), ): data.mean(dim="bad_dim") @@ -5644,7 +5644,7 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: data = create_test_data() with pytest.raises( ValueError, - match=r"Dimensions \('bad_dim',\) not found in data dimensions", + match=re.escape("Dimension(s) 'bad_dim' do not exist"), ): getattr(data, func)(dim="bad_dim") diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9681cb2de76..ab9843565ac 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1661,9 +1661,8 @@ def test_sel(self): assert_equal(actual, expected) -class TestDSMethodInheritance: +class TestAggregations: - @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_reduce_method(self): ds = xr.Dataset({"a": ("x", [False, True, False])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1673,7 +1672,6 @@ def test_reduce_method(self): result = dt.any() assert_equal(result, expected) - @pytest.mark.xfail(reason="reduce methods not implemented yet") def test_nan_reduce_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1683,7 +1681,6 @@ def test_nan_reduce_method(self): result = dt.mean() assert_equal(result, expected) - @pytest.mark.xfail(reason="cum methods not implemented yet") def test_cum_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) @@ -1698,6 +1695,41 @@ def test_cum_method(self): result = dt.cumsum() assert_equal(result, expected) + def test_dim_argument(self): + dt = DataTree.from_dict( + { + "/a": xr.Dataset({"A": ("x", [1, 2])}), + "/b": xr.Dataset({"B": ("y", [1, 2])}), + } + ) + + expected = DataTree.from_dict( + { + "/a": xr.Dataset({"A": 1.5}), + "/b": xr.Dataset({"B": 1.5}), + } + ) + actual = dt.mean() + assert_equal(expected, actual) + + actual = dt.mean(dim=...) + assert_equal(expected, actual) + + expected = DataTree.from_dict( + { + "/a": xr.Dataset({"A": 1.5}), + "/b": xr.Dataset({"B": ("y", [1.0, 2.0])}), + } + ) + actual = dt.mean("x") + assert_equal(expected, actual) + + with pytest.raises( + ValueError, + match=re.escape("Dimension(s) 'invalid' do not exist."), + ): + dt.mean("invalid") + class TestOps: @pytest.mark.xfail(reason="arithmetic not implemented yet") diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 9ef4a688302..f62fbb63cb5 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -283,16 +283,16 @@ def test_infix_dims_errors(supplied, all_): pytest.param(..., ..., id="ellipsis"), ], ) -def test_parse_dims(dim, expected) -> None: +def test_parse_dims_as_tuple(dim, expected) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables - actual = utils.parse_dims(dim, all_dims, replace_none=False) + actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=False) assert actual == expected def test_parse_dims_set() -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables dim = {"a", 1} - actual = utils.parse_dims(dim, all_dims) + actual = utils.parse_dims_as_tuple(dim, all_dims) assert set(actual) == dim @@ -301,7 +301,7 @@ def test_parse_dims_set() -> None: ) def test_parse_dims_replace_none(dim: None | EllipsisType) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables - actual = utils.parse_dims(dim, all_dims, replace_none=True) + actual = utils.parse_dims_as_tuple(dim, all_dims, replace_none=True) assert actual == all_dims @@ -316,7 +316,7 @@ def test_parse_dims_replace_none(dim: None | EllipsisType) -> None: def test_parse_dims_raises(dim) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables with pytest.raises(ValueError, match="'x'"): - utils.parse_dims(dim, all_dims, check_exists=True) + utils.parse_dims_as_tuple(dim, all_dims, check_exists=True) @pytest.mark.parametrize( diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index d2fc4f6d4e2..089ef558581 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -263,7 +263,7 @@ class DataStructure: create_example: str example_var_name: str numeric_only: bool = False - see_also_modules: tuple[str] = tuple + see_also_modules: tuple[str, ...] = tuple class Method: @@ -287,13 +287,13 @@ def __init__( self.additional_notes = additional_notes if bool_reduce: self.array_method = f"array_{name}" - self.np_example_array = """ - ... np.array([True, True, True, True, True, False], dtype=bool)""" + self.np_example_array = ( + """np.array([True, True, True, True, True, False], dtype=bool)""" + ) else: self.array_method = name - self.np_example_array = """ - ... np.array([1, 2, 3, 0, 2, np.nan])""" + self.np_example_array = """np.array([1, 2, 3, 0, 2, np.nan])""" @dataclass @@ -541,10 +541,27 @@ def generate_code(self, method, has_keep_attrs): ) +DATATREE_OBJECT = DataStructure( + name="DataTree", + create_example=""" + >>> dt = xr.DataTree( + ... xr.Dataset( + ... data_vars=dict(foo=("time", {example_array})), + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ), + ... )""", + example_var_name="dt", + numeric_only=True, + see_also_modules=("Dataset", "DataArray"), +) DATASET_OBJECT = DataStructure( name="Dataset", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -559,7 +576,8 @@ def generate_code(self, method, has_keep_attrs): DATAARRAY_OBJECT = DataStructure( name="DataArray", create_example=""" - >>> da = xr.DataArray({example_array}, + >>> da = xr.DataArray( + ... {example_array}, ... dims="time", ... coords=dict( ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), @@ -570,6 +588,15 @@ def generate_code(self, method, has_keep_attrs): numeric_only=False, see_also_modules=("Dataset",), ) +DATATREE_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=DATATREE_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=AGGREGATIONS_PREAMBLE, +) DATASET_GENERATOR = GenericAggregationGenerator( cls="", datastructure=DATASET_OBJECT, @@ -634,7 +661,7 @@ def generate_code(self, method, has_keep_attrs): create_example=""" >>> from xarray.namedarray.core import NamedArray >>> na = NamedArray( - ... "x",{example_array}, + ... "x", {example_array} ... )""", example_var_name="na", numeric_only=False, @@ -670,6 +697,7 @@ def write_methods(filepath, generators, preamble): write_methods( filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", generators=[ + DATATREE_GENERATOR, DATASET_GENERATOR, DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR,