From dea2e4d5ce61e473a8d2fc335b9b167acfba4cac Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 2 Jul 2021 16:54:06 -0700 Subject: [PATCH 01/12] Consolidate type bounds in a single place --- xarray/core/_typed_ops.pyi | 21 +++++++++---------- xarray/core/dataarray.py | 4 +--- xarray/core/types.py | 27 ++++++++++++++++++++++++ xarray/core/variable.py | 41 +++++++++---------------------------- xarray/util/generate_ops.py | 21 ++++++++++--------- 5 files changed, 59 insertions(+), 55 deletions(-) create mode 100644 xarray/core/types.py diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi index 4a6c2dc7b4e..0e0573a0afe 100644 --- a/xarray/core/_typed_ops.pyi +++ b/xarray/core/_typed_ops.pyi @@ -9,6 +9,16 @@ from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy from .npcompat import ArrayLike +from .types import ( + DaCompatible, + DsCompatible, + GroupByIncompatible, + ScalarOrArray, + T_DataArray, + T_Dataset, + T_Variable, + VarCompatible, +) from .variable import Variable try: @@ -16,17 +26,6 @@ try: except ImportError: DaskArray = np.ndarray -# DatasetOpsMixin etc. are parent classes of Dataset etc. -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy] - class DatasetOpsMixin: __slots__ = () def _binary_op(self, other, f, reflexive=...): ... diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7392f7e1e6b..49d6136e1fd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -12,7 +12,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, cast, ) @@ -61,6 +60,7 @@ from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS, _get_keep_attrs +from .types import T_DSorDA from .utils import ( Default, HybridMappingProxy, @@ -76,8 +76,6 @@ assert_unique_multiindex_level_names, ) -T_DataArray = TypeVar("T_DataArray", bound="DataArray") -T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) if TYPE_CHECKING: try: from dask.delayed import Delayed diff --git a/xarray/core/types.py b/xarray/core/types.py new file mode 100644 index 00000000000..40c8c523a44 --- /dev/null +++ b/xarray/core/types.py @@ -0,0 +1,27 @@ +from typing import TypeVar, Union + +from .dataarray import DataArray +from .dataset import Dataset +from .groupby import DataArrayGroupBy, GroupBy +from .npcompat import ArrayLike +from .variable import Variable + +import numpy as np + +try: + from dask.array import Array as DaskArray +except ImportError: + DaskArray = np.ndarray +# DatasetOpsMixin etc. are parent classes of Dataset etc. + +T_Dataset = TypeVar("T_Dataset", bound="Dataset") +T_DataArray = TypeVar("T_DataArray", bound="DataArray") +T_Variable = TypeVar("T_Variable", bound="Variable") +# Maybe we rename this to T_Data or something less Fortran-y? +T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) + +ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] +DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] +DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] +VarCompatible = Union[Variable, ScalarOrArray] +GroupByIncompatible = Union[Variable, GroupBy] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index ace09c6f482..0bbfa3cfc02 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -4,18 +4,7 @@ import warnings from collections import defaultdict from datetime import timedelta -from typing import ( - Any, - Dict, - Hashable, - List, - Mapping, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd @@ -58,17 +47,7 @@ # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) -VariableType = TypeVar("VariableType", bound="Variable") -"""Type annotation to be used when methods of Variable return self or a copy of self. -When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the -output as an instance of the subclass. - -Usage:: - - class Variable: - def f(self: VariableType, ...) -> VariableType: - ... -""" +from .types import T_Variable class MissingDimensionsError(ValueError): @@ -357,7 +336,7 @@ def data(self, data): self._data = data def astype( - self: VariableType, + self: T_Variable, dtype, *, order=None, @@ -365,7 +344,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> VariableType: + ) -> T_Variable: """ Copy of the Variable object, with data cast to a specified type. @@ -763,7 +742,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: VariableType, key) -> VariableType: + def __getitem__(self: T_Variable, key) -> T_Variable: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. @@ -782,7 +761,7 @@ def __getitem__(self: VariableType, key) -> VariableType: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType: + def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -962,12 +941,12 @@ def copy(self, deep=True, data=None): return self._replace(data=data) def _replace( - self: VariableType, + self: T_Variable, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> VariableType: + ) -> T_Variable: if dims is _default: dims = copy.copy(self._dims) if data is _default: @@ -1100,11 +1079,11 @@ def _to_dense(self): return self.copy(deep=False) def isel( - self: VariableType, + self: T_Variable, indexers: Mapping[Hashable, Any] = None, missing_dims: str = "raise", **indexers_kwargs: Any, - ) -> VariableType: + ) -> T_Variable: """Return a new array indexed along the specified dimension(s). Parameters diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index b6b7f8cbac7..1ccede945bf 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -198,23 +198,24 @@ def inplace(): from .dataset import Dataset from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy from .npcompat import ArrayLike +from .types import ( + DaCompatible, + DsCompatible, + GroupByIncompatible, + ScalarOrArray, + T_DataArray, + T_Dataset, + T_Variable, + VarCompatible, +) from .variable import Variable try: from dask.array import Array as DaskArray except ImportError: DaskArray = np.ndarray +''' -# DatasetOpsMixin etc. are parent classes of Dataset etc. -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy]''' CLASS_PREAMBLE = """{newline} class {cls_name}: From b91fa128fded1365b8e4dcd2a2e31c6b87a6945b Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 2 Jul 2021 17:00:14 -0700 Subject: [PATCH 02/12] --- xarray/core/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 40c8c523a44..22cd7b93844 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -12,7 +12,6 @@ from dask.array import Array as DaskArray except ImportError: DaskArray = np.ndarray -# DatasetOpsMixin etc. are parent classes of Dataset etc. T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") From e736124c1595e612db5f01189c4c34d50035ae93 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 2 Jul 2021 17:12:40 -0700 Subject: [PATCH 03/12] --- xarray/core/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 22cd7b93844..6ff34ad1558 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,13 +1,13 @@ from typing import TypeVar, Union +import numpy as np + from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, GroupBy from .npcompat import ArrayLike from .variable import Variable -import numpy as np - try: from dask.array import Array as DaskArray except ImportError: From 85e35615c126ca0d016c4b8161b7edd3757be7fd Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sat, 3 Jul 2021 13:21:41 -0700 Subject: [PATCH 04/12] More consolidation --- xarray/core/common.py | 8 +++---- xarray/core/computation.py | 2 +- xarray/core/dataarray.py | 4 +++- xarray/core/dataset.py | 2 +- xarray/core/parallel.py | 6 +++++- xarray/core/rolling_exp.py | 20 +++++++++++------ xarray/core/types.py | 44 +++++++++++++++++++++----------------- xarray/core/variable.py | 17 +++++++++++++-- xarray/core/weighted.py | 32 +++++++++++++++------------ 9 files changed, 84 insertions(+), 51 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 7b6e9198b43..1872820f7f9 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,3 +1,4 @@ +from __future__ import annotations import warnings from contextlib import suppress from html import escape @@ -38,8 +39,9 @@ from .dataset import Dataset from .variable import Variable from .weighted import Weighted + from .types import T_DataWithCoords + from .types import T_DSorDA -T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") C = TypeVar("C") T = TypeVar("T") @@ -795,9 +797,7 @@ def groupby_bins( }, ) - def weighted( - self: T_DataWithCoords, weights: "DataArray" - ) -> "Weighted[T_DataWithCoords]": + def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_DSorDA]: """ Weighted operations. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cd9e22d90db..14d7df513b3 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -40,7 +40,7 @@ from .dataarray import DataArray from .dataset import Dataset - T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) + from .types import T_DSorDA _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 49d6136e1fd..7b986da4f3b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,3 +1,4 @@ +from __future__ import annotations import datetime import warnings from typing import ( @@ -60,7 +61,6 @@ from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS, _get_keep_attrs -from .types import T_DSorDA from .utils import ( Default, HybridMappingProxy, @@ -90,6 +90,8 @@ except ImportError: iris_Cube = None + from .types import T_DSorDA + def _infer_coords_and_dims( shape, coords, dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 13da8cfad03..a4b1390ca32 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -111,7 +111,7 @@ from .dataarray import DataArray from .merge import CoercibleMapping - T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset") + from .types import T_DSorDA try: from dask.delayed import Delayed diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 795d30af28f..fc66049855e 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + try: import dask import dask.array @@ -20,6 +22,7 @@ List, Mapping, Sequence, + TYPE_CHECKING, Tuple, TypeVar, Union, @@ -33,7 +36,8 @@ from .dataarray import DataArray from .dataset import Dataset -T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +if TYPE_CHECKING: + from .types import T_DSorDA def unzip(iterable): diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index e0fe57a9fb0..d2da644a9b0 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,5 +1,6 @@ +from __future__ import annotations from distutils.version import LooseVersion -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Union, TypeVar import numpy as np @@ -7,12 +8,6 @@ from .pdcompat import count_not_none from .pycompat import is_duck_dask_array -if TYPE_CHECKING: - from .dataarray import DataArray # noqa: F401 - from .dataset import Dataset # noqa: F401 - -T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") - def _get_alpha(com=None, span=None, halflife=None, alpha=None): # pandas defines in terms of com (converting to alpha in the algo) @@ -79,6 +74,17 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) +# We seem to need to redefine this here, rather than in `core.types`, because it a) +# needs to be defined (can't be a string) b) can't be behind an `if TYPE_CHECKING` +# branch and c) we have import errors if we import it without at the module level like: +# from .types import T_DSorDA + +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset +T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") + + class RollingExp(Generic[T_DSorDA]): """ Exponentially-weighted moving window object. diff --git a/xarray/core/types.py b/xarray/core/types.py index 6ff34ad1558..6ff78c2f1ac 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,26 +1,30 @@ -from typing import TypeVar, Union +from __future__ import annotations +from typing import TYPE_CHECKING, TypeVar, Union import numpy as np -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, GroupBy -from .npcompat import ArrayLike -from .variable import Variable +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + from .groupby import DataArrayGroupBy, GroupBy + from .npcompat import ArrayLike + from .variable import Variable + from .common import DataWithCoords -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray + try: + from dask.array import Array as DaskArray + except ImportError: + DaskArray = np.ndarray -T_Dataset = TypeVar("T_Dataset", bound="Dataset") -T_DataArray = TypeVar("T_DataArray", bound="DataArray") -T_Variable = TypeVar("T_Variable", bound="Variable") -# Maybe we rename this to T_Data or something less Fortran-y? -T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) + T_Dataset = TypeVar("T_Dataset", bound=Dataset) + T_DataArray = TypeVar("T_DataArray", bound=DataArray) + T_Variable = TypeVar("T_Variable", bound=Variable) + # Maybe we rename this to T_Data or something less Fortran-y? + T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) + T_DataWithCoords = TypeVar("T_DataWithCoords", bound=DataWithCoords) -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy] + ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] + DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] + DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] + VarCompatible = Union[Variable, ScalarOrArray] + GroupByIncompatible = Union[Variable, GroupBy] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0bbfa3cfc02..8cac74854ab 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,10 +1,22 @@ +from __future__ import annotations import copy import itertools import numbers import warnings from collections import defaultdict from datetime import timedelta -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Dict, + Hashable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) import numpy as np import pandas as pd @@ -47,7 +59,8 @@ # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) -from .types import T_Variable +if TYPE_CHECKING: + from .types import T_Variable class MissingDimensionsError(ValueError): diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e8838b07157..effd995ad05 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,16 +1,9 @@ -from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union, TypeVar from . import duck_array_ops from .computation import dot from .pycompat import is_duck_dask_array -if TYPE_CHECKING: - from .common import DataWithCoords # noqa: F401 - from .dataarray import DataArray, Dataset - -T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") - - _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -59,7 +52,18 @@ """ -class Weighted(Generic[T_DataWithCoords]): +# We seem to need to redefine T_DSorDA here, rather than importing `core.types`, because +# a) it needs to be defined (can't be a string) b) it can't be behind an `if +# TYPE_CHECKING` branch and c) we have import errors if we import it without at the +# module level like: from .types import T_DSorDA + +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset +T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") + + +class Weighted(Generic[T_DSorDA]): """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or @@ -73,7 +77,7 @@ class Weighted(Generic[T_DataWithCoords]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_DataWithCoords, weights: "DataArray"): + def __init__(self, obj: T_DSorDA, weights: "DataArray"): """ Create a Weighted object @@ -116,7 +120,7 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj: T_DataWithCoords = obj + self.obj: T_DSorDA = obj self.weights: "DataArray" = weights def _check_dim(self, dim: Optional[Union[Hashable, Iterable[Hashable]]]): @@ -210,7 +214,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_DSorDA: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -221,7 +225,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_DSorDA: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -232,7 +236,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_DSorDA: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs From 3bbed834ebfdee7613e69ab4bb82909f8954deaf Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sat, 3 Jul 2021 14:07:18 -0700 Subject: [PATCH 05/12] --- xarray/core/common.py | 4 ++-- xarray/core/computation.py | 9 +++------ xarray/core/dataarray.py | 1 + xarray/core/dataset.py | 2 -- xarray/core/parallel.py | 3 +-- xarray/core/rolling_exp.py | 11 ++++++----- xarray/core/types.py | 3 ++- xarray/core/variable.py | 3 ++- xarray/core/weighted.py | 2 +- 9 files changed, 18 insertions(+), 20 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 1872820f7f9..4331eefadf1 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,4 +1,5 @@ from __future__ import annotations + import warnings from contextlib import suppress from html import escape @@ -37,10 +38,9 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset + from .types import T_DataWithCoords, T_DSorDA from .variable import Variable from .weighted import Weighted - from .types import T_DataWithCoords - from .types import T_DSorDA C = TypeVar("C") diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 14d7df513b3..c4e8eb8a1bc 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -21,7 +21,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, ) @@ -36,10 +35,8 @@ from .variable import Variable if TYPE_CHECKING: - from .coordinates import Coordinates # noqa - from .dataarray import DataArray + from .coordinates import Coordinates from .dataset import Dataset - from .types import T_DSorDA _NO_FILL_VALUE = utils.ReprObject("") @@ -199,7 +196,7 @@ def result_name(objects: list) -> Any: return name -def _get_coords_list(args) -> List["Coordinates"]: +def _get_coords_list(args) -> List[Coordinates]: coords_list = [] for arg in args: try: @@ -401,7 +398,7 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] -) -> "Dataset": +) -> Dataset: """Create a dataset as quickly as possible. Beware: the `variables` dict is modified INPLACE. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7b986da4f3b..00fab724396 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,4 +1,5 @@ from __future__ import annotations + import datetime import warnings from typing import ( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a4b1390ca32..6e303a2a517 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -25,7 +25,6 @@ Sequence, Set, Tuple, - TypeVar, Union, cast, overload, @@ -110,7 +109,6 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import CoercibleMapping - from .types import T_DSorDA try: diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fc66049855e..7387e646e4c 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -13,6 +13,7 @@ import itertools import operator from typing import ( + TYPE_CHECKING, Any, Callable, DefaultDict, @@ -22,9 +23,7 @@ List, Mapping, Sequence, - TYPE_CHECKING, Tuple, - TypeVar, Union, ) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index d2da644a9b0..70a4323452c 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,6 +1,7 @@ from __future__ import annotations + from distutils.version import LooseVersion -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Union, TypeVar +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union import numpy as np @@ -74,10 +75,10 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -# We seem to need to redefine this here, rather than in `core.types`, because it a) -# needs to be defined (can't be a string) b) can't be behind an `if TYPE_CHECKING` -# branch and c) we have import errors if we import it without at the module level like: -# from .types import T_DSorDA +# We seem to need to redefine T_DSorDA here, rather than importing `core.types`, because +# a) it needs to be defined (can't be a string) b) it can't be behind an `if +# TYPE_CHECKING` branch and c) we have import errors if we import it without at the +# module level like: from .types import T_DSorDA if TYPE_CHECKING: from .dataarray import DataArray diff --git a/xarray/core/types.py b/xarray/core/types.py index 6ff78c2f1ac..91fa6621b79 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,15 +1,16 @@ from __future__ import annotations + from typing import TYPE_CHECKING, TypeVar, Union import numpy as np if TYPE_CHECKING: + from .common import DataWithCoords from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, GroupBy from .npcompat import ArrayLike from .variable import Variable - from .common import DataWithCoords try: from dask.array import Array as DaskArray diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 8cac74854ab..56eba834cea 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,4 +1,5 @@ from __future__ import annotations + import copy import itertools import numbers @@ -6,6 +7,7 @@ from collections import defaultdict from datetime import timedelta from typing import ( + TYPE_CHECKING, Any, Dict, Hashable, @@ -15,7 +17,6 @@ Sequence, Tuple, Union, - TYPE_CHECKING, ) import numpy as np diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index effd995ad05..fe753de4de9 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union, TypeVar +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union from . import duck_array_ops from .computation import dot From 1d41f2322d4a56282fdd4ce1a99bff4b764b3852 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 20 Aug 2021 08:24:16 -0700 Subject: [PATCH 06/12] Update xarray/core/types.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/types.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 91fa6621b79..1f6352382ea 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -24,8 +24,8 @@ T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) T_DataWithCoords = TypeVar("T_DataWithCoords", bound=DataWithCoords) - ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] - DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] - DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] - VarCompatible = Union[Variable, ScalarOrArray] - GroupByIncompatible = Union[Variable, GroupBy] + ScalarOrArray = ArrayLike | np.generic | np.ndarray | DaskArray + DsCompatible = Dataset | DataArray | Variable | GroupBy | ScalarOrArray + DaCompatible = DataArray | Variable | DataArrayGroupBy | ScalarOrArray + VarCompatible = Variable | ScalarOrArray + GroupByIncompatible = Variable | GroupBy From d0886d827e8be65c470a47efee4b3cfe8b7ee4d1 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 21 Aug 2021 01:07:35 -0700 Subject: [PATCH 07/12] Update xarray/core/types.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/types.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 1f6352382ea..312f534ea11 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -17,15 +17,15 @@ except ImportError: DaskArray = np.ndarray - T_Dataset = TypeVar("T_Dataset", bound=Dataset) - T_DataArray = TypeVar("T_DataArray", bound=DataArray) - T_Variable = TypeVar("T_Variable", bound=Variable) - # Maybe we rename this to T_Data or something less Fortran-y? - T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) - T_DataWithCoords = TypeVar("T_DataWithCoords", bound=DataWithCoords) +T_Dataset = TypeVar("T_Dataset", bound="Dataset") +T_DataArray = TypeVar("T_DataArray", bound="DataArray") +T_Variable = TypeVar("T_Variable", bound="Variable") +# Maybe we rename this to T_Data or something less Fortran-y? +T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") - ScalarOrArray = ArrayLike | np.generic | np.ndarray | DaskArray - DsCompatible = Dataset | DataArray | Variable | GroupBy | ScalarOrArray - DaCompatible = DataArray | Variable | DataArrayGroupBy | ScalarOrArray - VarCompatible = Variable | ScalarOrArray - GroupByIncompatible = Variable | GroupBy +ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] +DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] +VarCompatible = Union["Variable", "ScalarOrArray"] +GroupByIncompatible = Union["Variable", "GroupBy"] From 715b4441cb52f4e4e9f4cf56233971338d6b5b42 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sat, 21 Aug 2021 13:04:57 -0700 Subject: [PATCH 08/12] --- xarray/core/dataarray.py | 2 +- xarray/core/parallel.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 445339fdf84..ffd83fa6ea2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -85,7 +85,7 @@ except ImportError: iris_Cube = None - from .types import T_DSorDA + from .types import T_DataArray, T_DSorDA def _infer_coords_and_dims( diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 5e3f3c3cb72..8cd028799b7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -1,14 +1,5 @@ from __future__ import annotations -try: - import dask - import dask.array - from dask.array.utils import meta_from_array - from dask.highlevelgraph import HighLevelGraph - -except ImportError: - pass - import collections import itertools import operator @@ -33,6 +24,16 @@ from .dataarray import DataArray from .dataset import Dataset +try: + import dask + import dask.array + from dask.array.utils import meta_from_array + from dask.highlevelgraph import HighLevelGraph + +except ImportError: + pass + + if TYPE_CHECKING: from .types import T_DSorDA From 1893237eb4117ecf275e2ddc20f07c26594a8d47 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sat, 21 Aug 2021 13:07:55 -0700 Subject: [PATCH 09/12] Rename T_DSorDA to T_Xarray --- xarray/core/common.py | 4 ++-- xarray/core/computation.py | 4 ++-- xarray/core/dataarray.py | 6 +++--- xarray/core/dataset.py | 6 +++--- xarray/core/parallel.py | 10 +++++----- xarray/core/rolling_exp.py | 16 ++++++++-------- xarray/core/types.py | 2 +- xarray/core/weighted.py | 18 +++++++++--------- 8 files changed, 33 insertions(+), 33 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index c90d4ae15d1..74763829856 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import T_DataWithCoords, T_DSorDA + from .types import T_DataWithCoords, T_Xarray from .variable import Variable from .weighted import Weighted @@ -797,7 +797,7 @@ def groupby_bins( }, ) - def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_DSorDA]: + def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_Xarray]: """ Weighted operations. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c4e8eb8a1bc..9278577cbd6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: from .coordinates import Coordinates from .dataset import Dataset - from .types import T_DSorDA + from .types import T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -1726,7 +1726,7 @@ def _calc_idxminmax( return res -def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]: +def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with unified chunk size along all chunked dimensions. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ffd83fa6ea2..e0c1a1c9d17 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -85,7 +85,7 @@ except ImportError: iris_Cube = None - from .types import T_DataArray, T_DSorDA + from .types import T_DataArray, T_Xarray def _infer_coords_and_dims( @@ -3699,11 +3699,11 @@ def unify_chunks(self) -> "DataArray": def map_blocks( self, - func: Callable[..., T_DSorDA], + func: Callable[..., T_Xarray], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> T_DSorDA: + ) -> T_Xarray: """ Apply a function to each block of this DataArray. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 61abd6ff410..3eee87b2fe1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -108,7 +108,7 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import CoercibleMapping - from .types import T_DSorDA + from .types import T_Xarray try: from dask.delayed import Delayed @@ -6623,11 +6623,11 @@ def unify_chunks(self) -> "Dataset": def map_blocks( self, - func: "Callable[..., T_DSorDA]", + func: "Callable[..., T_Xarray]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> "T_DSorDA": + ) -> "T_Xarray": """ Apply a function to each block of this Dataset. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8cd028799b7..4917714a9c2 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: - from .types import T_DSorDA + from .types import T_Xarray def unzip(iterable): @@ -125,8 +125,8 @@ def make_meta(obj): def infer_template( - func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs -) -> T_DSorDA: + func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs +) -> T_Xarray: """Infer return object by running the function on meta objects.""" meta_args = [make_meta(arg) for arg in (obj,) + args] @@ -165,12 +165,12 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping def map_blocks( - func: Callable[..., T_DSorDA], + func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union[DataArray, Dataset] = None, -) -> T_DSorDA: +) -> T_Xarray: """Apply a function to each block of a DataArray or Dataset. .. warning:: diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 70a4323452c..03ee492a0a3 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -75,18 +75,18 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -# We seem to need to redefine T_DSorDA here, rather than importing `core.types`, because +# We seem to need to redefine T_Xarray here, rather than importing `core.types`, because # a) it needs to be defined (can't be a string) b) it can't be behind an `if # TYPE_CHECKING` branch and c) we have import errors if we import it without at the -# module level like: from .types import T_DSorDA +# module level like: from .types import T_Xarray if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset -T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") -class RollingExp(Generic[T_DSorDA]): +class RollingExp(Generic[T_Xarray]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -110,16 +110,16 @@ class RollingExp(Generic[T_DSorDA]): def __init__( self, - obj: T_DSorDA, + obj: T_Xarray, windows: Mapping[Hashable, Union[int, float]], window_type: str = "span", ): - self.obj: T_DSorDA = obj + self.obj: T_Xarray = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self, keep_attrs: bool = None) -> T_DSorDA: + def mean(self, keep_attrs: bool = None) -> T_Xarray: """ Exponentially weighted moving average. @@ -146,7 +146,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA: move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs ) - def sum(self, keep_attrs: bool = None) -> T_DSorDA: + def sum(self, keep_attrs: bool = None) -> T_Xarray: """ Exponentially weighted moving sum. diff --git a/xarray/core/types.py b/xarray/core/types.py index 312f534ea11..26b1b381c30 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -21,7 +21,7 @@ T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") # Maybe we rename this to T_Data or something less Fortran-y? -T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index fe753de4de9..e6590a7519b 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -52,18 +52,18 @@ """ -# We seem to need to redefine T_DSorDA here, rather than importing `core.types`, because +# We seem to need to redefine T_Xarray here, rather than importing `core.types`, because # a) it needs to be defined (can't be a string) b) it can't be behind an `if # TYPE_CHECKING` branch and c) we have import errors if we import it without at the -# module level like: from .types import T_DSorDA +# module level like: from .types import T_Xarray if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset -T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") -class Weighted(Generic[T_DSorDA]): +class Weighted(Generic[T_Xarray]): """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or @@ -77,7 +77,7 @@ class Weighted(Generic[T_DSorDA]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_DSorDA, weights: "DataArray"): + def __init__(self, obj: T_Xarray, weights: "DataArray"): """ Create a Weighted object @@ -120,7 +120,7 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj: T_DSorDA = obj + self.obj: T_Xarray = obj self.weights: "DataArray" = weights def _check_dim(self, dim: Optional[Union[Hashable, Iterable[Hashable]]]): @@ -214,7 +214,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> T_DSorDA: + ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -225,7 +225,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DSorDA: + ) -> T_Xarray: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -236,7 +236,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DSorDA: + ) -> T_Xarray: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs From cdb793ebdfd7ed6dbe478da038da9bf5bced345a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:00:00 -0700 Subject: [PATCH 10/12] Update xarray/core/weighted.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/weighted.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e6590a7519b..9a37ce9ca49 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -52,15 +52,11 @@ """ -# We seem to need to redefine T_Xarray here, rather than importing `core.types`, because -# a) it needs to be defined (can't be a string) b) it can't be behind an `if -# TYPE_CHECKING` branch and c) we have import errors if we import it without at the -# module level like: from .types import T_Xarray - if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset -T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") + +from .types import T_Xarray class Weighted(Generic[T_Xarray]): From 9cf5755f4972d15502a35785463eda6c4c96d460 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:00:12 -0700 Subject: [PATCH 11/12] Update xarray/core/rolling_exp.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/rolling_exp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 03ee492a0a3..4993c583127 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -75,15 +75,11 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -# We seem to need to redefine T_Xarray here, rather than importing `core.types`, because -# a) it needs to be defined (can't be a string) b) it can't be behind an `if -# TYPE_CHECKING` branch and c) we have import errors if we import it without at the -# module level like: from .types import T_Xarray - if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset -T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") + +from .types import T_Xarray class RollingExp(Generic[T_Xarray]): From 4a00bc00ae0ab82c7edcd52a24aaf6bf2a8fd0ca Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sat, 21 Aug 2021 16:09:36 -0700 Subject: [PATCH 12/12] . --- xarray/core/rolling_exp.py | 10 ++-------- xarray/core/weighted.py | 5 ++--- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 4993c583127..31718267ee0 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,13 +1,14 @@ from __future__ import annotations from distutils.version import LooseVersion -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union +from typing import Generic, Hashable, Mapping, Union import numpy as np from .options import _get_keep_attrs from .pdcompat import count_not_none from .pycompat import is_duck_dask_array +from .types import T_Xarray def _get_alpha(com=None, span=None, halflife=None, alpha=None): @@ -75,13 +76,6 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -if TYPE_CHECKING: - from .dataarray import DataArray - from .dataset import Dataset - -from .types import T_Xarray - - class RollingExp(Generic[T_Xarray]): """ Exponentially-weighted moving window object. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 9a37ce9ca49..c31b24f53b5 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union from . import duck_array_ops from .computation import dot from .pycompat import is_duck_dask_array +from .types import T_Xarray _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -56,8 +57,6 @@ from .dataarray import DataArray from .dataset import Dataset -from .types import T_Xarray - class Weighted(Generic[T_Xarray]): """An object that implements weighted operations.