Skip to content

Commit

Permalink
refactor: Switch to a factory method for the app (#78)
Browse files Browse the repository at this point in the history
* refactor: Switch to a factory method for the app

This allows the tests to inject a specific configuration directly,
rather than relying on having the environment set before the config
is loaded. It avoids errors like the one that took me a few hours to
debug: I imported the wrong thing in tests, which caused the environment
to be read too early.

This closes #77.

* format and lint
  • Loading branch information
bjchambers authored Feb 8, 2024
1 parent a028d64 commit 1ca7bfd
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 66 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
COPY ./dewy /code/dewy
COPY --from=frontend-stage /app/dist /code/dewy/frontend/dist

CMD ["uvicorn", "dewy.main:app", "--host", "0.0.0.0", "--port", "8000"]
# `--factory` tells uvicorn that `dewy.main:create_app` is a callable returning
# the ASGI app, not the app instance itself.
CMD ["uvicorn", "dewy.main:create_app", "--factory", "--host", "0.0.0.0", "--port", "8000"]
5 changes: 3 additions & 2 deletions dewy/chunk/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dewy.common.collection_embeddings import CollectionEmbeddings
from dewy.common.db import PgPoolDep
from dewy.config import ConfigDep

from .models import Chunk, RetrieveRequest, RetrieveResponse, TextChunk

Expand Down Expand Up @@ -65,14 +66,14 @@ async def get_chunk(

@router.post("/retrieve")
async def retrieve_chunks(
pg_pool: PgPoolDep, request: RetrieveRequest
pg_pool: PgPoolDep, config: ConfigDep, request: RetrieveRequest
) -> RetrieveResponse:
"""Retrieve chunks based on a given query."""

# TODO: Revisit response synthesis and hierarchical fetching.

collection = await CollectionEmbeddings.for_collection_id(
pg_pool, request.collection_id
pg_pool, config, request.collection_id
)
text_results = await collection.retrieve_text_chunks(
query=request.query, n=request.n
Expand Down
21 changes: 14 additions & 7 deletions dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dewy.chunk.models import TextResult
from dewy.collection.models import DistanceMetric
from dewy.config import settings
from dewy.config import Config

from .extract import extract

Expand All @@ -19,6 +19,7 @@ class CollectionEmbeddings:
def __init__(
self,
pg_pool: asyncpg.Pool,
config: Config,
*,
collection_id: int,
text_embedding_model: str,
Expand All @@ -36,7 +37,7 @@ def __init__(

# TODO: Look at a sentence window splitter?
self._splitter = SentenceSplitter(chunk_size=256)
self._embedding = _resolve_embedding_model(self.text_embedding_model)
self._embedding = _resolve_embedding_model(config, self.text_embedding_model)

field = f"embedding::vector({text_embedding_dimensions})"

Expand Down Expand Up @@ -73,7 +74,9 @@ def __init__(
"""

@staticmethod
async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
async def for_collection_id(
pg_pool: asyncpg.Pool, config: Config, collection_id: int
) -> Self:
"""Retrieve the collection embeddings of the given collection."""
async with pg_pool.acquire() as conn:
result = await conn.fetchrow(
Expand All @@ -93,14 +96,17 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:

return CollectionEmbeddings(
pg_pool,
config,
collection_id=result["id"],
text_embedding_model=result["text_embedding_model"],
text_embedding_dimensions=result["text_embedding_dimensions"],
text_distance_metric=DistanceMetric(result["text_distance_metric"]),
)

@staticmethod
async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self):
async def for_document_id(
pg_pool: asyncpg.Pool, config: Config, document_id: int
) -> (str, Self):
"""Retrieve the collection embeddings and the URL of the given document."""

# TODO: Ideally the collection embeddings would be cached, and this
Expand All @@ -127,6 +133,7 @@ async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self
# TODO: Cache the configured ingestions, and only recreate when needed?
configured_ingestion = CollectionEmbeddings(
pg_pool,
config,
collection_id=result["id"],
text_embedding_model=result["text_embedding_model"],
text_embedding_dimensions=result["text_embedding_dimensions"],
Expand Down Expand Up @@ -338,9 +345,9 @@ async def get_dimensions(conn: asyncpg.Connection, model_name: str) -> int:
return dimensions


def _resolve_embedding_model(model: str) -> BaseEmbedding:
def _resolve_embedding_model(config: Config, model: str) -> BaseEmbedding:
if not model:
if settings.OPENAI_API_KEY:
if config.OPENAI_API_KEY:
model = DEFAULT_OPENAI_EMBEDDING_MODEL
else:
model = DEFAULT_HF_EMBEDDING_MODEL
Expand All @@ -349,7 +356,7 @@ def _resolve_embedding_model(model: str) -> BaseEmbedding:
if split[0] == "openai":
from llama_index.embeddings import OpenAIEmbedding

return OpenAIEmbedding(model=split[1], api_key=settings.OPENAI_API_KEY)
return OpenAIEmbedding(model=split[1], api_key=config.OPENAI_API_KEY)
elif split[0] == "hf":
from llama_index.embeddings import HuggingFaceEmbedding

Expand Down
43 changes: 25 additions & 18 deletions dewy/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Optional
from typing import Annotated, Any, Optional

from fastapi import Depends, Request
from fastapi.routing import APIRoute
from pydantic import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict
Expand Down Expand Up @@ -42,8 +43,27 @@ class Config(BaseSettings):
This is required for using openai models.
"""

def app_configs(self) -> dict[str, Any]:
API_DESCRIPTION: str = """This API allows ingesting and retrieving knowledge.
settings = Config()
Knowledge comes in a variety of forms -- text, image, tables, etc. and
from a variety of sources -- documents, web pages, audio, etc."""

app_configs: dict[str, Any] = {
"title": "Dewy Knowledge Base API",
"version": "0.1.3",
"summary": "Knowledge curation for Retrieval Augmented Generation",
"description": API_DESCRIPTION,
"servers": [
{"url": "http://localhost:8000", "description": "Local server"},
],
"generate_unique_id_function": custom_generate_unique_id_function,
}

if not self.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs

return app_configs


def convert_snake_case_to_camel_case(string: str) -> str:
Expand All @@ -59,21 +79,8 @@ def custom_generate_unique_id_function(route: APIRoute) -> str:
return convert_snake_case_to_camel_case(route.name)


API_DESCRIPTION: str = """This API allows ingesting and retrieving knowledge.
Knowledge comes in a variety of forms -- text, image, tables, etc. and
from a variety of sources -- documents, web pages, audio, etc."""
def _get_config(request: Request) -> Config:
    """Return the ``config`` attribute of the application serving ``request``."""
    application = request.app
    return application.config

app_configs: dict[str, Any] = {
"title": "Dewy Knowledge Base API",
"version": "0.1.3",
"summary": "Knowledge curation for Retrieval Augmented Generation",
"description": API_DESCRIPTION,
"servers": [
{"url": "http://localhost:8000", "description": "Local server"},
],
"generate_unique_id_function": custom_generate_unique_id_function,
}

if not settings.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs
# Dependency alias: a parameter annotated with `ConfigDep` is typed as `Config`
# and resolved through `_get_config`.
ConfigDep = Annotated[Config, Depends(_get_config)]
10 changes: 7 additions & 3 deletions dewy/document/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@

from dewy.common.collection_embeddings import CollectionEmbeddings
from dewy.common.db import PgConnectionDep, PgPoolDep
from dewy.config import Config, ConfigDep
from dewy.document.models import Document

from .models import AddDocumentRequest, DocumentStatus

router = APIRouter(prefix="/documents")


async def ingest_document(document_id: int, pg_pool: asyncpg.Pool) -> None:
async def ingest_document(
document_id: int, pg_pool: asyncpg.Pool, config: Config
) -> None:
try:
url, embeddings = await CollectionEmbeddings.for_document_id(
pg_pool, document_id
pg_pool, config, document_id
)
if url.startswith("error://"):
raise RuntimeError(url.removeprefix("error://"))
Expand Down Expand Up @@ -61,6 +64,7 @@ async def ingest_document(document_id: int, pg_pool: asyncpg.Pool) -> None:
@router.put("/")
async def add_document(
pg_pool: PgPoolDep,
config: ConfigDep,
background: BackgroundTasks,
req: AddDocumentRequest,
) -> Document:
Expand All @@ -79,7 +83,7 @@ async def add_document(
)

document = Document.model_validate(dict(row))
background.add_task(ingest_document, document.id, pg_pool)
background.add_task(ingest_document, document.id, pg_pool, config)
return document


Expand Down
68 changes: 41 additions & 27 deletions dewy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

import asyncpg
import uvicorn
from fastapi import FastAPI
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from loguru import logger

from dewy.common import db
from dewy.common.db_migration import apply_migrations
from dewy.config import app_configs, settings
from dewy.config import Config
from dewy.routes import api_router


Expand All @@ -27,12 +27,12 @@ class State(TypedDict):


@contextlib.asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
async def lifespan(app: FastAPI) -> AsyncIterator[State]:
"""Function creating instances used during the lifespan of the service."""

if settings.DB is not None:
async with db.create_pool(settings.DB.unicode_string()) as pg_pool:
if settings.APPLY_MIGRATIONS:
if app.config.DB is not None:
async with db.create_pool(app.config.DB.unicode_string()) as pg_pool:
if app.config.APPLY_MIGRATIONS:
async with pg_pool.acquire() as conn:
await apply_migrations(conn, migration_dir=migrations_path)

Expand All @@ -45,35 +45,49 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
yield state


app = FastAPI(lifespan=lifespan, **app_configs)
root_router = APIRouter()

origins = [
"*",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


@app.get("/healthcheck", include_in_schema=False)
@root_router.get("/healthcheck", include_in_schema=False)
async def healthcheck() -> dict[str, str]:
    """Return a static ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload


app.include_router(api_router)

if settings.SERVE_ADMIN_UI and os.path.isdir(react_build_path):
logger.info("Running admin UI at http://localhost:8000/admin")
# Serve static files from the React app build directory
app.mount(
"/admin", StaticFiles(directory=str(react_build_path), html=True), name="static"
def install_middleware(app: FastAPI) -> None:
    """Register the CORS middleware on ``app``.

    All origins, methods and headers are allowed, and credentials are enabled.
    """
    # NOTE(review): wildcard origins combined with allow_credentials=True may
    # not be a supported combination -- confirm against the FastAPI CORS docs.
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_options)


def create_app(config: Optional[Config] = None) -> FastAPI:
    """Build and return a fully wired FastAPI application.

    Args:
        config: Configuration to run with. When not provided, a fresh
            ``Config()`` is constructed.

    Returns:
        The application, with the configuration attached as ``app.config``,
        middleware installed, the root and API routers included and, if
        ``SERVE_ADMIN_UI`` is set and the React build directory exists, the
        admin UI mounted at ``/admin``.
    """
    config = config or Config()

    app = FastAPI(lifespan=lifespan, **config.app_configs())
    # Kept on the app object so it can be read back later via `app.config`.
    app.config = config

    install_middleware(app)
    for router in (root_router, api_router):
        app.include_router(router)

    admin_ui_available = config.SERVE_ADMIN_UI and os.path.isdir(react_build_path)
    if admin_ui_available:
        logger.info("Running admin UI at http://localhost:8000/admin")
        # Serve static files from the React app build directory
        admin_ui = StaticFiles(directory=str(react_build_path), html=True)
        app.mount("/admin", admin_ui, name="static")

    return app


# Function for running Dewy as a script
def run(*args):
uvicorn.run("dewy.main:app", host="0.0.0.0", port=8000)
# `create_app` is an application factory; factory=True tells uvicorn to call it
# to obtain the ASGI app instead of treating it as the app itself.
uvicorn.run("dewy.main:create_app", host="0.0.0.0", port=8000, factory=True)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ cmd = "pytest"

[tool.poe.tasks.extract-openapi]
help = "Update openapi.toml from the swagger docs"
cmd = "python scripts/extract_openapi.py dewy.main:app"
cmd = "python scripts/extract_openapi.py dewy.main:create_app"

[tool.poe.tasks.generate-client]
help = "Generate the openapi client"
Expand Down
2 changes: 1 addition & 1 deletion scripts/extract_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

parser = argparse.ArgumentParser(prog="extract-openapi.py")
parser.add_argument(
"app", help='App import string. Eg. "main:app"', default="dewy.main:app"
"app", help='App import string. Eg. "main:app"', default="dewy.main:create_app"
)
parser.add_argument("--app-dir", help="Directory containing the app", default=None)
parser.add_argument(
Expand Down
15 changes: 9 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dewy_client import Client
from httpx import AsyncClient

from dewy.config import Config

pytest_plugins = ["pytest_docker_fixtures"]

from pytest_docker_fixtures.images import configure as configure_image # noqa: E402
Expand All @@ -24,14 +26,15 @@

@pytest.fixture(scope="session")
async def app(pg, event_loop):
# Set environment variables before the application is loaded.
import os

(pg_host, pg_port) = pg
os.environ["DB"] = f"postgresql://dewydbuser:dewydbpwd@{pg_host}:{pg_port}/dewydb"
os.environ["APPLY_MIGRATIONS"] = "true"
config = Config(
DB=f"postgresql://dewydbuser:dewydbpwd@{pg_host}:{pg_port}/dewydb",
APPLY_MIGRATIONS=True,
)

from dewy.main import create_app

from dewy.main import app
app = create_app(config)

async with LifespanManager(app) as manager:
yield manager.app
Expand Down

0 comments on commit 1ca7bfd

Please sign in to comment.