Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make gadm functions public, and tweak type annotations #66

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions dep_tools/grids.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from json import loads
from pathlib import Path
from typing import Literal
from typing import Literal, Iterator

import antimeridian
import geopandas as gpd
import pandas as pd
from geopandas import GeoDataFrame, GeoSeries
from odc.geo import XY, BoundingBox, Geometry
from odc.geo.gridspec import GridSpec
from odc.geo.gridspec import GridSpec, GeoBox
from shapely.geometry import shape

# This EPSG code is what we're using for now
Expand Down Expand Up @@ -42,7 +42,7 @@
}


def _get_gadm() -> GeoDataFrame:
def gadm() -> GeoDataFrame:
if not GADM_FILE.exists() or not GADM_UNION_FILE.exists():
all_polys = pd.concat(
[
Expand All @@ -60,9 +60,9 @@ def _get_gadm() -> GeoDataFrame:
return gpd.read_file(GADM_FILE)


def _get_gadm_union() -> GeoDataFrame:
def gadm_union() -> GeoDataFrame:
if not GADM_UNION_FILE.exists():
_get_gadm()
gadm()

return gpd.read_file(GADM_UNION_FILE)

Expand All @@ -71,17 +71,17 @@ def get_tiles(
resolution: int | float = 30,
country_codes: list[str] | None = None,
buffer_distance: int | float | None = None,
) -> list[(list[int, int], GridSpec)]:
) -> Iterator[tuple[tuple[int, int], GeoBox]]:
"""Returns a list of tile IDs for the Pacific region, optionally filtered by country code."""

if country_codes is None:
geometries = _get_gadm_union()
geometries = gadm_union()
else:
if not all(code in COUNTRIES_AND_CODES.values() for code in country_codes):
raise ValueError(
f"Invalid country code. Must be one of {', '.join(COUNTRIES_AND_CODES.values())}"
)
geometries = _get_gadm().loc[lambda df: df["GID_0"].isin(country_codes)]
geometries = gadm().loc[lambda df: df["GID_0"].isin(country_codes)]

return grid(
resolution=resolution,
Expand All @@ -97,7 +97,7 @@ def grid(
return_type: Literal["GridSpec", "GeoSeries", "GeoDataFrame"] = "GridSpec",
intersect_with: GeoDataFrame | None = None,
buffer_distance: int | float | None = None,
) -> GridSpec | GeoSeries | GeoDataFrame:
) -> GridSpec | GeoSeries | GeoDataFrame | Iterator[tuple[tuple[int, int], GeoBox]]:
"""Returns a GridSpec or GeoSeries representing the Pacific grid, optionally
intersected with an area of interest.

Expand All @@ -110,7 +110,10 @@ def grid(
this is ignored.
intersect_with: The output is intersected with the supplied GeoDataFrame
before returning, returning only tiles which overlap with those
features. Forces the output to be a GeoDataFrame.
features. If `return_type` is `GridSpec`, an iterator of tuples each
containing the tile id (in column, row order) and its GeoBox.
Otherwise, a GeoDataFrame containing only the portions of each tile
that intersect the given GeoDataFrame is returned.
"""

if intersect_with is not None:
Expand Down
Loading