ref: Move chunks/documents out of unstructured (#4)
* Eliminate the top-level collection for simpler ingest/retrieval.
* Eliminate the "unstructured" distinction.
* Separate documents and chunks.
* Add more OpenAPI annotations for client generation.

Co-authored-by: Ryan Michael <kerinin@gmail.com>
bjchambers and kerinin authored Jan 17, 2024
1 parent 4549505 commit 093be1f
Showing 13 changed files with 198 additions and 152 deletions.
File renamed without changes.
14 changes: 14 additions & 0 deletions app/chunks/models.py
@@ -0,0 +1,14 @@
from typing import Optional, Sequence

from pydantic import BaseModel

from app.common.models import Chunk

class RetrieveResponse(BaseModel):
    """The response from a chunk retrieval request."""

    synthesized_text: Optional[str]
    """Synthesized text across all chunks, if requested."""

    chunks: Sequence[Chunk]
    """Retrieved chunks."""
25 changes: 25 additions & 0 deletions app/chunks/router.py
@@ -0,0 +1,25 @@
from fastapi import APIRouter

from app.common.models import Chunk, RetrieveRequest
from app.ingest.store import StoreDep
from .models import RetrieveResponse

router = APIRouter(tags=["chunks"], prefix="/chunks")

@router.post("/retrieve")
async def retrieve(
    store: StoreDep, request: RetrieveRequest
) -> RetrieveResponse:
    """Retrieve chunks based on a given query."""

    results = store.index.as_query_engine(
        similarity_top_k=request.n,
        response_mode=request.synthesis_mode.value,
        # TODO: metadata filters / ACLs
    ).query(request.query)

    chunks = [Chunk.from_llama_index(node) for node in results.source_nodes]
    return RetrieveResponse(
        synthesized_text=results.response,
        chunks=chunks,
    )
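For orientation, here is a minimal, hedged sketch of calling the new chunk-retrieval endpoint once the service is running. The host comes from the local server entry in app/config.py; the request fields (query, n, synthesis_mode) and the serialized enum value are inferred from the router and models in this change rather than confirmed.

import httpx

# Assumptions: the API runs locally on the server configured in app/config.py,
# RetrieveRequest accepts `query`, `n`, and `synthesis_mode` (as the router's
# usage suggests), and "no_text" is the serialized form of SynthesisMode.NO_TEXT.
response = httpx.post(
    "http://127.0.0.1:8000/api/chunks/retrieve",
    json={"query": "What is retrieval augmented generation?", "n": 5, "synthesis_mode": "no_text"},
    timeout=30.0,
)
response.raise_for_status()
body = response.json()
print(body["synthesized_text"])
for chunk in body["chunks"]:
    print(chunk["score"], chunk["content"])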
Empty file added app/common/__init__.py
Empty file.
53 changes: 36 additions & 17 deletions app/unstructured/models.py → app/common/models.py
@@ -1,8 +1,8 @@
from enum import Enum
from typing import Optional, Sequence

from typing import Optional, Self, Union
from pydantic import BaseModel, Field

from llama_index.schema import NodeWithScore

class SynthesisMode(str, Enum):
"""How result nodes should be synthesized into a single result."""
Expand Down Expand Up @@ -56,7 +56,6 @@ class SynthesisMode(str, Enum):
This mode is faster than accumulate since we make fewer calls to the LLM.
"""


class RetrieveRequest(BaseModel):
"""A request for retrieving unstructured (document) results."""

Expand All @@ -72,8 +71,7 @@ class RetrieveRequest(BaseModel):
The default (`NO_TEXT`) will disable synthesis.
"""


class TextNode(BaseModel):
class TextContent(BaseModel):
    text: str = Field(default="", description="Text content of the node.")
    start_char_idx: Optional[int] = Field(
        default=None, description="Start char index of the node."
@@ -82,18 +80,39 @@ class TextNode(BaseModel):
        default=None, description="End char index of the node."
    )

class ImageContent(BaseModel):
    text: Optional[str] = Field(..., description="Textual description of the image.")
    image: Optional[str] = Field(..., description="Image of the node.")
    image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.")
    image_path: Optional[str] = Field(..., description="Path of the image.")
    image_url: Optional[str] = Field(..., description="URL of the image.")

class NodeWithScore(BaseModel):
    node: TextNode
class Chunk(BaseModel):
    """A retrieved chunk."""
    content: Union[TextContent, ImageContent]
    score: Optional[float] = None


class RetrieveResponse(BaseModel):
    """The response from a retrieval request."""

    synthesized_text: Optional[str]
    """Synthesized text if requested."""

    # TODO: We may want to copy the NodeWithScore model to avoid API changes.
    retrieved_nodes: Sequence[NodeWithScore]
    """Retrieved nodes."""
    @staticmethod
    def from_llama_index(node: NodeWithScore) -> Self:
        score = node.score

        content = None
        from llama_index.schema import TextNode, ImageNode
        if isinstance(node.node, TextNode):
            content = TextContent(
                text = node.node.text,
                start_char_idx = node.node.start_char_idx,
                end_char_idx = node.node.end_char_idx
            )
        elif isinstance(node.node, ImageNode):
            content = ImageContent(
                text = node.node.text if node.node.text else None,
                image = node.node.image,
                image_mimetype = node.node.image_mimetype,
                image_path = node.node.image_path,
                image_url = node.node.image_url,
            )
        else:
            raise NotImplementedError(f"Unsupported node type ({node.node.class_name()}): {node!r}")

        return Chunk(content=content, score=score)
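A short sketch of the conversion path introduced above, assuming the pre-0.10 llama_index schema module used elsewhere in this diff; the node text and score are made-up values.

from llama_index.schema import NodeWithScore, TextNode

from app.common.models import Chunk, TextContent

# Wrap a llama_index text node in a NodeWithScore and map it to the
# API-facing Chunk model via the new from_llama_index helper.
node = NodeWithScore(
    node=TextNode(text="Dewy stores knowledge as chunks.", start_char_idx=0, end_char_idx=32),
    score=0.87,
)
chunk = Chunk.from_llama_index(node)
assert isinstance(chunk.content, TextContent)
print(chunk.score, chunk.content.text)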
49 changes: 45 additions & 4 deletions app/config.py
@@ -1,4 +1,5 @@
from typing import Any, Optional
from fastapi.routing import APIRoute

from pydantic import RedisDsn, ValidationInfo, field_validator
from pydantic_core import Url
@@ -79,18 +80,58 @@ def validate_ollama_base_url(cls, v, info: ValidationInfo):
        MODELS = ["LLM_MODEL", "EMBEDDING_MODEL"]
        if v is None:
            for model in MODELS:
                value = info.get(model, "")
                if value.startswith("ollama"):
                    raise ValueError(
                        f"{info.field_name} must be set to use '{model}={value}'"
                context = info.context
                if context:
                    value = context.get(model, "")
                    if value.startswith("ollama"):
                        raise ValueError(
                            f"{info.field_name} must be set to use '{model}={value}'"
                        )
        return v


settings = Config()

def convert_snake_case_to_camel_case(string: str) -> str:
"""Convert snake case to camel case"""

words = string.split("_")
return words[0] + "".join(word.title() for word in words[1:])


def custom_generate_unique_id_function(route: APIRoute) -> str:
"""Custom function to generate unique id for each endpoint"""

return convert_snake_case_to_camel_case(route.name)


app_configs: dict[str, Any] = {
"title": "Dewy Knowledge Base API",
"summary": "Knowledge curation for Retrieval Augmented Generation",
"description": """This API allows ingesting and retrieving knowledge.
Knowledge comes in a variety of forms -- text, image, tables, etc. and
from a variety of sources -- documents, web pages, audio, etc.
""",
"servers": [
{"url": "http://127.0.0.1:8000", "description": "Local server"},
],
"openapi_tags": [
{
"name": "documents",
"description": "Operations for ingesting and retrieving documents."

},
{
"name": "chunks",
"description": "Operations for retrieving individual chunks.",
},
{
"name": "collections",
"description": "Operations related to collections of documents."
},
],
"generate_unique_id_function": custom_generate_unique_id_function,
}

if not settings.ENVIRONMENT.is_debug:
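To illustrate the client-generation angle of this change: the custom unique-id function above turns snake_case route names into camelCase operation ids in the generated OpenAPI spec. The route names below are illustrative only, and importing app.config is assumed to succeed with whatever environment settings Config requires.

from app.config import convert_snake_case_to_camel_case

# Snake_case FastAPI route names become camelCase OpenAPI operation ids.
assert convert_snake_case_to_camel_case("retrieve") == "retrieve"
assert convert_snake_case_to_camel_case("add_document") == "addDocument"
assert convert_snake_case_to_camel_case("list_collection_chunks") == "listCollectionChunks"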
Empty file added app/documents/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions app/documents/models.py
@@ -0,0 +1,18 @@
from typing import Optional, Sequence

from pydantic import BaseModel

from app.common.models import Chunk

class RetrievedDocument(BaseModel):
    chunks: Sequence[Chunk]
    """Retrieved chunks in the given document."""

class RetrieveResponse(BaseModel):
    """The response from a document retrieval request."""

    synthesized_text: Optional[str]
    """Synthesized text across all documents, if requested."""

    documents: Sequence[RetrievedDocument]
    """Retrieved documents."""
49 changes: 49 additions & 0 deletions app/documents/router.py
@@ -0,0 +1,49 @@
from typing import Annotated

from fastapi import APIRouter, Body, HTTPException, status
from loguru import logger

from app.common.models import RetrieveRequest
from app.documents.models import RetrieveResponse
from app.ingest.extract import extract
from app.ingest.extract.source import ExtractSource
from app.ingest.store import StoreDep

router = APIRouter(tags=["documents"], prefix="/documents")

@router.put("/")
async def add(
    store: StoreDep,
    url: Annotated[str, Body(..., description="The URL of the document to add.")],
):
"""Add a document to the unstructured collection.
Parameters:
- collection: The ID of the collection to add to.
- document: The URL of the document to add.
"""

    # Load the content.
    logger.debug("Loading content from {}", url)
    documents = await extract(
        ExtractSource(
            url,
        )
    )
    logger.debug("Loaded {} pages from {}", len(documents), url)
    if not documents:
        raise HTTPException(
            status_code=status.HTTP_412_PRECONDITION_FAILED,
            detail=f"No content retrieved from '{url}'",
        )

    logger.debug("Inserting {} documents from {}", len(documents), url)
    nodes = await store.ingestion_pipeline.arun(documents=documents)
    logger.debug("Done. Inserted {} nodes", len(nodes))

@router.post("/retrieve")
async def retrieve(
    _store: StoreDep, _request: RetrieveRequest
) -> RetrieveResponse:
    """Retrieve documents based on a given query."""
    raise NotImplementedError()
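A hedged sketch of ingesting a document through the new documents router. The host and document URL are placeholders; depending on how FastAPI treats the single scalar Body parameter, the payload may need to be a bare JSON string (as shown) or an object such as {"url": ...}.

import httpx

# PUT /api/documents/ triggers extraction and ingestion of the referenced URL.
response = httpx.put(
    "http://127.0.0.1:8000/api/documents/",
    json="https://example.com/handbook.pdf",  # placeholder document URL
    timeout=60.0,
)
response.raise_for_status()  # a 412 indicates no content could be extracted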
7 changes: 4 additions & 3 deletions app/main.py
@@ -2,6 +2,7 @@
from typing import AsyncIterator, TypedDict

from fastapi import FastAPI
from fastapi.routing import APIRoute
from llama_index import StorageContext

from app.config import app_configs
@@ -20,13 +21,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[State]:

    yield state


app = FastAPI(lifespan=lifespan, **app_configs)
app = FastAPI(
    lifespan=lifespan,
    **app_configs)


@app.get("/healthcheck", include_in_schema=False)
async def healthcheck() -> dict[str, str]:
    return {"status": "ok"}


app.include_router(api_router)
6 changes: 4 additions & 2 deletions app/routes.py
@@ -1,7 +1,9 @@
from fastapi import APIRouter

from app.unstructured.router import router as unstructured_router
from app.chunks.router import router as chunks_router
from app.documents.router import router as documents_router

api_router = APIRouter(prefix="/api")

api_router.include_router(unstructured_router)
api_router.include_router(documents_router)
api_router.include_router(chunks_router)
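As a quick sanity check on the rerouting, mounting the API router on an app and printing its paths should show the new /api/documents and /api/chunks prefixes in place of the old unstructured routes. This is a sketch, assuming the app package and its settings are importable in your environment.

from fastapi import FastAPI

from app.routes import api_router

# List the registered paths; expect /api/documents/... and /api/chunks/retrieve.
app = FastAPI()
app.include_router(api_router)
for route in app.routes:
    print(getattr(route, "path", None))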