Commit

fill veg nodata with mode
Emma Ai committed Dec 17, 2024
1 parent 97df5ad commit 00fdfd9
Showing 5 changed files with 358 additions and 41 deletions.
86 changes: 86 additions & 0 deletions odc/stats/plugins/_utils.py
@@ -1,7 +1,9 @@
import re
import operator
import numpy as np
import dask
from osgeo import gdal, ogr, osr
from functools import partial


def rasterize_vector_mask(
@@ -140,3 +142,87 @@ def generate_numexpr_expressions(rules_df, final_class_column, previous):
expressions = sorted(expressions, key=len)

return expressions


def numpy_mode_exclude_nodata(values, target_value, exclude_values):
"""
Compute the mode of an array using NumPy, excluding nodata.
:param values: A flattened 1D array representing the neighborhood.
:param target_value: The value to be replaced
:param exclude_values: A list or set of values to exclude from the mode calculation.
:return: The mode of the array (smallest value in case of ties), excluding nodata.
"""

valid_mask = ~(
np.isin(values, list(set(exclude_values) | {target_value})) | np.isnan(values)
)
valid_values = values[valid_mask]
if len(valid_values) == 0:
return target_value
unique_vals, counts = np.unique(valid_values, return_counts=True)
max_count = counts.max()
# select the smallest value among ties
mode_value = unique_vals[counts == max_count].min()
return mode_value
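
For reference, a minimal sketch (values invented) of the tie-breaking behaviour documented above: excluded and target values are ignored, and the smaller of two equally frequent values wins.

import numpy as np

from odc.stats.plugins._utils import numpy_mode_exclude_nodata

# 3 and 5 are tied with two occurrences each; 0 is excluded, 255 is the target
values = np.array([3, 3, 5, 5, 0, 255, np.nan], dtype="float32")
print(numpy_mode_exclude_nodata(values, target_value=255, exclude_values={0}))  # 3.0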


def process_nodata_pixels(block, target_value, exclude_values, max_radius):
"""
Replace nodata pixels in a block with the mode of their 3x3 neighborhood.
:param block : numpy.ndarray The 2D array chunk.
:param target_value: The value to be replaced
:param exclude_values: A list or set of values to exclude from the mode calculation.
:param max_radius: maximum size of neighbourhood
:return: numpy.ndarray The modified block where nodata pixels are replaced.
"""
result = block.copy()
nodata_indices = np.argwhere(block == target_value)

for i, j in nodata_indices:
# start from the smallest/nearest neighbourhood and stop once a valid value
# is found, otherwise expand up to max_radius
for radius in range(1, max_radius + 1):
i_min, i_max = max(0, i - radius), min(block.shape[0], i + radius + 1)
j_min, j_max = max(0, j - radius), min(block.shape[1], j + radius + 1)

neighborhood = block[i_min:i_max, j_min:j_max].flatten()
tmp = numpy_mode_exclude_nodata(neighborhood, target_value, exclude_values)
if np.isnan(tmp) | (tmp == target_value):
continue
result[i, j] = tmp
break

return result


def replace_nodata_with_mode(
arr, target_value, exclude_values=None, neighbourhood_size=3
):
"""
Replace nodata-valued pixels in a Dask array with the mode of their neighborhood,
processing only the nodata pixels.
:param arr: A 2D Dask array.
:param target_value: The value to be replaced.
:param exclude_values: A list or set of values to exclude from the mode calculation.
:param neighbourhood_size: the size of the neighbourhood, e.g. 3 for a 3x3 block, 5 for a 5x5 block.
:return: A Dask array where nodata-valued pixels have been replaced.
"""
if exclude_values is None:
exclude_values = set()

radius = neighbourhood_size // 2
process_func = partial(
process_nodata_pixels,
target_value=target_value,
exclude_values=exclude_values,
max_radius=radius,
)
# Use map_overlap to handle edges and target only the nodata pixels
result = arr.map_overlap(
process_func,
depth=(radius, radius),
boundary="nearest",
dtype=arr.dtype,
trim=True,
)
return result
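
A minimal usage sketch of the new helper, assuming a small uint8 class raster in which 254 marks the pixels to fill and class 0 is ignored; the array values and chunking here are invented for illustration.

import dask.array as da
import numpy as np

from odc.stats.plugins._utils import replace_nodata_with_mode

arr = da.from_array(
    np.array(
        [
            [1, 1, 2, 2],
            [1, 254, 2, 2],
            [0, 1, 254, 2],
            [1, 1, 2, 2],
        ],
        dtype="uint8",
    ),
    chunks=(2, 2),
)

filled = replace_nodata_with_mode(
    arr,
    target_value=254,      # sentinel marking the pixels to fill
    exclude_values=[0],    # classes ignored when computing the mode
    neighbourhood_size=3,  # 3x3 window; map_overlap pads chunk edges
)
print(filled.compute())
# the two 254 pixels take the most frequent valid neighbour (1 and 2 here);
# the 0 stays untouched because it is not the target value

The depth passed to map_overlap equals neighbourhood_size // 2, so a neighbourhood never needs pixels beyond the halo shared between chunks.
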
110 changes: 94 additions & 16 deletions odc/stats/plugins/lc_fc_wo_a0.py
@@ -65,15 +65,6 @@ def native_transform(self, xx):
# clear dry pixels
clear = xx["water"].data == 0

# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# dilate both 'valid' and 'water'
for key, val in self.BAD_BITS_MASK.items():
if self.cloud_filters.get(key) is not None:
@@ -85,18 +76,42 @@
"where(b>0, 0, a)",
{"a": valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="uint8",
dtype="bool",
)
wet_clear = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_clear, "b": raw_mask.data},
clear = expr_eval(
"where(b>0, 0, a)",
{"a": clear, "b": raw_mask.data},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
dtype="bool",
)
wet = expr_eval(
"where(b>0, 0, a)",
{"a": wet, "b": raw_mask.data},
name="get_wet_pixels",
dtype="bool",
)

xx = xx.drop_vars(["water"])

# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# get "valid" wo pixels, both dry and wet
# to remark nodata reason in veg_frequency
wet_valid = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": valid},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# Pick out the fc pixels that have an unmixing error of less than the threshold
valid = expr_eval(
"where(b<_v, a, 0)",
Expand All @@ -114,12 +129,16 @@ def native_transform(self, xx):
xx["wet_clear"] = xr.DataArray(
wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
)
xx["wet_valid"] = xr.DataArray(
wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
)

return xx

def fuser(self, xx):

wet_clear = xx["wet_clear"]
wet_valid = xx["wet_valid"]

xx = _xr_fuse(
xx.drop_vars(["wet_clear"]),
Expand All @@ -128,6 +147,7 @@ def fuser(self, xx):
)

xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)
xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)

return xx

@@ -171,6 +191,59 @@ def _water_or_not(self, xx: xr.Dataset):
)
return data

def _wet_or_not(self, xx: xr.Dataset):
# mark wet valid pixels as 1
data = expr_eval(
"where(a>0, 1, 0)",
{"a": xx["wet_valid"].data},
name="get_wet",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet_valid"].data, "b": data},
name="get_wet",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

def _wet_valid_percent(self, data, nodata):
wet = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")

for t in data:
# add the wet flag (0 or 1) if not nodata
wet = expr_eval(
"where(a==nodata, b, a+b)",
{"a": t, "b": wet},
name="get_wet",
dtype="uint8",
**{"nodata": nodata},
)

# total valid
total = expr_eval(
"where(a==nodata, b, b+1)",
{"a": t, "b": total},
name="get_total_valid",
dtype="uint8",
**{"nodata": nodata},
)

wet = expr_eval(
"where(a<=0, nodata, b/a*100)",
{"a": total, "b": wet},
name="normalize_max_count",
dtype="float32",
**{"nodata": int(nodata)},
)

wet = da.ceil(wet).astype("uint8")
return wet

def _max_consecutive_months(self, data, nodata, normalize=False):
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
@@ -264,11 +337,16 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

data = self._wet_or_not(xx)
wet_percent = self._wet_valid_percent(data, NODATA)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
for k, v in zip(
self.measurements, [max_count_veg, max_count_water, wet_percent]
)
}
coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
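
For orientation, this is the arithmetic the new _wet_valid_percent performs, reduced to a single pixel's stack of monthly flags. It is only a sketch: the flag values are invented and the NODATA sentinel is assumed to be 255, matching the uint8 outputs.

import numpy as np

NODATA = 255  # assumed nodata sentinel
# monthly flags for one pixel: 1 = wet, 0 = dry but valid, NODATA = no observation
flags = np.array([1, 0, NODATA, 1, 0, NODATA], dtype="uint8")

valid = flags != NODATA
wet_months = int(flags[valid].sum())     # 2
valid_months = int(valid.sum())          # 4
percent = (
    np.uint8(np.ceil(wet_months / valid_months * 100))
    if valid_months
    else np.uint8(NODATA)
)
print(wet_months, valid_months, percent)  # 2 4 50

In the plugin the same two counts are accumulated across the time dimension with expr_eval, and the resulting percentage is emitted as a third measurement alongside max_count_veg and max_count_water in reduce().
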
61 changes: 46 additions & 15 deletions odc/stats/plugins/lc_veg_class_a1.py
@@ -9,6 +9,7 @@
from odc.stats._algebra import expr_eval

from ._registry import StatsPluginInterface, register
from ._utils import replace_nodata_with_mode

NODATA = 255

@@ -117,6 +118,23 @@ def l3_class(self, xx: xr.Dataset):
},
)

# all unmarked values (0) with a valid veg_frequency are terrestrial veg
l3_mask = expr_eval(
"where((a<=0)&(b<nodata), m, a)",
{"a": l3_mask, "b": xx["veg_frequency"].data},
name="mark_veg",
dtype="uint8",
**{
"m": self.output_classes["terrestrial_veg"],
"nodata": (
NODATA
if xx["veg_frequency"].attrs["nodata"]
!= xx["veg_frequency"].attrs["nodata"]
else xx["veg_frequency"].attrs["nodata"]
),
},
)

# if its mangrove or coast region
for b in self.optional_bands:
if b in xx.data_vars:
@@ -133,19 +151,26 @@
)

l3_mask = expr_eval(
"where(a&(b>0), m, b)",
"where(a&((b==_w)|(b==_s)), m, b)",
{"a": data, "b": l3_mask},
name="intertidal_water",
dtype="uint8",
**{"m": self.output_classes["intertidal"]},
**{
"m": self.output_classes["intertidal"],
"_w": self.output_classes["water"],
"_s": self.output_classes["surface"],
},
)

l3_mask = expr_eval(
"where(a&(b<=0), m, b)",
"where(a&(b==_v), m, b)",
{"a": data, "b": l3_mask},
name="intertidal_veg",
dtype="uint8",
**{"m": self.output_classes["aquatic_veg_herb"]},
**{
"m": self.output_classes["aquatic_veg_herb"],
"_v": self.output_classes["terrestrial_veg"],
},
)
elif b == "canopy_cover_class":
# aquatic_veg: (mangroves > 0) & (mangroves != nodata)
@@ -161,28 +186,34 @@
},
)

# all unmarked values (0) with 255 > veg >= 2 are terrestrial veg
# all unmarked values (0) with a valid wet_percentage take the mode of their neighbourhood
target_value = 254
l3_mask = expr_eval(
"where((a<=0)&(b>=2)&(b<nodata), m, a)",
{"a": l3_mask, "b": xx["veg_frequency"].data},
name="mark_veg",
"where((a<=0)&(b<nodata), _u, a)",
{"a": l3_mask, "b": xx["wet_percentage"].data},
name="mark_other_valid",
dtype="uint8",
**{
"m": self.output_classes["terrestrial_veg"],
"nodata": (
NODATA
if xx["veg_frequency"].attrs["nodata"]
!= xx["veg_frequency"].attrs["nodata"]
else xx["veg_frequency"].attrs["nodata"]
if xx["wet_percentage"].attrs["nodata"]
!= xx["wet_percentage"].attrs["nodata"]
else xx["wet_percentage"].attrs["nodata"]
),
"_u": target_value,
},
)
l3_mask = replace_nodata_with_mode(
l3_mask.squeeze(0),
target_value=target_value,
exclude_values=[0],
neighbourhood_size=5,
)

# Mask nans and pixels where none of the classes are applicable
l3_mask = expr_eval(
"where((a!=a)|(e<=0), nodata, e)",
"where((e<=0)|(e==254), nodata, e)",
{
"a": si5,
"e": l3_mask,
},
name="mark_nodata",
Expand All @@ -198,7 +229,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [l3_mask.squeeze(0)])
for k, v in zip(self.measurements, [l3_mask])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
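
Putting the lc_veg_class_a1.py changes together: observed-but-unclassified pixels are marked with a temporary sentinel (254), filled from the mode of a 5x5 neighbourhood, and anything still unclassified becomes NODATA. A condensed sketch with invented class codes follows.

import dask.array as da
import numpy as np

from odc.stats.plugins._utils import replace_nodata_with_mode

NODATA = 255
SENTINEL = 254  # temporary "valid but unclassified" marker
# 110 / 220 are placeholder class codes for illustration only
l3 = da.from_array(
    np.array(
        [
            [110, 110, 220, 220],
            [110, SENTINEL, 220, 220],
            [110, SENTINEL, SENTINEL, 220],
            [0, 110, 220, 220],
        ],
        dtype="uint8",
    ),
    chunks=(4, 4),
)

filled = replace_nodata_with_mode(
    l3, target_value=SENTINEL, exclude_values=[0], neighbourhood_size=5
)
# mirror the final expr_eval: whatever is still 0 or 254 becomes NODATA
result = da.where((filled <= 0) | (filled == SENTINEL), NODATA, filled)
print(result.compute())

In the plugin the mask is 3D, so l3_mask.squeeze(0) is passed before the fill; the sketch above drops that detail.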