From 00fdfd9c850c6f18418403ea2671687adb55f64e Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Tue, 17 Dec 2024 01:24:53 +0000 Subject: [PATCH] fill veg nodata with mode --- odc/stats/plugins/_utils.py | 86 +++++++++++++++++++++ odc/stats/plugins/lc_fc_wo_a0.py | 110 +++++++++++++++++++++++---- odc/stats/plugins/lc_veg_class_a1.py | 61 +++++++++++---- tests/test_landcover_plugin_a0.py | 35 +++++++-- tests/test_landcover_plugin_a1.py | 107 ++++++++++++++++++++++++-- 5 files changed, 358 insertions(+), 41 deletions(-) diff --git a/odc/stats/plugins/_utils.py b/odc/stats/plugins/_utils.py index b4deab9..a5d5729 100644 --- a/odc/stats/plugins/_utils.py +++ b/odc/stats/plugins/_utils.py @@ -1,7 +1,9 @@ import re import operator +import numpy as np import dask from osgeo import gdal, ogr, osr +from functools import partial def rasterize_vector_mask( @@ -140,3 +142,87 @@ def generate_numexpr_expressions(rules_df, final_class_column, previous): expressions = sorted(expressions, key=len) return expressions + + +def numpy_mode_exclude_nodata(values, target_value, exclude_values): + """ + Compute the mode of an array using NumPy, excluding nodata. + :param values: A flattened 1D array representing the neighborhood. + :param target_value: The value to be replaced + :param exclude_values: A list or set of values to exclude from the mode calculation. + :return: The mode of the array (smallest value in case of ties), excluding nodata. + """ + + valid_mask = ~( + np.isin(values, list(set(exclude_values) | {target_value})) | np.isnan(values) + ) + valid_values = values[valid_mask] + if len(valid_values) == 0: + return target_value + unique_vals, counts = np.unique(valid_values, return_counts=True) + max_count = counts.max() + # select the smallest value among ties + mode_value = unique_vals[counts == max_count].min() + return mode_value + + +def process_nodata_pixels(block, target_value, exclude_values, max_radius): + """ + Replace nodata pixels in a block with the mode of their 3x3 neighborhood. + :param block : numpy.ndarray The 2D array chunk. + :param target_value: The value to be replaced + :param exclude_values: A list or set of values to exclude from the mode calculation. + :param max_radius: maximum size of neighbourhood + :return: numpy.ndarray The modified block where nodata pixels are replaced. + """ + result = block.copy() + nodata_indices = np.argwhere(block == target_value) + + for i, j in nodata_indices: + # start from the smallest/nearest neighbourhood + # stop once finding the valid value otherwise expand till the max_radius + for radius in range(1, max_radius + 1): + i_min, i_max = max(0, i - radius), min(block.shape[0], i + radius + 1) + j_min, j_max = max(0, j - radius), min(block.shape[1], j + radius + 1) + + neighborhood = block[i_min:i_max, j_min:j_max].flatten() + tmp = numpy_mode_exclude_nodata(neighborhood, target_value, exclude_values) + if np.isnan(tmp) | (tmp == target_value): + continue + result[i, j] = tmp + break + + return result + + +def replace_nodata_with_mode( + arr, target_value, exclude_values=None, neighbourhood_size=3 +): + """ + Replace nodata-valued pixels in a Dask array with the mode of their neighborhood, + processing only the nodata pixels. + :param arr: A 2D Dask array. + :param target_value: The value to be replaced + :param exclude_values: A list or set of values to exclude from the mode calculation. + :param neighbourhood_size: the size of neighbourhood, e.g., 3:= 3*3 block, 5:=5*5 block + :return: A Dask array where nodata-valued pixels have been replaced. + """ + if exclude_values is None: + exclude_values = set() + + radius = neighbourhood_size // 2 + process_func = partial( + process_nodata_pixels, + target_value=target_value, + exclude_values=exclude_values, + max_radius=radius, + ) + # Use map_overlap to handle edges and target only the nodata pixels + result = arr.map_overlap( + process_func, + depth=(radius, radius), + boundary="nearest", + dtype=arr.dtype, + trim=True, + ) + return result diff --git a/odc/stats/plugins/lc_fc_wo_a0.py b/odc/stats/plugins/lc_fc_wo_a0.py index a9c37ab..230d064 100644 --- a/odc/stats/plugins/lc_fc_wo_a0.py +++ b/odc/stats/plugins/lc_fc_wo_a0.py @@ -65,15 +65,6 @@ def native_transform(self, xx): # clear dry pixels clear = xx["water"].data == 0 - # get "clear" wo pixels, both dry and wet used in water_frequency - wet_clear = expr_eval( - "where(a|b, a, _nan)", - {"a": wet, "b": clear}, - name="get_clear_pixels", - dtype="float32", - **{"_nan": np.nan}, - ) - # dilate both 'valid' and 'water' for key, val in self.BAD_BITS_MASK.items(): if self.cloud_filters.get(key) is not None: @@ -85,18 +76,42 @@ def native_transform(self, xx): "where(b>0, 0, a)", {"a": valid, "b": raw_mask.data}, name="get_valid_pixels", - dtype="uint8", + dtype="bool", ) - wet_clear = expr_eval( - "where(b>0, _nan, a)", - {"a": wet_clear, "b": raw_mask.data}, + clear = expr_eval( + "where(b>0, 0, a)", + {"a": clear, "b": raw_mask.data}, name="get_clear_pixels", - dtype="float32", - **{"_nan": np.nan}, + dtype="bool", + ) + wet = expr_eval( + "where(b>0, 0, a)", + {"a": wet, "b": raw_mask.data}, + name="get_wet_pixels", + dtype="bool", ) xx = xx.drop_vars(["water"]) + # get "clear" wo pixels, both dry and wet used in water_frequency + wet_clear = expr_eval( + "where(a|b, a, _nan)", + {"a": wet, "b": clear}, + name="get_clear_pixels", + dtype="float32", + **{"_nan": np.nan}, + ) + + # get "valid" wo pixels, both dry and wet + # to remark nodata reason in veg_frequency + wet_valid = expr_eval( + "where(a|b, a, _nan)", + {"a": wet, "b": valid}, + name="get_valid_pixels", + dtype="float32", + **{"_nan": np.nan}, + ) + # Pick out the fc pixels that have an unmixing error of less than the threshold valid = expr_eval( "where(b<_v, a, 0)", @@ -114,12 +129,16 @@ def native_transform(self, xx): xx["wet_clear"] = xr.DataArray( wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords ) + xx["wet_valid"] = xr.DataArray( + wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords + ) return xx def fuser(self, xx): wet_clear = xx["wet_clear"] + wet_valid = xx["wet_valid"] xx = _xr_fuse( xx.drop_vars(["wet_clear"]), @@ -128,6 +147,7 @@ def fuser(self, xx): ) xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan) + xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan) return xx @@ -171,6 +191,59 @@ def _water_or_not(self, xx: xr.Dataset): ) return data + def _wet_or_not(self, xx: xr.Dataset): + # mark water freq >= 0.5 as 1 + data = expr_eval( + "where(a>0, 1, 0)", + {"a": xx["wet_valid"].data}, + name="get_wet", + dtype="uint8", + ) + + # mark nans + data = expr_eval( + "where(a!=a, nodata, b)", + {"a": xx["wet_valid"].data, "b": data}, + name="get_wet", + dtype="uint8", + **{"nodata": int(NODATA)}, + ) + return data + + def _wet_valid_percent(self, data, nodata): + wet = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8") + total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8") + + for t in data: + # +1 if not nodata + wet = expr_eval( + "where(a==nodata, b, a+b)", + {"a": t, "b": wet}, + name="get_wet", + dtype="uint8", + **{"nodata": nodata}, + ) + + # total valid + total = expr_eval( + "where(a==nodata, b, b+1)", + {"a": t, "b": total}, + name="get_total_valid", + dtype="uint8", + **{"nodata": nodata}, + ) + + wet = expr_eval( + "where(a<=0, nodata, b/a*100)", + {"a": total, "b": wet}, + name="normalize_max_count", + dtype="float32", + **{"nodata": int(nodata)}, + ) + + wet = da.ceil(wet).astype("uint8") + return wet + def _max_consecutive_months(self, data, nodata, normalize=False): tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8") max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8") @@ -264,11 +337,16 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: data = self._water_or_not(xx) max_count_water = self._max_consecutive_months(data, NODATA, normalize=True) + data = self._wet_or_not(xx) + wet_percent = self._wet_valid_percent(data, NODATA) + attrs = xx.attrs.copy() attrs["nodata"] = int(NODATA) data_vars = { k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs) - for k, v in zip(self.measurements, [max_count_veg, max_count_water]) + for k, v in zip( + self.measurements, [max_count_veg, max_count_water, wet_percent] + ) } coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:]) return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs) diff --git a/odc/stats/plugins/lc_veg_class_a1.py b/odc/stats/plugins/lc_veg_class_a1.py index 18c2253..9107416 100644 --- a/odc/stats/plugins/lc_veg_class_a1.py +++ b/odc/stats/plugins/lc_veg_class_a1.py @@ -9,6 +9,7 @@ from odc.stats._algebra import expr_eval from ._registry import StatsPluginInterface, register +from ._utils import replace_nodata_with_mode NODATA = 255 @@ -117,6 +118,23 @@ def l3_class(self, xx: xr.Dataset): }, ) + # all unmarked values (0) and 255 > veg >= 2 is terretrial veg + l3_mask = expr_eval( + "where((a<=0)&(b0), m, b)", + "where(a&((b==_w)|(b==_s)), m, b)", {"a": data, "b": l3_mask}, name="intertidal_water", dtype="uint8", - **{"m": self.output_classes["intertidal"]}, + **{ + "m": self.output_classes["intertidal"], + "_w": self.output_classes["water"], + "_s": self.output_classes["surface"], + }, ) l3_mask = expr_eval( - "where(a&(b<=0), m, b)", + "where(a&(b==_v), m, b)", {"a": data, "b": l3_mask}, name="intertidal_veg", dtype="uint8", - **{"m": self.output_classes["aquatic_veg_herb"]}, + **{ + "m": self.output_classes["aquatic_veg_herb"], + "_v": self.output_classes["terrestrial_veg"], + }, ) elif b == "canopy_cover_class": # aquatic_veg: (mangroves > 0) & (mangroves != nodata) @@ -161,28 +186,34 @@ def l3_class(self, xx: xr.Dataset): }, ) - # all unmarked values (0) and 255 > veg >= 2 is terretrial veg + # all unmarked values (0) and wet_percentage != nodata is mode of neighbourhood + target_value = 254 l3_mask = expr_eval( - "where((a<=0)&(b>=2)&(b xr.Dataset: attrs["nodata"] = int(NODATA) data_vars = { k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs) - for k, v in zip(self.measurements, [l3_mask.squeeze(0)]) + for k, v in zip(self.measurements, [l3_mask]) } coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:]) return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs) diff --git a/tests/test_landcover_plugin_a0.py b/tests/test_landcover_plugin_a0.py index 0285ec7..bf7446a 100644 --- a/tests/test_landcover_plugin_a0.py +++ b/tests/test_landcover_plugin_a0.py @@ -319,7 +319,9 @@ def fc_wo_dataset(): def test_native_transform(fc_wo_dataset, bits): xx = fc_wo_dataset.copy() xx["water"] = da.bitwise_or(xx["water"], bits) - stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) + stats_veg = StatsVegCount( + measurements=["veg_frequency", "water_frequency", "wet_percentage"] + ) out_xx = stats_veg.native_transform(xx).compute() expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2])) @@ -340,7 +342,9 @@ def test_native_transform(fc_wo_dataset, bits): def test_fusing(fc_wo_dataset): - stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) + stats_veg = StatsVegCount( + measurements=["veg_frequency", "water_frequency", "wet_percentage"] + ) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)).compute() valid_index = ( @@ -360,7 +364,9 @@ def test_fusing(fc_wo_dataset): def test_veg_or_not(fc_wo_dataset): - stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) + stats_veg = StatsVegCount( + measurements=["veg_frequency", "water_frequency", "wet_percentage"] + ) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)) yy = stats_veg._veg_or_not(xx).compute() @@ -394,7 +400,9 @@ def test_water_or_not(fc_wo_dataset): def test_reduce(fc_wo_dataset): - stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) + stats_veg = StatsVegCount( + measurements=["veg_frequency", "water_frequency", "wet_percentage"] + ) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)) xx = stats_veg.reduce(xx).compute() @@ -428,9 +436,26 @@ def test_reduce(fc_wo_dataset): assert (xx.water_frequency.data == expected_value).all() + expected_value = np.array( + [ + [0, 255, 100, 255, 255, 255, 255], + [0, 255, 255, 0, 255, 255, 255], + [255, 100, 255, 255, 0, 0, 255], + [255, 255, 0, 255, 255, 255, 255], + [255, 255, 255, 255, 255, 255, 255], + [255, 0, 255, 255, 255, 0, 255], + [255, 255, 255, 0, 255, 255, 50], + ], + dtype="uint8", + ) + + assert (xx.wet_percentage.data == expected_value).all() + def test_consecutive_month(consecutive_count): - stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) + stats_veg = StatsVegCount( + measurements=["veg_frequency", "water_frequency", "wet_percentage"] + ) xx = stats_veg._max_consecutive_months(consecutive_count, 255).compute() expected_value = np.array( [ diff --git a/tests/test_landcover_plugin_a1.py b/tests/test_landcover_plugin_a1.py index 8651734..5c49083 100644 --- a/tests/test_landcover_plugin_a1.py +++ b/tests/test_landcover_plugin_a1.py @@ -2,6 +2,7 @@ import xarray as xr import dask.array as da from odc.stats.plugins.lc_veg_class_a1 import StatsVegClassL1 +from odc.stats.plugins._utils import replace_nodata_with_mode import pytest import pandas as pd @@ -22,10 +23,14 @@ def dataset(): wo_fq = da.from_array(wo_fq, chunks=(1, -1, -1)) veg_fq = np.array( - [[[0, 3, 1, 2], [0, 7, 5, 0], [0, 2, 11, 3], [11, 5, 8, 4]]], dtype="uint8" + [[[0, 3, 1, 2], [0, 7, 5, 0], [0, 2, 11, 3], [11, 255, 8, 4]]], dtype="uint8" ) veg_fq = da.from_array(veg_fq, chunks=(1, -1, -1)) + wet_percentage = np.array( + [[[0, 10, 20, 30], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 80, 40]]], dtype="uint8" + ) + dem_h = np.array( [ [ @@ -58,7 +63,7 @@ def dataset(): [5529, 833, 580, 1144], [1172, 4680, 4999, 1746], [2702, 5572, 3048, 1382], - [3080, 3149, 4080, 2463], + [3080, -999, 4080, 2463], ] ], dtype="int16", @@ -71,7 +76,7 @@ def dataset(): [5159, 801, 4187, 1861], [1123, 5827, 5080, 3464], [1209, 1744, 4020, 413], - [4375, 4321, 4531, 4030], + [4375, -999, 4531, 4030], ] ], dtype="int16", @@ -84,7 +89,7 @@ def dataset(): [2798, 5539, 4431, 5996], [705, 2869, 4741, 4349], [1716, 4392, 5325, 878], - [4174, 3233, 3368, 1118], + [4174, -999, 3368, 1118], ] ], dtype="int16", @@ -106,6 +111,9 @@ def dataset(): "veg_frequency": xr.DataArray( veg_fq, dims=("spec", "y", "x"), attrs={"nodata": 255} ), + "wet_percentage": xr.DataArray( + wet_percentage, dims=("spec", "y", "x"), attrs={"nodata": 255} + ), "dem_h": xr.DataArray(dem_h, dims=("spec", "y", "x"), attrs={"nodata": np.nan}), "elevation": xr.DataArray( nidem, dims=("spec", "y", "x"), attrs={"nodata": np.nan} @@ -125,6 +133,95 @@ def dataset(): return xx +@pytest.fixture +def setup_data(): + target_value = 0 + # Case 1: Replace within smallest neighborhood (3x3) + input_1 = np.array( + [ + [1, 1, 1, 1, 1], + [1, 0, 2, 0, 1], + [1, 3, 4, 3, 1], + [1, 0, 2, 0, 1], + [1, 1, 1, 1, 1], + ] + ) + expected_1 = np.array( + [ + [1, 1, 1, 1, 1], + [1, 1, 2, 1, 1], + [1, 3, 4, 3, 1], + [1, 1, 2, 1, 1], + [1, 1, 1, 1, 1], + ] + ) + + # Case 2: Replace after expanding to maximum neighborhood (5x5) + input_2 = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ] + ) + expected_2 = np.array( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + ] + ) # Correct propagation of '1' to the valid 5x5 neighborhood. + + # Case 3: No valid replacement (everything excluded) + input_3 = np.array( + [ + [5, 5, 5, 5, 5], + [5, 0, 5, 0, 5], + [5, 5, 5, 5, 5], + [5, 0, 5, 0, 5], + [5, 5, 5, 5, 5], + ] + ) + exclude_values_3 = [5] + expected_3 = np.array( + [ + [5, 5, 5, 5, 5], + [5, 0, 5, 0, 5], + [5, 5, 5, 5, 5], + [5, 0, 5, 0, 5], + [5, 5, 5, 5, 5], + ] + ) # Zeros remain unchanged because '5' is excluded. + + input_1 = da.from_array(input_1, chunks=(5, 5)) + + input_2 = da.from_array(input_2, chunks=(5, 5)) + + input_3 = da.from_array(input_3, chunks=(5, 5)) + + return [ + (input_1, target_value, expected_1, None), + (input_2, target_value, expected_2, None), + (input_3, target_value, expected_3, exclude_values_3), + ] + + +def test_replace_nodata_with_mode(setup_data): + for input_dask_array, target_value, expected, exclude_values in setup_data: + result = replace_nodata_with_mode( + input_dask_array, + target_value, + exclude_values=exclude_values, + neighbourhood_size=5, + ) + + assert (result.compute() == expected).all() + + def test_l3_classes(dataset): stats_l3 = StatsVegClassL1( output_classes={ @@ -151,7 +248,7 @@ def test_l3_classes(dataset): dtype="uint8", ) - res = stats_l3.l3_class(dataset) + res = stats_l3.l3_class(dataset).compute() assert (res == expected_res).all() res = stats_l3.reduce(dataset)