Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Commit

Permalink
Alternative fix for #188 (#268)
Browse files Browse the repository at this point in the history
* alternative fix for 188

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
TomNicholas and pre-commit-ci[bot] authored Oct 24, 2023
1 parent 86e9f59 commit ff86d3c
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
HybridMappingProxy,
_default,
either_dict_or_kwargs,
maybe_wrap_array,
)
from xarray.core.variable import Variable

Expand Down Expand Up @@ -235,6 +236,65 @@ def _replace(
inplace=inplace,
)

def map(
self,
func: Callable,
keep_attrs: bool | None = None,
args: Iterable[Any] = (),
**kwargs: Any,
) -> Dataset:
"""Apply a function to each data variable in this dataset
Parameters
----------
func : callable
Function which can be called in the form `func(x, *args, **kwargs)`
to transform each DataArray `x` in this dataset into another
DataArray.
keep_attrs : bool or None, optional
If True, both the dataset's and variables' attributes (`attrs`) will be
copied from the original objects to the new ones. If False, the new dataset
and variables will be returned without copying the attributes.
args : iterable, optional
Positional arguments passed on to `func`.
**kwargs : Any
Keyword arguments passed on to `func`.
Returns
-------
applied : Dataset
Resulting dataset from applying ``func`` to each data variable.
Examples
--------
>>> da = xr.DataArray(np.random.randn(2, 3))
>>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])})
>>> ds
<xarray.Dataset>
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
Dimensions without coordinates: dim_0, dim_1, x
Data variables:
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 -0.9773
bar (x) int64 -1 2
>>> ds.map(np.fabs)
<xarray.Dataset>
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
Dimensions without coordinates: dim_0, dim_1, x
Data variables:
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
bar (x) float64 1.0 2.0
"""

# Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188).
# TODO Refactor xarray upstream to avoid needing to overwrite this.
# TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated
variables = {
k: maybe_wrap_array(v, func(v, *args, **kwargs))
for k, v in self.data_vars.items()
}
# return type(self)(variables, attrs=attrs)
return Dataset(variables)


class DataTree(
NamedNode,
Expand Down

0 comments on commit ff86d3c

Please sign in to comment.