Fix DataTree.coords.__setitem__ by adding DataTreeCoordinates class #9451

Merged
merged 50 commits into from
Sep 11, 2024
Changes from 5 commits
Commits
704db79
add a DataTreeCoordinates class
TomNicholas Sep 8, 2024
417e3e9
passing read-only properties tests
TomNicholas Sep 8, 2024
9562e92
tests for modifying in-place
TomNicholas Sep 8, 2024
0e7de82
WIP making the modification test pass
TomNicholas Sep 8, 2024
839858f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
9370b9b
get to the delete tests
TomNicholas Sep 8, 2024
9b50567
test
TomNicholas Sep 8, 2024
c466f8d
improve error message
TomNicholas Sep 8, 2024
0397eca
implement delitem
TomNicholas Sep 8, 2024
85bb221
test KeyError
TomNicholas Sep 8, 2024
7802c63
Merge branch 'delitem' into datatree_coords_setitem
TomNicholas Sep 8, 2024
1bf5082
subclass Coordinates instead of DatasetCoordinates
TomNicholas Sep 8, 2024
e8620cf
use Frozen(self._data._coord_variables)
TomNicholas Sep 8, 2024
1108504
Simplify when to raise KeyError
TomNicholas Sep 8, 2024
0a7201b
correct bug in suggestion
TomNicholas Sep 8, 2024
51e11bc
Update xarray/core/coordinates.py
TomNicholas Sep 8, 2024
7ecdd16
simplify _update_coords by creating new node data first
TomNicholas Sep 8, 2024
dfcdb6d
Merge branch 'main' into datatree_coords_setitem
TomNicholas Sep 9, 2024
3278153
Merge branch 'main' into datatree_coords_setitem
TomNicholas Sep 9, 2024
f672c5e
update indexes correctly
TomNicholas Sep 9, 2024
7fb1622
passes test
TomNicholas Sep 9, 2024
897b7c4
update ._drop_indexed_coords
TomNicholas Sep 9, 2024
b5a56f4
Merge branch 'main' into datatree_coords_setitem
TomNicholas Sep 9, 2024
fdae5bc
some mypy fixes
TomNicholas Sep 10, 2024
9dc845a
remove the apparently-unused _drop_indexed_coords method
TomNicholas Sep 10, 2024
6595fe9
Merge branch 'datatree_coords_setitem' of https://github.com/TomNicho…
TomNicholas Sep 10, 2024
ed87554
fix import error
TomNicholas Sep 10, 2024
c155bc1
test that Dataset and DataArray constructors can handle being passed …
TomNicholas Sep 10, 2024
217cb84
test dt.coords can be passed to DataTree constructor
TomNicholas Sep 10, 2024
540bb0f
improve readability of inline comment
TomNicholas Sep 10, 2024
7126efa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2024
8486227
initial tests with inherited coords
TomNicholas Sep 10, 2024
12f24df
Merge branch 'datatree_coords_setitem' of https://github.com/TomNicho…
TomNicholas Sep 10, 2024
8f09c93
ignore typeerror indicating dodgy inheritance
TomNicholas Sep 11, 2024
d23105f
try to avoid Unbound type error
TomNicholas Sep 11, 2024
978e05e
cast return value correctly
TomNicholas Sep 11, 2024
bd47575
cehck that .coords works with inherited coords
TomNicholas Sep 11, 2024
8ef94df
Merge branch 'main' into datatree_coords_setitem
TomNicholas Sep 11, 2024
10b8a78
Merge branch 'main' into datatree_coords_setitem
TomNicholas Sep 11, 2024
b9ede22
fix data->dataset
TomNicholas Sep 11, 2024
540a825
fix return type of __getitem__
TomNicholas Sep 11, 2024
b30d5e0
Use .dataset instead of .to_dataset()
TomNicholas Sep 11, 2024
639ad07
_check_alignment -> check_alignment
TomNicholas Sep 11, 2024
0a9a328
remove dict comprehension
TomNicholas Sep 11, 2024
80bc0bd
KeyError message formatting
TomNicholas Sep 11, 2024
a366bf6
keep generic types for .dims and .sizes
TomNicholas Sep 11, 2024
4d352bd
test verifying you cant delete inherited coord
TomNicholas Sep 11, 2024
4626fa8
fix mypy complaint
TomNicholas Sep 11, 2024
ea8a195
type hint as accepting objects
TomNicholas Sep 11, 2024
af94af4
update note about .dims returning all dims
TomNicholas Sep 11, 2024
129 changes: 125 additions & 4 deletions xarray/core/coordinates.py
Expand Up @@ -36,6 +36,7 @@
     from xarray.core.common import DataWithCoords
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
+    from xarray.core.datatree import DataTree

 # Used as the key corresponding to a DataArray's variable when converting
 # arbitrary DataArray objects to datasets
Expand Down Expand Up @@ -197,12 +198,12 @@ class Coordinates(AbstractCoordinates):

Coordinates are either:

-    - returned via the :py:attr:`Dataset.coords` and :py:attr:`DataArray.coords`
-      properties
+    - returned via the :py:attr:`Dataset.coords`, :py:attr:`DataArray.coords`,
+      and :py:attr:`DataTree.coords` properties,
     - built from Pandas or other index objects
-      (e.g., :py:meth:`Coordinates.from_pandas_multiindex`)
+      (e.g., :py:meth:`Coordinates.from_pandas_multiindex`),
     - built directly from coordinate data and Xarray ``Index`` objects (beware that
-      no consistency check is done on those inputs)
+      no consistency check is done on those inputs),

Parameters
----------
Expand Down Expand Up @@ -796,6 +797,126 @@ def _ipython_key_completions_(self):
]


+class DataTreeCoordinates(DatasetCoordinates):
+    """
+    Dictionary like container for coordinates of a DataTree node (variables + indexes).
+    """
+
+    # TODO: This collection can be passed directly to the :py:class:`~xarray.Dataset`
+    # and :py:class:`~xarray.DataArray` constructors via their `coords` argument.
+    # This will add both the coordinates variables and their index.
+
+    # TODO: This only needs to be a separate class from `DatasetCoordinates` because DataTree nodes store their variables differently
+    # internally than how Datasets do, see https://github.com/pydata/xarray/issues/9203.
+
+    # TODO should inherited coordinates be here? It would be very hard to allow updating them...
+    # But actually maybe the ChainMap approach would make this work okay??
Member Author

I can't see an easy way to make assignment of coordinates to parent nodes work. I think it would be better just to have dt.coords constructed from inherited coords, but not allow setting coordinates further up the tree. I believe that's consistent with the behaviour we currently have for dt.__setitem__ anyway.

Member

Agreed!
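
The behaviour agreed on in this thread (inherited coordinates are visible on a node, but assignment and deletion never reach up the tree) can be illustrated with `collections.ChainMap`, which the TODO mentions. This is only a toy sketch: `parent_coords` and `child_coords` are hypothetical plain dictionaries, not xarray's internal storage, and the real class may well behave differently for overrides.

```python
from collections import ChainMap

# Hypothetical stand-ins for the coordinate variables stored on two nodes.
parent_coords = {"time": [0, 1, 2]}
child_coords = {"x": [10, 20]}

# Lookups fall through to the parent; writes and deletes only touch the
# first mapping (the child's own coordinates).
visible = ChainMap(child_coords, parent_coords)

inherited_value = visible["time"]  # readable through the chain

try:
    del visible["time"]  # "time" lives on the parent only
except KeyError:
    deleted_inherited = False
else:
    deleted_inherited = True

# Assigning an inherited name creates a local entry; the parent is untouched.
visible["time"] = [5, 6, 7]
```

In this model the parent's mapping can only be changed by going to the parent directly, which matches the suggestion that users replace coordinates on the parent node themselves.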


+    _data: DataTree
Member Author (TomNicholas, Sep 10, 2024)

Inheriting from Coordinates and using DataTree instead of DataWithCoords here both expose a few legitimate typing issues. Particularly:

Mostly we have got away with these discrepancies so far but inheritance here is probably going to mean I need to either ignore or cast all over the place?

cc @headtr1ck

Member Author

actually this wasn't very difficult to work around, it mostly just required adding a type: ignore[assignment] here :)


+    __slots__ = ("_data",)
+
+    def __init__(self, datatree: DataTree):
+        self._data = datatree
+
+    @property
+    def _names(self) -> set[Hashable]:
+        return set(self._data._coord_variables)
+
+    @property
+    def dims(self) -> Frozen[Hashable, int]:
+        # TODO is there a potential bug here? What happens if a dim is only present on data variables?
+        return self._data.dims
Member

I think this should be fine. Dimensions on DatasetCoordinates also include dimensions that are only present on data variables.

Member Author

Really? This (on main) looks wrong to me:

In [1]: import xarray as xr

In [2]: ds = xr.Dataset({'a': ('x', [0, 1])})

In [3]: ds
Out[3]: 
<xarray.Dataset> Size: 16B
Dimensions:  (x: 2)
Dimensions without coordinates: x
Data variables:
    a        (x) int64 16B 0 1

In [4]: ds.coords
Out[4]: 
Coordinates:
    *empty*

In [5]: ds.coords.dims
Out[5]: FrozenMappingWarningOnValuesAccess({'x': 2})

I mean the fact no-one has raised this before means it probably isn't of much consequence, but it does seem incorrect / misleading.

Member Author

Raised #9466 to track this
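
To make the discrepancy in the session above concrete, here is a minimal sketch using plain dictionaries rather than xarray objects. Each variable is modelled as a `(dims, shape)` pair, and `sizes_of` is a hypothetical helper introduced only for this illustration.

```python
# Each variable is modelled as (dims, shape); all names are illustrative only.
data_vars = {"a": (("x",), (2,))}
coord_vars = {}  # the dataset in the session has no coordinate variables


def sizes_of(variables):
    """Collect dimension sizes from a mapping of name -> (dims, shape)."""
    sizes = {}
    for dims, shape in variables.values():
        sizes.update(zip(dims, shape))
    return sizes


# What `ds.coords.dims` reports in the session: dimensions of the whole dataset.
all_dims = sizes_of({**data_vars, **coord_vars})

# What one might expect instead: only dimensions used by coordinate variables.
coord_only_dims = sizes_of(coord_vars)
```

Here `all_dims` contains `x` while `coord_only_dims` is empty, which is the gap that the tracking issue describes.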


+    @property
+    def dtypes(self) -> Frozen[Hashable, np.dtype]:
+        """Mapping from coordinate names to dtypes.
+
+        Cannot be modified directly, but is updated when adding new variables.
+
+        See Also
+        --------
+        Dataset.dtypes
+        """
+        return Frozen({n: v.dtype for n, v in self._data._coord_variables.items()})
+
+    @property
+    def variables(self) -> Mapping[Hashable, Variable]:
+        variables = self._data._data_variables | self._data._coord_variables
+        return Frozen({k: v for k, v in variables.items() if k in self._names})
+
+    def __getitem__(self, key: Hashable) -> DataArray:
+        if key in self._data._data_variables or key in self._data.children:
+            raise KeyError(key)
+        return self._data[key]
+
+    def to_dataset(self) -> Dataset:
+        """Convert these coordinates into a new Dataset"""
+        return self._data.to_dataset()._copy_listed(self._names)
+
+    def _update_coords(
+        self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
+    ) -> None:
+
+        # TODO I don't know how to update coordinates that live in parent nodes
Member

I think it is OK for this to be an error. The user can replace those coordinates on the parent nodes.

+        # TODO We would have to find the correct node and update `._node_coord_variables`
+
+        coord_variables = self._data._coord_variables.copy()
+        coord_variables.update(coords)
+
+        # check for inconsistent state *before* modifying anything in-place
+        variables = coord_variables | self._data._data_variables.copy()
+        # TODO is there a subtlety here with rebuild_dims?
+        dims = calculate_dimensions(variables)
+        new_coord_names = set(coords)
+        for dim, size in dims.items():
+            if dim in variables:
+                new_coord_names.add(dim)
+
+        # TODO we need to upgrade these variables to coord variables somehow
+        # coord_variables.update(new_coord_names)
+
+        self._data._coord_variables = coord_variables
+        self._data._dims = dims
+
+        # TODO(shoyer): once ._indexes is always populated by a dict, modify
+        # it to update inplace instead.
+        original_indexes = dict(self._data.xindexes)
+        original_indexes.update(indexes)
+        self._data._indexes = original_indexes
+
+    def _drop_coords(self, coord_names):
+        # should drop indexed coordinates only
+        for name in coord_names:
+            del self._data._coord_variables[name]
+            del self._data._indexes[name]
+        # self._data._coord_names.difference_update(coord_names)
+
+    def _drop_indexed_coords(self, coords_to_drop: set[Hashable]) -> None:
+        assert self._data.xindexes is not None
+        new_coords = drop_indexed_coords(coords_to_drop, self)
+        for name in self._data._coord_names - new_coords._names:
+            del self._data._coord_variables[name]
+        self._data._indexes = dict(new_coords.xindexes)
+        # self._data._coord_names.intersection_update(new_coords._names)
+
+    def __delitem__(self, key: Hashable) -> None:
+        if key in self:
+            del self._data[key]
+        else:
+            raise KeyError(
+                f"{key!r} is not in coordinate variables {tuple(self.keys())}"
+            )
+
+    def _ipython_key_completions_(self):
+        """Provide method for the key-autocompletions in IPython."""
+        return [
+            key
+            for key in self._data._ipython_key_completions_()
+            if key not in self._data.data_vars
+        ]


class DataArrayCoordinates(Coordinates, Generic[T_DataArray]):
"""Dictionary like container for DataArray coordinates (variables + indexes).

Expand Down
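
The comment in `_update_coords` about checking for inconsistent state before modifying anything in-place describes a validate-then-commit pattern. The following is a minimal sketch of that pattern under simplifying assumptions: variables are one-dimensional `(dim, values)` pairs, and `ToyNode`, `calculate_sizes` and `update_coords` are hypothetical names, not xarray's API.

```python
class ToyNode:
    """Hypothetical stand-in for a tree node; variables are (dim, values) pairs."""

    def __init__(self, data_vars, coord_vars):
        self.data_vars = dict(data_vars)
        self.coord_vars = dict(coord_vars)


def calculate_sizes(variables):
    """Return {dim: size}, raising ValueError if two variables disagree."""
    sizes = {}
    for name, (dim, values) in variables.items():
        size = len(values)
        if sizes.setdefault(dim, size) != size:
            raise ValueError(f"conflicting dimension sizes for {dim!r} on {name!r}")
    return sizes


def update_coords(node, new_coords):
    # Build the candidate state on copies first ...
    candidate = {**node.coord_vars, **new_coords}
    calculate_sizes({**node.data_vars, **candidate})  # may raise ValueError
    # ... and only commit once validation has passed.
    node.coord_vars = candidate


node = ToyNode(
    data_vars={"foo": ("x", [1.0, 2.0])},
    coord_vars={"x": ("x", [-1, -2])},
)
update_coords(node, {"x": ("x", ["a", "b"])})  # same length: accepted

try:
    update_coords(node, {"x": ("x", [-1])})  # wrong length: rejected
except ValueError:
    rejected = True
else:
    rejected = False
```

Because the check runs on a candidate copy, a rejected update leaves the node exactly as it was, which is what the "should not be modified" assertion in `test_modify` relies on.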
6 changes: 3 additions & 3 deletions xarray/core/datatree.py
Expand Up @@ -16,7 +16,7 @@
 from xarray.core import utils
 from xarray.core.alignment import align
 from xarray.core.common import TreeAttrAccessMixin
-from xarray.core.coordinates import DatasetCoordinates
+from xarray.core.coordinates import DataTreeCoordinates
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset, DataVariables
 from xarray.core.datatree_mapping import (
Expand Down Expand Up @@ -1148,11 +1148,11 @@ def xindexes(self) -> Indexes[Index]:
)

     @property
-    def coords(self) -> DatasetCoordinates:
+    def coords(self) -> DataTreeCoordinates:
         """Dictionary of xarray.DataArray objects corresponding to coordinate
         variables
         """
-        return DatasetCoordinates(self.to_dataset())
+        return DataTreeCoordinates(self)

     @property
     def data_vars(self) -> DataVariables:
Expand Down
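
A plausible reading of the change to the `coords` property is that the old code wrapped a freshly created object, so writes through the returned view never reached the tree. The sketch below models that difference with hypothetical toy classes (`ToyTree`, `ToyCoords`); it assumes that `to_dataset()` hands back a new object and is not a description of xarray's actual internals.

```python
class ToyTree:
    """Hypothetical node that stores its coordinates in a private dict."""

    def __init__(self, coords):
        self._coord_variables = dict(coords)

    def to_dataset(self):
        # Returns a new object holding a copy of the node's contents.
        return ToyTree(self._coord_variables)


class ToyCoords:
    """Minimal mutable view onto whichever object it is given."""

    def __init__(self, data):
        self._data = data

    def __setitem__(self, key, value):
        self._data._coord_variables[key] = value


tree = ToyTree({"x": [0, 1]})

# Old pattern: the view wraps a temporary copy, so the write is lost.
ToyCoords(tree.to_dataset())["z"] = [7, 8]
lost = "z" not in tree._coord_variables

# New pattern: the view wraps the tree itself, so the write sticks.
ToyCoords(tree)["z"] = [7, 8]
kept = "z" in tree._coord_variables
```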
120 changes: 119 additions & 1 deletion xarray/tests/test_datatree.py
Expand Up @@ -7,11 +7,13 @@
 import pytest

 import xarray as xr
+from xarray import Dataset
+from xarray.core.coordinates import DataTreeCoordinates
 from xarray.core.datatree import DataTree
 from xarray.core.datatree_ops import _MAPPED_DOCSTRING_ADDENDUM, insert_doc_addendum
 from xarray.core.treenode import NotFoundInTreeError
 from xarray.testing import assert_equal, assert_identical
-from xarray.tests import create_test_data, source_ndarray
+from xarray.tests import assert_array_equal, create_test_data, source_ndarray


class TestTreeCreation:
Expand Down Expand Up @@ -525,6 +527,122 @@
assert_identical(results.to_dataset(), expected)


+class TestCoordsInterface:
+    def test_properties(self):
+        # use int64 for repr consistency on windows
+        ds = Dataset(
+            data_vars={
+                "foo": (["x", "y"], np.random.randn(2, 3)),
+            },
+            coords={
+                "x": ("x", np.array([-1, -2], "int64")),
+                "y": ("y", np.array([0, 1, 2], "int64")),
+                "a": ("x", np.array([4, 5], "int64")),
+                "b": np.int64(-10),
+            },
+        )
+        dt = DataTree(data=ds)
+        dt["child"] = DataTree()
+
+        coords = dt.coords
+        assert isinstance(coords, DataTreeCoordinates)
+
+        # len
+        assert len(coords) == 4
+
+        # iter
+        assert list(coords) == ["x", "y", "a", "b"]
+
+        assert_identical(coords["x"].variable, dt["x"].variable)
+        assert_identical(coords["y"].variable, dt["y"].variable)
+
+        assert "x" in coords
+        assert "a" in coords
+        assert 0 not in coords
+        assert "foo" not in coords
+        assert "child" not in coords
+
+        with pytest.raises(KeyError):
+            coords["foo"]
+
+        # TODO this currently raises a ValueError instead of a KeyError
+        # with pytest.raises(KeyError):
+        #     coords[0]
+
+        # repr
+        expected = dedent(
+            """\
+        Coordinates:
+          * x        (x) int64 16B -1 -2
+          * y        (y) int64 24B 0 1 2
+            a        (x) int64 16B 4 5
+            b        int64 8B -10"""
+        )
+        actual = repr(coords)
+        assert expected == actual
+
+        # dims
+        assert coords.sizes == {"x": 2, "y": 3}
+
+        # dtypes
+        assert coords.dtypes == {
+            "x": np.dtype("int64"),
+            "y": np.dtype("int64"),
+            "a": np.dtype("int64"),
+            "b": np.dtype("int64"),
+        }
+
+    def test_modify(self):
+        ds = Dataset(
+            data_vars={
+                "foo": (["x", "y"], np.random.randn(2, 3)),
+            },
+            coords={
+                "x": ("x", np.array([-1, -2], "int64")),
+                "y": ("y", np.array([0, 1, 2], "int64")),
+                "a": ("x", np.array([4, 5], "int64")),
+                "b": np.int64(-10),
+            },
+        )
+        dt = DataTree(data=ds)
+        dt["child"] = DataTree()
+
+        actual = dt.copy(deep=True)
+        actual.coords["x"] = ("x", ["a", "b"])

Check failure on line 611 in xarray/tests/test_datatree.py (GitHub Actions), TestCoordsInterface.test_modify:

- ubuntu-latest py3.10 bare-minimum, macos-latest py3.10, ubuntu-latest py3.10 min-all-deps, ubuntu-latest py3.10: AttributeError: can't set attribute '_coord_variables'
- ubuntu-latest py3.11 all-but-dask, macos-latest py3.12, ubuntu-latest py3.12: AttributeError: property '_coord_variables' of 'DataTree' object has no setter

+        assert_array_equal(actual["x"], ["a", "b"])
+
+        actual = dt.copy(deep=True)
+        actual.coords["z"] = ("z", ["a", "b"])
+        assert_array_equal(actual["z"], ["a", "b"])
+
+        actual = dt.copy(deep=True)
+        with pytest.raises(ValueError, match=r"conflicting dimension sizes"):
+            actual.coords["x"] = ("x", [-1])
+        assert_identical(actual, dt)  # should not be modified
+
+        actual = dt.copy()
+        del actual.coords["b"]
+        expected = dt.reset_coords("b", drop=True)
+        assert_identical(expected, actual)
+
+        with pytest.raises(KeyError):
+            del dt.coords["not_found"]
+
+        with pytest.raises(KeyError):
+            del dt.coords["foo"]
+
+        actual = dt.copy(deep=True)
+        actual.coords.update({"c": 11})
+        expected = dt.merge({"c": 11}).set_coords("c")
+        assert_identical(expected, actual)
+
+        # regression test for GH3746
+        del actual.coords["x"]
+        assert "x" not in actual.xindexes
+
+        # TODO test with coordinate inheritance too...


class TestDictionaryInterface: ...


Expand Down
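
The check failures reported for `test_modify` at this commit all come from assigning to `_coord_variables`, which the error text identifies as a property without a setter. The sketch below reproduces that failure mode on a hypothetical `ToyTree` class. The attribute name `_node_coord_variables` is borrowed from the TODO in `_update_coords`, and the assumption that it is the underlying storage is mine rather than something this diff confirms.

```python
class ToyTree:
    """Hypothetical node exposing a derived, read-only mapping."""

    def __init__(self):
        self._node_coord_variables = {"x": [0, 1]}

    @property
    def _coord_variables(self):
        # Read-only: no setter is defined for this property.
        return dict(self._node_coord_variables)


tree = ToyTree()

try:
    tree._coord_variables = {"x": ["a", "b"]}
except AttributeError:
    # The message wording differs between Python versions.
    assignment_failed = True
else:
    assignment_failed = False

# Writing to the underlying storage attribute works instead.
tree._node_coord_variables = {"x": ["a", "b"]}
```

In this model the fix is to write to the storage the property is derived from, or to give the property a setter. Which route the final version of the PR took is not visible in this five-commit view.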