Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Switch DB to postgres #10

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Ignore everything
**

# Include (don't ignore) the application code
!app
!./pyproject.toml
!./poetry.lock
**/__pycache__

# Re-ignore pycache within `app`.
**/__pycache__

# Include (don't ignore) the migrations.
!migrations/*.sql
!yoyo.ini
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
# Finally, copy in the application code.
COPY ./app /code/app

COPY ./migrations/0001_schema.sql /code/migrations/0001_schema.sql

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
File renamed without changes.
20 changes: 10 additions & 10 deletions app/statements/models.py → app/chunks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ class RetrieveRequest(BaseModel):
"""Whether to include a generated summary."""


class BaseStatement(BaseModel):
class BaseChunk(BaseModel):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attempting to s/Statements/Chunk

kind: Literal["text", "raw_text", "image"]

score: Optional[float] = None
"""The similarity score of this statement."""
"""The similarity score of this chunk."""


class TextStatement(BaseStatement):
class TextChunk(BaseChunk):
kind: Literal["text"] = "text"
raw: bool
text: str = Field(default="", description="Text content of the node.")
text: str = Field(default="", description="Text content of the chunk.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node."
default=None, description="Start char index of the chunk."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node."
default=None, description="End char index of the chunk."
)


class ImageStatement(BaseStatement):
class ImageChunk(BaseChunk):
kind: Literal["image"] = "image"
text: Optional[str] = Field(..., description="Textual description of the image.")
image: Optional[str] = Field(..., description="Image of the node.")
Expand All @@ -59,7 +59,7 @@ class RetrieveResponse(BaseModel):
"""The response from a chunk retrieval request."""

summary: Optional[str]
"""Summary of the retrieved statements."""
"""Summary of the retrieved chunks."""

statements: Sequence[Union[TextStatement, ImageStatement]]
"""Retrieved statements."""
chunks: Sequence[Union[TextChunk, ImageChunk]]
"""Retrieved chunks."""
16 changes: 8 additions & 8 deletions app/statements/router.py → app/chunks/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from app.ingest.store import StoreDep

from .models import ImageStatement, RetrieveRequest, RetrieveResponse, TextStatement
from .models import ImageChunk, RetrieveRequest, RetrieveResponse, TextChunk

router = APIRouter(tags=["statements"], prefix="/statements")
router = APIRouter(prefix="/chunks")


@router.post("/retrieve")
async def retrieve(store: StoreDep, request: RetrieveRequest) -> RetrieveResponse:
"""Retrieve statements based on a given query."""
async def retrieve_chunks(store: StoreDep, request: RetrieveRequest) -> RetrieveResponse:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resource Naming: I think it makes sense for this to be /chunks/retrieve since it operates on chunks (and conceivably we could have other methods, like list chunks, etc.)

Method Naming: I can see retrieve (if this is the only retireve method, so we have service.retrieve(...)) but called it retrieve_chunks for consistency with other things that name the resource. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

retrieve_chunk seems fine - it's explicit about what's being returned which may be good.

It will be easier (in the UI) to treat this as a GET with query params until we have enough query complexity to warrant a full-on post body - it would also simplify the API a bit, allowing this to be treated as just the list view over chunks. I don't want to overly emphasize the FE's needs though.

"""Retrieve chunks based on a given query."""

from llama_index.response_synthesizers import ResponseMode

Expand All @@ -30,23 +30,23 @@ async def retrieve(store: StoreDep, request: RetrieveRequest) -> RetrieveRespons

return RetrieveResponse(
summary=results.response,
statements=statements if request.include_statements else [],
chunks=statements if request.include_statements else [],
)


def node_to_statement(node: NodeWithScore) -> Union[TextStatement, ImageStatement]:
def node_to_statement(node: NodeWithScore) -> Union[TextChunk, ImageChunk]:
from llama_index.schema import ImageNode, TextNode

if isinstance(node.node, TextNode):
return TextStatement(
return TextChunk(
raw=True,
score=node.score,
text=node.node.text,
start_char_idx=node.node.start_char_idx,
end_char_idx=node.node.end_char_idx,
)
elif isinstance(node.node, ImageNode):
return ImageStatement(
return ImageChunk(
score=node.score,
text=node.node.text if node.node.text else None,
image=node.node.image,
Expand Down
21 changes: 21 additions & 0 deletions app/collections/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

from pydantic import BaseModel, ConfigDict, TypeAdapter

class Collection(BaseModel):
model_config=ConfigDict(from_attributes=True)

"""A collection of indexed documents."""
id: int
"""The ID of the collection."""

name: str
"""The name of the collection."""

collection_validator = TypeAdapter(Collection)


class CollectionCreate(BaseModel):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how we want to standardize naming. Basically,I have:

  • Collection the model that we generally return.
  • CollectionCreate (for creation). Could also do CreateCollection or CreateCollectionInput, CreateCollectionRequest, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like Litestar at least goes with <Resource> (for the result), <Resource>Create (for the creation request), and <Resource>Update for update requests:

class Author(BaseModel):
    id: UUID | None
    name: str
    dob: date | None = None

class AuthorCreate(BaseModel):
    name: str
    dob: date | None = None

class AuthorUpdate(BaseModel):
    name: str | None = None
    dob: date | None = None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually ended up with something like CreateCollectionRequest and CreateCollectionResponse in the gRPC days, but maybe that's more verbose than is necessary in this case.

"""The request to create a collection."""

name: str
"""The name of the collection."""
32 changes: 17 additions & 15 deletions app/collections/router.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
from typing import Annotated, List

from fastapi import APIRouter, Path
from sqlmodel import Session, select
from pydantic import parse_obj_as

from app.common.schema import Collection, EngineDep
from app.common.db import PgConnectionDep
from app.collections.models import *

router = APIRouter(tags=["collections"], prefix="/collections")
router = APIRouter(prefix="/collections")


@router.put("/")
async def add(engine: EngineDep, collection: Collection) -> Collection:
async def add_collection(conn: PgConnectionDep, collection: CollectionCreate) -> Collection:
"""Create a collection."""
with Session(engine) as session:
session.add(collection)
session.commit()
session.refresh(collection)
return collection
result = await conn.fetchrow("""
INSERT INTO collection (name) VALUES ($1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think asyncpg prepares statements using an LRU cache and the hash of the query. We could (maybe) manually manage that for more efficiency... but should be unnecessary -- we'll likely only have one or two queries per request, so the hash shouldn't be too bad.

RETURNING id, name
""",
collection.name)
return Collection.model_validate(dict(result))


@router.get("/")
async def list(engine: EngineDep) -> List[Collection]:
async def list_collections(conn: PgConnectionDep) -> List[Collection]:
"""List collections."""
with Session(engine) as session:
return session.exec(select(Collection)).all()
results = await conn.fetch("SELECT id, name FROM collection")
return [Collection.model_validate(dict(result)) for result in results]


PathCollectionId = Annotated[int, Path(..., description="The collection ID.")]


@router.get("/{id}")
async def get(id: PathCollectionId, engine: EngineDep) -> Collection:
async def get_collection(id: PathCollectionId, conn: PgConnectionDep) -> Collection:
"""Get a specific collection."""
with Session(engine) as session:
return session.get(Collection, id)
result = await conn.fetchrow("SELECT id, name FROM collection WHERE id = $1", id)
return Collection.model_validate(dict(result))
32 changes: 32 additions & 0 deletions app/common/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import contextlib
from enum import Enum
from typing import Annotated, AsyncIterator, Optional
from uuid import UUID
import asyncpg

from fastapi import Depends, Request

@contextlib.asynccontextmanager
async def create_pool(dsn: str) -> AsyncIterator[asyncpg.Pool]:
"""
Create a postgres connection pool.

Arguments:
- dsn: Connection arguments specified using as a single string in
the following format:
`postgres://user:pass@host:port/database?option=value`.
"""
pool = await asyncpg.create_pool(dsn)
yield pool
pool.close()

def _pg_pool(request: Request) -> asyncpg.Pool:
return request.state.pg_pool

PgPoolDep = Annotated[asyncpg.Pool, Depends(_pg_pool)]

async def _pg_connection(pool: PgPoolDep) -> asyncpg.Connection:
async with pool.acquire() as connection:
yield connection

PgConnectionDep = Annotated[asyncpg.Connection, Depends(_pg_connection)]
53 changes: 0 additions & 53 deletions app/common/schema.py

This file was deleted.

33 changes: 6 additions & 27 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional

from fastapi.routing import APIRoute
from pydantic import RedisDsn, ValidationInfo, field_validator
from pydantic import RedisDsn, PostgresDsn, ValidationInfo, field_validator
from pydantic_core import Url
from pydantic_settings import BaseSettings

Expand All @@ -19,8 +19,11 @@ class Config:
env_file = ".env"
env_file_encoding = "utf-8"

DB: str = "sqlite:///database.db?check_same_thread=false"
"""The database to connect to."""
DB: PostgresDsn
"""The Postgres database to connect to."""

APPLY_MIGRATIONS: bool = False
"""Whether migrations should be applied to the database."""

ENVIRONMENT: Environment = Environment.PRODUCTION
"""The environment the application is running in."""
Expand Down Expand Up @@ -115,37 +118,13 @@ def custom_generate_unique_id_function(route: APIRoute) -> str:
from a variety of sources -- documents, web pages, audio, etc.
"""

STATEMENTS_DESCRIPTION: str = """Operations for retrieving statements.

Statements include chunks of raw-text, images, and tables from documents,
as well as extracted propositions (facts) and other information from
the documents.

Additionally, a summary of retrieved statements may be requested as well
as the statements or instead of the statements.
"""

app_configs: dict[str, Any] = {
"title": "Dewy Knowledge Base API",
"summary": "Knowledge curation for Retrieval Augmented Generation",
"description": API_DESCRIPTION,
"servers": [
{"url": "http://localhost:8000", "description": "Local server"},
],
"openapi_tags": [
{
"name": "documents",
"description": "Operations on specific documents, including ingestion.",
},
{
"name": "statements",
"description": STATEMENTS_DESCRIPTION,
},
{
"name": "collections",
"description": "Operations related to collections of documents.",
},
],
"generate_unique_id_function": custom_generate_unique_id_function,
}

Expand Down
26 changes: 26 additions & 0 deletions app/documents/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class IngestState(Enum):
PENDING = "pending"
"""Document is pending ingestion."""

INGESTED = "ingested"
"""Document has been ingested."""

FAILED = "failed"
"""Document failed to be ingested. See `ingest_errors` for details."""

class Document(BaseModel):
"""Schema for documents in the SQL DB."""

id: Optional[int] = None
collection_id: int

url: str

ingest_state: Optional[IngestState] = None
ingest_error: Optional[str] = None
Loading