Skip to content

Commit

Permalink
Merge pull request #63 from digitalearthpacific/update-tweak-for-geomad
Browse files Browse the repository at this point in the history
Update tweak for geomad
  • Loading branch information
alexgleith authored Oct 2, 2024
2 parents f9f5e2d + d22ef59 commit f25c342
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 31 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# Ignore cached geopackage files
*.gpkg
7 changes: 6 additions & 1 deletion dep_tools/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ def write_to_s3(
s3_dump(buffer.read(), bucket, key, client, **s3_dump_kwargs)
elif isinstance(d, Item):
s3_dump(
json.dumps(d.to_dict(), indent=4), bucket, key, client, **s3_dump_kwargs
json.dumps(d.to_dict(), indent=4),
bucket,
key,
client,
ContentType="application/json",
**s3_dump_kwargs,
)
elif isinstance(d, str):
s3_dump(d, bucket, key, client, **s3_dump_kwargs)
Expand Down
105 changes: 99 additions & 6 deletions dep_tools/grids.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,102 @@
from json import loads
from pathlib import Path
from typing import Literal

import antimeridian
import geopandas as gpd
import pandas as pd
from geopandas import GeoDataFrame, GeoSeries
from odc.geo import XY, BoundingBox
from odc.geo import XY, BoundingBox, Geometry
from odc.geo.gridspec import GridSpec
from shapely.geometry import shape

# This EPSG code is what we're using for now,
# but it's not ideal, as it's not an equal-area projection...
PACIFIC_EPSG = "EPSG:3832"

# Local cache of the per-country GADM boundaries (written on first download).
GADM_FILE = Path(__file__).parent / "gadm_pacific.gpkg"
# Local cache of the dissolved union of all country boundaries.
GADM_UNION_FILE = Path(__file__).parent / "gadm_pacific_union.gpkg"
# Pacific countries of interest, mapped to their GADM (ISO 3166-1 alpha-3) codes.
# The codes are used both to build GADM download URLs and as valid inputs to
# the country_codes argument of get_tiles().
COUNTRIES_AND_CODES = {
    "American Samoa": "ASM",
    "Cook Islands": "COK",
    "Fiji": "FJI",
    "French Polynesia": "PYF",
    "Guam": "GUM",
    "Kiribati": "KIR",
    "Marshall Islands": "MHL",
    "Micronesia": "FSM",
    "Nauru": "NRU",
    "New Caledonia": "NCL",
    "Niue": "NIU",
    "Northern Mariana Islands": "MNP",
    "Palau": "PLW",
    "Papua New Guinea": "PNG",
    "Pitcairn Islands": "PCN",
    "Solomon Islands": "SLB",
    "Samoa": "WSM",
    "Tokelau": "TKL",
    "Tonga": "TON",
    "Tuvalu": "TUV",
    "Vanuatu": "VUT",
    "Wallis and Futuna": "WLF",
}


def _get_gadm() -> GeoDataFrame:
    """Load the GADM admin-0 boundaries for all Pacific countries.

    If either cache file is missing, the per-country GADM 4.1 geopackages
    are downloaded, concatenated, and written to GADM_FILE, and their
    dissolved union (geometry only) is written to GADM_UNION_FILE.
    Subsequent calls read straight from the local cache.
    """
    cache_is_complete = GADM_FILE.exists() and GADM_UNION_FILE.exists()
    if not cache_is_complete:
        frames = []
        for code in COUNTRIES_AND_CODES.values():
            url = f"https://geodata.ucdavis.edu/gadm/gadm4.1/gpkg/gadm41_{code}.gpkg"
            frames.append(gpd.read_file(url, layer="ADM_ADM_0"))
        combined = pd.concat(frames)

        combined.to_file(GADM_FILE)
        # Union of all countries, geometry only, for whole-region queries.
        combined.dissolve()[["geometry"]].to_file(GADM_UNION_FILE)

    return gpd.read_file(GADM_FILE)


def _get_gadm_union() -> GeoDataFrame:
    """Return the dissolved union of all Pacific GADM boundaries.

    Builds the cache first if needed: _get_gadm() writes GADM_UNION_FILE
    as a side effect when it downloads the country boundaries.
    """
    if not GADM_UNION_FILE.exists():
        _get_gadm()

    return gpd.read_file(GADM_UNION_FILE)


def get_tiles(
    resolution: int | float = 30,
    country_codes: list[str] | None = None,
    buffer_distance: int | float | None = None,
) -> "Iterator[tuple[tuple[int, int], GeoBox]]":
    """Return the Pacific-grid tiles that intersect land.

    Args:
        resolution: Pixel resolution in metres for the returned tiles.
        country_codes: Optional list of country codes (values of
            COUNTRIES_AND_CODES) to restrict the area of interest;
            all countries are used when None.
        buffer_distance: Optional buffer applied to the area of interest
            before intersecting with the grid (passed through to grid()).

    Returns:
        An iterator of (tile_index, geobox) pairs, as produced by
        GridSpec.tiles_from_geopolygon via grid(return_type="GridSpec").

    Raises:
        ValueError: If any entry of country_codes is not a known code.
    """
    if country_codes is None:
        geometries = _get_gadm_union()
    else:
        # Validate up front so the caller sees exactly which codes are bad.
        invalid_codes = [
            code for code in country_codes if code not in COUNTRIES_AND_CODES.values()
        ]
        if invalid_codes:
            raise ValueError(
                f"Invalid country code(s): {', '.join(invalid_codes)}. "
                f"Must be one of {', '.join(COUNTRIES_AND_CODES.values())}"
            )
        geometries = _get_gadm().loc[lambda df: df["GID_0"].isin(country_codes)]

    return grid(
        resolution=resolution,
        return_type="GridSpec",
        intersect_with=geometries,
        buffer_distance=buffer_distance,
    )


def grid(
resolution: int | float = 30,
crs=PACIFIC_EPSG,
return_type: Literal["GridSpec", "GeoSeries", "GeoDataFrame"] = "GridSpec",
intersect_with: GeoDataFrame | None = None,
buffer_distance: int | float | None = None,
) -> GridSpec | GeoSeries | GeoDataFrame:
"""Returns a GridSpec or GeoSeries representing the Pacific grid, optionally
intersected with an area of interest.
Expand All @@ -34,8 +114,21 @@ def grid(
"""

if intersect_with is not None:
full_grid = _geoseries(resolution, crs)
return _intersect_grid(full_grid, intersect_with)
if return_type != "GridSpec":
full_grid = _geoseries(resolution, crs)
return _intersect_grid(full_grid, intersect_with)
else:
gridspec = _gridspec(resolution, crs)
geometry = Geometry(loads(intersect_with.to_json()))
# This is a bit of a hack, but it works. Geometries that are transformed by the tiles_from_geopolygon
# are not valid, but doing the simplification and buffer fixes them.
buffer = 0.0 if buffer_distance is None else buffer_distance
fixed = (
geometry.to_crs(PACIFIC_EPSG, check_and_fix=True, wrapdateline=True)
.simplify(0.01)
.buffer(buffer)
)
return gridspec.tiles_from_geopolygon(geopolygon=fixed)

return {
"GridSpec": _gridspec,
Expand All @@ -44,13 +137,13 @@ def grid(
}[return_type](resolution, crs)


def _intersect_grid(grid: GeoSeries, areas_of_interest) -> GeoDataFrame:
    """Keep only the grid cells that intersect the areas of interest.

    The areas are reprojected to the grid's CRS, spatially joined against
    the cells, and the sjoin bookkeeping column is dropped.
    """
    cells = gpd.GeoDataFrame(geometry=grid)
    aoi = areas_of_interest.to_crs(grid.crs)
    return gpd.sjoin(cells, aoi).drop(columns=["index_right"])


def _gridspec(resolution, crs=PACIFIC_EPSG):
def _gridspec(resolution, crs=PACIFIC_EPSG) -> GridSpec:
gridspec_origin = XY(-3000000.0, -4000000.0)

side_in_meters = 96_000
Expand All @@ -64,7 +157,7 @@ def _gridspec(resolution, crs=PACIFIC_EPSG):
)


def _geodataframe(resolution, crs=PACIFIC_EPSG) -> GeoDataFrame:
    """Materialise the Pacific grid as a GeoDataFrame in the given CRS."""
    cells = _geoseries(resolution, crs)
    return GeoDataFrame(geometry=cells, crs=crs)


Expand Down
13 changes: 4 additions & 9 deletions dep_tools/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from xarray import DataArray, Dataset

from .landsat_utils import mask_clouds as mask_clouds_landsat
from .s2_utils import harmonize_to_old
from .s2_utils import mask_clouds as mask_clouds_s2
from .utils import scale_and_offset, scale_to_int16

Expand Down Expand Up @@ -33,6 +32,7 @@ def __init__(
def process(self, xr: DataArray | Dataset) -> DataArray | Dataset:
if self.mask_clouds:
xr = mask_clouds_landsat(xr, **self.mask_kwargs)

if self.scale_and_offset:
# These values only work for SR bands of landsat. Ideally we could
# read from metadata. _Really_ ideally we could just pass "scale"
Expand All @@ -49,29 +49,24 @@ class S2Processor(Processor):
def __init__(
self,
send_area_to_processor: bool = False,
harmonize_to_old: bool = True,
scale_and_offset: bool = True,
scale_and_offset: bool = False,
mask_clouds: bool = True,
mask_clouds_kwargs: dict = dict(),
) -> None:
super().__init__(send_area_to_processor)
self.harmonize_to_old = harmonize_to_old
self.scale_and_offset = scale_and_offset
self.mask_clouds = mask_clouds
self.mask_kwargs = mask_clouds_kwargs
self.mask_clouds_kwargs = mask_clouds_kwargs

def process(self, xr: DataArray) -> DataArray:
if self.mask_clouds:
xr = mask_clouds_s2(xr, **self.mask_kwargs)
xr = mask_clouds_s2(xr, **self.mask_clouds_kwargs)

if self.scale_and_offset and not self.harmonize_to_old:
print(
"Warning: scale and offset is dangerous when used without harmonize_to_old"
)

if self.harmonize_to_old:
xr = harmonize_to_old(xr)

if self.scale_and_offset:
scale = 1 / 10000
offset = 0
Expand Down
35 changes: 20 additions & 15 deletions dep_tools/s2_utils.py
Original file line number Diff line number Diff line change
def mask_clouds(
    xr: DataArray,
    filters: Iterable[Tuple[str, int]] | None = None,
    keep_ints: bool = False,
    return_mask: bool = False,
) -> DataArray | tuple[DataArray, DataArray]:
    """Mask Sentinel-2 pixels flagged as cloud, shadow, cirrus or saturated.

    Builds a boolean mask from the ``scl`` (scene classification) variable,
    optionally cleans it up with morphological filters, and applies it to
    the input.

    Args:
        xr: Data with an ``scl`` variable holding SCL class values.
        filters: Optional (operation, size) pairs passed to
            odc.algo.mask_cleanup, e.g. [("opening", 2), ("dilation", 3)].
        keep_ints: If True, use erase_bad to preserve integer dtypes;
            otherwise where() is used, which promotes to float with NaN.
        return_mask: If True, also return the boolean cloud mask.

    Returns:
        The masked data, or a (masked data, cloud mask) tuple when
        return_mask is True.
    """
    # SCL classes treated as invalid; commented-out classes are kept.
    # NO_DATA = 0
    SATURATED_OR_DEFECTIVE = 1
    # DARK_AREA_PIXELS = 2
    CLOUD_SHADOWS = 3
    # VEGETATION = 4
    # NOT_VEGETATED = 5
    # WATER = 6
    # UNCLASSIFIED = 7
    CLOUD_MEDIUM_PROBABILITY = 8
    CLOUD_HIGH_PROBABILITY = 9
    THIN_CIRRUS = 10
    # SNOW = 11

    cloud_mask = xr.scl.isin(
        [
            SATURATED_OR_DEFECTIVE,
            CLOUD_SHADOWS,
            CLOUD_MEDIUM_PROBABILITY,
            CLOUD_HIGH_PROBABILITY,
            THIN_CIRRUS,
        ]
    )

    if filters is not None:
        cloud_mask = mask_cleanup(cloud_mask, filters)

    if keep_ints:
        masked = erase_bad(xr, cloud_mask)
    else:
        masked = xr.where(~cloud_mask)

    if return_mask:
        return masked, cloud_mask
    else:
        return masked


def harmonize_to_old(data: DataArray) -> DataArray:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dep_tools.grids import get_tiles, _get_gadm, PACIFIC_EPSG
from json import loads

from odc.geo import Geometry

def test_get_gadm():
    """All 22 Pacific countries load and convert to an ODC Geometry."""
    # Renamed from `all`, which shadowed the builtin.
    gadm = _get_gadm()
    assert len(gadm) == 22

    # Convert to a ODC Geometry
    geom = Geometry(loads(gadm.to_json()), crs=gadm.crs)
    assert geom.crs == 4326


def test_get_tiles():
    """get_tiles yields (tile_index, geobox) pairs in the Pacific CRS."""
    # Retrieving all the tiles takes ~2 minutes, so keep the generator
    # lazy and only inspect the first item.
    tiles = get_tiles(resolution=30)

    # Each item is a tuple of the tile index and the geobox object.
    tile, geobox = next(tiles)
    assert type(tile) is tuple
    assert geobox.crs == PACIFIC_EPSG

    # Restricting to one country is fast enough to count exhaustively.
    tiles = list(get_tiles(resolution=30, country_codes=["FJI"]))
    assert len(tiles) == 27

0 comments on commit f25c342

Please sign in to comment.