diff --git a/cid/main.py b/cid/main.py index 79e6c54..d5f81ef 100644 --- a/cid/main.py +++ b/cid/main.py @@ -1,9 +1,9 @@ import logging -from typing import Any, Dict +from typing import Any, Dict, Generator -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI from sqlalchemy import create_engine, desc -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from cid.config import DATABASE_URL from cid.models import AwsImage, AzureImage, GoogleImage @@ -15,13 +15,20 @@ app = FastAPI() +def get_db() -> Generator: + db = SessionLocal() + try: + yield db + finally: + db.close() + + @app.get("/") def read_root() -> dict: return {"Hello": "World"} -def latest_aws_image(request: Request) -> Dict[str, Any]: - db = SessionLocal() +def latest_aws_image(db: Session) -> Dict[str, Any]: regions = db.query(AwsImage.region).distinct().order_by(AwsImage.region).all() latest_image = ( db.query(AwsImage).order_by(desc(AwsImage.version), desc(AwsImage.date)).first() @@ -48,7 +55,7 @@ def latest_aws_image(request: Request) -> Dict[str, Any]: } -def latest_azure_image(request: Request) -> Dict[str, Any]: +def latest_azure_image(db: Session) -> Dict[str, Any]: db = SessionLocal() latest_image = db.query(AzureImage).order_by(desc(AzureImage.version)).first() @@ -63,7 +70,7 @@ def latest_azure_image(request: Request) -> Dict[str, Any]: } -def latest_google_image(request: Request) -> Dict[str, Any]: +def latest_google_image(db: Session) -> Dict[str, Any]: db = SessionLocal() latest_image = ( db.query(GoogleImage) @@ -83,9 +90,9 @@ def latest_google_image(request: Request) -> Dict[str, Any]: @app.get("/latest") -def latest(request: Request) -> Dict[str, Any]: +def latest(db: Session = Depends(get_db)) -> Dict[str, Any]: # noqa: B008 return { - "latest_aws_image": latest_aws_image(request), - "latest_azure_image": latest_azure_image(request), - "latest_google_image": latest_google_image(request), + "latest_aws_image": latest_aws_image(db), + "latest_azure_image": latest_azure_image(db), + "latest_google_image": latest_google_image(db), } diff --git a/poetry.lock b/poetry.lock index de1b5f9..3c8c3d8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1356,6 +1356,34 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlalchemy-utils" +version = "0.41.2" +description = "Various utility functions for SQLAlchemy." +optional = false +python-versions = ">=3.7" +files = [ + {file = "SQLAlchemy-Utils-0.41.2.tar.gz", hash = "sha256:bc599c8c3b3319e53ce6c5c3c471120bd325d0071fb6f38a10e924e3d07b9990"}, + {file = "SQLAlchemy_Utils-0.41.2-py3-none-any.whl", hash = "sha256:85cf3842da2bf060760f955f8467b87983fb2e30f1764fd0e24a48307dc8ec6e"}, +] + +[package.dependencies] +SQLAlchemy = ">=1.3" + +[package.extras] +arrow = ["arrow (>=0.3.4)"] +babel = ["Babel (>=1.3)"] +color = ["colour (>=0.0.4)"] +encrypted = ["cryptography (>=0.6)"] +intervals = ["intervals (>=0.7.1)"] +password = ["passlib (>=1.6,<2.0)"] +pendulum = ["pendulum (>=2.0.5)"] +phone = ["phonenumbers (>=5.9.2)"] +test = ["Jinja2 (>=2.3)", "Pygments (>=1.2)", "backports.zoneinfo", "docutils (>=0.10)", "flake8 (>=2.4.0)", "flexmock (>=0.9.7)", "isort (>=4.2.2)", "pg8000 (>=1.12.4)", "psycopg (>=3.1.8)", "psycopg2 (>=2.5.1)", "psycopg2cffi (>=2.8.1)", "pymysql", "pyodbc", "pytest (==7.4.4)", "python-dateutil (>=2.6)", "pytz (>=2014.2)"] +test-all = ["Babel (>=1.3)", "Jinja2 (>=2.3)", "Pygments (>=1.2)", "arrow (>=0.3.4)", "backports.zoneinfo", "colour (>=0.0.4)", "cryptography (>=0.6)", "docutils (>=0.10)", "flake8 (>=2.4.0)", "flexmock (>=0.9.7)", "furl (>=0.4.1)", "intervals (>=0.7.1)", "isort (>=4.2.2)", "passlib (>=1.6,<2.0)", "pendulum (>=2.0.5)", "pg8000 (>=1.12.4)", "phonenumbers (>=5.9.2)", "psycopg (>=3.1.8)", "psycopg2 (>=2.5.1)", "psycopg2cffi (>=2.8.1)", "pymysql", "pyodbc", "pytest (==7.4.4)", "python-dateutil", "python-dateutil (>=2.6)", "pytz (>=2014.2)"] +timezone = ["python-dateutil"] +url = ["furl (>=0.4.1)"] + [[package]] name = "starlette" version = "0.37.2" @@ -1788,4 +1816,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.12,<4.0" -content-hash = "739cd03c462b318e2f153dccbd18f18f6bec722a7886c038833bdffb0189337a" +content-hash = "1278dc2b2591331754bf9f455e1d5cb4b35e6886631485410367fa699a7d0b29" diff --git a/pyproject.toml b/pyproject.toml index b010258..0fbd4b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ pre-commit = "^3.7.1" tox = "^4.15.0" pytest-randomly = "^3.15.0" pytest-sugar = "^1.0.0" +sqlalchemy-utils = "^0.41.2" [build-system] requires = ["poetry-core"] diff --git a/tests/test_main.py b/tests/test_main.py index a199e13..e24400b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,30 @@ """Tests for the main module.""" +from datetime import datetime + from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy_utils import drop_database + +from cid.main import app, get_db, latest_aws_image +from cid.models import AwsImage + +# Create an in-memory SQLite database for testing +DATABASE_URL = "sqlite:///./test.db" +engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + -from cid.main import app +def override_get_db(): + try: + db = TestingSessionLocal() + yield db + finally: + db.close() + + +app.dependency_overrides[get_db] = override_get_db # Create a TestClient instance to test the FastAPI app # https://fastapi.tiangolo.com/tutorial/testing/ @@ -14,3 +36,38 @@ def test_read_root(): response = client.get("/") assert response.status_code == 200 assert response.json() == {"Hello": "World"} + + +def teardown(): + AwsImage.metadata.drop_all(bind=engine) + TestingSessionLocal.remove() + engine.dispose() + + +def test_latest_aws_image(): + AwsImage.metadata.create_all(bind=engine) + + db = TestingSessionLocal() + aws_image = AwsImage( + id="ami-12345678", + name="test_image", + version="1.0", + # SQLite only supports datetime objects + date=datetime.strptime("2022-01-01", "%Y-%m-%d").date(), + region="us-west-1", + imageId="ami-12345678", + ) + db.add(aws_image) + db.commit() + + response = latest_aws_image(db) + + assert response == { + "name": "test_image", + "version": "1.0", + "date": datetime(2022, 1, 1, 0, 0), + "amis": {"us-west-1": "ami-12345678"}, + } + + drop_database(engine.url) + app.dependency_overrides.pop(get_db)