Skip to content

Commit

Permalink
Migrate datatree mapping.py (#8948)
Browse files Browse the repository at this point in the history
* DAS-2064: rename/relocate mapping.py -> xarray.core.datatree_mapping.py

DAS-2064: fix circular import issue.

* DAS-2064 - Minor changes to datatree_mapping.py.

---------

Co-authored-by: Matt Savoie <matthew.savoie@colorado.edu>
  • Loading branch information
owenlittlejohns and flamingbear authored Apr 17, 2024
1 parent 9a37053 commit 60f3e74
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 33 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
Expand All @@ -36,11 +41,6 @@
from xarray.datatree_.datatree.formatting_html import (
datatree_repr as datatree_repr_html,
)
from xarray.datatree_.datatree.mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset

from xarray.core.iterators import LevelOrderIter
from xarray.core.treenode import NodePath, TreeNode

Expand Down Expand Up @@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal:
if node_a.name != node_b.name:
diff = dedent(
f"""\
if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
Expand Down Expand Up @@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
func : callable
Function to apply to datasets with signature:
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
Expand Down Expand Up @@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

Expand Down Expand Up @@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
return _map_over_subtree


def _handle_errors_with_path_context(path):
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path}"
)
raise

return wrapper
Expand All @@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(path_to_node, obj):
def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
Expand Down
3 changes: 1 addition & 2 deletions xarray/core/iterators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections import abc
from collections.abc import Iterator
from typing import Callable

Expand All @@ -9,7 +8,7 @@
"""These iterators are copied from anytree.iterators, with minor modifications."""


class LevelOrderIter(abc.Iterator):
class LevelOrderIter(Iterator):
"""Iterate over tree applying level-order strategy starting at `node`.
This is the iterator used by `DataTree` to traverse nodes.
Expand Down
3 changes: 0 additions & 3 deletions xarray/datatree_/datatree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# import public API
from .mapping import TreeIsomorphismError, map_over_subtree
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError


__all__ = (
"TreeIsomorphismError",
"InvalidTreeError",
"NotFoundInTreeError",
"map_over_subtree",
)
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xarray.core.formatting import _compat_to_str, diff_dataset_repr

from xarray.datatree_.datatree.mapping import diff_treestructure
from xarray.core.datatree_mapping import diff_treestructure
from xarray.datatree_.datatree.render import RenderTree

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xarray import Dataset

from .mapping import map_over_subtree
from xarray.core.datatree_mapping import map_over_subtree

"""
Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
import pytest
import xarray as xr

import xarray as xr
from xarray.core.datatree import DataTree
from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.testing import assert_equal

empty = xr.Dataset()
Expand All @@ -12,7 +16,7 @@
class TestCheckTreesIsomorphic:
def test_not_a_tree(self):
with pytest.raises(TypeError, match="not a tree"):
check_isomorphic("s", 1)
check_isomorphic("s", 1) # type: ignore[arg-type]

def test_different_widths(self):
dt1 = DataTree.from_dict(d={"a": empty})
Expand Down Expand Up @@ -69,7 +73,7 @@ def test_not_isomorphic_complex_tree(self, create_test_datatree):
def test_checking_from_root(self, create_test_datatree):
dt1 = create_test_datatree()
dt2 = create_test_datatree()
real_root = DataTree(name="real root")
real_root: DataTree = DataTree(name="real root")
dt2.name = "not_real_root"
dt2.parent = real_root
with pytest.raises(TreeIsomorphismError):
Expand Down

0 comments on commit 60f3e74

Please sign in to comment.