diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi index 4a6c2dc7b4e..0e0573a0afe 100644 --- a/xarray/core/_typed_ops.pyi +++ b/xarray/core/_typed_ops.pyi @@ -9,6 +9,16 @@ from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy from .npcompat import ArrayLike +from .types import ( + DaCompatible, + DsCompatible, + GroupByIncompatible, + ScalarOrArray, + T_DataArray, + T_Dataset, + T_Variable, + VarCompatible, +) from .variable import Variable try: @@ -16,17 +26,6 @@ try: except ImportError: DaskArray = np.ndarray -# DatasetOpsMixin etc. are parent classes of Dataset etc. -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy] - class DatasetOpsMixin: __slots__ = () def _binary_op(self, other, f, reflexive=...): ... diff --git a/xarray/core/common.py b/xarray/core/common.py index d3001532aa0..74763829856 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from contextlib import suppress from html import escape @@ -36,10 +38,10 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset + from .types import T_DataWithCoords, T_Xarray from .variable import Variable from .weighted import Weighted -T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") C = TypeVar("C") T = TypeVar("T") @@ -795,9 +797,7 @@ def groupby_bins( }, ) - def weighted( - self: T_DataWithCoords, weights: "DataArray" - ) -> "Weighted[T_DataWithCoords]": + def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_Xarray]: """ Weighted operations. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 5bfb14793bb..9278577cbd6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -21,7 +21,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, ) @@ -36,11 +35,9 @@ from .variable import Variable if TYPE_CHECKING: - from .coordinates import Coordinates # noqa - from .dataarray import DataArray + from .coordinates import Coordinates from .dataset import Dataset - - T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) + from .types import T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -199,7 +196,7 @@ def result_name(objects: list) -> Any: return name -def _get_coords_list(args) -> List["Coordinates"]: +def _get_coords_list(args) -> List[Coordinates]: coords_list = [] for arg in args: try: @@ -400,8 +397,8 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable] -) -> "Dataset": + variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] +) -> Dataset: """Create a dataset as quickly as possible. Beware: the `variables` dict is modified INPLACE. @@ -1729,7 +1726,7 @@ def _calc_idxminmax( return res -def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]: +def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with unified chunk size along all chunked dimensions. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e86bae08a3f..e0c1a1c9d17 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import warnings from typing import ( @@ -12,7 +14,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, cast, ) @@ -70,8 +71,6 @@ assert_unique_multiindex_level_names, ) -T_DataArray = TypeVar("T_DataArray", bound="DataArray") -T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) if TYPE_CHECKING: try: from dask.delayed import Delayed @@ -86,6 +85,8 @@ except ImportError: iris_Cube = None + from .types import T_DataArray, T_Xarray + def _infer_coords_and_dims( shape, coords, dims @@ -3698,11 +3699,11 @@ def unify_chunks(self) -> "DataArray": def map_blocks( self, - func: Callable[..., T_DSorDA], + func: Callable[..., T_Xarray], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> T_DSorDA: + ) -> T_Xarray: """ Apply a function to each block of this DataArray. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a5eaa82cfdd..98374fa2ba3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -25,7 +25,6 @@ Sequence, Set, Tuple, - TypeVar, Union, cast, overload, @@ -109,8 +108,7 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import CoercibleMapping - - T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset") + from .types import T_Xarray try: from dask.delayed import Delayed @@ -6630,11 +6628,11 @@ def unify_chunks(self) -> "Dataset": def map_blocks( self, - func: "Callable[..., T_DSorDA]", + func: "Callable[..., T_Xarray]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> "T_DSorDA": + ) -> "T_Xarray": """ Apply a function to each block of this Dataset. diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 20ec3608ebb..4917714a9c2 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import collections import itertools import operator from typing import ( + TYPE_CHECKING, Any, Callable, DefaultDict, @@ -12,7 +15,6 @@ Mapping, Sequence, Tuple, - TypeVar, Union, ) @@ -32,7 +34,8 @@ pass -T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +if TYPE_CHECKING: + from .types import T_Xarray def unzip(iterable): @@ -122,8 +125,8 @@ def make_meta(obj): def infer_template( - func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs -) -> T_DSorDA: + func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs +) -> T_Xarray: """Infer return object by running the function on meta objects.""" meta_args = [make_meta(arg) for arg in (obj,) + args] @@ -162,12 +165,12 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping def map_blocks( - func: Callable[..., T_DSorDA], + func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union[DataArray, Dataset] = None, -) -> T_DSorDA: +) -> T_Xarray: """Apply a function to each block of a DataArray or Dataset. .. warning:: diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index e0fe57a9fb0..31718267ee0 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,17 +1,14 @@ +from __future__ import annotations + from distutils.version import LooseVersion -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union +from typing import Generic, Hashable, Mapping, Union import numpy as np from .options import _get_keep_attrs from .pdcompat import count_not_none from .pycompat import is_duck_dask_array - -if TYPE_CHECKING: - from .dataarray import DataArray # noqa: F401 - from .dataset import Dataset # noqa: F401 - -T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") +from .types import T_Xarray def _get_alpha(com=None, span=None, halflife=None, alpha=None): @@ -79,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -class RollingExp(Generic[T_DSorDA]): +class RollingExp(Generic[T_Xarray]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -103,16 +100,16 @@ class RollingExp(Generic[T_DSorDA]): def __init__( self, - obj: T_DSorDA, + obj: T_Xarray, windows: Mapping[Hashable, Union[int, float]], window_type: str = "span", ): - self.obj: T_DSorDA = obj + self.obj: T_Xarray = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self, keep_attrs: bool = None) -> T_DSorDA: + def mean(self, keep_attrs: bool = None) -> T_Xarray: """ Exponentially weighted moving average. @@ -139,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA: move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs ) - def sum(self, keep_attrs: bool = None) -> T_DSorDA: + def sum(self, keep_attrs: bool = None) -> T_Xarray: """ Exponentially weighted moving sum. diff --git a/xarray/core/types.py b/xarray/core/types.py new file mode 100644 index 00000000000..26b1b381c30 --- /dev/null +++ b/xarray/core/types.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar, Union + +import numpy as np + +if TYPE_CHECKING: + from .common import DataWithCoords + from .dataarray import DataArray + from .dataset import Dataset + from .groupby import DataArrayGroupBy, GroupBy + from .npcompat import ArrayLike + from .variable import Variable + + try: + from dask.array import Array as DaskArray + except ImportError: + DaskArray = np.ndarray + +T_Dataset = TypeVar("T_Dataset", bound="Dataset") +T_DataArray = TypeVar("T_DataArray", bound="DataArray") +T_Variable = TypeVar("T_Variable", bound="Variable") +# Maybe we rename this to T_Data or something less Fortran-y? +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + +ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] +DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] +VarCompatible = Union["Variable", "ScalarOrArray"] +GroupByIncompatible = Union["Variable", "GroupBy"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2798a4ab956..f96cbe63d07 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import itertools import numbers @@ -5,6 +7,7 @@ from collections import defaultdict from datetime import timedelta from typing import ( + TYPE_CHECKING, Any, Dict, Hashable, @@ -13,7 +16,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, ) @@ -66,17 +68,8 @@ # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) -VariableType = TypeVar("VariableType", bound="Variable") -"""Type annotation to be used when methods of Variable return self or a copy of self. -When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the -output as an instance of the subclass. - -Usage:: - - class Variable: - def f(self: VariableType, ...) -> VariableType: - ... -""" +if TYPE_CHECKING: + from .types import T_Variable class MissingDimensionsError(ValueError): @@ -362,7 +355,7 @@ def data(self, data): self._data = data def astype( - self: VariableType, + self: T_Variable, dtype, *, order=None, @@ -370,7 +363,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> VariableType: + ) -> T_Variable: """ Copy of the Variable object, with data cast to a specified type. @@ -775,7 +768,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: VariableType, key) -> VariableType: + def __getitem__(self: T_Variable, key) -> T_Variable: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. @@ -794,7 +787,7 @@ def __getitem__(self: VariableType, key) -> VariableType: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType: + def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -974,12 +967,12 @@ def copy(self, deep=True, data=None): return self._replace(data=data) def _replace( - self: VariableType, + self: T_Variable, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> VariableType: + ) -> T_Variable: if dims is _default: dims = copy.copy(self._dims) if data is _default: @@ -1101,7 +1094,7 @@ def to_numpy(self) -> np.ndarray: return data - def as_numpy(self: VariableType) -> VariableType: + def as_numpy(self: T_Variable) -> T_Variable: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) @@ -1136,11 +1129,11 @@ def _to_dense(self): return self.copy(deep=False) def isel( - self: VariableType, + self: T_Variable, indexers: Mapping[Any, Any] = None, missing_dims: str = "raise", **indexers_kwargs: Any, - ) -> VariableType: + ) -> T_Variable: """Return a new array indexed along the specified dimension(s). Parameters diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e8838b07157..c31b24f53b5 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,15 +1,9 @@ -from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union from . import duck_array_ops from .computation import dot from .pycompat import is_duck_dask_array - -if TYPE_CHECKING: - from .common import DataWithCoords # noqa: F401 - from .dataarray import DataArray, Dataset - -T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") - +from .types import T_Xarray _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -59,7 +53,12 @@ """ -class Weighted(Generic[T_DataWithCoords]): +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + + +class Weighted(Generic[T_Xarray]): """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or @@ -73,7 +72,7 @@ class Weighted(Generic[T_DataWithCoords]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_DataWithCoords, weights: "DataArray"): + def __init__(self, obj: T_Xarray, weights: "DataArray"): """ Create a Weighted object @@ -116,7 +115,7 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj: T_DataWithCoords = obj + self.obj: T_Xarray = obj self.weights: "DataArray" = weights def _check_dim(self, dim: Optional[Union[Hashable, Iterable[Hashable]]]): @@ -210,7 +209,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -221,7 +220,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_Xarray: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -232,7 +231,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_Xarray: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index b6b7f8cbac7..1ccede945bf 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -198,23 +198,24 @@ def inplace(): from .dataset import Dataset from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy from .npcompat import ArrayLike +from .types import ( + DaCompatible, + DsCompatible, + GroupByIncompatible, + ScalarOrArray, + T_DataArray, + T_Dataset, + T_Variable, + VarCompatible, +) from .variable import Variable try: from dask.array import Array as DaskArray except ImportError: DaskArray = np.ndarray +''' -# DatasetOpsMixin etc. are parent classes of Dataset etc. -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy]''' CLASS_PREAMBLE = """{newline} class {cls_name}: