Skip to content

Commit

Permalink
.to_dataset in map over subtree xarray-contrib/datatree#194
Browse files Browse the repository at this point in the history
* tests

* use to_dataset instead of DatasetView in map_over_subtree

* no longer forbid initialising a DatasetView with init

* whatsnew
  • Loading branch information
TomNicholas authored Jan 10, 2023
1 parent 0bf5256 commit 8054d96
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 13 deletions.
10 changes: 2 additions & 8 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class DatasetView(Dataset):
This includes all API on Dataset, which will be inherited.
This requires overriding all inherited private constructors.
We leave the public init constructor because it is used by type() in some xarray code (see datatree GH issue #188)
"""

# TODO what happens if user alters (in-place) a DataArray they extracted from this object?
Expand All @@ -113,14 +115,6 @@ class DatasetView(Dataset):
"_variables",
)

def __init__(
self,
data_vars: Optional[Mapping[Any, Any]] = None,
coords: Optional[Mapping[Any, Any]] = None,
attrs: Optional[Mapping[Any, Any]] = None,
):
raise AttributeError("DatasetView objects are not to be initialized directly")

@classmethod
def _from_node(
cls,
Expand Down
11 changes: 6 additions & 5 deletions datatree/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,22 +189,23 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
*args_as_tree_length_iterables,
*list(kwargs_as_tree_length_iterables.values()),
):
node_args_as_datasetviews = [
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
node_args_as_datasets = [
a.to_dataset() if isinstance(a, DataTree) else a
for a in all_node_args[:n_args]
]
node_kwargs_as_datasetviews = dict(
node_kwargs_as_datasets = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.ds if isinstance(v, DataTree) else v
v.to_dataset() if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
)
)

# Now we can call func on the data in this particular set of corresponding nodes
results = (
func(*node_args_as_datasetviews, **node_kwargs_as_datasetviews)
func(*node_args_as_datasets, **node_kwargs_as_datasets)
if not node_of_first_tree.is_empty
else None
)
Expand Down
16 changes: 16 additions & 0 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,22 @@ def test_arithmetic(self, create_test_datatree):
result = 10.0 * dt["set1"].ds
assert result.identical(expected)

def test_init_via_type(self):
# from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188
# xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray

a = xr.DataArray(
np.random.rand(3, 4, 10),
dims=["x", "y", "time"],
coords={"area": (["x", "y"], np.random.rand(3, 4))},
).to_dataset(name="data")
dt = DataTree(data=a)

def weighted_mean(ds):
return ds.weighted(ds.area).mean(["x", "y"])

weighted_mean(dt.ds)


class TestRestructuring:
def test_drop_nodes(self):
Expand Down
44 changes: 44 additions & 0 deletions datatree/tests/test_mapping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import xarray as xr

Expand Down Expand Up @@ -252,6 +253,49 @@ def times_ten(ds):
assert_equal(result_tree, expected, from_root=False)


class TestMutableOperations:
def test_construct_using_type(self):
# from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188
# xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray

a = xr.DataArray(
np.random.rand(3, 4, 10),
dims=["x", "y", "time"],
coords={"area": (["x", "y"], np.random.rand(3, 4))},
).to_dataset(name="data")
b = xr.DataArray(
np.random.rand(2, 6, 14),
dims=["x", "y", "time"],
coords={"area": (["x", "y"], np.random.rand(2, 6))},
).to_dataset(name="data")
dt = DataTree.from_dict({"a": a, "b": b})

def weighted_mean(ds):
return ds.weighted(ds.area).mean(["x", "y"])

dt.map_over_subtree(weighted_mean)

def test_alter_inplace(self):
simpsons = DataTree.from_dict(
d={
"/": xr.Dataset({"age": 83}),
"/Herbert": xr.Dataset({"age": 40}),
"/Homer": xr.Dataset({"age": 39}),
"/Homer/Bart": xr.Dataset({"age": 10}),
"/Homer/Lisa": xr.Dataset({"age": 8}),
"/Homer/Maggie": xr.Dataset({"age": 1}),
},
name="Abe",
)

def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset:
"""Add some years to the age, but by altering the given dataset"""
ds["age"] = ds["age"] + years
return ds

simpsons.map_over_subtree(fast_forward, years=10)


@pytest.mark.xfail
class TestMapOverSubTreeInplace:
def test_map_over_subtree_inplace(self):
Expand Down
2 changes: 2 additions & 0 deletions docs/source/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- Allow for altering of given dataset inside function called by :py:func:`map_over_subtree` (:issue:`188`, :pull:`194`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Documentation
~~~~~~~~~~~~~
Expand Down

0 comments on commit 8054d96

Please sign in to comment.