diff --git a/dep_tools/grids.py b/dep_tools/grids.py index 98afd11..73e8ec2 100644 --- a/dep_tools/grids.py +++ b/dep_tools/grids.py @@ -1,13 +1,13 @@ from json import loads from pathlib import Path -from typing import Literal +from typing import Literal, Iterator import antimeridian import geopandas as gpd import pandas as pd from geopandas import GeoDataFrame, GeoSeries from odc.geo import XY, BoundingBox, Geometry -from odc.geo.gridspec import GridSpec +from odc.geo.gridspec import GridSpec, GeoBox from shapely.geometry import shape # This EPSG code is what we're using for now @@ -42,7 +42,7 @@ } -def _get_gadm() -> GeoDataFrame: +def gadm() -> GeoDataFrame: if not GADM_FILE.exists() or not GADM_UNION_FILE.exists(): all_polys = pd.concat( [ @@ -60,9 +60,9 @@ def _get_gadm() -> GeoDataFrame: return gpd.read_file(GADM_FILE) -def _get_gadm_union() -> GeoDataFrame: +def gadm_union() -> GeoDataFrame: if not GADM_UNION_FILE.exists(): - _get_gadm() + gadm() return gpd.read_file(GADM_UNION_FILE) @@ -71,17 +71,17 @@ def get_tiles( resolution: int | float = 30, country_codes: list[str] | None = None, buffer_distance: int | float | None = None, -) -> list[(list[int, int], GridSpec)]: +) -> Iterator[tuple[tuple[int, int], GeoBox]]: """Returns a list of tile IDs for the Pacific region, optionally filtered by country code.""" if country_codes is None: - geometries = _get_gadm_union() + geometries = gadm_union() else: if not all(code in COUNTRIES_AND_CODES.values() for code in country_codes): raise ValueError( f"Invalid country code. Must be one of {', '.join(COUNTRIES_AND_CODES.values())}" ) - geometries = _get_gadm().loc[lambda df: df["GID_0"].isin(country_codes)] + geometries = gadm().loc[lambda df: df["GID_0"].isin(country_codes)] return grid( resolution=resolution, @@ -97,7 +97,7 @@ def grid( return_type: Literal["GridSpec", "GeoSeries", "GeoDataFrame"] = "GridSpec", intersect_with: GeoDataFrame | None = None, buffer_distance: int | float | None = None, -) -> GridSpec | GeoSeries | GeoDataFrame: +) -> GridSpec | GeoSeries | GeoDataFrame | Iterator[tuple[tuple[int, int], GeoBox]]: """Returns a GridSpec or GeoSeries representing the Pacific grid, optionally intersected with an area of interest. @@ -110,7 +110,10 @@ def grid( this is ignored. intersect_with: The output is intersected with the supplied GeoDataFrame before returning, returning only tiles which overlap with those - features. Forces the output to be a GeoDataFrame. + features. If `return_type` is `GridSpec`, an iterator of tuples each + containing the tile id (in column, row order) and its GeoBox. + Otherwise, a GeoDataFrame containing only the portions of each tile + that intersect the given GeoDataFrame is returned. """ if intersect_with is not None: