Feature/aggregation (#71)
* Update tests

* Add missing tests from history, solve SQL query, and manage DB Janitor properly

* Add support for more aggregation methods

* Add aggregation endpoint

* Add notebook example for aggregation
zacdezgeo authored Oct 2, 2024
1 parent e1939ca commit 3a0068b
Showing 7 changed files with 56,712 additions and 46 deletions.
56,478 changes: 56,456 additions & 22 deletions notebooks/space2stats_api_adm_example.ipynb

Large diffs are not rendered by default.
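The notebook diff is too large to render inline. As a rough sketch of the kind of call the aggregation example presumably demonstrates, the snippet below posts an admin-boundary GeoJSON Feature to the new /aggregate endpoint. The base URL, boundary coordinates, and field name are illustrative assumptions, not values taken from the notebook.

# Hypothetical sketch only; placeholders, not notebook contents.
import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment of the API

adm_boundary = {  # assumed admin-boundary Feature
    "type": "Feature",
    "geometry": {
        "type": "Polygon",
        "coordinates": [[[41.0, -2.3], [41.3, -2.3], [41.3, -2.0], [41.0, -2.0], [41.0, -2.3]]],
    },
    "properties": {"name": "example admin area"},
}

resp = requests.post(
    f"{BASE_URL}/aggregate",
    json={
        "aoi": adm_boundary,
        "spatial_join_method": "centroid",
        "fields": ["sum_pop_2020"],
        "aggregation_type": "sum",
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. {"sum_pop_2020": ...}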

15 changes: 13 additions & 2 deletions space2stats_api/src/space2stats/api/app.py
@@ -13,7 +13,7 @@
from ..lib import StatsTable
from .db import close_db_connection, connect_to_db
from .errors import add_exception_handlers
-from .schemas import SummaryRequest
+from .schemas import AggregateRequest, SummaryRequest
from .settings import Settings

s3_client = boto3.client("s3")
@@ -50,7 +50,6 @@ async def lifespan(app: FastAPI):
add_exception_handlers(app)

def stats_table(request: Request):
"""Dependency to generate a per-request connection to stats table"""
with request.app.state.pool.connection() as conn:
yield StatsTable(conn=conn, table_name=settings.PGTABLENAME)

@@ -66,6 +65,18 @@ def get_summary(body: SummaryRequest, table: StatsTable = Depends(stats_table)):
except pg.errors.UndefinedColumn as e:
raise HTTPException(status_code=400, detail=e.diag.message_primary) from e

@app.post("/aggregate", response_model=Dict[str, float])
def get_aggregate(body: AggregateRequest, table: StatsTable = Depends(stats_table)):
try:
return table.aggregate(
aoi=body.aoi,
spatial_join_method=body.spatial_join_method,
fields=body.fields,
aggregation_type=body.aggregation_type,
)
except pg.errors.UndefinedColumn as e:
raise HTTPException(status_code=400, detail=e.diag.message_primary) from e

@app.get("/fields", response_model=List[str])
def fields(table: StatsTable = Depends(stats_table)):
return table.fields()
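A minimal sketch (not part of this commit) of exercising the new endpoint through FastAPI's TestClient; the AOI polygon and field names are assumptions modeled on the test fixtures elsewhere in this PR.

# Sketch only; payload values are illustrative.
from fastapi.testclient import TestClient
from space2stats.api.app import build_app

app = build_app()
with TestClient(app) as client:
    response = client.post(
        "/aggregate",
        json={
            "aoi": {
                "type": "Feature",
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [[[41.14, -2.10], [41.15, -2.10], [41.15, -2.11], [41.14, -2.10]]],
                },
                "properties": {},
            },
            "spatial_join_method": "touches",
            "fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
            "aggregation_type": "avg",
        },
    )
    print(response.status_code, response.json())

Running this requires the PostgreSQL settings (PGHOST, PGTABLENAME, and so on) that the conftest fixtures normally provide, since the lifespan hook opens a connection pool.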
8 changes: 8 additions & 0 deletions space2stats_api/src/space2stats/api/schemas.py
@@ -1,5 +1,6 @@
from typing import List, Literal, Optional

from geojson_pydantic import Feature
from pydantic import BaseModel

from ..types import AoiModel
@@ -10,3 +11,10 @@ class SummaryRequest(BaseModel):
spatial_join_method: Literal["touches", "centroid", "within"]
fields: List[str]
geometry: Optional[Literal["polygon", "point"]] = None


class AggregateRequest(BaseModel):
aoi: Feature
spatial_join_method: Literal["touches", "centroid", "within"]
fields: List[str]
aggregation_type: Literal["sum", "avg", "count", "max", "min"]
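A brief sketch of how the new model validates a request body; the coordinate values and field name are placeholders.

from space2stats.api.schemas import AggregateRequest

payload = {
    "aoi": {
        "type": "Feature",
        "geometry": {
            "type": "Polygon",
            "coordinates": [[[41.14, -2.10], [41.15, -2.10], [41.15, -2.11], [41.14, -2.10]]],
        },
        "properties": {},
    },
    "spatial_join_method": "centroid",
    "fields": ["sum_pop_2020"],
    "aggregation_type": "sum",
}

req = AggregateRequest(**payload)  # validates against the Literal constraints
# An invalid value is rejected before any SQL is built, e.g.
# AggregateRequest(**{**payload, "aggregation_type": "median"}) raises pydantic.ValidationError.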
54 changes: 54 additions & 0 deletions space2stats_api/src/space2stats/lib.py
@@ -125,3 +125,57 @@ def fields(self) -> List[str]:
columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"]

return columns

def aggregate(
self,
aoi: AoiModel,
spatial_join_method: Literal["touches", "centroid", "within"],
fields: List[str],
aggregation_type: Literal["sum", "avg", "count", "max", "min"],
) -> Dict[str, float]:
"""Aggregate Statistics from a GeoJSON feature."""
if not isinstance(aoi, Feature):
aoi = AoiModel.model_validate(aoi)

# Get H3 ids from geometry
resolution = 6
h3_ids = list(
generate_h3_ids(
aoi.geometry.model_dump(exclude_none=True),
resolution,
spatial_join_method,
)
)

if not h3_ids:
return {}

# Prepare SQL aggregation query
aggregations = [f"{aggregation_type}({field}) AS {field}" for field in fields]
sql_query = pg.sql.SQL(
"""
SELECT {0}
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(
pg.sql.SQL(", ").join(pg.sql.SQL(a) for a in aggregations),
pg.sql.Identifier(self.table_name),
)

# Convert h3_ids to a list to ensure compatibility with psycopg
h3_ids = list(h3_ids)
with self.conn.cursor() as cur:
cur.execute(
sql_query,
[h3_ids],
)
row = cur.fetchone() # Get a single row of results
colnames = [desc[0] for desc in cur.description]

# Create a dictionary to hold the aggregation results
aggregated_results: Dict[str, float] = {}
for idx, col in enumerate(colnames):
aggregated_results[col] = row[idx]

return aggregated_results
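For reference, a sketch of calling the new method directly against an open psycopg connection; the DSN, table name, and AOI are placeholder assumptions. Plain GeoJSON dicts are accepted because the method validates non-Feature input via AoiModel.

# Sketch only; adjust the DSN and table name to your environment.
import psycopg
from space2stats.lib import StatsTable

aoi = {
    "type": "Feature",
    "geometry": {
        "type": "Polygon",
        "coordinates": [[[41.14, -2.10], [41.15, -2.10], [41.15, -2.11], [41.14, -2.10]]],
    },
    "properties": {},
}

with psycopg.connect("postgresql://user:password@localhost:5432/space2stats") as conn:
    table = StatsTable(conn=conn, table_name="space2stats")
    result = table.aggregate(
        aoi=aoi,
        spatial_join_method="within",
        fields=["sum_pop_2020"],
        aggregation_type="sum",
    )
    # result maps each field to its aggregated value, e.g. {"sum_pop_2020": 375.0};
    # an empty dict means no resolution-6 H3 cells matched the AOI.

For a single field, the generated statement is roughly SELECT sum(sum_pop_2020) AS sum_pop_2020 FROM space2stats WHERE hex_id = ANY (%s), with the H3 id list passed as the parameter.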
57 changes: 38 additions & 19 deletions space2stats_api/src/tests/conftest.py
@@ -4,16 +4,15 @@
import psycopg
import pytest
from fastapi.testclient import TestClient
from geojson_pydantic import Feature
from moto import mock_aws
from pytest_postgresql.janitor import DatabaseJanitor
from space2stats.api.app import build_app


@pytest.fixture
def s3_mock():
"""
Mock S3 environment and create a test bucket.
"""
"""Mock S3 environment and create a test bucket."""
with mock_aws():
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
@@ -22,21 +21,17 @@ def s3_mock():

@pytest.fixture()
def aws_credentials():
"""
Mocked AWS credentials for moto.
"""
"""Mocked AWS credentials for moto."""
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"


@pytest.fixture(scope="session")
@pytest.fixture(scope="function")
def database(postgresql_proc):
"""
Set up a PostgreSQL database for testing and clean up afterwards.
"""
"""Set up a PostgreSQL database for testing and clean up afterwards."""
with DatabaseJanitor(
user=postgresql_proc.user,
host=postgresql_proc.host,
@@ -50,7 +45,6 @@ def database(postgresql_proc):
)
with psycopg.connect(db_url) as conn:
with conn.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS space2stats")
cur.execute(
"""
CREATE TABLE space2stats (
@@ -60,21 +54,28 @@
);
"""
)
conn.commit()

# Insert data that corresponds to the expected H3 IDs
cur.execute(
"""
INSERT INTO space2stats (hex_id, sum_pop_2020, sum_pop_f_10_2020)
-VALUES ('862a1070fffffff', 100, 200), ('862a10767ffffff', 150, 250);
-"""
+VALUES
+('862a1070fffffff', 100, 200),
+('862a10767ffffff', 150, 250),
+('862a1073fffffff', 120, 220),
+('867a74817ffffff', 125, 225),
+('867a74807ffffff', 125, 225);
+"""
)
conn.commit()

yield jan


@pytest.fixture(autouse=True)
def mock_env(monkeypatch, database):
"""
Automatically set environment variables for PostgreSQL and S3.
"""
"""Automatically set environment variables for PostgreSQL and S3."""
monkeypatch.setenv("PGHOST", database.host)
monkeypatch.setenv("PGPORT", str(database.port))
monkeypatch.setenv("PGDATABASE", database.dbname)
@@ -86,9 +87,27 @@ def mock_env(monkeypatch, database):

@pytest.fixture
def client():
"""
Provide a test client for FastAPI.
"""
"""Provide a test client for FastAPI."""
app = build_app()
with TestClient(app) as test_client:
yield test_client


@pytest.fixture
def aoi_example():
"""Provide an example AOI feature for testing."""
return Feature(
type="Feature",
geometry={
"type": "Polygon",
"coordinates": [
[
[41.14127371265408, -2.1034653113510444],
[41.140645873470845, -2.104696345752785],
[41.14205369446421, -2.104701102391104],
[41.14127371265408, -2.1034653113510444],
]
],
},
properties={},
)
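A hypothetical test (not part of this commit) showing how the new aoi_example and client fixtures could compose; the field name, join method, and the assumption that the example AOI overlaps the seeded H3 cells are illustrative.

def test_aggregate_with_aoi_fixture(client, aoi_example):
    # aoi_example is a geojson_pydantic Feature, so dump it back to plain GeoJSON
    response = client.post(
        "/aggregate",
        json={
            "aoi": aoi_example.model_dump(exclude_none=True),
            "spatial_join_method": "touches",
            "fields": ["sum_pop_2020"],
            "aggregation_type": "max",
        },
    )
    assert response.status_code == 200
    assert "sum_pop_2020" in response.json()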
21 changes: 19 additions & 2 deletions space2stats_api/src/tests/test_api.py
@@ -1,7 +1,8 @@
import pytest

aoi = {
"type": "Feature",
"geometry": {
# This polygon intersects with the test data
"type": "Polygon",
"coordinates": [
[
@@ -55,14 +56,30 @@ def test_bad_fields_validated(client):
assert response.json() == {"error": 'column "a_non_existent_field" does not exist'}


@pytest.mark.parametrize("aggregation_type", ["sum", "avg", "count", "max", "min"])
def test_aggregate_methods(client, aggregation_type):
request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
"aggregation_type": aggregation_type,
}

response = client.post("/aggregate", json=request_payload)
assert response.status_code == 200
response_json = response.json()
assert isinstance(response_json, dict)
assert "sum_pop_2020" in response_json
assert "sum_pop_f_10_2020" in response_json


def test_get_summary_with_geometry_multipolygon(client):
request_payload = {
"aoi": {
**aoi,
"geometry": {
"type": "MultiPolygon",
"coordinates": [
# Ensure at least one multipolygon interacts with test data
aoi["geometry"]["coordinates"],
[
[
