Skip to content

Commit

Permalink
Refactor testing from patch to pytest fixtures (#47)
Browse files Browse the repository at this point in the history
* Refactor testing from mocks to pytest fixtures

* Update tests to use data directly instead of mocking

* Update CI for data fixture

* remove top level module export and refactor conftests (#48)

---------

Co-authored-by: Vincent Sarago <vincent.sarago@gmail.com>
  • Loading branch information
zacdezgeo and vincentsarago authored Sep 9, 2024
1 parent d3a9b8b commit 8d1cd7a
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 260 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
run: |
poetry install --with test
- name: install lib postgres
uses: nyurik/action-setup-postgis@v2

- name: Run pre-commit
working-directory: ./space2stats_api/src
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,4 @@ space2stats_api/space2stats_env
cdk.out
lambda_layer
.venv
.envrc
378 changes: 198 additions & 180 deletions space2stats_api/src/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions space2stats_api/src/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ uvicorn = "*"
pre-commit = "*"
pytest = "*"
pytest-cov = "*"
pytest-mock = "*"
pytest-postgresql = "*"
moto = "^5.0.13"

Expand Down
2 changes: 0 additions & 2 deletions space2stats_api/src/space2stats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""space2stats."""

from .main import fields, summaries # noqa

__version__ = "0.1.0"
47 changes: 45 additions & 2 deletions space2stats_api/src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os

import boto3
import psycopg
import pytest
from fastapi.testclient import TestClient
from moto import mock_aws
from pytest_postgresql.janitor import DatabaseJanitor


@pytest.fixture()
Expand All @@ -28,10 +31,50 @@ def s3_client(aws_credentials):
def test_bucket(s3_client) -> str:
bucket_name = "test-bucket"
s3_client.create_bucket(Bucket=bucket_name)

return bucket_name


@pytest.fixture(scope="session")
def database(postgresql_proc):
with DatabaseJanitor(
user=postgresql_proc.user,
host=postgresql_proc.host,
port=postgresql_proc.port,
dbname="testdb",
version=postgresql_proc.version,
password="password",
) as jan:
db_url = (
f"postgresql://{jan.user}:{jan.password}@{jan.host}:{jan.port}/{jan.dbname}"
)
with psycopg.connect(db_url) as conn:
with conn.cursor() as cur:
cur.execute("""
CREATE TABLE IF NOT EXISTS space2stats (
hex_id TEXT PRIMARY KEY,
sum_pop_2020 INT,
sum_pop_f_10_2020 INT
);
""")
cur.execute("""
INSERT INTO space2stats (hex_id, sum_pop_2020, sum_pop_f_10_2020)
VALUES ('hex_1', 100, 200), ('hex_2', 150, 250);
""")

yield jan


@pytest.fixture(autouse=True)
def set_bucket_name(monkeypatch, test_bucket):
def client(monkeypatch, database, test_bucket):
monkeypatch.setenv("PGHOST", database.host)
monkeypatch.setenv("PGPORT", str(database.port))
monkeypatch.setenv("PGDATABASE", database.dbname)
monkeypatch.setenv("PGUSER", database.user)
monkeypatch.setenv("PGPASSWORD", database.password)
monkeypatch.setenv("PGTABLENAME", "space2stats")
monkeypatch.setenv("S3_BUCKET_NAME", test_bucket)

from space2stats.app import app

with TestClient(app) as test_client:
yield test_client
81 changes: 5 additions & 76 deletions space2stats_api/src/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
from unittest.mock import patch

import pytest
from fastapi.testclient import TestClient
from pytest_postgresql.janitor import DatabaseJanitor

aoi = {
"type": "Feature",
"geometry": {
Expand All @@ -22,56 +16,20 @@
}


@pytest.fixture(scope="session")
def database(postgresql_proc):
"""Fake Database."""
with DatabaseJanitor(
user=postgresql_proc.user,
host=postgresql_proc.host,
port=postgresql_proc.port,
dbname="testdb",
version=postgresql_proc.version,
password="password",
) as jan:
yield jan


@pytest.fixture(autouse=True)
def client(monkeypatch, database, test_bucket):
monkeypatch.setenv("PGHOST", database.host)
monkeypatch.setenv("PGPORT", str(database.port))
monkeypatch.setenv("PGDATABASE", database.dbname)
monkeypatch.setenv("PGUSER", database.user)
monkeypatch.setenv("PGPASSWORD", database.password)
monkeypatch.setenv("PGTABLENAME", "space2stats")

from space2stats.app import app

with TestClient(app) as app:
yield app


def test_read_root(client):
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to Space2Stats!"}


@patch("space2stats.main._get_summaries")
def test_get_summary(mock_get_summaries, client):
mock_get_summaries.return_value = (
[("hex_1", 100, 200)],
["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"],
)

def test_get_summary(client):
request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
assert isinstance(response_json, list)
Expand All @@ -80,17 +38,10 @@ def test_get_summary(mock_get_summaries, client):
assert "hex_id" in summary
for field in request_payload["fields"]:
assert field in summary
# +1 for the 'hex_id'
assert len(summary) == len(request_payload["fields"]) + 1


@patch("space2stats.main._get_summaries")
def test_get_summary_with_geometry_polygon(mock_get_summaries, client):
mock_get_summaries.return_value = (
[("hex_1", 100, 200)],
["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"],
)

def test_get_summary_with_geometry_polygon(client):
request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
Expand All @@ -99,7 +50,6 @@ def test_get_summary_with_geometry_polygon(mock_get_summaries, client):
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
assert isinstance(response_json, list)
Expand All @@ -108,19 +58,10 @@ def test_get_summary_with_geometry_polygon(mock_get_summaries, client):
assert "hex_id" in summary
assert "geometry" in summary
assert summary["geometry"]["type"] == "Polygon"
for field in request_payload["fields"]:
assert field in summary
# +1 for the 'hex_id' and +1 for 'geometry'
assert len(summary) == len(request_payload["fields"]) + 2


@patch("space2stats.main._get_summaries")
def test_get_summary_with_geometry_point(mock_get_summaries, client):
mock_get_summaries.return_value = (
[("hex_1", 100, 200)],
["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"],
)

def test_get_summary_with_geometry_point(client):
request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
Expand All @@ -129,7 +70,6 @@ def test_get_summary_with_geometry_point(mock_get_summaries, client):
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
assert isinstance(response_json, list)
Expand All @@ -138,25 +78,14 @@ def test_get_summary_with_geometry_point(mock_get_summaries, client):
assert "hex_id" in summary
assert "geometry" in summary
assert summary["geometry"]["type"] == "Point"
for field in request_payload["fields"]:
assert field in summary
# +1 for the 'hex_id' and +1 for 'geometry'
assert len(summary) == len(request_payload["fields"]) + 2


@patch("space2stats.app.get_available_fields")
def test_get_fields(mock_get_available_fields, client):
mock_get_available_fields.return_value = [
"sum_pop_2020",
"sum_pop_f_10_2020",
"field3",
]

def test_get_fields(client):
response = client.get("/fields")

assert response.status_code == 200
response_json = response.json()

expected_fields = ["sum_pop_2020", "sum_pop_f_10_2020", "field3"]
expected_fields = ["sum_pop_2020", "sum_pop_f_10_2020"]
for field in expected_fields:
assert field in response_json

0 comments on commit 8d1cd7a

Please sign in to comment.