From 085a7fa62e03d7a40c59c042d4cb2be70c2e99b0 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 19 Sep 2024 17:16:30 -0700 Subject: [PATCH] Refactor to better support usage as module (#61) * Refactor to better support usage as module * Bump version * Ensure only 1 connection during lambda execution * Fix typing * Rework to lazy-app to make for easier testing * Update path * Cleanup helper for flexibility * Use context manager to close connections * Add tests * Pre-commit * Rework settings * Simplify db setup * Add missing quote and change steps to load data * Fix settings import * Update for fast reloading and how to run API locally * Update README.md Co-authored-by: Anthony Lukach --------- Co-authored-by: Zachary Deziel --- README.md | 62 +++++++- space2stats_api/cdk/aws_stack.py | 2 +- space2stats_api/src/space2stats/__init__.py | 6 +- space2stats_api/src/space2stats/__main__.py | 8 +- .../src/space2stats/api/__init__.py | 3 + space2stats_api/src/space2stats/api/app.py | 77 ++++++++++ .../src/space2stats/{ => api}/db.py | 5 +- .../src/space2stats/{ => api}/errors.py | 0 .../src/space2stats/{ => api}/handler.py | 8 +- .../src/space2stats/api/schemas.py | 12 ++ .../src/space2stats/api/settings.py | 6 + space2stats_api/src/space2stats/app.py | 78 ---------- space2stats_api/src/space2stats/lib.py | 125 ++++++++++++++++ space2stats_api/src/space2stats/main.py | 135 ------------------ space2stats_api/src/space2stats/settings.py | 3 - space2stats_api/src/space2stats/types.py | 6 + space2stats_api/src/tests/conftest.py | 19 ++- space2stats_api/src/tests/test_errors.py | 4 +- space2stats_api/src/tests/test_module.py | 37 +++++ 19 files changed, 355 insertions(+), 241 deletions(-) create mode 100644 space2stats_api/src/space2stats/api/__init__.py create mode 100644 space2stats_api/src/space2stats/api/app.py rename space2stats_api/src/space2stats/{ => api}/db.py (91%) rename space2stats_api/src/space2stats/{ => api}/errors.py (100%) rename space2stats_api/src/space2stats/{ => api}/handler.py (65%) create mode 100644 space2stats_api/src/space2stats/api/schemas.py create mode 100644 space2stats_api/src/space2stats/api/settings.py delete mode 100644 space2stats_api/src/space2stats/app.py create mode 100644 space2stats_api/src/space2stats/lib.py delete mode 100644 space2stats_api/src/space2stats/main.py create mode 100644 space2stats_api/src/space2stats/types.py create mode 100644 space2stats_api/src/tests/test_module.py diff --git a/README.md b/README.md index ccf6735..fb4799c 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,17 @@ # Space2Stats -Consistent, comparable, authoritative data describing sub-national variation is a constant point of complication for World Bank teams, our development partners, and client countries when assessing and investigating economic issues and national policy. This project will focus on creating and disseminating such data through aggregation of geospatial information at standard administrative divisions, and through the attribution of household survey data with foundational geospatial variables. +Consistent, comparable, authoritative data describing sub-national variation is a constant point of complication for World Bank teams, our development partners, and client countries when assessing and investigating economic issues and national policy. This project will focus on creating and disseminating such data through aggregation of geospatial information at standard administrative divisions, and through the attribution of household survey data with foundational geospatial variables. ## Getting Started Locally -- Setup the database: +- Setup the database: + ``` docker-compose up -d ``` - Create a `db.env` file: + ```.env PGHOST=localhost PGPORT=5439 @@ -20,15 +22,65 @@ PGTABLENAME=space2stats ``` - Load our dataset into the database + ``` ./postgres/download_parquet.sh -python postgres/chunk_parquet.py -./postgres/load_parquet_chunks.sh +./load_to_prod.sh ``` > You can get started with a subset of data for NYC with `./load_nyc_sample.sh` which requires changing your `db.env` value for `PGTABLENAME` to `space2stats_nyc_sample`. -- Access your data using the Space2statS API! See the [example notebook](notebooks/space2stats_api_demo.ipynb). +- Access your data using the Space2stats API! See the [example notebook](notebooks/space2stats_api_demo.ipynb). + +## Usage as an API + +The API can be run with: +``` +python -m space2stats +``` + +## Usage as a module +The module can be installed via `pip` directly from Github: +``` +pip install "git+https://github.com/worldbank/DECAT_Space2Stats.git#subdirectory=space2stats_api/src" +``` + +It can then be used within Python as such: + +```py +from space2stats import StatsTable + +with StatsTable.connect() as stats_table: + ... +``` + +Connection parameters may be explicitely provided. Otherwise, connection parameters will expected to be available via standard [PostgreSQL Environment Variables](https://www.postgresql.org/docs/current/libpq-envars.html#LIBPQ-ENVARS). + +```py +from space2stats import StatsTable + +with StatsTable.connect( + PGHOST="localhost", + PGPORT="5432", + PGUSER="postgres", + PGPASSWORD="changeme", + PGDATABASE="postgis", + PGTABLENAME="space2stats", +) as stats_table: + ... + +# alternatively: +# settings = Settings( +# PGHOST="localhost", +# PGPORT="5432", +# PGUSER="postgres", +# PGPASSWORD="changeme", +# PGDATABASE="postgis", +# PGTABLENAME="space2stats", +# ) +# with StatsTable.connect(settings): +# ... +``` diff --git a/space2stats_api/cdk/aws_stack.py b/space2stats_api/cdk/aws_stack.py index 4f113c6..251edb2 100644 --- a/space2stats_api/cdk/aws_stack.py +++ b/space2stats_api/cdk/aws_stack.py @@ -27,7 +27,7 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: "Space2StatsFunction", entry="../src", runtime=_lambda.Runtime.PYTHON_3_11, - index="space2stats/handler.py", + index="space2stats/api/handler.py", timeout=Duration.seconds(120), handler="handler", environment={ diff --git a/space2stats_api/src/space2stats/__init__.py b/space2stats_api/src/space2stats/__init__.py index 7298f5b..31bca4f 100644 --- a/space2stats_api/src/space2stats/__init__.py +++ b/space2stats_api/src/space2stats/__init__.py @@ -1,3 +1,7 @@ """space2stats.""" -__version__ = "0.1.0" +from .lib import StatsTable +from .settings import Settings + +__all__ = ["StatsTable", "Settings"] +__version__ = "1.0.0" diff --git a/space2stats_api/src/space2stats/__main__.py b/space2stats_api/src/space2stats/__main__.py index feb354d..670ca42 100644 --- a/space2stats_api/src/space2stats/__main__.py +++ b/space2stats_api/src/space2stats/__main__.py @@ -1,7 +1,5 @@ import os -from .app import app - try: import uvicorn # noqa @@ -15,9 +13,11 @@ ), "uvicorn must be installed: `python -m pip install 'space2stats[server]'`" uvicorn.run( - app=app, + app="space2stats.api.app:build_app", host=os.getenv("UVICORN_HOST", "127.0.0.1"), - port=os.getenv("UVICORN_PORT", "8000"), + port=int(os.getenv("UVICORN_PORT", "8000")), root_path=os.getenv("UVICORN_ROOT_PATH", ""), log_level="info", + factory=True, + reload=True, ) diff --git a/space2stats_api/src/space2stats/api/__init__.py b/space2stats_api/src/space2stats/api/__init__.py new file mode 100644 index 0000000..604140b --- /dev/null +++ b/space2stats_api/src/space2stats/api/__init__.py @@ -0,0 +1,3 @@ +from .app import build_app + +__all__ = ["build_app"] diff --git a/space2stats_api/src/space2stats/api/app.py b/space2stats_api/src/space2stats/api/app.py new file mode 100644 index 0000000..aeeb422 --- /dev/null +++ b/space2stats_api/src/space2stats/api/app.py @@ -0,0 +1,77 @@ +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional + +import boto3 +from asgi_s3_response_middleware import S3ResponseMiddleware +from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse +from starlette.requests import Request +from starlette_cramjam.middleware import CompressionMiddleware + +from ..lib import StatsTable +from .db import close_db_connection, connect_to_db +from .errors import add_exception_handlers +from .schemas import SummaryRequest +from .settings import Settings + +s3_client = boto3.client("s3") + + +def build_app(settings: Optional[Settings] = None) -> FastAPI: + settings = settings or Settings() + + @asynccontextmanager + async def lifespan(app: FastAPI): + await connect_to_db(app, settings=settings) + yield + await close_db_connection(app) + + app = FastAPI( + default_response_class=ORJSONResponse, + lifespan=lifespan, + ) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + app.add_middleware(CompressionMiddleware) + app.add_middleware( + S3ResponseMiddleware, + s3_bucket_name=settings.S3_BUCKET_NAME, + s3_client=s3_client, + ) + + add_exception_handlers(app) + + def stats_table(request: Request): + """Dependency to generate a per-request connection to stats table""" + with request.app.state.pool.connection() as conn: + yield StatsTable(conn=conn, table_name=settings.PGTABLENAME) + + @app.post("/summary", response_model=List[Dict[str, Any]]) + def get_summary(body: SummaryRequest, table: StatsTable = Depends(stats_table)): + return table.summaries( + body.aoi, + body.spatial_join_method, + body.fields, + body.geometry, + ) + + @app.get("/fields", response_model=List[str]) + def fields(table: StatsTable = Depends(stats_table)): + return table.fields() + + @app.get("/") + def read_root(): + return {"message": "Welcome to Space2Stats!"} + + @app.get("/health") + def health(): + return {"status": "ok"} + + return app diff --git a/space2stats_api/src/space2stats/db.py b/space2stats_api/src/space2stats/api/db.py similarity index 91% rename from space2stats_api/src/space2stats/db.py rename to space2stats_api/src/space2stats/api/db.py index 67ff0a2..15a622e 100644 --- a/space2stats_api/src/space2stats/db.py +++ b/space2stats_api/src/space2stats/api/db.py @@ -10,13 +10,10 @@ async def connect_to_db( app: FastAPI, - settings: Optional[Settings] = None, + settings: Settings, pool_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Connect to Database.""" - if not settings: - settings = Settings() - pool_kwargs = pool_kwargs or {} app.state.pool = ConnectionPool( diff --git a/space2stats_api/src/space2stats/errors.py b/space2stats_api/src/space2stats/api/errors.py similarity index 100% rename from space2stats_api/src/space2stats/errors.py rename to space2stats_api/src/space2stats/api/errors.py diff --git a/space2stats_api/src/space2stats/handler.py b/space2stats_api/src/space2stats/api/handler.py similarity index 65% rename from space2stats_api/src/space2stats/handler.py rename to space2stats_api/src/space2stats/api/handler.py index e9a7c33..ab8c7e6 100644 --- a/space2stats_api/src/space2stats/handler.py +++ b/space2stats_api/src/space2stats/api/handler.py @@ -5,14 +5,18 @@ from mangum import Mangum -from .app import app +from .app import build_app from .db import connect_to_db +from .settings import Settings + +settings = Settings(DB_MAX_CONN_SIZE=1) # disable connection pooling +app = build_app(settings) @app.on_event("startup") async def startup_event() -> None: """Connect to database on startup.""" - await connect_to_db(app) + await connect_to_db(app, settings=settings) handler = Mangum(app, lifespan="off") diff --git a/space2stats_api/src/space2stats/api/schemas.py b/space2stats_api/src/space2stats/api/schemas.py new file mode 100644 index 0000000..ec7e741 --- /dev/null +++ b/space2stats_api/src/space2stats/api/schemas.py @@ -0,0 +1,12 @@ +from typing import List, Literal, Optional + +from pydantic import BaseModel + +from ..types import AoiModel + + +class SummaryRequest(BaseModel): + aoi: AoiModel + spatial_join_method: Literal["touches", "centroid", "within"] + fields: List[str] + geometry: Optional[Literal["polygon", "point"]] = None diff --git a/space2stats_api/src/space2stats/api/settings.py b/space2stats_api/src/space2stats/api/settings.py new file mode 100644 index 0000000..a36edb4 --- /dev/null +++ b/space2stats_api/src/space2stats/api/settings.py @@ -0,0 +1,6 @@ +from ..settings import Settings as DbSettings + + +class Settings(DbSettings): + # Bucket for large responses + S3_BUCKET_NAME: str diff --git a/space2stats_api/src/space2stats/app.py b/space2stats_api/src/space2stats/app.py deleted file mode 100644 index af2baf2..0000000 --- a/space2stats_api/src/space2stats/app.py +++ /dev/null @@ -1,78 +0,0 @@ -from contextlib import asynccontextmanager -from typing import Any, Dict, List - -import boto3 -from asgi_s3_response_middleware import S3ResponseMiddleware -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import ORJSONResponse -from starlette.requests import Request -from starlette_cramjam.middleware import CompressionMiddleware - -from .db import close_db_connection, connect_to_db -from .errors import add_exception_handlers -from .main import ( - SummaryRequest, - get_available_fields, - get_summaries_from_geom, - settings, -) - -s3_client = boto3.client("s3") - - -@asynccontextmanager -async def lifespan(app: FastAPI): - await connect_to_db(app) - yield - await close_db_connection(app) - - -app = FastAPI( - default_response_class=ORJSONResponse, - lifespan=lifespan, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -app.add_middleware(CompressionMiddleware) -app.add_middleware( - S3ResponseMiddleware, - s3_bucket_name=settings.S3_BUCKET_NAME, - s3_client=s3_client, -) - -add_exception_handlers(app) - - -@app.post("/summary", response_model=List[Dict[str, Any]]) -def get_summary(request: Request, body: SummaryRequest): - with request.app.state.pool.connection() as conn: - return get_summaries_from_geom( - body.aoi, - body.spatial_join_method, - body.fields, - conn, - geometry=body.geometry, - ) - - -@app.get("/fields", response_model=List[str]) -def fields(request: Request): - with request.app.state.pool.connection() as conn: - return get_available_fields(conn) - - -@app.get("/") -def read_root(): - return {"message": "Welcome to Space2Stats!"} - - -@app.get("/health") -def health(): - return {"status": "ok"} diff --git a/space2stats_api/src/space2stats/lib.py b/space2stats_api/src/space2stats/lib.py new file mode 100644 index 0000000..689a101 --- /dev/null +++ b/space2stats_api/src/space2stats/lib.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass +from typing import Dict, List, Literal, Optional + +import psycopg as pg +from geojson_pydantic import Feature +from psycopg import Connection + +from .h3_utils import generate_h3_geometries, generate_h3_ids +from .settings import Settings +from .types import AoiModel + + +@dataclass +class StatsTable: + conn: Connection + table_name: str + + @classmethod + def connect(cls, settings: Optional[Settings] = None, **kwargs) -> "StatsTable": + """ + Helper method to connect to the database and return a StatsTable instance. + + ```py + with StatsTable.connect() as stats_table: + stats_table.fields() + ``` + """ + settings = settings or Settings(**kwargs, _extra="forbid") + conn = pg.connect(settings.DB_CONNECTION_STRING) + return cls(conn=conn, table_name=settings.PGTABLENAME) + + def __enter__(self) -> "StatsTable": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if self.conn: + self.conn.close() + + def _get_summaries(self, fields: List[str], h3_ids: List[str]): + colnames = ["hex_id"] + fields + cols = [pg.sql.Identifier(c) for c in colnames] + sql_query = pg.sql.SQL( + """ + SELECT {0} + FROM {1} + WHERE hex_id = ANY (%s) + """ + ).format(pg.sql.SQL(", ").join(cols), pg.sql.Identifier(self.table_name)) + + # Convert h3_ids to a list to ensure compatibility with psycopg + h3_ids = list(h3_ids) + with self.conn.cursor() as cur: + cur.execute( + sql_query, + [ + h3_ids, + ], + ) + rows = cur.fetchall() + colnames = [desc[0] for desc in cur.description] + + return rows, colnames + + def summaries( + self, + aoi: AoiModel, + spatial_join_method: Literal["touches", "centroid", "within"], + fields: List[str], + geometry: Optional[Literal["polygon", "point"]] = None, + ): + """Retrieve Statistics from a GeoJSON feature.""" + if not isinstance(aoi, Feature): + aoi = AoiModel.model_validate(aoi) + + # Get H3 ids from geometry + resolution = 6 + h3_ids = generate_h3_ids( + aoi.geometry.model_dump(exclude_none=True), + resolution, + spatial_join_method, + ) + + if not h3_ids: + return [] + + # Get Summaries from H3 ids + rows, colnames = self._get_summaries(fields=fields, h3_ids=h3_ids) + if not rows: + return [] + + # Format Summaries + summaries: List[Dict] = [] + geometries = generate_h3_geometries(h3_ids, geometry) if geometry else None + + for idx, row in enumerate(rows): + summary = {"hex_id": row[0]} + if geometry and geometries: + summary["geometry"] = geometries[idx] + + summary.update( + { + col: row[idx] + for idx, col in enumerate(colnames[1:], start=1) + if col in fields + } + ) + summaries.append(summary) + + return summaries + + def fields(self) -> List[str]: + sql_query = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = %s + """ + + with self.conn.cursor() as cur: + cur.execute( + sql_query, + [self.table_name], + ) + columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"] + + return columns diff --git a/space2stats_api/src/space2stats/main.py b/space2stats_api/src/space2stats/main.py deleted file mode 100644 index 433f3c3..0000000 --- a/space2stats_api/src/space2stats/main.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Dict, List, Literal, Optional - -import psycopg as pg -from geojson_pydantic import Feature, Polygon -from psycopg import Connection -from pydantic import BaseModel -from typing_extensions import TypeAlias - -from .h3_utils import generate_h3_geometries, generate_h3_ids -from .settings import Settings - -settings = Settings() - -AoiModel: TypeAlias = Feature[Polygon, Dict] - - -class SummaryRequest(BaseModel): - aoi: AoiModel - spatial_join_method: Literal["touches", "centroid", "within"] - fields: List[str] - geometry: Optional[Literal["polygon", "point"]] = None - - -def _get_summaries(fields: List[str], h3_ids: List[str], conn: Connection): - colnames = ["hex_id"] + fields - cols = [pg.sql.Identifier(c) for c in colnames] - sql_query = pg.sql.SQL( - """ - SELECT {0} - FROM {1} - WHERE hex_id = ANY (%s) - """ - ).format(pg.sql.SQL(", ").join(cols), pg.sql.Identifier(settings.PGTABLENAME)) - - # Convert h3_ids to a list to ensure compatibility with psycopg - h3_ids = list(h3_ids) - with conn.cursor() as cur: - cur.execute( - sql_query, - [ - h3_ids, - ], - ) - rows = cur.fetchall() - colnames = [desc[0] for desc in cur.description] - - return rows, colnames - - -def get_summaries_from_geom( - aoi: AoiModel, - spatial_join_method: Literal["touches", "centroid", "within"], - fields: List[str], - conn: Connection, - geometry: Optional[Literal["polygon", "point"]] = None, -): - # Get H3 ids from geometry - resolution = 6 - h3_ids = generate_h3_ids( - aoi.geometry.model_dump(exclude_none=True), - resolution, - spatial_join_method, - ) - - if not h3_ids: - return [] - - # Get Summaries from H3 ids - rows, colnames = _get_summaries(fields, h3_ids, conn) - if not rows: - return [] - - # Format Summaries - summaries: List[Dict] = [] - geometries = generate_h3_geometries(h3_ids, geometry) if geometry else None - - for idx, row in enumerate(rows): - summary = {"hex_id": row[0]} - if geometry and geometries: - summary["geometry"] = geometries[idx] - - summary.update( - { - col: row[idx] - for idx, col in enumerate(colnames[1:], start=1) - if col in fields - } - ) - summaries.append(summary) - - return summaries - - -def get_available_fields(conn: Connection) -> List[str]: - sql_query = """ - SELECT column_name - FROM information_schema.columns - WHERE table_name = %s - """ - with conn.cursor() as cur: - cur.execute( - sql_query, - [ - settings.PGTABLENAME, - ], - ) - columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"] - - return columns - - -def summaries( - aoi: AoiModel, - spatial_join_method: Literal["touches", "centroid", "within"], - fields: List[str], - geometry: Optional[Literal["polygon", "point"]] = None, -) -> List[Dict]: - """Retrieve Statistics from a GeoJSON feature.""" - if not isinstance(aoi, Feature): - aoi = AoiModel.model_validate(aoi) - - with pg.connect(settings.DB_CONNECTION_STRING) as conn: - return get_summaries_from_geom( - aoi, - spatial_join_method, - fields, - conn, - geometry=geometry, - ) - - -def fields() -> List[str]: - """List Available Fields in the Table.""" - with pg.connect(settings.DB_CONNECTION_STRING) as conn: - return get_available_fields(conn) diff --git a/space2stats_api/src/space2stats/settings.py b/space2stats_api/src/space2stats/settings.py index 4aff4c8..a791091 100644 --- a/space2stats_api/src/space2stats/settings.py +++ b/space2stats_api/src/space2stats/settings.py @@ -9,9 +9,6 @@ class Settings(BaseSettings): PGPASSWORD: str PGTABLENAME: str - # Bucket for large responses - S3_BUCKET_NAME: str - # see https://www.psycopg.org/psycopg3/docs/api/pool.html#the-connectionpool-class for options DB_MIN_CONN_SIZE: int = 1 DB_MAX_CONN_SIZE: int = 10 diff --git a/space2stats_api/src/space2stats/types.py b/space2stats_api/src/space2stats/types.py new file mode 100644 index 0000000..a358eb2 --- /dev/null +++ b/space2stats_api/src/space2stats/types.py @@ -0,0 +1,6 @@ +from typing import Dict + +from geojson_pydantic import Feature, Polygon +from typing_extensions import TypeAlias + +AoiModel: TypeAlias = Feature[Polygon, Dict] diff --git a/space2stats_api/src/tests/conftest.py b/space2stats_api/src/tests/conftest.py index 8c1e925..c5fdfb7 100644 --- a/space2stats_api/src/tests/conftest.py +++ b/space2stats_api/src/tests/conftest.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from moto import mock_aws from pytest_postgresql.janitor import DatabaseJanitor +from space2stats.api.app import build_app @pytest.fixture @@ -59,23 +60,27 @@ def database(postgresql_proc): ) with psycopg.connect(db_url) as conn: with conn.cursor() as cur: - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS space2stats ( hex_id TEXT PRIMARY KEY, sum_pop_2020 INT, sum_pop_f_10_2020 INT ); - """) - cur.execute(""" + """ + ) + cur.execute( + """ INSERT INTO space2stats (hex_id, sum_pop_2020, sum_pop_f_10_2020) VALUES ('hex_1', 100, 200), ('hex_2', 150, 250); - """) + """ + ) yield jan @pytest.fixture(autouse=True) -def client(monkeypatch, database, test_bucket): +def mock_env(monkeypatch, database, test_bucket): monkeypatch.setenv("PGHOST", database.host) monkeypatch.setenv("PGPORT", str(database.port)) monkeypatch.setenv("PGDATABASE", database.dbname) @@ -84,7 +89,9 @@ def client(monkeypatch, database, test_bucket): monkeypatch.setenv("PGTABLENAME", "space2stats") monkeypatch.setenv("S3_BUCKET_NAME", test_bucket) - from space2stats.app import app +@pytest.fixture +def client(): + app = build_app() with TestClient(app) as test_client: yield test_client diff --git a/space2stats_api/src/tests/test_errors.py b/space2stats_api/src/tests/test_errors.py index bb65fd3..95cc005 100644 --- a/space2stats_api/src/tests/test_errors.py +++ b/space2stats_api/src/tests/test_errors.py @@ -5,8 +5,8 @@ from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient from psycopg.errors import OperationalError -from space2stats.app import app -from space2stats.errors import ( +from space2stats.api import app +from space2stats.api.errors import ( database_exception_handler, http_exception_handler, validation_exception_handler, diff --git a/space2stats_api/src/tests/test_module.py b/space2stats_api/src/tests/test_module.py new file mode 100644 index 0000000..29bf1b2 --- /dev/null +++ b/space2stats_api/src/tests/test_module.py @@ -0,0 +1,37 @@ +from space2stats import Settings, StatsTable + + +def test_stats_table(mock_env): + with StatsTable.connect() as stats_table: + assert stats_table.table_name == "space2stats" + assert stats_table.conn.closed == 0 + stats_table.conn.execute("SELECT 1") + + +def test_stats_table_connect(mock_env, database): + with StatsTable.connect( + PGHOST=database.host, + PGPORT=database.port, + PGDATABASE=database.dbname, + PGUSER=database.user, + PGPASSWORD=database.password, + PGTABLENAME="XYZ", + ) as stats_table: + assert stats_table.table_name == "XYZ" + assert stats_table.conn.closed == 0 + stats_table.conn.execute("SELECT 1") + + +def test_stats_table_settings(mock_env, database): + settings = Settings( + PGHOST=database.host, + PGPORT=database.port, + PGDATABASE=database.dbname, + PGUSER=database.user, + PGPASSWORD=database.password, + PGTABLENAME="ABC", + ) + with StatsTable.connect(settings) as stats_table: + assert stats_table.table_name == "ABC" + assert stats_table.conn.closed == 0 + stats_table.conn.execute("SELECT 1")