This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Unexpected behavior for .map_over_subtree #188

Closed
jbusecke opened this issue Jan 9, 2023 · 5 comments · Fixed by #194
Labels
bug Something isn't working

Comments

@jbusecke
Contributor

jbusecke commented Jan 9, 2023

I want to apply a weighted mean over every dataset in a tree but cannot map a function as expected.

Take this simple example:

import numpy as np
import xarray as xr
import datatree  # needed for datatree.DataTree below

# Two datasets with different shapes, each carrying an 'area' coordinate on (x, y)
a = xr.DataArray(np.random.rand(3, 4, 10), dims=['x', 'y', 'time'], coords={'area': (['x', 'y'], np.random.rand(3, 4))}).to_dataset(name='data')
b = xr.DataArray(np.random.rand(2, 6, 14), dims=['x', 'y', 'time'], coords={'area': (['x', 'y'], np.random.rand(2, 6))}).to_dataset(name='data')

dt = datatree.DataTree.from_dict({'a': a, 'b': b})
dt
DataTree('None', parent=None)
├── DataTree('a')
│       Dimensions:  (x: 3, y: 4, time: 10)
│       Coordinates:
│           area     (x, y) float64 0.5258 0.4488 0.9383 0.7907 ... 0.7133 0.2833 0.5557
│       Dimensions without coordinates: x, y, time
│       Data variables:
│           data     (x, y, time) float64 0.4141 0.06808 0.1853 ... 0.5805 0.0713 0.6484
└── DataTree('b')
        Dimensions:  (x: 2, y: 6, time: 14)
        Coordinates:
            area     (x, y) float64 0.9447 0.6305 0.7047 0.687 ... 0.916 0.732 0.6114
        Dimensions without coordinates: x, y, time
        Data variables:
            data     (x, y, time) float64 0.1824 0.3305 0.1357 ... 0.9851 0.09906 0.8442

I am defining a very simple weighted mean function:

def weighted_mean(ds):
    return ds.weighted(ds.area).mean(['x', 'y'])

which works as expected on a single dataset:

weighted_mean(a)

[image: notebook output of weighted_mean(a)]

But when I try to map this over the tree

dt_mapped = dt.map_over_subtree(weighted_mean)

I get this:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[61], line 1
----> 1 dt_mapped = dt.map_over_subtree(weighted_mean)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/datatree/datatree.py:992, in DataTree.map_over_subtree(self, func, *args, **kwargs)
    964 """
    965 Apply a function to every dataset in this subtree, returning a new tree which stores the results.
    966 
   (...)
    987     One or more subtrees containing results from applying ``func`` to the data at each node.
    988 """
    989 # TODO this signature means that func has no way to know which node it is being called upon - change?
    990 
    991 # TODO fix this typing error
--> 992 return map_over_subtree(func)(self, *args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/datatree/mapping.py:207, in map_over_subtree.<locals>._map_over_subtree(*args, **kwargs)
    195 node_kwargs_as_datasetviews = dict(
    196     zip(
    197         [k for k in kwargs_as_tree_length_iterables.keys()],
   (...)
    202     )
    203 )
    205 # Now we can call func on the data in this particular set of corresponding nodes
    206 results = (
--> 207     func(*node_args_as_datasetviews, **node_kwargs_as_datasetviews)
    208     if not node_of_first_tree.is_empty
    209     else None
    210 )
    212 # TODO implement mapping over multiple trees in-place using if conditions from here on?
    213 out_data_objects[node_of_first_tree.path] = results

Cell In[57], line 11, in weighted_mean(ds)
     10 def weighted_mean(ds):
---> 11     return ds.weighted(ds.area).mean(['x', 'y'])

File /srv/conda/envs/notebook/lib/python3.10/site-packages/xarray/core/weighted.py:488, in Weighted.mean(self, dim, skipna, keep_attrs)
    481 def mean(
    482     self,
    483     dim: Dims = None,
    484     skipna: bool | None = None,
    485     keep_attrs: bool | None = None,
    486 ) -> T_Xarray:
--> 488     return self._implementation(
    489         self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
    490     )

File /srv/conda/envs/notebook/lib/python3.10/site-packages/xarray/core/weighted.py:550, in DatasetWeighted._implementation(self, func, dim, **kwargs)
    546 def _implementation(self, func, dim, **kwargs) -> Dataset:
    548     self._check_dim(dim)
--> 550     return self.obj.map(func, dim=dim, **kwargs)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/xarray/core/dataset.py:5957, in Dataset.map(self, func, keep_attrs, args, **kwargs)
   5955         v._copy_attrs_from(self.data_vars[k])
   5956 attrs = self.attrs if keep_attrs else None
-> 5957 return type(self)(variables, attrs=attrs)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/datatree/datatree.py:121, in DatasetView.__init__(self, data_vars, coords, attrs)
    115 def __init__(
    116     self,
    117     data_vars: Optional[Mapping[Any, Any]] = None,
    118     coords: Optional[Mapping[Any, Any]] = None,
    119     attrs: Optional[Mapping[Any, Any]] = None,
    120 ):
--> 121     raise AttributeError("DatasetView objects are not to be initialized directly")

AttributeError: DatasetView objects are not to be initialized directly

Strangely, if I modify the function to return a DataArray and map it

def weighted_mean_da(ds):
    return ds.data.weighted(ds.area).mean(['x', 'y'])

dt_mapped_da = dt.map_over_subtree(weighted_mean_da)
dt_mapped_da

This gives me the desired result, but I am not sure whether this is intended behavior (is the returned DataArray converted back to a Dataset?)

[image: notebook output of dt_mapped_da]

@jbusecke
Contributor Author

jbusecke commented Jan 9, 2023

I was using v0.0.10 for the record

@TomNicholas TomNicholas added the bug Something isn't working label Jan 9, 2023
@TomNicholas
Member

This bug has something to do with not properly copying node.ds (or not converting it from a DatasetView to a Dataset) before it gets passed to the user function in map_over_subtree.

@TomNicholas
Member

TomNicholas commented Jan 9, 2023

> I am not sure if this is intended behavior (the return dataarray is converted back to a dataset?)

Yes that's intended - nodes of a tree can't hold just one DataArray, they actually extract all the underlying variable objects and reorganise them into something which makes the node look like it is a Dataset. So you should expect DataArrays to be promoted to single-variable Datasets.

@TomNicholas
Member

You can actually reproduce this with just

dt = datatree.DataTree(data=a)
weighted_mean(dt.ds)

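It's because dt.ds returns an immutable DatasetView object, which is like a Dataset except that you can't change it or create one yourself directly. That latter restriction is what breaks here: the weighted code calls Dataset.map, which builds its result with type(self)(...), and for a DatasetView that call raises the AttributeError shown in the traceback.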

I should be able to fix it either by:

  1. Not passing the DatasetView in map_over_subtree
  2. Removing the restriction on instantiating a DatasetView.
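The mechanism can be illustrated without xarray. The following is only a minimal sketch: `Base` and `ReadOnlyView` are hypothetical stand-ins for `Dataset` and `DatasetView` (not the real classes), and `_from_base` is an invented helper that bypasses `__init__`. It shows how an inherited method that rebuilds its result via `type(self)` fails as soon as the subclass forbids direct construction.

```python
class Base:
    """Hypothetical stand-in for Dataset (not xarray's class)."""

    def __init__(self, variables=None):
        self.variables = dict(variables or {})

    def map(self, func):
        # Same pattern as in the traceback: the result is rebuilt via type(self).
        new_vars = {name: func(value) for name, value in self.variables.items()}
        return type(self)(new_vars)


class ReadOnlyView(Base):
    """Hypothetical stand-in for DatasetView: direct construction is forbidden."""

    def __init__(self, *args, **kwargs):
        raise AttributeError("ReadOnlyView objects are not to be initialized directly")

    @classmethod
    def _from_base(cls, base):
        # Invented helper: create the view without going through __init__.
        obj = object.__new__(cls)
        obj.variables = dict(base.variables)
        return obj


plain = Base({"data": 1.0})
view = ReadOnlyView._from_base(plain)

# Works: type(self) is Base, which can be constructed freely.
doubled = plain.map(lambda v: v * 2)

# Fails: type(self) is ReadOnlyView, whose __init__ always raises.
try:
    view.map(lambda v: v * 2)
    error_message = None
except AttributeError as err:
    error_message = str(err)
```

In this sketch, option 1 corresponds to handing `plain` to the user function instead of `view`, and option 2 corresponds to dropping the `raise` in `ReadOnlyView.__init__`.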

@TomNicholas
Member

Actually a better solution here might have been to re-implement .map on DatasetView, as that would get around the problematic type(self) call in Dataset.map. This would be useful because it is probably better to pass immutable DatasetView objects to map_over_subtree if we can (see #266 for one motivation).
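As a rough illustration of that idea, again with hypothetical stand-in classes rather than the real `Dataset` and `DatasetView`, the view can override `map` so that the result is built as the plain parent type instead of via `type(self)`. The class names and the `_from_base` helper are assumptions made for this sketch; the actual change in datatree may look different.

```python
class Base:
    """Hypothetical stand-in for Dataset."""

    def __init__(self, variables=None):
        self.variables = dict(variables or {})

    def map(self, func):
        new_vars = {name: func(value) for name, value in self.variables.items()}
        return type(self)(new_vars)


class ReadOnlyView(Base):
    """Hypothetical stand-in for DatasetView, still not directly constructible."""

    def __init__(self, *args, **kwargs):
        raise AttributeError("ReadOnlyView objects are not to be initialized directly")

    @classmethod
    def _from_base(cls, base):
        # Invented helper: create the view without going through __init__.
        obj = object.__new__(cls)
        obj.variables = dict(base.variables)
        return obj

    def map(self, func):
        # Override: return a plain Base, so type(self) is never called.
        new_vars = {name: func(value) for name, value in self.variables.items()}
        return Base(new_vars)


view = ReadOnlyView._from_base(Base({"data": 1.0}))
result = view.map(lambda v: v * 2)
```

With this override the user function can keep receiving the immutable view, while anything it computes through `map` comes back as an ordinary, constructible object.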

TomNicholas added a commit that referenced this issue Oct 24, 2023
* alternative fix for 188

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>