
Lazy mask_fillvalues preprocessor function #2340

Merged
merged 13 commits into from
Jun 4, 2024
2 changes: 1 addition & 1 deletion environment.yml
@@ -19,7 +19,7 @@ dependencies:
- geopy
- humanfriendly
- importlib_metadata # required for Python < 3.10
- iris >3.8.0
- iris >=3.9.0
- iris-esmf-regrid >=0.10.0 # github.com/SciTools-incubator/iris-esmf-regrid/pull/342
- isodate
- jinja2
149 changes: 87 additions & 62 deletions esmvalcore/preprocessor/_mask.py
@@ -4,6 +4,7 @@
masking with ancillary variables, masking with Natural Earth shapefiles
(land or ocean), masking on thresholds, missing values masking.
"""
from __future__ import annotations

import logging
import os
@@ -323,7 +324,14 @@ def _mask_with_shp(cube, shapefilename, region_indices=None):
return cube


def count_spells(data, threshold, axis, spell_length):
def count_spells(
data: np.ndarray | da.Array,
threshold: float | None,
axis: int,
spell_length,
) -> np.ndarray | da.Array:
# Copied from:
# https://scitools-iris.readthedocs.io/en/stable/generated/gallery/general/plot_custom_aggregation.html
"""Count data occurrences.

Define a function to perform the custom statistical operation.
@@ -338,10 +346,10 @@ def count_spells(data, threshold, axis, spell_length):

Parameters
----------
data: ndarray
data:
raw data to be compared with value threshold.

threshold: float
threshold:
threshold point for 'significant' datapoints.

axis: int
@@ -353,15 +361,17 @@ def count_spells(data, threshold, axis, spell_length):

Returns
-------
int
:obj:`numpy.ndarray` or :obj:`dask.array.Array`
Number of counts.
"""
if axis < 0:
# just cope with negative axis numbers
axis += data.ndim
# Threshold the data to find the 'significant' points.
array_module = da if isinstance(data, da.Array) else np
if not threshold:
data_hits = np.ones_like(data, dtype=bool)
# Keeps the mask of the input data.
data_hits = array_module.ma.ones_like(data, dtype=bool)
else:
data_hits = data > float(threshold)

@@ -371,17 +381,16 @@ def count_spells(data, threshold, axis, spell_length):
# if you want overlapping windows set the step to be m*spell_length
# where m is a float
###############################################################
hit_windows = rolling_window(data_hits,
window=spell_length,
step=spell_length,
axis=axis)

hit_windows = rolling_window(
data_hits,
window=spell_length,
step=spell_length,
axis=axis,
)
# Find the windows "full of True-s" (along the added 'window axis').
full_windows = np.all(hit_windows, axis=axis + 1)

full_windows = array_module.all(hit_windows, axis=axis + 1)
# Count points fulfilling the condition (along the time axis).
spell_point_counts = np.sum(full_windows, axis=axis, dtype=int)

spell_point_counts = array_module.sum(full_windows, axis=axis, dtype=int)
return spell_point_counts
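
A quick aside (not part of the diff): the dispatch on `array_module` above works because NumPy and Dask expose matching functions, so the same code path runs eagerly on `numpy` arrays and stays lazy on `dask` arrays. A minimal sketch with made-up data:

```python
# Sketch only: illustrates the numpy/dask dispatch used by count_spells.
# Shapes and values are invented for demonstration.
import dask.array as da
import numpy as np

data = da.arange(24, chunks=6).reshape(4, 6)
array_module = da if isinstance(data, da.Array) else np

hits = data > 10                                    # still a lazy dask array
counts = array_module.sum(hits, axis=0, dtype=int)  # deferred reduction
print(isinstance(counts, da.Array))  # True: nothing has been computed yet
print(counts.compute())              # [2 2 2 2 2 3]
```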


@@ -572,10 +581,12 @@ def mask_multimodel(products):
f"got {product_types}")


def mask_fillvalues(products,
threshold_fraction,
min_value=None,
time_window=1):
def mask_fillvalues(
products,
threshold_fraction: float,
min_value: float | None = None,
time_window: int = 1,
):
"""Compute and apply a multi-dataset fillvalues mask.

Construct the mask that fills a certain time window with missing values
@@ -590,15 +601,15 @@ def mask_fillvalues(products,
products: iris.cube.Cube
data products to be masked.

threshold_fraction: float
threshold_fraction:
fractional threshold to be used as argument for Aggregator.
Must be between 0 and 1.

min_value: float
min_value:
minimum value threshold; default None
If default, no thresholding applied so the full mask will be selected.

time_window: float
time_window:
time window to compute missing data counts; default set to 1.

Returns
@@ -611,48 +622,58 @@ def mask_fillvalues(products,
NotImplementedError
Implementation missing for data with higher dimensionality than 4.
"""
combined_mask = None
array_module = da if any(c.has_lazy_data() for p in products
for c in p.cubes) else np

logger.debug("Creating fillvalues mask")
used = set()
combined_mask = None
for product in products:
for cube in product.cubes:
cube.data = np.ma.fix_invalid(cube.data, copy=False)
mask = _get_fillvalues_mask(cube, threshold_fraction, min_value,
time_window)
for i, cube in enumerate(product.cubes):
cube = cube.copy()
product.cubes[i] = cube
cube.data = array_module.ma.fix_invalid(cube.core_data())
mask = _get_fillvalues_mask(
cube,
threshold_fraction,
min_value,
time_window,
)
if combined_mask is None:
combined_mask = np.zeros_like(mask)
combined_mask = array_module.zeros_like(mask)
# Select only valid (not all masked) pressure levels
n_dims = len(mask.shape)
if n_dims == 2:
valid = ~np.all(mask)
if valid:
combined_mask |= mask
used.add(product)
elif n_dims == 3:
valid = ~np.all(mask, axis=(1, 2))
combined_mask[valid] |= mask[valid]
if np.any(valid):
used.add(product)
if mask.ndim in (2, 3):
valid = ~mask.all(axis=(-2, -1), keepdims=True)
else:
raise NotImplementedError(
f"Unable to handle {n_dims} dimensional data"
f"Unable to handle {mask.ndim} dimensional data"
)
combined_mask = array_module.where(
valid,
combined_mask | mask,
combined_mask,
)

if np.any(combined_mask):
logger.debug("Applying fillvalues mask")
used = {p.copy_provenance() for p in used}
for product in products:
for cube in product.cubes:
cube.data.mask |= combined_mask
for other in used:
if other.filename != product.filename:
product.wasderivedfrom(other)
for product in products:
for cube in product.cubes:
array = cube.core_data()
data = array_module.ma.getdata(array)
mask = array_module.ma.getmaskarray(array) | combined_mask
cube.data = array_module.ma.masked_array(data, mask)

# Record provenance
input_products = {p.copy_provenance() for p in products}
for other in input_products:
if other.filename != product.filename:
product.wasderivedfrom(other)

return products


def _get_fillvalues_mask(cube, threshold_fraction, min_value, time_window):
def _get_fillvalues_mask(
cube: iris.cube.Cube,
threshold_fraction: float,
min_value: float | None,
time_window: int,
) -> np.ndarray | da.Array:
"""Compute the per-model missing values mask.

Construct the mask that fills a certain time window with missing
@@ -662,7 +683,6 @@ def _get_fillvalues_mask(cube, threshold_fraction, min_value, time_window):
counts the number of valid (unmasked) data points within that
window; a simple value thresholding is also applied if needed.
"""
# basic checks
if threshold_fraction < 0 or threshold_fraction > 1.0:
raise ValueError(
f"Fraction of missing values {threshold_fraction} should be "
@@ -678,19 +698,24 @@ def _get_fillvalues_mask(cube, threshold_fraction, min_value, time_window):
counts_threshold = int(max_counts_per_time_window * threshold_fraction)

# Make an aggregator
spell_count = Aggregator('spell_count',
count_spells,
units_func=lambda units: 1)
spell_count = Aggregator(
'spell_count',
count_spells,
lazy_func=count_spells,
units_func=lambda units: 1,
)

# Calculate the statistic.
counts_windowed_cube = cube.collapsed('time',
spell_count,
threshold=min_value,
spell_length=time_window)
counts_windowed_cube = cube.collapsed(
'time',
spell_count,
threshold=min_value,
spell_length=time_window,
)

# Create mask
mask = counts_windowed_cube.data < counts_threshold
if np.ma.isMaskedArray(mask):
mask = mask.data | mask.mask
mask = counts_windowed_cube.core_data() < counts_threshold
array_module = da if isinstance(mask, da.Array) else np
mask = array_module.ma.filled(mask, True)

return mask
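
For context (not part of the PR): passing the same function as both the eager `call_func` and the `lazy_func` of an iris `Aggregator` is what lets `cube.collapsed` return a cube with lazy data. A hedged sketch, with an illustrative aggregator name and toy data:

```python
# Sketch only: a custom Aggregator whose lazy_func keeps collapsed() lazy.
import dask.array as da
import numpy as np
from iris.analysis import Aggregator
from iris.coords import DimCoord
from iris.cube import Cube


def my_count(data, axis, **kwargs):
    # Works for numpy and dask alike: both namespaces provide ``sum``.
    module = da if isinstance(data, da.Array) else np
    return module.sum(data > 0, axis=axis, dtype=int)


counter = Aggregator('count', my_count, lazy_func=my_count,
                     units_func=lambda units: 1)

time = DimCoord(np.arange(4, dtype=float), standard_name='time',
                units='days since 2000-01-01')
cube = Cube(da.ones((4, 3), chunks=(2, 3)), dim_coords_and_dims=[(time, 0)])

collapsed = cube.collapsed('time', counter)
print(collapsed.has_lazy_data())  # True: the aggregation is deferred
```
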
4 changes: 1 addition & 3 deletions setup.py
@@ -56,9 +56,7 @@
'pyyaml',
'requests',
'scipy>=1.6',
# See the following issue for info on the iris pin below:
# https://github.com/ESMValGroup/ESMValCore/issues/2407
'scitools-iris>3.8.0',
'scitools-iris>=3.9.0',
'shapely>=2.0.0',
'stratify>=0.3',
'yamale',
35 changes: 32 additions & 3 deletions tests/integration/preprocessor/_mask/test_mask.py
@@ -182,13 +182,17 @@ def test_mask_landseaice(self):
np.ma.set_fill_value(expected, 1e+20)
assert_array_equal(result_ice.data, expected)

def test_mask_fillvalues(self, mocker):
@pytest.mark.parametrize('lazy', [True, False])
def test_mask_fillvalues(self, mocker, lazy):
"""Test the fillvalues mask: func mask_fillvalues."""
data_1 = data_2 = self.mock_data
data_2.mask = np.ones((4, 3, 3), bool)
coords_spec = [(self.times, 0), (self.lats, 1), (self.lons, 2)]
cube_1 = iris.cube.Cube(data_1, dim_coords_and_dims=coords_spec)
cube_2 = iris.cube.Cube(data_2, dim_coords_and_dims=coords_spec)
if lazy:
cube_1.data = cube_1.lazy_data().rechunk((2, None, None))
cube_2.data = cube_2.lazy_data()
filename_1 = 'file1.nc'
filename_2 = 'file2.nc'
product_1 = mocker.create_autospec(
@@ -215,10 +219,17 @@ def test_mask_fillvalues_zero_threshold(self, mocker):
result_1 = product.cubes[0]
if product.filename == filename_2:
result_2 = product.cubes[0]

assert cube_1.has_lazy_data() == lazy
assert cube_2.has_lazy_data() == lazy
assert result_1.has_lazy_data() == lazy
assert result_2.has_lazy_data() == lazy

assert_array_equal(result_2.data.mask, data_2.mask)
assert_array_equal(result_1.data, data_1)

def test_mask_fillvalues_zero_threshold(self, mocker):
@pytest.mark.parametrize('lazy', [True, False])
def test_mask_fillvalues_zero_threshold(self, mocker, lazy):
"""Test the fillvalues mask: func mask_fillvalues for 0-threshold."""
data_1 = self.mock_data
data_2 = self.mock_data[0:3]
@@ -232,6 +243,10 @@ def test_mask_fillvalues_zero_threshold(self, mocker):
coords_spec2 = [(self.time2, 0), (self.lats, 1), (self.lons, 2)]
cube_1 = iris.cube.Cube(data_1, dim_coords_and_dims=coords_spec)
cube_2 = iris.cube.Cube(data_2, dim_coords_and_dims=coords_spec2)
if lazy:
cube_1.data = cube_1.lazy_data().rechunk((2, None, None))
cube_2.data = cube_2.lazy_data()

filename_1 = Path('file1.nc')
filename_2 = Path('file2.nc')
product_1 = mocker.create_autospec(
@@ -255,6 +270,12 @@ def test_mask_fillvalues_zero_threshold(self, mocker):
result_1 = product.cubes[0]
if product.filename == filename_2:
result_2 = product.cubes[0]

assert cube_1.has_lazy_data() == lazy
assert cube_2.has_lazy_data() == lazy
assert result_1.has_lazy_data() == lazy
assert result_2.has_lazy_data() == lazy

# identical masks
assert_array_equal(
result_2.data[0, ...].mask,
@@ -265,7 +286,8 @@ def test_mask_fillvalues_zero_threshold(self, mocker):
assert_array_equal(result_1[1:2].data.mask, cumulative_mask)
assert_array_equal(result_2[2:3].data.mask, cumulative_mask)

def test_mask_fillvalues_min_value_none(self, mocker):
@pytest.mark.parametrize('lazy', [True, False])
def test_mask_fillvalues_min_value_none(self, mocker, lazy):
"""Test ``mask_fillvalues`` for min_value=None."""
# We use non-masked data here and explicitly set some values to 0 here
# since this caused problems in the past, see
@@ -278,6 +300,10 @@ def test_mask_fillvalues_min_value_none(self, mocker):
coords_spec2 = [(self.time2, 0), (self.lats, 1), (self.lons, 2)]
cube_1 = iris.cube.Cube(data_1, dim_coords_and_dims=coords_spec)
cube_2 = iris.cube.Cube(data_2, dim_coords_and_dims=coords_spec2)
if lazy:
cube_1.data = cube_1.lazy_data().rechunk((2, None, None))
cube_2.data = cube_2.lazy_data()

filename_1 = Path('file1.nc')
filename_2 = Path('file2.nc')

@@ -303,10 +329,13 @@ def test_mask_fillvalues_min_value_none(self, mocker):
min_value=None,
)

assert cube_1.has_lazy_data() == lazy
assert cube_2.has_lazy_data() == lazy
assert len(results) == 2
for product in results:
if product.filename in (filename_1, filename_2):
assert len(product.cubes) == 1
assert product.cubes[0].has_lazy_data() == lazy
assert not np.ma.is_masked(product.cubes[0].data)
else:
assert False, f"Invalid filename: {product.filename}"
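
Side note (not from the PR): the `lazy` parametrization above relies on re-assigning a cube's own data as a rechunked dask array. A small standalone sketch of that trick:

```python
# Sketch only: turn a realized cube lazy again so lazy code paths are tested.
import numpy as np
import iris.cube

cube = iris.cube.Cube(np.arange(24.0).reshape(4, 3, 2))
assert not cube.has_lazy_data()

cube.data = cube.lazy_data().rechunk((2, None, None))
assert cube.has_lazy_data()
assert cube.lazy_data().chunksize == (2, 3, 2)
```
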
1 change: 1 addition & 0 deletions tests/unit/preprocessor/_mask/test_mask.py
@@ -5,6 +5,7 @@
import numpy as np

import iris
import iris.fileformats
import tests
from cf_units import Unit
from esmvalcore.preprocessor._mask import (_apply_fx_mask,