Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Commit

Permalink
actually allow DataTree objects as values in from_dict (#159)
Browse files Browse the repository at this point in the history
* allow passing `DataTree` objects as `dict` values

* add a test verifying that DataTree objects are actually allowed

* ignore mypy error with copied copy method

* whatsnew

Co-authored-by: Tom Nicholas <thomas.nicholas@columbia.edu>
  • Loading branch information
keewis and TomNicholas authored Dec 7, 2022
1 parent 0f436e3 commit d8e7d38
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
9 changes: 7 additions & 2 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None:
@classmethod
def from_dict(
cls,
d: MutableMapping[str, Dataset | DataArray | None],
d: MutableMapping[str, Dataset | DataArray | DataTree | None],
name: Optional[str] = None,
) -> DataTree:
"""
Expand Down Expand Up @@ -790,7 +790,12 @@ def from_dict(
for path, data in d.items():
# Create and set new node
node_name = NodePath(path).name
new_node = cls(name=node_name, data=data)
if isinstance(data, cls):
# TODO ignoring type error only needed whilst .copy() method is copied from Dataset.copy().
new_node = data.copy() # type: ignore[attr-defined]
new_node.orphan()
else:
new_node = cls(name=node_name, data=data)
obj._set_item(
path,
new_node,
Expand Down
9 changes: 9 additions & 0 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,15 @@ def test_full(self, simple_datatree):
"/set3",
]

def test_datatree_values(self):
dat1 = DataTree(data=xr.Dataset({"a": 1}))
expected = DataTree()
expected["a"] = dat1

actual = DataTree.from_dict({"a": dat1})

dtt.assert_identical(actual, expected)

def test_roundtrip(self, simple_datatree):
dt = simple_datatree
roundtrip = DataTree.from_dict(dt.to_dict())
Expand Down
3 changes: 3 additions & 0 deletions docs/source/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Deprecations
Bug fixes
~~~~~~~~~

- Allow ``Datatree`` objects as values in :py:meth:`DataTree.from_dict` (:pull:`159`).
By `Justus Magin <https://github.com/keewis>`_.

Documentation
~~~~~~~~~~~~~

Expand Down

0 comments on commit d8e7d38

Please sign in to comment.