Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fill_value and set_empty_bucket_to in BucketResampler get_sum #602

Merged
merged 11 commits into from
Jul 24, 2024
53 changes: 37 additions & 16 deletions pyresample/bucket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _expand_bin_statistics(bins, unique_bin, unique_idx, weights_sorted):
# assign the valid index to array
weight_idx[unique_bin[~unique_bin.mask].data] = unique_idx[~unique_bin.mask]

return weights_sorted[weight_idx] # last value of weigths_sorted always nan
return weights_sorted[weight_idx] # last value of weights_sorted always nan


@dask.delayed(pure=True)
Expand Down Expand Up @@ -202,20 +202,29 @@ def _get_indices(self):
target_shape = self.target_area.shape
self.idxs = self.y_idxs * target_shape[1] + self.x_idxs

def get_sum(self, data, skipna=True):
def get_sum(self, data, fill_value=np.nan, skipna=True, empty_bucket_value=0):
"""Calculate sums for each bin with drop-in-a-bucket resampling.

Parameters
----------
data : Numpy or Dask array
Data to be binned and summed.
fill_value : float
Fill value of the input data marking missing/invalid values.
Default: np.nan
skipna : boolean (optional)
If True, skips NaN values for the sum calculation
(similarly to Numpy's `nansum`). Buckets containing only NaN are set to zero.
If False, sets the bucket to NaN if one or more NaN values are present in the bucket
(similarly to Numpy's `sum`).
In both cases, empty buckets are set to 0.
Default: True
If True, skips missing values (as marked by NaN or `fill_value`) for the sum calculation
(similarly to Numpy's `nansum`). Buckets containing only missing values are set to `empty_bucket_value`.
If False, sets the bucket to fill_value if one or more missing values are present in the bucket
(similarly to Numpy's `sum`).
In both cases, empty buckets are set to `empty_bucket_value`.
Default: True
empty_bucket_value : float
Set empty buckets to the given value. Empty buckets are considered as the buckets with value 0.
Note that a bucket could become 0 as the result of a sum
of positive and negative values. If the user needs to identify these zero-buckets reliably,
`get_count()` can be used for this purpose.
Default: 0

Returns
-------
Expand All @@ -228,8 +237,9 @@ def get_sum(self, data, skipna=True):
data = data.data
data = data.ravel()

# Remove NaN values from the data when used as weights
weights = da.where(np.isnan(data), 0, data)
# Remove fill_values values from the data when used as weights
invalid_mask = _get_invalid_mask(data, fill_value)
weights = da.where(invalid_mask, 0, data)

# Rechunk indices to match the data chunking
if weights.chunks != self.idxs.chunks:
Expand All @@ -241,16 +251,19 @@ def get_sum(self, data, skipna=True):
weights=weights, density=False)

# TODO remove following line in favour of weights = data when dask histogram bug (issue #6935) is fixed
sums = self._mask_bins_with_nan_if_not_skipna(skipna, data, out_size, sums)
sums = self._mask_bins_with_nan_if_not_skipna(skipna, data, out_size, sums, fill_value)

if empty_bucket_value != 0:
sums = da.where(sums == 0, empty_bucket_value, sums)

return sums.reshape(self.target_area.shape)

def _mask_bins_with_nan_if_not_skipna(self, skipna, data, out_size, statistic):
def _mask_bins_with_nan_if_not_skipna(self, skipna, data, out_size, statistic, fill_value):
if not skipna:
nans = np.isnan(data)
nan_bins, _ = da.histogram(self.idxs[nans], bins=out_size,
range=(0, out_size))
statistic = da.where(nan_bins > 0, np.nan, statistic)
missing_val = _get_invalid_mask(data, fill_value)
missing_val_bins, _ = da.histogram(self.idxs[missing_val], bins=out_size,
range=(0, out_size))
statistic = da.where(missing_val_bins > 0, fill_value, statistic)
return statistic

def _call_bin_statistic(self, statistic_method, data, fill_value=None, skipna=None):
Expand Down Expand Up @@ -456,6 +469,14 @@ def get_fractions(self, data, categories=None, fill_value=np.nan):
return results


def _get_invalid_mask(data, fill_value):
"""Get a boolean array where values equal to fill_value in data are True."""
if np.isnan(fill_value):
return np.isnan(data)
else:
return data == fill_value


def round_to_resolution(arr, resolution):
"""Round the values in *arr* to closest resolution element.

Expand Down
Loading
Loading