From d8e7d389cfefde1766a324673930ea3d8457cf87 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 7 Dec 2022 20:27:36 +0100 Subject: [PATCH] actually allow `DataTree` objects as values in `from_dict` (#159) * allow passing `DataTree` objects as `dict` values * add a test verifying that DataTree objects are actually allowed * ignore mypy error with copied copy method * whatsnew Co-authored-by: Tom Nicholas --- datatree/datatree.py | 9 +++++++-- datatree/tests/test_datatree.py | 9 +++++++++ docs/source/whats-new.rst | 3 +++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index da53b7e3..5d588da1 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -754,7 +754,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: @classmethod def from_dict( cls, - d: MutableMapping[str, Dataset | DataArray | None], + d: MutableMapping[str, Dataset | DataArray | DataTree | None], name: Optional[str] = None, ) -> DataTree: """ @@ -790,7 +790,12 @@ def from_dict( for path, data in d.items(): # Create and set new node node_name = NodePath(path).name - new_node = cls(name=node_name, data=data) + if isinstance(data, cls): + # TODO ignoring type error only needed whilst .copy() method is copied from Dataset.copy(). + new_node = data.copy() # type: ignore[attr-defined] + new_node.orphan() + else: + new_node = cls(name=node_name, data=data) obj._set_item( path, new_node, diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index dd08618d..f5d7b928 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -391,6 +391,15 @@ def test_full(self, simple_datatree): "/set3", ] + def test_datatree_values(self): + dat1 = DataTree(data=xr.Dataset({"a": 1})) + expected = DataTree() + expected["a"] = dat1 + + actual = DataTree.from_dict({"a": dat1}) + + dtt.assert_identical(actual, expected) + def test_roundtrip(self, simple_datatree): dt = simple_datatree roundtrip = DataTree.from_dict(dt.to_dict()) diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index 8d9d6573..cfe73a50 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -37,6 +37,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Allow ``Datatree`` objects as values in :py:meth:`DataTree.from_dict` (:pull:`159`). + By `Justus Magin `_. + Documentation ~~~~~~~~~~~~~