Skip to content

Commit

Permalink
feat: Switch DB to postgres (#10)
Browse files Browse the repository at this point in the history
This uses asyncpg directly.

This removes the tags from the methods, and changes the names to be
more descriptive. The intent is to have a single `service` generated
from the OpenAPI specification that includes methods like
`retrieve_chunks` rather than `chunks.retrieve(...)`.
  • Loading branch information
bjchambers authored Jan 23, 2024
1 parent 1b3da43 commit 3277785
Show file tree
Hide file tree
Showing 19 changed files with 599 additions and 175 deletions.
11 changes: 10 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Ignore everything
**

# Include (don't ignore) the application code
!app
!./pyproject.toml
!./poetry.lock
**/__pycache__

# Re-ignore pycache within `app`.
**/__pycache__

# Include (don't ignore) the migrations.
!migrations/*.sql
!yoyo.ini
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
# Finally, copy in the application code.
COPY ./app /code/app

COPY ./migrations/0001_schema.sql /code/migrations/0001_schema.sql

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
File renamed without changes.
20 changes: 10 additions & 10 deletions app/statements/models.py → app/chunks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ class RetrieveRequest(BaseModel):
"""Whether to include a generated summary."""


class BaseStatement(BaseModel):
class BaseChunk(BaseModel):
kind: Literal["text", "raw_text", "image"]

score: Optional[float] = None
"""The similarity score of this statement."""
"""The similarity score of this chunk."""


class TextStatement(BaseStatement):
class TextChunk(BaseChunk):
kind: Literal["text"] = "text"
raw: bool
text: str = Field(default="", description="Text content of the node.")
text: str = Field(default="", description="Text content of the chunk.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node."
default=None, description="Start char index of the chunk."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node."
default=None, description="End char index of the chunk."
)


class ImageStatement(BaseStatement):
class ImageChunk(BaseChunk):
kind: Literal["image"] = "image"
text: Optional[str] = Field(..., description="Textual description of the image.")
image: Optional[str] = Field(..., description="Image of the node.")
Expand All @@ -59,7 +59,7 @@ class RetrieveResponse(BaseModel):
"""The response from a chunk retrieval request."""

summary: Optional[str]
"""Summary of the retrieved statements."""
"""Summary of the retrieved chunks."""

statements: Sequence[Union[TextStatement, ImageStatement]]
"""Retrieved statements."""
chunks: Sequence[Union[TextChunk, ImageChunk]]
"""Retrieved chunks."""
16 changes: 8 additions & 8 deletions app/statements/router.py → app/chunks/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from app.ingest.store import StoreDep

from .models import ImageStatement, RetrieveRequest, RetrieveResponse, TextStatement
from .models import ImageChunk, RetrieveRequest, RetrieveResponse, TextChunk

router = APIRouter(tags=["statements"], prefix="/statements")
router = APIRouter(prefix="/chunks")


@router.post("/retrieve")
async def retrieve(store: StoreDep, request: RetrieveRequest) -> RetrieveResponse:
"""Retrieve statements based on a given query."""
async def retrieve_chunks(store: StoreDep, request: RetrieveRequest) -> RetrieveResponse:
"""Retrieve chunks based on a given query."""

from llama_index.response_synthesizers import ResponseMode

Expand All @@ -30,23 +30,23 @@ async def retrieve(store: StoreDep, request: RetrieveRequest) -> RetrieveRespons

return RetrieveResponse(
summary=results.response,
statements=statements if request.include_statements else [],
chunks=statements if request.include_statements else [],
)


def node_to_statement(node: NodeWithScore) -> Union[TextStatement, ImageStatement]:
def node_to_statement(node: NodeWithScore) -> Union[TextChunk, ImageChunk]:
from llama_index.schema import ImageNode, TextNode

if isinstance(node.node, TextNode):
return TextStatement(
return TextChunk(
raw=True,
score=node.score,
text=node.node.text,
start_char_idx=node.node.start_char_idx,
end_char_idx=node.node.end_char_idx,
)
elif isinstance(node.node, ImageNode):
return ImageStatement(
return ImageChunk(
score=node.score,
text=node.node.text if node.node.text else None,
image=node.node.image,
Expand Down
21 changes: 21 additions & 0 deletions app/collections/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

from pydantic import BaseModel, ConfigDict, TypeAdapter

class Collection(BaseModel):
model_config=ConfigDict(from_attributes=True)

"""A collection of indexed documents."""
id: int
"""The ID of the collection."""

name: str
"""The name of the collection."""

collection_validator = TypeAdapter(Collection)


class CollectionCreate(BaseModel):
"""The request to create a collection."""

name: str
"""The name of the collection."""
32 changes: 17 additions & 15 deletions app/collections/router.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
from typing import Annotated, List

from fastapi import APIRouter, Path
from sqlmodel import Session, select
from pydantic import parse_obj_as

from app.common.schema import Collection, EngineDep
from app.common.db import PgConnectionDep
from app.collections.models import *

router = APIRouter(tags=["collections"], prefix="/collections")
router = APIRouter(prefix="/collections")


@router.put("/")
async def add(engine: EngineDep, collection: Collection) -> Collection:
async def add_collection(conn: PgConnectionDep, collection: CollectionCreate) -> Collection:
"""Create a collection."""
with Session(engine) as session:
session.add(collection)
session.commit()
session.refresh(collection)
return collection
result = await conn.fetchrow("""
INSERT INTO collection (name) VALUES ($1)
RETURNING id, name
""",
collection.name)
return Collection.model_validate(dict(result))


@router.get("/")
async def list(engine: EngineDep) -> List[Collection]:
async def list_collections(conn: PgConnectionDep) -> List[Collection]:
"""List collections."""
with Session(engine) as session:
return session.exec(select(Collection)).all()
results = await conn.fetch("SELECT id, name FROM collection")
return [Collection.model_validate(dict(result)) for result in results]


PathCollectionId = Annotated[int, Path(..., description="The collection ID.")]


@router.get("/{id}")
async def get(id: PathCollectionId, engine: EngineDep) -> Collection:
async def get_collection(id: PathCollectionId, conn: PgConnectionDep) -> Collection:
"""Get a specific collection."""
with Session(engine) as session:
return session.get(Collection, id)
result = await conn.fetchrow("SELECT id, name FROM collection WHERE id = $1", id)
return Collection.model_validate(dict(result))
32 changes: 32 additions & 0 deletions app/common/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import contextlib
from enum import Enum
from typing import Annotated, AsyncIterator, Optional
from uuid import UUID
import asyncpg

from fastapi import Depends, Request

@contextlib.asynccontextmanager
async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]:
"""
Create a postgres connection pool.
Arguments:
- dsn: Connection arguments specified using as a single string in
the following format:
`postgres://user:pass@host:port/database?option=value`.
"""
pool = await asyncpg.create_pool(dsn)
yield pool
pool.close()

def _pg_pool(request: Request) -> asyncpg.Pool:
return request.state.pg_pool

PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)]

async def _pg_connection(pool: PgPoolDep) -> asyncpg.Connection:
async with pool.acquire() as connection:
yield connection

PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)]
53 changes: 0 additions & 53 deletions app/common/schema.py

This file was deleted.

33 changes: 6 additions & 27 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional

from fastapi.routing import APIRoute
from pydantic import RedisDsn, ValidationInfo, field_validator
from pydantic import RedisDsn, PostgresDsn, ValidationInfo, field_validator
from pydantic_core import Url
from pydantic_settings import BaseSettings

Expand All @@ -19,8 +19,11 @@ class Config:
env_file = ".env"
env_file_encoding = "utf-8"

DB: str = "sqlite:///database.db?check_same_thread=false"
"""The database to connect to."""
DB: PostgresDsn
"""The Postgres database to connect to."""

APPLY_MIGRATIONS: bool = False
"""Whether migrations should be applied to the database."""

ENVIRONMENT: Environment = Environment.PRODUCTION
"""The environment the application is running in."""
Expand Down Expand Up @@ -115,37 +118,13 @@ def custom_generate_unique_id_function(route: APIRoute) -> str:
from a variety of sources -- documents, web pages, audio, etc.
"""

STATEMENTS_DESCRIPTION: str = """Operations for retrieving statements.
Statements include chunks of raw-text, images, and tables from documents,
as well as extracted propositions (facts) and other information from
the documents.
Additionally, a summary of retrieved statements may be requested as well
as the statements or instead of the statements.
"""

app_configs: dict[str, Any] = {
"title": "Dewy Knowledge Base API",
"summary": "Knowledge curation for Retrieval Augmented Generation",
"description": API_DESCRIPTION,
"servers": [
{"url": "http://localhost:8000", "description": "Local server"},
],
"openapi_tags": [
{
"name": "documents",
"description": "Operations on specific documents, including ingestion.",
},
{
"name": "statements",
"description": STATEMENTS_DESCRIPTION,
},
{
"name": "collections",
"description": "Operations related to collections of documents.",
},
],
"generate_unique_id_function": custom_generate_unique_id_function,
}

Expand Down
26 changes: 26 additions & 0 deletions app/documents/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class IngestState(Enum):
PENDING = "pending"
"""Document is pending ingestion."""

INGESTED = "ingested"
"""Document has been ingested."""

FAILED = "failed"
"""Document failed to be ingested. See `ingest_errors` for details."""

class Document(BaseModel):
"""Schema for documents in the SQL DB."""

id: Optional[int] = None
collection_id: int

url: str

ingest_state: Optional[IngestState] = None
ingest_error: Optional[str] = None
Loading

0 comments on commit 3277785

Please sign in to comment.