Skip to content

Commit

Permalink
fix data type casting issue in l3 plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Oct 23, 2024
1 parent bac430c commit 6befa4d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
36 changes: 21 additions & 15 deletions odc/stats/plugins/lc_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Tuple
import xarray as xr
from odc.stats._algebra import expr_eval
from ._registry import StatsPluginInterface, register

NODATA = 255
Expand All @@ -22,30 +23,35 @@ def measurements(self) -> Tuple[str, ...]:

def reduce(self, xx: xr.Dataset) -> xr.Dataset:

l34_dss = xx.classes_l3_l4
urban_dss = xx.urban_classes
cultivated_dss = xx.cultivated_class

# Cultivated pipeline applies a mask which feeds only terrestrial veg (110) to the model
# Just exclude no data (255) and apply the cultivated results
cultivated_mask = cultivated_dss != int(NODATA)
l34_cultivated_masked = xr.where(cultivated_mask, cultivated_dss, l34_dss)
res = expr_eval(
"where(a<nodata, a, b)",
{"a": xx.cultivated_class.data, "b": xx.classes_l3_l4.data},
name="mask_cultivated",
dtype="float32",
**{"nodata": NODATA},
)

# Urban is classified on l3/4 surface output (210)
urban_mask = l34_dss == 210
l34_urban_cultivated_masked = xr.where(
urban_mask, urban_dss, l34_cultivated_masked
# Mask urban results with bare sfc (210)

res = expr_eval(
"where(a==_u, b, a)",
{
"a": res,
"b": xx.urban_classes.data,
},
name="mark_urban",
dtype="uint8",
**{"_u": 210},
)

attrs = xx.attrs.copy()
attrs["nodata"] = NODATA
l34_urban_cultivated_masked = l34_urban_cultivated_masked.squeeze(dim=["spec"])
dims = l34_urban_cultivated_masked.dims
dims = xx.classes_l3_l4.dims[1:]

data_vars = {
"level3_class": xr.DataArray(
l34_urban_cultivated_masked.data, dims=dims, attrs=attrs
)
"level3_class": xr.DataArray(res.squeeze(), dims=dims, attrs=attrs)
}

coords = dict((dim, xx.coords[dim]) for dim in dims)
Expand Down
19 changes: 13 additions & 6 deletions tests/test_lc_level3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import xarray as xr
import dask.array as da

from odc.stats.plugins.lc_level3 import StatsLccsLevel3
import pytest
Expand All @@ -24,7 +25,7 @@ def image_groups():
[223, 255, 223],
]
],
dtype="int",
dtype="uint8",
)

urban = np.array(
Expand All @@ -36,7 +37,7 @@ def image_groups():
[216, 216, 216],
]
],
dtype="int",
dtype="uint8",
)

cultivated = np.array(
Expand All @@ -48,7 +49,7 @@ def image_groups():
[255, 255, 255],
]
],
dtype="int",
dtype="uint8",
)

tuples = [
Expand All @@ -63,13 +64,19 @@ def image_groups():

data_vars = {
"classes_l3_l4": xr.DataArray(
l34, dims=("spec", "y", "x"), attrs={"nodata": 255}
da.from_array(l34, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
"urban_classes": xr.DataArray(
urban, dims=("spec", "y", "x"), attrs={"nodata": 255}
da.from_array(urban, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
"cultivated_class": xr.DataArray(
cultivated, dims=("spec", "y", "x"), attrs={"nodata": 255}
da.from_array(cultivated, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
Expand Down

0 comments on commit 6befa4d

Please sign in to comment.