A broadcasting sum for xarray.Dataset #6053

Open
mjwillson opened this issue Dec 8, 2021 · 4 comments

Comments

@mjwillson

mjwillson commented Dec 8, 2021

I've found it useful to have a version of Dataset.sum which sums variables in a way that's consistent with what would happen if they were broadcast to the full Dataset dimensions.

The difference is in what it does with variables that don't contain some of the dimensions it's asked to sum over. Standard sum simply skips the summation over those dimensions for those variables, whereas a broadcasting sum multiplies the variable by the product of the sizes of the missing dimensions, like so:

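```python
import math
import xarray as xr

def broadcast_sum(dataset: xr.Dataset, dims) -> xr.Dataset:
    def broadcast_sum_var(var):
        # Sum over the requested dims that this variable actually has.
        present_sum_dims = [dim for dim in dims if dim in var.dims]
        summed = var.sum(present_sum_dims) if present_sum_dims else var
        # Each missing dim contributes a factor of its size, as if var were broadcast along it.
        return summed * math.prod(dataset.sizes[dim] for dim in dims if dim not in var.dims)
    return dataset.map(broadcast_sum_var)
```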

This is consistent with mathematical sum notation, where the sum doesn't become a no-op just because the summand doesn't reference the index being summed over. E.g.:

$\sum_{n=1}^N x = N x$
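To make the identity concrete, here is a small check on a toy Dataset (the variable names `x` and `w` and the size `N = 4` are made up for illustration). The existing `Dataset.sum` leaves a variable without dimension `n` unchanged, while explicitly broadcasting it along `n` with `expand_dims` before summing gives `N x`.

```python
import numpy as np
import xarray as xr

N = 4
ds = xr.Dataset(
    {
        "x": ((), 2.5),                              # summand that does not reference n
        "w": (("n",), np.arange(N, dtype=float)),    # variable that does reference n
    }
)

# Current behaviour: summing over "n" is a no-op for "x".
standard = ds.sum("n")

# Broadcasting semantics: repeat "x" along a new dim "n" of size N, then sum it away.
broadcast_then_sum = ds["x"].expand_dims(n=N).sum("n")

print(float(standard["x"]))        # 2.5
print(float(broadcast_then_sum))   # 10.0, i.e. N * x
```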

I've found it useful when doing broadcasting operations across different variables after the sum, where you want the summation to be consistent with the broadcasting logic that will be applied later.

Would you be open to adding this, and if so, do you have a preference for how? (A separate method, or an option to .sum?)

@dcherian
Contributor

dcherian commented Jul 7, 2022

xr.broadcast(ds)[0].sum(dims) should do this.

We could add it to https://xarray.pydata.org/en/latest/howdoi.html and to the Aggregations section of the docs.
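A quick sanity check of this suggestion on a toy Dataset (names and sizes are illustrative). `xr.broadcast` expands every data variable to all of the Dataset's dimensions, so summing afterwards multiplies a variable that lacks a summed dimension by that dimension's size. This matches the rule described in the original post.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((2, 3))),    # has the summed dim "y"
        "b": (("x",), np.array([1.0, 2.0])),   # lacks "y"
    }
)

# Workaround from this comment: broadcast every variable to the full dims, then sum.
workaround = xr.broadcast(ds)[0].sum(["y"])

# Rule from the original post: sum present dims, scale by the sizes of missing ones.
expected_a = ds["a"].sum("y")
expected_b = ds["b"] * ds.sizes["y"]

print(workaround["b"].values)  # [3. 6.]
xr.testing.assert_allclose(workaround["a"], expected_a)
xr.testing.assert_allclose(workaround["b"], expected_b)
```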

@headtr1ck
Collaborator

See discussion in #6749

Maybe the current implementation of sum is not correct?

@mjwillson
Author

Re xr.broadcast(ds)[0].sum(dims) -- Thanks, that's neat and may be useful as a workaround, but it looks like it could incur significant extra CPU and RAM costs (tiling all variables to the full size in memory before summing over the tiled values)? Or is there some clever optimisation under the hood which would avoid this?

I also only wanted it to (behave as though it) broadcast the dims that are summed over, but this looks like it will broadcast all dims including those not summed over?
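The concern about unsummed dimensions shows up with a toy Dataset that has a third variable `c` on an unrelated dimension `z` (all names here are illustrative). After `xr.broadcast` and a sum over `y`, `c` carries an extra `x` dimension. Scaling only by the sizes of the missing summed dimensions keeps `c` on `z` alone.

```python
import math

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((2, 3))),
        "c": (("z",), np.array([1.0, 10.0])),  # shares no dims with "a"
    }
)
sum_dims = ["y"]

# Workaround: every variable is first expanded to ALL dataset dims (x, y, z).
workaround = xr.broadcast(ds)[0].sum(sum_dims)

# Broadcasting only over the summed dims: "c" lacks "y", so scale it by the size of "y".
missing = [d for d in sum_dims if d not in ds["c"].dims]
scaled_c = ds["c"] * math.prod(ds.sizes[d] for d in missing)

print(sorted(workaround["c"].dims))  # ['x', 'z']
print(scaled_c.dims)                 # ('z',)
```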

Overall I think it'd be better to have an option on sum (like missing_dim='broadcast' as suggested in #6749), rather than documenting a partial workaround like this, given the caveats attached to the workaround and that (to me at least) the broadcasting sum is more in keeping with the usual mathematical semantics of 'sum' than what 'sum' currently does.

@dcherian
Contributor

A more explicit API could be ds.broadcasting.sum()
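Until something like this exists upstream, you can approximate a `ds.broadcasting.sum(...)` spelling in user code with xarray's accessor mechanism (`xr.register_dataset_accessor`). This sketch is hypothetical: the accessor name `broadcasting` and its `sum` method are illustrative and not part of xarray. It applies the per-variable rule from the original post, scaling by dimension sizes rather than materialising broadcast arrays.

```python
import math

import numpy as np
import xarray as xr


@xr.register_dataset_accessor("broadcasting")  # hypothetical accessor name
class BroadcastingAccessor:
    def __init__(self, ds: xr.Dataset):
        self._ds = ds

    def sum(self, dims):
        """Sum as if every variable were broadcast along the summed dims."""
        if isinstance(dims, str):
            dims = [dims]
        sizes = self._ds.sizes

        def reduce_var(var: xr.DataArray) -> xr.DataArray:
            present = [d for d in dims if d in var.dims]
            summed = var.sum(present) if present else var
            # Each summed dim the variable lacks contributes a factor of its size.
            return summed * math.prod(sizes[d] for d in dims if d not in var.dims)

        return self._ds.map(reduce_var)


ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((2, 3))),
        "b": (("x",), np.array([1.0, 2.0])),
    }
)
print(ds.broadcasting.sum("y")["b"].values)  # [3. 6.]
```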
