Skip to content

Commit

Permalink
Added a plugin for processing level3 Land Cover product. (#151)
Browse files Browse the repository at this point in the history
* Added a plugin for land cover level 3 and a unit test for this level.

* Add the level3 unit test

* Applied formatting.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added comments on classes and passed cultivated classification as is.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tebadi and pre-commit-ci[bot] authored Sep 18, 2024
1 parent cc2b79a commit b960c1d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 0 deletions.
1 change: 1 addition & 0 deletions odc/stats/plugins/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def import_all():
# TODO: make that more automatic
modules = [
"odc.stats.plugins.lc_treelite_cultivated.py",
"odc.stats.plugins.lc_level3",
"odc.stats.plugins.lc_treelite_woody",
"odc.stats.plugins.lc_tf_urban",
"odc.stats.plugins.lc_veg_class_a1",
Expand Down
57 changes: 57 additions & 0 deletions odc/stats/plugins/lc_level3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Land Cover Level3 classification
"""

from typing import Tuple
import xarray as xr
from ._registry import StatsPluginInterface, register

NODATA = 255


class StatsLccsLevel3(StatsPluginInterface):
NAME = "ga_ls_lccs_level3"
SHORT_NAME = NAME
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["level3_class"]
return _measurements

def reduce(self, xx: xr.Dataset) -> xr.Dataset:

l34_dss = xx.classes_l3_l4
urban_dss = xx.urban_classes
cultivated_dss = xx.cultivated_class

# Cultivated pipeline applies a mask which feeds only terrestrial veg (110) to the model
# Just exclude no data (255) and apply the cultivated results
cultivated_mask = cultivated_dss != int(NODATA)
l34_cultivated_masked = xr.where(cultivated_mask, cultivated_dss, l34_dss)

# Urban is classified on l3/4 surface output (210)
urban_mask = l34_dss == 210
l34_urban_cultivated_masked = xr.where(
urban_mask, urban_dss, l34_cultivated_masked
)

attrs = xx.attrs.copy()
attrs["nodata"] = NODATA
l34_urban_cultivated_masked = l34_urban_cultivated_masked.squeeze(dim=["spec"])
dims = l34_urban_cultivated_masked.dims

data_vars = {
"level3_class": xr.DataArray(
l34_urban_cultivated_masked.data, dims=dims, attrs=attrs
)
}

coords = dict((dim, xx.coords[dim]) for dim in dims)
level3 = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)

return level3


register("lccs_level3", StatsLccsLevel3)
83 changes: 83 additions & 0 deletions tests/test_lc_level3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np
import pandas as pd
import xarray as xr

from odc.stats.plugins.lc_level3 import StatsLccsLevel3
import pytest

expected_l3_classes = [
[111, 112, 215],
[124, 112, 215],
[221, 215, 216],
[223, 255, 223],
]


@pytest.fixture(scope="module")
def image_groups():
l34 = np.array(
[
[
[110, 110, 210],
[124, 110, 210],
[221, 210, 210],
[223, 255, 223],
]
],
dtype="int",
)

urban = np.array(
[
[
[215, 215, 215],
[216, 216, 215],
[116, 215, 216],
[216, 216, 216],
]
],
dtype="int",
)

cultivated = np.array(
[
[
[111, 112, 255],
[255, 112, 255],
[255, 255, 255],
[255, 255, 255],
]
],
dtype="int",
)

tuples = [
(np.datetime64("2000-01-01T00"), np.datetime64("2000-01-01")),
]
index = pd.MultiIndex.from_tuples(tuples, names=["time", "solar_day"])
coords = {
"x": np.linspace(10, 20, l34.shape[2]),
"y": np.linspace(0, 5, l34.shape[1]),
"spec": index,
}

data_vars = {
"classes_l3_l4": xr.DataArray(
l34, dims=("spec", "y", "x"), attrs={"nodata": 255}
),
"urban_classes": xr.DataArray(
urban, dims=("spec", "y", "x"), attrs={"nodata": 255}
),
"cultivated_class": xr.DataArray(
cultivated, dims=("spec", "y", "x"), attrs={"nodata": 255}
),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
return xx


def test_urban_class(image_groups):

lc_level3 = StatsLccsLevel3()
level3_classes = lc_level3.reduce(image_groups)
assert (level3_classes.level3_class.values == expected_l3_classes).all()

0 comments on commit b960c1d

Please sign in to comment.