diff --git a/app/collections/router.py b/app/collections/router.py index 234af09..070ae1b 100644 --- a/app/collections/router.py +++ b/app/collections/router.py @@ -3,15 +3,15 @@ from fastapi import APIRouter, Path from sqlmodel import Session, select -from app.common.schema import Collection, DbDep +from app.common.schema import Collection, EngineDep router = APIRouter(tags=["collections"], prefix="/collections") @router.put("/") -async def add(db: DbDep, collection: Collection) -> Collection: +async def add(engine: EngineDep, collection: Collection) -> Collection: """Create a collection.""" - with Session(db) as session: + with Session(engine) as session: session.add(collection) session.commit() session.refresh(collection) @@ -19,18 +19,17 @@ async def add(db: DbDep, collection: Collection) -> Collection: @router.get("/") -async def list(db: DbDep) -> List[Collection]: +async def list(engine: EngineDep) -> List[Collection]: """List collections.""" - with Session(db) as session: - collections = session.exec(select(Collection)).all() - return collections + with Session(engine) as session: + return session.exec(select(Collection)).all() PathCollectionId = Annotated[int, Path(..., description="The collection ID.")] @router.get("/{id}") -async def get(id: PathCollectionId, db: DbDep) -> Collection: +async def get(id: PathCollectionId, engine: EngineDep) -> Collection: """Get a specific collection.""" - with Session(db) as session: + with Session(engine) as session: return session.get(Collection, id) diff --git a/app/common/schema.py b/app/common/schema.py index 6af9fb4..7fdae38 100644 --- a/app/common/schema.py +++ b/app/common/schema.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Annotated, Optional from fastapi import Depends, Request @@ -11,6 +12,18 @@ class Collection(SQLModel, table=True): # TODO: We may want this to be unique per-tenant rather than globally unique names. 
name: str = Field(index=True, unique=True) +class IngestState(Enum): + UNKNOWN = "unknown" + """Document is in an unknown state.""" + + PENDING = "pending" + """Document is pending ingestion.""" + + INGESTED = "ingested" + """Document has been ingested.""" + + FAILED = "failed" + """Document failed to be ingested. See `ingest_error` for details.""" + class Document(SQLModel, table=True): """Schema for documents in the SQL DB.""" @@ -21,14 +34,20 @@ class Document(SQLModel, table=True): ) id: Optional[int] = Field(default=None, primary_key=True) - collection_id: int = Field(foreign_key="collection.id") + collection_id: Optional[int] = Field(foreign_key="collection.id") url: str = Field(index=True) doc_id: Optional[str] = Field(default=None) + ingest_state: IngestState = Field(default=IngestState.UNKNOWN) + """The state of the document ingestion.""" + + ingest_error: Optional[str] = Field(default=None) + """Errors which occurred during ingestion, if any.""" + def _db(request: Request) -> Engine: return request.state.engine -DbDep = Annotated[Engine, Depends(_db)] +EngineDep = Annotated[Engine, Depends(_db)] diff --git a/app/documents/router.py b/app/documents/router.py index 9763f6c..055cf0e 100644 --- a/app/documents/router.py +++ b/app/documents/router.py @@ -1,36 +1,85 @@ -from typing import Annotated +from typing import Annotated, List -from fastapi import APIRouter, Body, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Path, status from loguru import logger +from sqlalchemy import Engine +from sqlmodel import Session, select +from app.common.schema import Document, EngineDep, IngestState from app.ingest.extract import extract from app.ingest.extract.source import ExtractSource -from app.ingest.store import StoreDep +from app.ingest.store import Store, StoreDep router = APIRouter(tags=["documents"], prefix="/documents") +async def ingest_document(id: int, store: Store, engine: Engine): + # Load the content. 
+ with Session(engine) as session: + document = session.get(Document, id) + + logger.debug("Loading content from {}", document.url) + documents = await extract( + ExtractSource( + document.url, + ) + ) + logger.debug("Loaded {} pages from {}", len(documents), document.url) + if not documents: + raise HTTPException( + status_code=status.HTTP_412_PRECONDITION_FAILED, + detail=f"No content retrieved from '{document.url}'", + ) + + logger.debug("Inserting {} documents from {}", len(documents), document.url) + nodes = await store.ingestion_pipeline.arun(documents=documents) + logger.debug("Done. Inserted {} nodes", len(nodes)) + + document.ingest_state = IngestState.INGESTED + document.ingest_error = None + session.add(document) + session.commit() + @router.put("/") async def add( store: StoreDep, + engine: EngineDep, + background: BackgroundTasks, url: Annotated[str, Body(..., description="The URL of the document to add.")], -): +) -> Document: """Add a document.""" - # Load the content. - logger.debug("Loading content from {}", url) - documents = await extract( - ExtractSource( - url, - ) + # Update the document in the DB. + document = Document( + url = url ) - logger.debug("Loaded {} pages from {}", len(documents), url) - if not documents: - raise HTTPException( - status_code=status.HTTP_412_PRECONDITION_FAILED, - detail=f"No content retrieved from '{url}'", - ) + with Session(engine) as session: + # TODO: Support update (and fail if the document doesn't exist/etc.) + + document.ingest_state = IngestState.PENDING + document.ingest_error = None + + session.add(document) + session.commit() + session.refresh(document) + + # Create the background task to update the state. 
+ background.add_task(ingest_document, document.id, store, engine) + + return document + +PathDocumentId = Annotated[int, Path(..., description="The document ID.")] + +@router.get("/") +async def list(engine: EngineDep) -> List[Document]: + """List documents.""" + with Session(engine) as session: + return session.exec(select(Document)).all() - logger.debug("Inserting {} documents from {}", len(documents), url) - nodes = await store.ingestion_pipeline.arun(documents=documents) - logger.debug("Done. Inserted {} nodes", len(nodes)) +@router.get("/{id}") +async def get( + engine: EngineDep, + id: PathDocumentId +) -> Document: + with Session(engine) as session: + return session.get(Document, id) diff --git a/docker-compose.yml b/docker-compose.yml index 94bb356..b18cdf3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,7 @@ services: REDIS: "redis://default:testing123@redis:6379" LLAMA_INDEX_CACHE_DIR: "/tmp/cache/llama_index" HF_HOME: "/tmp/cache/hf" + # DB: "sqlite:///var/db/database.db?check_same_thread=false" env_file: - .env build: @@ -21,6 +22,7 @@ services: - redis volumes: - llama-cache:/tmp/cache + - db:/var/db redis: build: @@ -36,6 +38,7 @@ services: - 8001:8001 volumes: + db: redis-data: llama-cache: diff --git a/example_notebook.ipynb b/example_notebook.ipynb index 2d0d88c..f834a40 100644 --- a/example_notebook.ipynb +++ b/example_notebook.ipynb @@ -21,9 +21,21 @@ "source": [ "# Add \"Query Rewriting for Retrieval-Augmented Large Language Models\"\n", "response = client.put(f\"/documents/\",\n", - " content = \"\\\"https://arxiv.org/pdf/2305.14283.pdf\\\"\",\n", - " timeout = None)\n", - "response.raise_for_status()" + " content = \"\\\"https://arxiv.org/pdf/2305.14283.pdf\\\"\")\n", + "response.raise_for_status()\n", + "print(response.json())\n", + "document_id = response.json()['id']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Report the status of the document 
ingestion.\n", + "response = client.get(f\"/documents/{document_id}\")\n", + "print(response.raise_for_status().json())" ] }, {