Skip to content

Commit

Permalink
combine keep_attrs and combine_attrs in apply_ufunc (#5041)
Browse files Browse the repository at this point in the history
  • Loading branch information
keewis authored May 13, 2021
1 parent 1f52ae0 commit 751f76a
Show file tree
Hide file tree
Showing 4 changed files with 460 additions and 34 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ v0.18.1 (unreleased)

New Features
~~~~~~~~~~~~
- allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of
:py:func:`apply_ufunc` (:pull:`5041`)
By `Justus Magin <https://github.com/keewis>`_.
- :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes,
such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`).
By `Jimmy Westling <https://github.com/illviljan>`_.
Expand Down
92 changes: 59 additions & 33 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

from . import dtypes, duck_array_ops, utils
from .alignment import align, deep_align
from .merge import merge_coordinates_without_align
from .options import OPTIONS
from .merge import merge_attrs, merge_coordinates_without_align
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array
from .utils import is_dict_like
from .variable import Variable
Expand All @@ -50,6 +50,11 @@ def _first_of_type(args, kind):
raise ValueError("This should be unreachable.")


def _all_of_type(args, kind):
"""Return all objects of type 'kind'"""
return [arg for arg in args if isinstance(arg, kind)]


class _UFuncSignature:
"""Core dimensions signature for a given function.
Expand Down Expand Up @@ -202,7 +207,10 @@ def _get_coords_list(args) -> List["Coordinates"]:


def build_output_coords(
args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset()
args: list,
signature: _UFuncSignature,
exclude_dims: AbstractSet = frozenset(),
combine_attrs: str = "override",
) -> "List[Dict[Any, Variable]]":
"""Build output coordinates for an operation.
Expand Down Expand Up @@ -230,7 +238,7 @@ def build_output_coords(
else:
# TODO: save these merged indexes, instead of re-computing them later
merged_vars, unused_indexes = merge_coordinates_without_align(
coords_list, exclude_dims=exclude_dims
coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs
)

output_coords = []
Expand All @@ -248,7 +256,12 @@ def build_output_coords(


def apply_dataarray_vfunc(
func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False
func,
*args,
signature,
join="inner",
exclude_dims=frozenset(),
keep_attrs="override",
):
"""Apply a variable level function over DataArray, Variable and/or ndarray
objects.
Expand All @@ -260,12 +273,16 @@ def apply_dataarray_vfunc(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
)

if keep_attrs:
objs = _all_of_type(args, DataArray)

if keep_attrs == "drop":
name = result_name(args)
else:
first_obj = _first_of_type(args, DataArray)
name = first_obj.name
else:
name = result_name(args)
result_coords = build_output_coords(args, signature, exclude_dims)
result_coords = build_output_coords(
args, signature, exclude_dims, combine_attrs=keep_attrs
)

data_vars = [getattr(a, "variable", a) for a in args]
result_var = func(*data_vars)
Expand All @@ -279,13 +296,12 @@ def apply_dataarray_vfunc(
(coords,) = result_coords
out = DataArray(result_var, coords, name=name, fastpath=True)

if keep_attrs:
if isinstance(out, tuple):
for da in out:
# This is adding attrs in place
da._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)
attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
if isinstance(out, tuple):
for da in out:
da.attrs = attrs
else:
out.attrs = attrs

return out

Expand Down Expand Up @@ -400,7 +416,7 @@ def apply_dataset_vfunc(
dataset_join="exact",
fill_value=_NO_FILL_VALUE,
exclude_dims=frozenset(),
keep_attrs=False,
keep_attrs="override",
):
"""Apply a variable level function over Dataset, dict of DataArray,
DataArray, Variable and/or ndarray objects.
Expand All @@ -414,15 +430,16 @@ def apply_dataset_vfunc(
"dataset_fill_value argument."
)

if keep_attrs:
first_obj = _first_of_type(args, Dataset)
objs = _all_of_type(args, Dataset)

if len(args) > 1:
args = deep_align(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
)

list_of_coords = build_output_coords(args, signature, exclude_dims)
list_of_coords = build_output_coords(
args, signature, exclude_dims, combine_attrs=keep_attrs
)
args = [getattr(arg, "data_vars", arg) for arg in args]

result_vars = apply_dict_of_variables_vfunc(
Expand All @@ -435,13 +452,13 @@ def apply_dataset_vfunc(
(coord_vars,) = list_of_coords
out = _fast_dataset(result_vars, coord_vars)

if keep_attrs:
if isinstance(out, tuple):
for ds in out:
# This is adding attrs in place
ds._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)
attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
if isinstance(out, tuple):
for ds in out:
ds.attrs = attrs
else:
out.attrs = attrs

return out


Expand Down Expand Up @@ -609,14 +626,12 @@ def apply_variable_ufunc(
dask="forbidden",
output_dtypes=None,
vectorize=False,
keep_attrs=False,
keep_attrs="override",
dask_gufunc_kwargs=None,
):
"""Apply a ndarray level function over Variable and/or ndarray objects."""
from .variable import Variable, as_compatible_data

first_obj = _first_of_type(args, Variable)

dim_sizes = unified_dim_sizes(
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
)
Expand Down Expand Up @@ -736,6 +751,12 @@ def func(*arrays):
)
)

objs = _all_of_type(args, Variable)
attrs = merge_attrs(
[obj.attrs for obj in objs],
combine_attrs=keep_attrs,
)

output = []
for dims, data in zip(output_dims, result_data):
data = as_compatible_data(data)
Expand All @@ -758,8 +779,7 @@ def func(*arrays):
)
)

if keep_attrs:
var.attrs.update(first_obj.attrs)
var.attrs = attrs
output.append(var)

if signature.num_outputs == 1:
Expand Down Expand Up @@ -801,7 +821,7 @@ def apply_ufunc(
join: str = "exact",
dataset_join: str = "exact",
dataset_fill_value: object = _NO_FILL_VALUE,
keep_attrs: bool = False,
keep_attrs: Union[bool, str] = None,
kwargs: Mapping = None,
dask: str = "forbidden",
output_dtypes: Sequence = None,
Expand Down Expand Up @@ -1098,6 +1118,12 @@ def apply_ufunc(
if kwargs:
func = functools.partial(func, **kwargs)

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

if isinstance(keep_attrs, bool):
keep_attrs = "override" if keep_attrs else "drop"

variables_vfunc = functools.partial(
apply_variable_ufunc,
func,
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def merge_coordinates_without_align(
objects: "List[Coordinates]",
prioritized: Mapping[Hashable, MergeElement] = None,
exclude_dims: AbstractSet = frozenset(),
combine_attrs: str = "override",
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]:
"""Merge variables/indexes from coordinates without automatic alignments.
Expand All @@ -335,7 +336,7 @@ def merge_coordinates_without_align(
else:
filtered = collected

return merge_collected(filtered, prioritized)
return merge_collected(filtered, prioritized, combine_attrs=combine_attrs)


def determine_coords(
Expand Down
Loading

0 comments on commit 751f76a

Please sign in to comment.