From 1ca7bfdc2ee0613979d0fbebbbe160b15a3d10a6 Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:23:00 -0800 Subject: [PATCH] refactor: Switch to a factory method for the app (#78) * refactor: Switch to a factory method for the app This allows the tests to inject a specific configuration directly, rather than relying on having the environment set before the config is loaded. It avoids errors like what took me a few hours to debug when I imported the wrong thing in tests, causing the environment to be read too early. This closes #77. * format and lint --- Dockerfile | 2 +- dewy/chunk/router.py | 5 +- dewy/common/collection_embeddings.py | 21 ++++++--- dewy/config.py | 43 ++++++++++-------- dewy/document/router.py | 10 ++-- dewy/main.py | 68 +++++++++++++++++----------- pyproject.toml | 2 +- scripts/extract_openapi.py | 2 +- tests/conftest.py | 15 +++--- 9 files changed, 102 insertions(+), 66 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2eb7ec1..fb6866f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,4 +31,4 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt COPY ./dewy /code/dewy COPY --from=frontend-stage /app/dist /code/dewy/frontend/dist -CMD ["uvicorn", "dewy.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file +CMD ["uvicorn", "dewy.main:create_app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/dewy/chunk/router.py b/dewy/chunk/router.py index 279e2ec..40b55bf 100644 --- a/dewy/chunk/router.py +++ b/dewy/chunk/router.py @@ -4,6 +4,7 @@ from dewy.common.collection_embeddings import CollectionEmbeddings from dewy.common.db import PgPoolDep +from dewy.config import ConfigDep from .models import Chunk, RetrieveRequest, RetrieveResponse, TextChunk @@ -65,14 +66,14 @@ async def get_chunk( @router.post("/retrieve") async def retrieve_chunks( - pg_pool: PgPoolDep, request: RetrieveRequest + pg_pool: PgPoolDep, config: ConfigDep, request: RetrieveRequest ) -> RetrieveResponse: """Retrieve chunks based on a given query.""" # TODO: Revisit response synthesis and hierarchical fetching. collection = await CollectionEmbeddings.for_collection_id( - pg_pool, request.collection_id + pg_pool, config, request.collection_id ) text_results = await collection.retrieve_text_chunks( query=request.query, n=request.n diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py index 6bbb486..dc5d4c0 100644 --- a/dewy/common/collection_embeddings.py +++ b/dewy/common/collection_embeddings.py @@ -8,7 +8,7 @@ from dewy.chunk.models import TextResult from dewy.collection.models import DistanceMetric -from dewy.config import settings +from dewy.config import Config from .extract import extract @@ -19,6 +19,7 @@ class CollectionEmbeddings: def __init__( self, pg_pool: asyncpg.Pool, + config: Config, *, collection_id: int, text_embedding_model: str, @@ -36,7 +37,7 @@ def __init__( # TODO: Look at a sentence window splitter? self._splitter = SentenceSplitter(chunk_size=256) - self._embedding = _resolve_embedding_model(self.text_embedding_model) + self._embedding = _resolve_embedding_model(config, self.text_embedding_model) field = f"embedding::vector({text_embedding_dimensions})" @@ -73,7 +74,9 @@ def __init__( """ @staticmethod - async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self: + async def for_collection_id( + pg_pool: asyncpg.Pool, config: Config, collection_id: int + ) -> Self: """Retrieve the collection embeddings of the given collection.""" async with pg_pool.acquire() as conn: result = await conn.fetchrow( @@ -93,6 +96,7 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self: return CollectionEmbeddings( pg_pool, + config, collection_id=result["id"], text_embedding_model=result["text_embedding_model"], text_embedding_dimensions=result["text_embedding_dimensions"], @@ -100,7 +104,9 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self: ) @staticmethod - async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self): + async def for_document_id( + pg_pool: asyncpg.Pool, config: Config, document_id: int + ) -> (str, Self): """Retrieve the collection embeddings and the URL of the given document.""" # TODO: Ideally the collection embeddings would be cached, and this @@ -127,6 +133,7 @@ async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self # TODO: Cache the configured ingestions, and only recreate when needed? configured_ingestion = CollectionEmbeddings( pg_pool, + config, collection_id=result["id"], text_embedding_model=result["text_embedding_model"], text_embedding_dimensions=result["text_embedding_dimensions"], @@ -338,9 +345,9 @@ async def get_dimensions(conn: asyncpg.Connection, model_name: str) -> int: return dimensions -def _resolve_embedding_model(model: str) -> BaseEmbedding: +def _resolve_embedding_model(config: Config, model: str) -> BaseEmbedding: if not model: - if settings.OPENAI_API_KEY: + if config.OPENAI_API_KEY: model = DEFAULT_OPENAI_EMBEDDING_MODEL else: model = DEFAULT_HF_EMBEDDING_MODEL @@ -349,7 +356,7 @@ def _resolve_embedding_model(model: str) -> BaseEmbedding: if split[0] == "openai": from llama_index.embeddings import OpenAIEmbedding - return OpenAIEmbedding(model=split[1], api_key=settings.OPENAI_API_KEY) + return OpenAIEmbedding(model=split[1], api_key=config.OPENAI_API_KEY) elif split[0] == "hf": from llama_index.embeddings import HuggingFaceEmbedding diff --git a/dewy/config.py b/dewy/config.py index f0450bd..43de4af 100644 --- a/dewy/config.py +++ b/dewy/config.py @@ -1,5 +1,6 @@ -from typing import Any, Optional +from typing import Annotated, Any, Optional +from fastapi import Depends, Request from fastapi.routing import APIRoute from pydantic import PostgresDsn from pydantic_settings import BaseSettings, SettingsConfigDict @@ -42,8 +43,27 @@ class Config(BaseSettings): This is required for using openai models. """ + def app_configs(self) -> dict[str, Any]: + API_DESCRIPTION: str = """This API allows ingesting and retrieving knowledge. -settings = Config() + Knowledge comes in a variety of forms -- text, image, tables, etc. and + from a variety of sources -- documents, web pages, audio, etc.""" + + app_configs: dict[str, Any] = { + "title": "Dewy Knowledge Base API", + "version": "0.1.3", + "summary": "Knowledge curation for Retrieval Augmented Generation", + "description": API_DESCRIPTION, + "servers": [ + {"url": "http://localhost:8000", "description": "Local server"}, + ], + "generate_unique_id_function": custom_generate_unique_id_function, + } + + if not self.ENVIRONMENT.is_debug: + app_configs["openapi_url"] = None # hide docs + + return app_configs def convert_snake_case_to_camel_case(string: str) -> str: @@ -59,21 +79,8 @@ def custom_generate_unique_id_function(route: APIRoute) -> str: return convert_snake_case_to_camel_case(route.name) -API_DESCRIPTION: str = """This API allows ingesting and retrieving knowledge. - -Knowledge comes in a variety of forms -- text, image, tables, etc. and -from a variety of sources -- documents, web pages, audio, etc.""" +def _get_config(request: Request) -> Config: + return request.app.config -app_configs: dict[str, Any] = { - "title": "Dewy Knowledge Base API", - "version": "0.1.3", - "summary": "Knowledge curation for Retrieval Augmented Generation", - "description": API_DESCRIPTION, - "servers": [ - {"url": "http://localhost:8000", "description": "Local server"}, - ], - "generate_unique_id_function": custom_generate_unique_id_function, -} -if not settings.ENVIRONMENT.is_debug: - app_configs["openapi_url"] = None # hide docs +ConfigDep = Annotated[Config, Depends(_get_config)] diff --git a/dewy/document/router.py b/dewy/document/router.py index 2b0ae62..b221c26 100644 --- a/dewy/document/router.py +++ b/dewy/document/router.py @@ -6,6 +6,7 @@ from dewy.common.collection_embeddings import CollectionEmbeddings from dewy.common.db import PgConnectionDep, PgPoolDep +from dewy.config import Config, ConfigDep from dewy.document.models import Document from .models import AddDocumentRequest, DocumentStatus @@ -13,10 +14,12 @@ router = APIRouter(prefix="/documents") -async def ingest_document(document_id: int, pg_pool: asyncpg.Pool) -> None: +async def ingest_document( + document_id: int, pg_pool: asyncpg.Pool, config: Config +) -> None: try: url, embeddings = await CollectionEmbeddings.for_document_id( - pg_pool, document_id + pg_pool, config, document_id ) if url.startswith("error://"): raise RuntimeError(url.removeprefix("error://")) @@ -61,6 +64,7 @@ async def ingest_document(document_id: int, pg_pool: asyncpg.Pool) -> None: @router.put("/") async def add_document( pg_pool: PgPoolDep, + config: ConfigDep, background: BackgroundTasks, req: AddDocumentRequest, ) -> Document: @@ -79,7 +83,7 @@ async def add_document( ) document = Document.model_validate(dict(row)) - background.add_task(ingest_document, document.id, pg_pool) + background.add_task(ingest_document, document.id, pg_pool, config) return document diff --git a/dewy/main.py b/dewy/main.py index 181871a..a988b75 100644 --- a/dewy/main.py +++ b/dewy/main.py @@ -5,14 +5,14 @@ import asyncpg import uvicorn -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from loguru import logger from dewy.common import db from dewy.common.db_migration import apply_migrations -from dewy.config import app_configs, settings +from dewy.config import Config from dewy.routes import api_router @@ -27,12 +27,12 @@ class State(TypedDict): @contextlib.asynccontextmanager -async def lifespan(_app: FastAPI) -> AsyncIterator[State]: +async def lifespan(app: FastAPI) -> AsyncIterator[State]: """Function creating instances used during the lifespan of the service.""" - if settings.DB is not None: - async with db.create_pool(settings.DB.unicode_string()) as pg_pool: - if settings.APPLY_MIGRATIONS: + if app.config.DB is not None: + async with db.create_pool(app.config.DB.unicode_string()) as pg_pool: + if app.config.APPLY_MIGRATIONS: async with pg_pool.acquire() as conn: await apply_migrations(conn, migration_dir=migrations_path) @@ -45,35 +45,49 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[State]: yield state -app = FastAPI(lifespan=lifespan, **app_configs) +root_router = APIRouter() -origins = [ - "*", -] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -@app.get("/healthcheck", include_in_schema=False) +@root_router.get("/healthcheck", include_in_schema=False) async def healthcheck() -> dict[str, str]: return {"status": "ok"} -app.include_router(api_router) - -if settings.SERVE_ADMIN_UI and os.path.isdir(react_build_path): - logger.info("Running admin UI at http://localhost:8000/admin") - # Serve static files from the React app build directory - app.mount( - "/admin", StaticFiles(directory=str(react_build_path), html=True), name="static" +def install_middleware(app: FastAPI) -> None: + origins = [ + "*", + ] + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) +def create_app(config: Optional[Config] = None) -> FastAPI: + config = config or Config() + app = FastAPI(lifespan=lifespan, **config.app_configs()) + app.config = config + + install_middleware(app) + + app.include_router(root_router) + app.include_router(api_router) + + if config.SERVE_ADMIN_UI and os.path.isdir(react_build_path): + logger.info("Running admin UI at http://localhost:8000/admin") + # Serve static files from the React app build directory + app.mount( + "/admin", + StaticFiles(directory=str(react_build_path), html=True), + name="static", + ) + + return app + + # Function for running Dewy as a script def run(*args): - uvicorn.run("dewy.main:app", host="0.0.0.0", port=8000) + uvicorn.run("dewy.main:create_app", host="0.0.0.0", port=8000) diff --git a/pyproject.toml b/pyproject.toml index ca63953..542603d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ cmd = "pytest" [tool.poe.tasks.extract-openapi] help = "Update openapi.toml from the swagger docs" -cmd = "python scripts/extract_openapi.py dewy.main:app" +cmd = "python scripts/extract_openapi.py dewy.main:create_app" [tool.poe.tasks.generate-client] help = "Generate the openapi client" diff --git a/scripts/extract_openapi.py b/scripts/extract_openapi.py index a818cd0..a1ae451 100644 --- a/scripts/extract_openapi.py +++ b/scripts/extract_openapi.py @@ -8,7 +8,7 @@ parser = argparse.ArgumentParser(prog="extract-openapi.py") parser.add_argument( - "app", help='App import string. Eg. "main:app"', default="dewy.main:app" + "app", help='App import string. Eg. "main:app"', default="dewy.main:create_app" ) parser.add_argument("--app-dir", help="Directory containing the app", default=None) parser.add_argument( diff --git a/tests/conftest.py b/tests/conftest.py index 2b14b01..fe09fd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ from dewy_client import Client from httpx import AsyncClient +from dewy.config import Config + pytest_plugins = ["pytest_docker_fixtures"] from pytest_docker_fixtures.images import configure as configure_image # noqa: E402 @@ -24,14 +26,15 @@ @pytest.fixture(scope="session") async def app(pg, event_loop): - # Set environment variables before the application is loaded. - import os - (pg_host, pg_port) = pg - os.environ["DB"] = f"postgresql://dewydbuser:dewydbpwd@{pg_host}:{pg_port}/dewydb" - os.environ["APPLY_MIGRATIONS"] = "true" + config = Config( + DB=f"postgresql://dewydbuser:dewydbpwd@{pg_host}:{pg_port}/dewydb", + APPLY_MIGRATIONS=True, + ) + + from dewy.main import create_app - from dewy.main import app + app = create_app(config) async with LifespanManager(app) as manager: yield manager.app