Mapping DataTree methods over nodes with variables for which the args are invalid #8949

Closed
TomNicholas opened this issue Apr 16, 2024 · 0 comments · Fixed by #9589
Labels
topic-DataTree Related to the implementation of a DataTree class

TomNicholas commented Apr 16, 2024

What is your issue?

In the datatree call today we narrowed down an issue with how datatree maps methods over many variables in many nodes. This issue is essentially xarray-contrib/datatree#67, but I'll attempt to discuss the problem and solution in more general terms.

Context in xarray

xarray.Dataset is essentially a mapping of variable names to Variable objects, and most Dataset methods implicitly map a method defined on Variable over all these variables (e.g. .mean()). Sometimes the mapped method can be naively applied to every variable in the dataset, but sometimes it doesn't make sense to apply it to some of the variables. For example .mean(dim='time') only makes sense for the variables in the dataset that actually have a time dimension.

xarray.Dataset handles this for the user by either working out what version of the method does make sense for that variable (e.g. only trying to take the mean along the reduction dimensions actually present on that variable), or just passing the variable through unaltered. There are some weird subtleties lurking here, e.g. with statistical reductions like std and var.

# Some reduction functions (e.g. std, var) need to run on variables

There is therefore a difference between

ds.map(Variable.{REDUCTION}, dim='time') and ds.{REDUCTION}(dim='time')

For example:

In [13]: ds = xr.Dataset({'a': ('x', [1, 2]), 'b': 0})

In [14]: ds.isel(x=0)
Out[14]: 
<xarray.Dataset> Size: 16B
Dimensions:  ()
Data variables:
    a        int64 8B 1
    b        int64 8B 0

In [15]: ds.map(Variable.isel, x=0)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[15], line 1
----> 1 ds.map(Variable.isel, x=0)

...

ValueError: Dimensions {'x'} do not exist. Expected one or more of ()

(Aside: It would be nice for Dataset.map to include information about which variable it raised an exception on in the error message.)

Clearly Dataset.isel does more than just applying Variable.isel using Dataset.map.
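
As a rough illustration of what that "more" amounts to, here is a toy model in plain Python. It is not xarray's implementation: a variable is represented as a hypothetical `(dims, data)` pair limited to scalars and 1-D lists, and the names `variable_isel` and `dataset_isel` are made up for this sketch. The dataset-level function validates the indexers against the union of all dims, then hands each variable only the indexers that apply to it.

```python
# Toy model only: a "variable" is a (dims, data) pair, restricted to
# 0-D scalars and 1-D lists so the sketch stays short.


def variable_isel(var, **indexers):
    """Strict per-variable indexing: every indexer must name one of the variable's dims."""
    dims, data = var
    missing = set(indexers) - set(dims)
    if missing:
        raise ValueError(f"Dimensions {missing} do not exist. Expected one or more of {dims}")
    if dims and dims[0] in indexers:
        # indexing a 1-D variable with an integer yields a scalar
        return ((), data[indexers[dims[0]]])
    return var


def dataset_isel(variables, **indexers):
    """Dataset-style indexing: validate once against all dims, then filter per variable."""
    all_dims = {d for dims, _ in variables.values() for d in dims}
    missing = set(indexers) - all_dims
    if missing:
        raise ValueError(f"Dimensions {missing} do not exist. Expected one or more of {all_dims}")
    result = {}
    for name, (dims, data) in variables.items():
        # pass along only the indexers this variable can actually use
        relevant = {d: i for d, i in indexers.items() if d in dims}
        result[name] = variable_isel((dims, data), **relevant)
    return result


ds = {"a": (("x",), [1, 2]), "b": ((), 0)}
selected = dataset_isel(ds, x=0)
```

In this model `b` passes through untouched, whereas calling `variable_isel(ds["b"], x=0)` directly raises, which mirrors the difference between `ds.isel(x=0)` and `ds.map(Variable.isel, x=0)` above.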

Issue in DataTree

In datatree we have to map methods over different variables in the same node, but also over different variables in different nodes. Currently the implementation of a method naively maps the Dataset method over every node using map_over_subtree, but if there is a node containing a variable for which the method args are invalid, it will raise an exception.

This causes problems for users, for example in xarray-contrib/datatree#67. A minimal example of this problem would be

In [18]: ds1 = xr.Dataset({'a': ('x', [1, 2])})

In [19]: ds2 = xr.Dataset({'b': 0})

In [20]: dt = DataTree.from_dict({'node1': ds1, 'node2': ds2})

In [21]: dt
Out[21]: 
DataTree('None', parent=None)
├── DataTree('node1')
│       Dimensions:  (x: 2)
│       Dimensions without coordinates: x
│       Data variables:
│           a        (x) int64 16B 1 2
└── DataTree('node2')
        Dimensions:  ()
        Data variables:
            b        int64 8B 0

In [22]: dt.isel(x=0)
ValueError: Dimensions {'x'} do not exist. Expected one or more of FrozenMappingWarningOnValuesAccess({})
Raised whilst mapping function over node with path /node2

(The slightly weird error message here is related to the deprecation cycle in #8500)

We would have preferred that variable b in node2 survived unchanged, like it does in the pure Dataset example.
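
The failure mode can be reproduced in a small stdlib-only model. Everything here is an assumption made for illustration: the tree is a plain dict from path to node, a node is a dict of `(dims, data)` pairs, and `node_isel` and `naive_map_over_tree` are hypothetical stand-ins rather than datatree functions. The point is only that validating the arguments against each node separately makes one node without `x` abort the whole operation.

```python
def node_isel(variables, **indexers):
    """Per-node indexing that validates the indexers against this node's dims only."""
    node_dims = {d for dims, _ in variables.values() for d in dims}
    missing = set(indexers) - node_dims
    if missing:
        raise ValueError(f"Dimensions {missing} do not exist. Expected one or more of {node_dims}")
    return {
        name: ((), data[indexers[dims[0]]]) if dims and dims[0] in indexers else (dims, data)
        for name, (dims, data) in variables.items()
    }


def naive_map_over_tree(tree, func, **kwargs):
    """Apply `func` to every node as-is, re-raising with the offending node's path."""
    result = {}
    for path, variables in tree.items():
        try:
            result[path] = func(variables, **kwargs)
        except Exception as err:
            message = f"{err}\nRaised whilst mapping function over node with path {path}"
            raise type(err)(message) from err
    return result


tree = {
    "/node1": {"a": (("x",), [1, 2])},
    "/node2": {"b": ((), 0)},
}

failure = None
try:
    naive_map_over_tree(tree, node_isel, x=0)
except ValueError as err:
    failure = str(err)
```

Here `node1` would have been indexed without trouble, but the call as a whole fails because `node2` has no `x` dimension.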

Desired behaviour

We can kind of think of the desired behaviour like a hypothesis property we want (xref #1846), but not quite. It would be something like

dt.{REDUCTION}().flatten_into_dataset() == dt.flatten_into_dataset().{REDUCTION}()

except that .flatten_into_dataset() can't really exist for all cases, otherwise we wouldn't need datatree.
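
For trees where flattening happens to be possible, the property can be checked directly. The sketch below uses a stdlib-only toy model and makes several assumptions: nodes are dicts of `(dims, data)` pairs with scalar or 1-D data, and `flatten_into_dataset`, `reduce_dataset` and `reduce_tree` are hypothetical helpers, not existing xarray or datatree functions. In this model flattening refuses trees whose nodes reuse a variable name, which is one example of a case where no flat dataset exists.

```python
from statistics import fmean


def _reduce_variables(variables, func, dim):
    """Reduce the variables that have `dim`; pass the rest through unchanged."""
    return {
        name: ((), func(data)) if dim in dims else (dims, data)
        for name, (dims, data) in variables.items()
    }


def _check_dim(datasets, dim):
    all_dims = {d for variables in datasets for dims, _ in variables.values() for d in dims}
    if dim not in all_dims:
        raise ValueError(f"Dimension {dim!r} not found. Expected one of {all_dims}")


def reduce_dataset(variables, func, dim):
    _check_dim([variables], dim)
    return _reduce_variables(variables, func, dim)


def reduce_tree(tree, func, dim):
    # validate against the whole tree, then reduce node by node
    _check_dim(tree.values(), dim)
    return {path: _reduce_variables(variables, func, dim) for path, variables in tree.items()}


def flatten_into_dataset(tree):
    flat = {}
    for variables in tree.values():
        for name, var in variables.items():
            if name in flat:
                raise ValueError(f"variable {name!r} appears in more than one node")
            flat[name] = var
    return flat


tree = {"/node1": {"a": (("x",), [1, 2])}, "/node2": {"b": ((), 0)}}
lhs = flatten_into_dataset(reduce_tree(tree, fmean, "x"))
rhs = reduce_dataset(flatten_into_dataset(tree), fmean, "x")
```

With this tree both sides agree, so the reduction commutes with flattening in the toy model.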

Proposed Solution

There are two ways I can imagine implementing this.

  1. Use map_over_subtree to apply the method as-is and try to catch known possible KeyErrors for missing dimensions. This would be fragile.
  2. Do some kind of pre-checking of the data in the tree, and potentially adjust the method before applying it using map_over_subtree.

I think @shoyer and I concluded that we should do (2), in the form of some kind of new primitive, i.e. DataTree.reduce. (Actually DataTree.reduce already exists, but it should be changed so that it does more than just map Dataset.reduce over the tree with map_over_subtree.) Taking after Dataset.reduce, it would look something like this:

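from typing import Any, Callable

from xarray.core.types import Dims


class DataTree:
    def reduce(self, func: Callable, dim: Dims = None, **kwargs: Any) -> "DataTree":
        # union of the dimension names present on any node in the tree
        all_dims_in_tree = {d for node in self.subtree for d in node.dims}

        # normalise `dim` to a tuple of names; None (or ...) means "reduce over everything"
        if dim is None or dim is ...:
            dims = tuple(all_dims_in_tree)
        else:
            dims = (dim,) if isinstance(dim, str) else tuple(dim)

        missing_dims = tuple(d for d in dims if d not in all_dims_in_tree)
        if missing_dims:
            raise ValueError(f"Dimensions {missing_dims} do not exist anywhere in the tree")

        # TODO this could probably be refactored to call `map_over_subtree`
        results = {}
        for node in self.subtree:
            # using only the reduction dims that are actually present here would fix datatree GH issue #67
            reduce_dims = [d for d in node.dims if d in dims]
            results[node.path] = node.ds.reduce(func, dim=reduce_dims, **kwargs)

        # TODO build the result tree from `results` and return it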

Then every method that has this pattern of acting over one or more dims should be mapped over the tree using DataTree.reduce, not map_over_subtree.
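
The two steps of that proposal (check the requested dims against the whole tree, then reduce each node over only the dims it has) can be tried out without xarray. This is a stdlib-only sketch under stated assumptions: the tree is a dict from path to node, a node is a dict of `(dims, data)` pairs with scalar or 1-D data, and `reduce_tree` is a hypothetical name, not the eventual `DataTree.reduce`.

```python
from statistics import fmean


def reduce_tree(tree, func, dim):
    """Tree-level reduction: validate dims once for the whole tree, then reduce per node."""
    dims = (dim,) if isinstance(dim, str) else tuple(dim)

    # step 1: a dim is only "missing" if no node anywhere in the tree has it
    all_dims_in_tree = {
        d
        for variables in tree.values()
        for var_dims, _ in variables.values()
        for d in var_dims
    }
    missing_dims = tuple(d for d in dims if d not in all_dims_in_tree)
    if missing_dims:
        raise ValueError(f"Dimensions {missing_dims} do not exist anywhere in the tree")

    # step 2: within each node, reduce only the variables that carry a requested dim
    result = {}
    for path, variables in tree.items():
        result[path] = {
            name: ((), func(data)) if var_dims and var_dims[0] in dims else (var_dims, data)
            for name, (var_dims, data) in variables.items()
        }
    return result


tree = {
    "/node1": {"a": (("x",), [1, 2])},
    "/node2": {"b": ((), 0)},
}
reduced = reduce_tree(tree, fmean, dim="x")
```

In this model `a` is reduced to its mean and `b` in `node2` survives unchanged, while asking for a dimension that no node has (for example `dim="time"`) still raises a `ValueError`.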

cc @shoyer, @flamingbear, @owenlittlejohns

@TomNicholas TomNicholas added the topic-DataTree Related to the implementation of a DataTree class label Apr 16, 2024
shoyer added a commit to shoyer/xarray that referenced this issue Oct 7, 2024
They now allow for dimensions that are missing on particular nodes, and
use Xarray's standard generate_aggregations machinery, like aggregations
for DataArray and Dataset.

Fixes pydata#8949, pydata#8963
@shoyer shoyer closed this as completed in 707231e Oct 13, 2024
@github-project-automation github-project-automation bot moved this from To do to Done in DataTree integration Oct 13, 2024