Skip to content

Commit

Permalink
Move hard-coded dependency column lists into the plugin config (`class_condition`)
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 28, 2024
1 parent daf6e19 commit 50ed73b
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions odc/stats/plugins/lc_level34.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
Plugin of Module A3 in LandCover PipeLine
"""

from typing import Optional, Dict
from typing import Optional, Dict, List

import xarray as xr
import s3fs
import os
import pandas as pd
import dask.array as da
import logging

from ._registry import StatsPluginInterface, register
from ._utils import rasterize_vector_mask, generate_numexpr_expressions
from odc.stats._algebra import expr_eval
from osgeo import gdal

NODATA = 255
_log = logging.getLogger(__name__)


class StatsLccsLevel4(StatsPluginInterface):
Expand All @@ -27,6 +29,7 @@ class StatsLccsLevel4(StatsPluginInterface):
def __init__(
self,
class_def_path: str = None,
class_condition: Dict[str, List] = None,
urban_mask: str = None,
filter_expression: str = None,
mask_threshold: Optional[float] = None,
Expand All @@ -43,6 +46,9 @@ def __init__(
elif not os.path.exists(class_def_path):
raise FileNotFoundError(f"{class_def_path} not found")

if class_condition is None:
raise ValueError("Missing input to generate classification conditions")

if urban_mask is None:
raise ValueError("Missing urban mask shapefile")

Expand All @@ -54,8 +60,12 @@ def __init__(
raise ValueError("Missing urban mask filter")

self.class_def = pd.read_csv(class_def_path)
cols = list(self.class_def.columns[:6]) + list(self.class_def.columns[9:-6])
self.class_def = self.class_def[cols].astype(str).fillna("nan")
self.class_condition = class_condition
cols = set()
for k, v in self.class_condition.items():
cols |= {k} | set(v)

self.class_def = self.class_def[list(cols)].astype(str).fillna("nan")

self.urban_mask = urban_mask
self.filter_expression = filter_expression
Expand All @@ -77,6 +87,7 @@ def classification(self, xx, class_def, con_cols, class_col):
res = da.full(xx.level_3_4.shape, 0, dtype="uint8")

for expression in expressions:
_log.info(expression)
local_dict.update({"res": res})
res = expr_eval(
expression,
Expand All @@ -98,9 +109,10 @@ def classification(self, xx, class_def, con_cols, class_col):
return res

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
con_cols = ["level1", "artificial_surface", "cultivated"]
class_col = "level3"
level3 = self.classification(xx, self.class_def, con_cols, class_col)
level3 = self.classification(
xx, self.class_def, self.class_condition[class_col], class_col
)

# apply urban mask
# 215 -> 216 if urban_mask == 0
Expand All @@ -119,6 +131,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
dtype="uint8",
)

# append level3 to the input dataset so it can be used
# to classify level4
attrs = xx.attrs.copy()
attrs["nodata"] = NODATA
dims = xx.level_3_4.dims[1:]
Expand All @@ -127,18 +141,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
level3.squeeze(), dims=dims, attrs=attrs, coords=coords
)

con_cols = [
"level1",
"level3",
"woody",
"water_season",
"water_frequency",
"pv_pc_50",
"bs_pc_50",
]
class_col = "level4"

level4 = self.classification(xx, self.class_def, con_cols, class_col)
level4 = self.classification(
xx, self.class_def, self.class_condition[class_col], class_col
)

data_vars = {
k: xr.DataArray(v, dims=dims, attrs=attrs)
Expand Down

0 comments on commit 50ed73b

Please sign in to comment.