Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate TypeVars in a single place #5569

Merged
merged 14 commits into from
Aug 21, 2021
21 changes: 10 additions & 11 deletions xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@ from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
from .npcompat import ArrayLike
from .types import (
DaCompatible,
DsCompatible,
GroupByIncompatible,
ScalarOrArray,
T_DataArray,
T_Dataset,
T_Variable,
VarCompatible,
)
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

# DatasetOpsMixin etc. are parent classes of Dataset etc.
T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")

ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray]
DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray]
VarCompatible = Union[Variable, ScalarOrArray]
GroupByIncompatible = Union[Variable, GroupBy]

class DatasetOpsMixin:
__slots__ = ()
def _binary_op(self, other, f, reflexive=...): ...
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from contextlib import suppress
from html import escape
Expand Down Expand Up @@ -36,10 +38,10 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataWithCoords, T_DSorDA
from .variable import Variable
from .weighted import Weighted

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

C = TypeVar("C")
T = TypeVar("T")
Expand Down Expand Up @@ -795,9 +797,7 @@ def groupby_bins(
},
)

def weighted(
self: T_DataWithCoords, weights: "DataArray"
) -> "Weighted[T_DataWithCoords]":
def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_DSorDA]:
"""
Weighted operations.

Expand Down
11 changes: 4 additions & 7 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -36,11 +35,9 @@
from .variable import Variable

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataarray import DataArray
from .coordinates import Coordinates
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
from .types import T_DSorDA

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -199,7 +196,7 @@ def result_name(objects: list) -> Any:
return name


def _get_coords_list(args) -> List["Coordinates"]:
def _get_coords_list(args) -> List[Coordinates]:
coords_list = []
for arg in args:
try:
Expand Down Expand Up @@ -401,7 +398,7 @@ def apply_dict_of_variables_vfunc(

def _fast_dataset(
variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
) -> "Dataset":
) -> Dataset:
"""Create a dataset as quickly as possible.

Beware: the `variables` dict is modified INPLACE.
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import warnings
from typing import (
Expand All @@ -12,7 +14,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
Expand Down Expand Up @@ -76,8 +77,6 @@
assert_unique_multiindex_level_names,
)

T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
if TYPE_CHECKING:
try:
from dask.delayed import Delayed
Expand All @@ -92,6 +91,8 @@
except ImportError:
iris_Cube = None

from .types import T_DSorDA


def _infer_coords_and_dims(
shape, coords, dims
Expand Down
4 changes: 1 addition & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
Expand Down Expand Up @@ -110,8 +109,7 @@
from ..backends import AbstractDataStore, ZarrStore
from .dataarray import DataArray
from .merge import CoercibleMapping

T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset")
from .types import T_DSorDA

try:
from dask.delayed import Delayed
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

try:
import dask
import dask.array
Expand All @@ -11,6 +13,7 @@
import itertools
import operator
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Expand All @@ -21,7 +24,6 @@
Mapping,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -33,7 +35,8 @@
from .dataarray import DataArray
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
if TYPE_CHECKING:
from .types import T_DSorDA


def unzip(iterable):
Expand Down
19 changes: 13 additions & 6 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from distutils.version import LooseVersion
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union

Expand All @@ -7,12 +9,6 @@
from .pdcompat import count_not_none
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .dataarray import DataArray # noqa: F401
from .dataset import Dataset # noqa: F401

T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")


def _get_alpha(com=None, span=None, halflife=None, alpha=None):
# pandas defines in terms of com (converting to alpha in the algo)
Expand Down Expand Up @@ -79,6 +75,17 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


# We seem to need to redefine T_DSorDA here, rather than importing `core.types`, because
# a) it needs to be defined (can't be a string) b) it can't be behind an `if
# TYPE_CHECKING` branch and c) we have import errors if we import it without at the
# module level like: from .types import T_DSorDA

if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")


class RollingExp(Generic[T_DSorDA]):
"""
Exponentially-weighted moving window object.
Expand Down
31 changes: 31 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Union

import numpy as np

if TYPE_CHECKING:
from .common import DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, GroupBy
from .npcompat import ArrayLike
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

T_Dataset = TypeVar("T_Dataset", bound=Dataset)
T_DataArray = TypeVar("T_DataArray", bound=DataArray)
T_Variable = TypeVar("T_Variable", bound=Variable)
# Maybe we rename this to T_Data or something less Fortran-y?
T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
T_DataWithCoords = TypeVar("T_DataWithCoords", bound=DataWithCoords)

ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray]
DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray]
VarCompatible = Union[Variable, ScalarOrArray]
GroupByIncompatible = Union[Variable, GroupBy]
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
33 changes: 13 additions & 20 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import copy
import itertools
import numbers
import warnings
from collections import defaultdict
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Dict,
Hashable,
Expand All @@ -13,7 +16,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand Down Expand Up @@ -58,17 +60,8 @@
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)

VariableType = TypeVar("VariableType", bound="Variable")
"""Type annotation to be used when methods of Variable return self or a copy of self.
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
output as an instance of the subclass.

Usage::

class Variable:
def f(self: VariableType, ...) -> VariableType:
...
"""
if TYPE_CHECKING:
from .types import T_Variable


class MissingDimensionsError(ValueError):
Expand Down Expand Up @@ -357,15 +350,15 @@ def data(self, data):
self._data = data

def astype(
self: VariableType,
self: T_Variable,
dtype,
*,
order=None,
casting=None,
subok=None,
copy=None,
keep_attrs=True,
) -> VariableType:
) -> T_Variable:
"""
Copy of the Variable object, with data cast to a specified type.

Expand Down Expand Up @@ -763,7 +756,7 @@ def _broadcast_indexes_vectorized(self, key):

return out_dims, VectorizedIndexer(tuple(out_key)), new_order

def __getitem__(self: VariableType, key) -> VariableType:
def __getitem__(self: T_Variable, key) -> T_Variable:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.

Expand All @@ -782,7 +775,7 @@ def __getitem__(self: VariableType, key) -> VariableType:
data = np.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)

def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable:
"""Used by IndexVariable to return IndexVariable objects when possible."""
return self._replace(dims=dims, data=data)

Expand Down Expand Up @@ -962,12 +955,12 @@ def copy(self, deep=True, data=None):
return self._replace(data=data)

def _replace(
self: VariableType,
self: T_Variable,
dims=_default,
data=_default,
attrs=_default,
encoding=_default,
) -> VariableType:
) -> T_Variable:
if dims is _default:
dims = copy.copy(self._dims)
if data is _default:
Expand Down Expand Up @@ -1100,11 +1093,11 @@ def _to_dense(self):
return self.copy(deep=False)

def isel(
self: VariableType,
self: T_Variable,
indexers: Mapping[Hashable, Any] = None,
missing_dims: str = "raise",
**indexers_kwargs: Any,
) -> VariableType:
) -> T_Variable:
"""Return a new array indexed along the specified dimension(s).

Parameters
Expand Down
Loading