Skip to content

Commit

Permalink
Consolidate TypeVars in a single place (#5569)
Browse files Browse the repository at this point in the history
* Consolidate type bounds in a single place

* More consolidation

* Update xarray/core/types.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* Update xarray/core/types.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* Rename T_DSorDA to T_Xarray

* Update xarray/core/weighted.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* Update xarray/core/rolling_exp.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* .

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
  • Loading branch information
max-sixty and Illviljan authored Aug 21, 2021
1 parent befd1b9 commit 6b59d9a
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 97 deletions.
21 changes: 10 additions & 11 deletions xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@ from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
from .npcompat import ArrayLike
from .types import (
DaCompatible,
DsCompatible,
GroupByIncompatible,
ScalarOrArray,
T_DataArray,
T_Dataset,
T_Variable,
VarCompatible,
)
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

# DatasetOpsMixin etc. are parent classes of Dataset etc.
T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")

ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray]
DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray]
VarCompatible = Union[Variable, ScalarOrArray]
GroupByIncompatible = Union[Variable, GroupBy]

class DatasetOpsMixin:
__slots__ = ()
def _binary_op(self, other, f, reflexive=...): ...
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from contextlib import suppress
from html import escape
Expand Down Expand Up @@ -36,10 +38,10 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataWithCoords, T_Xarray
from .variable import Variable
from .weighted import Weighted

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

C = TypeVar("C")
T = TypeVar("T")
Expand Down Expand Up @@ -795,9 +797,7 @@ def groupby_bins(
},
)

def weighted(
self: T_DataWithCoords, weights: "DataArray"
) -> "Weighted[T_DataWithCoords]":
def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_Xarray]:
"""
Weighted operations.
Expand Down
15 changes: 6 additions & 9 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -36,11 +35,9 @@
from .variable import Variable

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataarray import DataArray
from .coordinates import Coordinates
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
from .types import T_Xarray

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -199,7 +196,7 @@ def result_name(objects: list) -> Any:
return name


def _get_coords_list(args) -> List["Coordinates"]:
def _get_coords_list(args) -> List[Coordinates]:
coords_list = []
for arg in args:
try:
Expand Down Expand Up @@ -400,8 +397,8 @@ def apply_dict_of_variables_vfunc(


def _fast_dataset(
variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable]
) -> "Dataset":
variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
) -> Dataset:
"""Create a dataset as quickly as possible.
Beware: the `variables` dict is modified INPLACE.
Expand Down Expand Up @@ -1729,7 +1726,7 @@ def _calc_idxminmax(
return res


def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]:
"""
Given any number of Dataset and/or DataArray objects, returns
new objects with unified chunk size along all chunked dimensions.
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import warnings
from typing import (
Expand All @@ -12,7 +14,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
Expand Down Expand Up @@ -70,8 +71,6 @@
assert_unique_multiindex_level_names,
)

T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
if TYPE_CHECKING:
try:
from dask.delayed import Delayed
Expand All @@ -86,6 +85,8 @@
except ImportError:
iris_Cube = None

from .types import T_DataArray, T_Xarray


def _infer_coords_and_dims(
shape, coords, dims
Expand Down Expand Up @@ -3698,11 +3699,11 @@ def unify_chunks(self) -> "DataArray":

def map_blocks(
self,
func: Callable[..., T_DSorDA],
func: Callable[..., T_Xarray],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: Union["DataArray", "Dataset"] = None,
) -> T_DSorDA:
) -> T_Xarray:
"""
Apply a function to each block of this DataArray.
Expand Down
8 changes: 3 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
Expand Down Expand Up @@ -109,8 +108,7 @@
from ..backends import AbstractDataStore, ZarrStore
from .dataarray import DataArray
from .merge import CoercibleMapping

T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset")
from .types import T_Xarray

try:
from dask.delayed import Delayed
Expand Down Expand Up @@ -6630,11 +6628,11 @@ def unify_chunks(self) -> "Dataset":

def map_blocks(
self,
func: "Callable[..., T_DSorDA]",
func: "Callable[..., T_Xarray]",
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: Union["DataArray", "Dataset"] = None,
) -> "T_DSorDA":
) -> "T_Xarray":
"""
Apply a function to each block of this Dataset.
Expand Down
15 changes: 9 additions & 6 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import collections
import itertools
import operator
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Expand All @@ -12,7 +15,6 @@
Mapping,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -32,7 +34,8 @@
pass


T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
if TYPE_CHECKING:
from .types import T_Xarray


def unzip(iterable):
Expand Down Expand Up @@ -122,8 +125,8 @@ def make_meta(obj):


def infer_template(
func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs
) -> T_DSorDA:
func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs
) -> T_Xarray:
"""Infer return object by running the function on meta objects."""
meta_args = [make_meta(arg) for arg in (obj,) + args]

Expand Down Expand Up @@ -162,12 +165,12 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping


def map_blocks(
func: Callable[..., T_DSorDA],
func: Callable[..., T_Xarray],
obj: Union[DataArray, Dataset],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: Union[DataArray, Dataset] = None,
) -> T_DSorDA:
) -> T_Xarray:
"""Apply a function to each block of a DataArray or Dataset.
.. warning::
Expand Down
21 changes: 9 additions & 12 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

from distutils.version import LooseVersion
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union
from typing import Generic, Hashable, Mapping, Union

import numpy as np

from .options import _get_keep_attrs
from .pdcompat import count_not_none
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .dataarray import DataArray # noqa: F401
from .dataset import Dataset # noqa: F401

T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")
from .types import T_Xarray


def _get_alpha(com=None, span=None, halflife=None, alpha=None):
Expand Down Expand Up @@ -79,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


class RollingExp(Generic[T_DSorDA]):
class RollingExp(Generic[T_Xarray]):
"""
Exponentially-weighted moving window object.
Similar to EWM in pandas
Expand All @@ -103,16 +100,16 @@ class RollingExp(Generic[T_DSorDA]):

def __init__(
self,
obj: T_DSorDA,
obj: T_Xarray,
windows: Mapping[Hashable, Union[int, float]],
window_type: str = "span",
):
self.obj: T_DSorDA = obj
self.obj: T_Xarray = obj
dim, window = next(iter(windows.items()))
self.dim = dim
self.alpha = _get_alpha(**{window_type: window})

def mean(self, keep_attrs: bool = None) -> T_DSorDA:
def mean(self, keep_attrs: bool = None) -> T_Xarray:
"""
Exponentially weighted moving average.
Expand All @@ -139,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA:
move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
)

def sum(self, keep_attrs: bool = None) -> T_DSorDA:
def sum(self, keep_attrs: bool = None) -> T_Xarray:
"""
Exponentially weighted moving sum.
Expand Down
31 changes: 31 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Union

import numpy as np

if TYPE_CHECKING:
from .common import DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, GroupBy
from .npcompat import ArrayLike
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

T_Dataset = TypeVar("T_Dataset", bound="Dataset")
T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_Variable = TypeVar("T_Variable", bound="Variable")
# Maybe we rename this to T_Data or something less Fortran-y?
T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
VarCompatible = Union["Variable", "ScalarOrArray"]
GroupByIncompatible = Union["Variable", "GroupBy"]
Loading

0 comments on commit 6b59d9a

Please sign in to comment.