From 093be1f0d768a75fb82e37a992a5c93cfe379b61 Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Wed, 17 Jan 2024 12:39:30 -0800 Subject: [PATCH] ref: Move chunks/documents out of unstructured (#4) * Eliminate the top level collection for simpler ingest / retrieval. * Eliminate the "unstructured" distinction * Separate documents and chunks * Add more openapi annotations for client generation Co-authored-by: Ryan Michael --- app/{unstructured => chunks}/__init__.py | 0 app/chunks/models.py | 14 +++ app/chunks/router.py | 25 +++++ app/common/__init__.py | 0 app/{unstructured => common}/models.py | 53 ++++++---- app/config.py | 49 ++++++++- app/documents/__init__.py | 0 app/documents/models.py | 18 ++++ app/documents/router.py | 49 +++++++++ app/main.py | 7 +- app/routes.py | 6 +- app/unstructured/router.py | 123 ----------------------- example_notebook.ipynb | 6 +- 13 files changed, 198 insertions(+), 152 deletions(-) rename app/{unstructured => chunks}/__init__.py (100%) create mode 100644 app/chunks/models.py create mode 100644 app/chunks/router.py create mode 100644 app/common/__init__.py rename app/{unstructured => common}/models.py (63%) create mode 100644 app/documents/__init__.py create mode 100644 app/documents/models.py create mode 100644 app/documents/router.py delete mode 100644 app/unstructured/router.py diff --git a/app/unstructured/__init__.py b/app/chunks/__init__.py similarity index 100% rename from app/unstructured/__init__.py rename to app/chunks/__init__.py diff --git a/app/chunks/models.py b/app/chunks/models.py new file mode 100644 index 0000000..0ee0ae7 --- /dev/null +++ b/app/chunks/models.py @@ -0,0 +1,14 @@ +from typing import Optional, Sequence + +from pydantic import BaseModel + +from app.common.models import Chunk + +class RetrieveResponse(BaseModel): + """The response from a chunk retrieval request.""" + + synthesized_text: Optional[str] + """Synthesized text across all chunks, if requested.""" + + chunks: Sequence[Chunk] + """Retrieved chunks.""" diff --git a/app/chunks/router.py b/app/chunks/router.py new file mode 100644 index 0000000..18b9eba --- /dev/null +++ b/app/chunks/router.py @@ -0,0 +1,25 @@ +from fastapi import APIRouter + +from app.common.models import Chunk, RetrieveRequest +from app.ingest.store import StoreDep +from .models import RetrieveResponse + +router = APIRouter(tags=["chunks"], prefix="/chunks") + +@router.post("/retrieve") +async def retrieve( + store: StoreDep, request: RetrieveRequest +) -> RetrieveResponse: + """Retrieve chunks based on a given query.""" + + results = store.index.as_query_engine( + similarity_top_k=request.n, + response_mode=request.synthesis_mode.value, + # TODO: metadata filters / ACLs + ).query(request.query) + + chunks = [Chunk.from_llama_index(node) for node in results.source_nodes] + return RetrieveResponse( + synthesized_text=results.response, + chunks=chunks, + ) \ No newline at end of file diff --git a/app/common/__init__.py b/app/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/unstructured/models.py b/app/common/models.py similarity index 63% rename from app/unstructured/models.py rename to app/common/models.py index 9c5e45d..b03e979 100644 --- a/app/unstructured/models.py +++ b/app/common/models.py @@ -1,8 +1,8 @@ from enum import Enum -from typing import Optional, Sequence - +from typing import Optional, Self, Union from pydantic import BaseModel, Field +from llama_index.schema import NodeWithScore class SynthesisMode(str, Enum): """How result nodes should be synthesized into a single result.""" @@ -56,7 +56,6 @@ class SynthesisMode(str, Enum): This mode is faster than accumulate since we make fewer calls to the LLM. """ - class RetrieveRequest(BaseModel): """A request for retrieving unstructured (document) results.""" @@ -72,8 +71,7 @@ class RetrieveRequest(BaseModel): The default (`NO_TEXT`) will disable synthesis. """ - -class TextNode(BaseModel): +class TextContent(BaseModel): text: str = Field(default="", description="Text content of the node.") start_char_idx: Optional[int] = Field( default=None, description="Start char index of the node." @@ -82,18 +80,39 @@ class TextNode(BaseModel): default=None, description="End char index of the node." ) +class ImageContent(BaseModel): + text: Optional[str] = Field(..., description="Textual description of the image.") + image: Optional[str] = Field(..., description="Image of the node.") + image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.") + image_path: Optional[str] = Field(..., description="Path of the image.") + image_url: Optional[str] = Field(..., description="URL of the image.") -class NodeWithScore(BaseModel): - node: TextNode +class Chunk(BaseModel): + """A retrieved chunk.""" + content: Union[TextContent, ImageContent] score: Optional[float] = None - -class RetrieveResponse(BaseModel): - """The response from a retrieval request.""" - - synthesized_text: Optional[str] - """Synthesized text if requested.""" - - # TODO: We may want to copy the NodeWithScore model to avoid API changes. - retrieved_nodes: Sequence[NodeWithScore] - """Retrieved nodes.""" + @staticmethod + def from_llama_index(node: NodeWithScore) -> Self: + score = node.score + + content = None + from llama_index.schema import TextNode, ImageNode + if isinstance(node.node, TextNode): + content = TextContent( + text = node.node.text, + start_char_idx = node.node.start_char_idx, + end_char_idx = node.node.end_char_idx + ) + elif isinstance(node.node, ImageNode): + content = ImageContent( + text = node.node.text if node.node.text else None, + image = node.node.image, + image_mimetype = node.node.image_mimetype, + image_path = node.node.image_path, + image_url = node.node.image_url, + ) + else: + raise NotImplementedError(f"Unsupported node type ({node.node.class_name()}): {node!r}") + + return Chunk(content=content, score=score) diff --git a/app/config.py b/app/config.py index 876efde..65d86e8 100644 --- a/app/config.py +++ b/app/config.py @@ -1,4 +1,5 @@ from typing import Any, Optional +from fastapi.routing import APIRoute from pydantic import RedisDsn, ValidationInfo, field_validator from pydantic_core import Url @@ -79,18 +80,58 @@ def validate_ollama_base_url(cls, v, info: ValidationInfo): MODELS = ["LLM_MODEL", "EMBEDDING_MODEL"] if v is None: for model in MODELS: - value = info.get(model, "") - if value.startswith("ollama"): - raise ValueError( - f"{info.field_name} must be set to use '{model}={value}'" + context = info.context + if context: + value = context.get(model, "") + if value.startswith("ollama"): + raise ValueError( + f"{info.field_name} must be set to use '{model}={value}'" ) return v settings = Config() +def convert_snake_case_to_camel_case(string: str) -> str: + """Convert snake case to camel case""" + + words = string.split("_") + return words[0] + "".join(word.title() for word in words[1:]) + + +def custom_generate_unique_id_function(route: APIRoute) -> str: + """Custom function to generate unique id for each endpoint""" + + return convert_snake_case_to_camel_case(route.name) + + app_configs: dict[str, Any] = { "title": "Dewy Knowledge Base API", + "summary": "Knowledge curation for Retrieval Augmented Generation", + "description": """This API allows ingesting and retrieving knowledge. + + Knowledge comes in a variety of forms -- text, image, tables, etc. and + from a variety of sources -- documents, web pages, audio, etc. + """, + "servers": [ + {"url": "http://127.0.0.1:8000", "description": "Local server"}, + ], + "openapi_tags": [ + { + "name": "documents", + "description": "Operations for ingesting and retrieving documents." + + }, + { + "name": "chunks", + "description": "Operations for retrieving individual chunks.", + }, + { + "name": "collections", + "description": "Operations related to collections of documents." + }, + ], + "generate_unique_id_function": custom_generate_unique_id_function, } if not settings.ENVIRONMENT.is_debug: diff --git a/app/documents/__init__.py b/app/documents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/documents/models.py b/app/documents/models.py new file mode 100644 index 0000000..2987875 --- /dev/null +++ b/app/documents/models.py @@ -0,0 +1,18 @@ +from typing import Optional, Sequence + +from pydantic import BaseModel + +from app.common.models import Chunk + +class RetrievedDocument(BaseModel): + chunks: Sequence[Chunk] + """Retrieved chunks in the given document..""" + +class RetrieveResponse(BaseModel): + """The response from a chunk retrieval request.""" + + synthesized_text: Optional[str] + """Synthesized text across all documents, if requested.""" + + documents: Sequence[RetrievedDocument] + """Retrieved documents.""" \ No newline at end of file diff --git a/app/documents/router.py b/app/documents/router.py new file mode 100644 index 0000000..bde7788 --- /dev/null +++ b/app/documents/router.py @@ -0,0 +1,49 @@ +from typing import Annotated + +from fastapi import APIRouter, Body, HTTPException, status +from loguru import logger + +from app.common.models import RetrieveRequest +from app.documents.models import RetrieveResponse +from app.ingest.extract import extract +from app.ingest.extract.source import ExtractSource +from app.ingest.store import StoreDep + +router = APIRouter(tags=["documents"], prefix="/documents") + +@router.put("/") +async def add( + store: StoreDep, + url: Annotated[str, Body(..., description="The URL of the document to add.")], +): + """Add a document to the unstructured collection. + + Parameters: + - collection: The ID of the collection to add to. + - document: The URL of the document to add. + """ + + # Load the content. + logger.debug("Loading content from {}", url) + documents = await extract( + ExtractSource( + url, + ) + ) + logger.debug("Loaded {} pages from {}", len(documents), url) + if not documents: + raise HTTPException( + status_code=status.HTTP_412_PRECONDITION_FAILED, + detail=f"No content retrieved from '{url}'", + ) + + logger.debug("Inserting {} documents from {}", len(documents), url) + nodes = await store.ingestion_pipeline.arun(documents=documents) + logger.debug("Done. Inserted {} nodes", len(nodes)) + +@router.post("/retrieve") +async def retrieve( + _store: StoreDep, _request: RetrieveRequest +) -> RetrieveResponse: + """Retrieve documents based on a given query.""" + raise NotImplementedError() \ No newline at end of file diff --git a/app/main.py b/app/main.py index 313520d..792450a 100644 --- a/app/main.py +++ b/app/main.py @@ -2,6 +2,7 @@ from typing import AsyncIterator, TypedDict from fastapi import FastAPI +from fastapi.routing import APIRoute from llama_index import StorageContext from app.config import app_configs @@ -20,13 +21,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[State]: yield state - -app = FastAPI(lifespan=lifespan, **app_configs) +app = FastAPI( + lifespan=lifespan, + **app_configs) @app.get("/healthcheck", include_in_schema=False) async def healthcheck() -> dict[str, str]: return {"status": "ok"} - app.include_router(api_router) diff --git a/app/routes.py b/app/routes.py index 04712b5..aa7fc90 100644 --- a/app/routes.py +++ b/app/routes.py @@ -1,7 +1,9 @@ from fastapi import APIRouter -from app.unstructured.router import router as unstructured_router +from app.chunks.router import router as chunks_router +from app.documents.router import router as documents_router api_router = APIRouter(prefix="/api") -api_router.include_router(unstructured_router) +api_router.include_router(documents_router) +api_router.include_router(chunks_router) \ No newline at end of file diff --git a/app/unstructured/router.py b/app/unstructured/router.py deleted file mode 100644 index 10dc221..0000000 --- a/app/unstructured/router.py +++ /dev/null @@ -1,123 +0,0 @@ -from typing import Annotated - -import llama_index -from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status -from loguru import logger - -from app.ingest.extract import ExtractSource, extract -from app.ingest.store import StoreDep -from app.unstructured.models import ( - NodeWithScore, - RetrieveRequest, - RetrieveResponse, - TextNode, -) - -router = APIRouter(tags=["unstructured"], prefix="/unstructured") - - -class Collection: - def __init__( - self, - collection: Annotated[str, Path(..., description="the name of the collection")], - ) -> None: - self.collection = collection - self.id = hash(collection) - - -PathCollection = Annotated[Collection, Depends()] - - -@router.put("/{collection}/documents") -async def add_document_unstructured( - store: StoreDep, - # TODO: Use the collection. - collection: PathCollection, - url: Annotated[str, Body(..., description="The URL of the document to add.")], -): - """Add a document to the unstructured collection. - - Parameters: - - collection: The ID of the collection to add to. - - document: The URL of the document to add. - """ - - # Load the content. - logger.debug("Loading content from {}", url) - documents = await extract( - ExtractSource( - url, - ) - ) - logger.debug("Loaded {} pages from {}", len(documents), url) - if not documents: - raise HTTPException( - status_code=status.HTTP_412_PRECONDITION_FAILED, - detail=f"No content retrieved from '{url}'", - ) - - logger.debug("Inserting {} documents from {}", len(documents), url) - nodes = await store.ingestion_pipeline.arun(documents=documents) - logger.debug("Done. Inserted {} nodes", len(nodes)) - - -@router.delete("/{collection}/documents/{document}") -async def delete_document_unstructured( - store: StoreDep, collection: PathCollection, document: str -): - """Delete a document from the unstructured collection. - - Parameters: - - collection: The ID of the collection to remove from. - - document: The ID of the document to remove. - """ - raise NotImplementedError() - - -class RetrieveParams: - def __init__( - self, - query: Annotated[ - str, Query(..., description="The query string to use for retrieval") - ], - n: Annotated[ - int, Query(description="Number of document chunks to retrieve") - ] = 10, - ): - self.query = query - self.n = n - - -@router.post("/{collection}/retrieve") -async def retrieve_documents_unstructured( - store: StoreDep, collection: PathCollection, request: RetrieveRequest -) -> RetrieveResponse: - """Retrieve chunks based on a given query.""" - - results = store.index.as_query_engine( - similarity_top_k=request.n, - response_mode=request.synthesis_mode.value, - # TODO: metadata filters / ACLs - ).query(request.query) - - retrieved_nodes = [convert_node(node) for node in results.source_nodes] - return RetrieveResponse( - synthesized_text=results.response, - retrieved_nodes=retrieved_nodes, - ) - - -def convert_node(node: llama_index.schema.NodeWithScore) -> NodeWithScore: - score = node.score - node = node.node - - converted = None - if isinstance(node, llama_index.schema.TextNode): - converted = TextNode( - text=node.text, - start_char_idx=node.start_char_idx, - end_char_idx=node.end_char_idx, - ) - else: - raise NotImplementedError(f"Conversion of {node!r} ({node.class_name})") - return NodeWithScore(node=converted, score=score) diff --git a/example_notebook.ipynb b/example_notebook.ipynb index 183f097..808b365 100644 --- a/example_notebook.ipynb +++ b/example_notebook.ipynb @@ -20,7 +20,7 @@ "outputs": [], "source": [ "# Add \"Query Rewriting for Retrieval-Augmented Large Language Models\"\n", - "response = client.put(f\"/unstructured/my_collection/documents\",\n", + "response = client.put(f\"/documents\",\n", " content = \"\\\"https://arxiv.org/pdf/2305.14283.pdf\\\"\",\n", " timeout = None)\n", "response.raise_for_status()" @@ -33,7 +33,7 @@ "outputs": [], "source": [ "# Retrieve 4 items with no synthesis.\n", - "results = client.post(f\"/unstructured/my_collection/retrieve\",\n", + "results = client.post(f\"/chunks/retrieve\",\n", " json = {\n", " \"query\": \"retrieval augmented generation\",\n", " \"n\": 4\n", @@ -51,7 +51,7 @@ "outputs": [], "source": [ "# Retrieve 32 items, and synthesis a response to the query.\n", - "results = client.post(f\"/unstructured/my_collection/retrieve\",\n", + "results = client.post(f\"/chunks/retrieve\",\n", " json = {\n", " \"query\": \"How does query-rewriting improve Retrieval-Augmented-Generation?\",\n", " \"synthesis_mode\": \"compact\",\n",