Skip to content

Commit

Permalink
fix: Support model configurations (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjchambers authored Jan 29, 2024
1 parent 813573a commit 86fb278
Show file tree
Hide file tree
Showing 10 changed files with 392 additions and 60 deletions.
18 changes: 12 additions & 6 deletions dewy/chunks/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Literal, Optional, Sequence, Union, Annotated
from typing import Annotated, Literal, Optional, Sequence, Union

from pydantic import BaseModel, Field


class TextChunk(BaseModel):
id: int
document_id: int
Expand All @@ -26,7 +27,9 @@ class ImageChunk(BaseModel):
image_path: Optional[str] = Field(..., description="Path of the image.")
image_url: Optional[str] = Field(..., description="URL of the image.")

Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')]

Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator="kind")]


class RetrieveRequest(BaseModel):
"""A request for retrieving chunks from a collection."""
Expand Down Expand Up @@ -61,17 +64,18 @@ class RetrieveRequest(BaseModel):
include_summary: bool = False
"""Whether to include a generated summary."""


class TextResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
"""The ID of the document associated with this result"""

score: float
"""The similarity score of this result."""

text: str
text: str
"Textual description of the chunk."

raw: bool
Expand All @@ -82,10 +86,11 @@ class TextResult(BaseModel):
default=None, description="End char index of the chunk."
)


class ImageResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
"""The ID of the document associated with this result"""

Expand All @@ -97,6 +102,7 @@ class ImageResult(BaseModel):
image_path: Optional[str] = Field(..., description="Path of the image.")
image_url: Optional[str] = Field(..., description="URL of the image.")


class RetrieveResponse(BaseModel):
"""The response from a retrieval request."""

Expand All @@ -107,4 +113,4 @@ class RetrieveResponse(BaseModel):
"""Retrieved text chunks."""

image_results: Sequence[ImageResult]
"""Retrieved image chunks."""
"""Retrieved image chunks."""
20 changes: 15 additions & 5 deletions dewy/chunks/router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union, Annotated, List
from typing import Annotated, List

from fastapi import APIRouter, Query, Path
from fastapi import APIRouter, Path, Query

from dewy.common.collection_embeddings import CollectionEmbeddings
from dewy.common.db import PgPoolDep
Expand All @@ -9,11 +9,16 @@

router = APIRouter(prefix="/chunks")


@router.get("/")
async def list_chunks(
pg_pool: PgPoolDep,
collection_id: Annotated[int | None, Query(description="Limit to chunks associated with this collection")] = None,
document_id: Annotated[int | None, Query(description="Limit to chunks associated with this document")] = None,
collection_id: Annotated[
int | None, Query(description="Limit to chunks associated with this collection")
] = None,
document_id: Annotated[
int | None, Query(description="Limit to chunks associated with this document")
] = None,
) -> List[Chunk]:
"""List chunks."""

Expand All @@ -31,8 +36,10 @@ async def list_chunks(
)
return [Chunk.model_validate(dict(result)) for result in results]


PathChunkId = Annotated[int, Path(..., description="The chunk ID.")]


@router.get("/{id}")
async def get_chunk(
pg_pool: PgPoolDep,
Expand All @@ -48,6 +55,7 @@ async def get_chunk(
)
return Chunk.model_validate(dict(result))


@router.post("/retrieve")
async def retrieve_chunks(
pg_pool: PgPoolDep, request: RetrieveRequest
Expand All @@ -59,7 +67,9 @@ async def retrieve_chunks(
collection = await CollectionEmbeddings.for_collection_id(
pg_pool, request.collection_id
)
text_results = await collection.retrieve_text_chunks(query=request.query, n=request.n)
text_results = await collection.retrieve_text_chunks(
query=request.query, n=request.n
)

return RetrieveResponse(
summary=None,
Expand Down
27 changes: 10 additions & 17 deletions dewy/collections/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,30 @@
from fastapi import APIRouter, Path

from dewy.collections.models import Collection, CollectionCreate
from dewy.common.collection_embeddings import get_dimensions
from dewy.common.db import PgConnectionDep
from loguru import logger

router = APIRouter(prefix="/collections")


def get_dimensions(model: str) -> int:
    """Return the embedding vector size for a known model name.

    Args:
        model: Prefixed model identifier, e.g. "hf:BAAI/bge-small-en".

    Returns:
        The number of dimensions of the embeddings the model produces.

    Raises:
        ValueError: If `model` is not one of the hard-coded names below.
    """
    # TODO: Consider instantiating the model and embedding a sample string
    # to measure the dimensions, which would make new models easier to add.
    if model == "openai:text-embedding-ada-002":
        return 1536
    if model == "hf:BAAI/bge-small-en":
        return 384
    raise ValueError(f"Unsupported model '{model}'")


@router.put("/")
async def add_collection(
conn: PgConnectionDep, collection: CollectionCreate
) -> Collection:
"""Create a collection."""
dimensions = get_dimensions(collection.text_embedding_model)
dimensions = await get_dimensions(conn, collection.text_embedding_model)
logger.info("Dimensions: {}", dimensions)
async with conn.transaction():
result = await conn.fetchrow(
"""
INSERT INTO collection (name, text_embedding_model, text_distance_metric)
VALUES ($1, $2, $3)
INSERT INTO collection (
name,
text_embedding_model,
text_distance_metric
) VALUES ($1, $2, $3)
RETURNING id, name, text_embedding_model, text_distance_metric
""",
""",
collection.name,
collection.text_embedding_model,
collection.text_distance_metric.value,
Expand Down
88 changes: 76 additions & 12 deletions dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from io import text_encoding
from typing import List, Self, Tuple

import asyncpg
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings import BaseEmbedding
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from loguru import logger

from dewy.chunks.models import TextResult
from dewy.collections.models import DistanceMetric
from dewy.collections.router import get_dimensions
from dewy.config import settings

from .extract import extract

Expand All @@ -22,6 +23,7 @@ def __init__(
*,
collection_id: int,
text_embedding_model: str,
text_embedding_dimensions: int,
text_distance_metric: DistanceMetric,
) -> None:
"""Create a new CollectionEmbeddings."""
Expand All @@ -35,14 +37,12 @@ def __init__(

# TODO: Look at a sentence window splitter?
self._splitter = SentenceSplitter()
# TODO: Support other embeddings (based on the model).
self._embedding = OpenAIEmbedding()
self._embedding = _resolve_embedding_model(self.text_embedding_model)

field = f"embedding::vector({text_embedding_dimensions})"

# TODO: Figure out how to limit by the number of *chunks* not the number
# of embeddings.
dimensions = get_dimensions(self.text_embedding_model)
field = f"embedding::vector({dimensions})"

self._retrieve_embeddings = f"""
SELECT
chunk_id,
Expand Down Expand Up @@ -80,10 +80,13 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
result = await conn.fetchrow(
"""
SELECT
id,
collection.id as id,
text_embedding_model,
text_distance_metric
text_distance_metric,
text_embedding_dimensions.dimensions AS text_embedding_dimensions
FROM collection
JOIN text_embedding_dimensions
ON text_embedding_dimensions.name = collection.text_embedding_model
WHERE collection.id = $1;
""",
collection_id,
Expand All @@ -93,6 +96,7 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
pg_pool,
collection_id=result["id"],
text_embedding_model=result["text_embedding_model"],
text_embedding_dimensions=result["text_embedding_dimensions"],
text_distance_metric=DistanceMetric(result["text_distance_metric"]),
)

Expand All @@ -110,9 +114,12 @@ async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self
collection.name,
collection.id as id,
collection.text_embedding_model,
collection.text_distance_metric
collection.text_distance_metric,
text_embedding_dimensions.dimensions AS text_embedding_dimensions
FROM document
JOIN collection ON document.collection_id = collection.id
JOIN text_embedding_dimensions
ON text_embedding_dimensions.name = collection.text_embedding_model
WHERE document.id = $1;
""",
document_id,
Expand All @@ -123,6 +130,7 @@ async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self
pg_pool,
collection_id=result["id"],
text_embedding_model=result["text_embedding_model"],
text_embedding_dimensions=result["text_embedding_dimensions"],
text_distance_metric=DistanceMetric(result["text_distance_metric"]),
)
return (result["url"], configured_ingestion)
Expand Down Expand Up @@ -170,9 +178,9 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult
TextResult(
chunk_id=e["chunk_id"],
document_id=e["document_id"],
score=e["score"],
score=e["score"],
text=e["text"],
raw=True,
raw=True,
start_char_idx=None,
end_char_idx=None,
)
Expand Down Expand Up @@ -276,3 +284,59 @@ async def _chunk_sentences(self, text: str) -> List[str]:
# all resulting nodes are resident in memory.
# - It uses metadata to return the "window" (if using sentence windows).
return [node.text for node in await self._splitter.acall([TextNode(text=text)])]


# Fallback model names used by `_resolve_embedding_model` when it is given an
# empty model name: the OpenAI default is chosen if `settings.OPENAI_API_KEY`
# is set, otherwise the HuggingFace default. The part before the first ":"
# selects the provider ("openai" or "hf").
DEFAULT_OPENAI_EMBEDDING_MODEL: str = "openai:text-embedding-ada-002"
DEFAULT_HF_EMBEDDING_MODEL: str = "hf:BAAI/bge-small-en"


async def get_dimensions(conn: asyncpg.Connection, model_name: str) -> int:
    """Return the embedding dimensionality for `model_name`.

    The `text_embedding_dimensions` table is consulted first. On a miss the
    model is instantiated, a short probe string is embedded to measure the
    vector length, and the result is recorded in the table for later calls.

    Args:
        conn: Connection used for both the lookup and the insert.
        model_name: Prefixed model name, e.g. "hf:BAAI/bge-small-en".

    Returns:
        The number of dimensions of the model's text embeddings.

    Raises:
        ValueError: Propagated from `_resolve_embedding_model` when the
            model name is not recognized.
    """
    dimensions = await conn.fetchval(
        """
        SELECT dimensions
        FROM text_embedding_dimensions
        WHERE name = $1
        """,
        model_name,
    )

    if dimensions is not None:
        return dimensions

    # Not recorded yet: measure by embedding a throwaway string.
    model = _resolve_embedding_model(model_name)
    dimensions = len(await model.aget_text_embedding("test string"))

    # Another caller may have recorded the same model between our SELECT and
    # this INSERT. `ON CONFLICT DO NOTHING` turns that race into a no-op
    # instead of a unique-violation error; the competing writer measured the
    # same model, so the stored value should match ours.
    # NOTE(review): assumes a uniqueness constraint on `name` — confirm in the
    # migration that creates `text_embedding_dimensions`.
    await conn.execute(
        """
        INSERT INTO text_embedding_dimensions (name, dimensions)
        VALUES ($1, $2)
        ON CONFLICT DO NOTHING
        """,
        model_name,
        dimensions,
    )

    return dimensions


def _resolve_embedding_model(model: str) -> BaseEmbedding:
    """Instantiate the embedding model named by `model`.

    Names have the form "<provider>:<name>", where the provider is "openai"
    or "hf". An empty `model` falls back to a default: the OpenAI default
    when `settings.OPENAI_API_KEY` is set, otherwise the HuggingFace default.

    Args:
        model: Prefixed model name, or an empty string to use the default.

    Returns:
        An `OpenAIEmbedding` or `HuggingFaceEmbedding` for the named model.

    Raises:
        ValueError: If the provider prefix is unknown or no model name
            follows the ":" separator.
    """
    if not model:
        model = (
            DEFAULT_OPENAI_EMBEDDING_MODEL
            if settings.OPENAI_API_KEY
            else DEFAULT_HF_EMBEDDING_MODEL
        )

    # Split on the FIRST ":" only, so a model name that itself contains ":"
    # is passed through intact rather than truncated.
    provider, _, name = model.partition(":")
    if not name:
        # Covers both "no separator" and "nothing after the separator",
        # which would otherwise fail later with a less helpful error.
        raise ValueError(f"Unrecognized embedding model '{model}'")

    # Each provider class is imported inside the branch that uses it.
    if provider == "openai":
        from llama_index.embeddings import OpenAIEmbedding

        return OpenAIEmbedding(model=name)
    if provider == "hf":
        from llama_index.embeddings import HuggingFaceEmbedding

        return HuggingFaceEmbedding(model_name=name)
    raise ValueError(f"Unrecognized embedding model '{model}'")
6 changes: 5 additions & 1 deletion dewy/common/db_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ async def apply_migrations(
if applied_migrations:
logger.warn("Unrecognized migrations applied: {}", applied_migrations)

logger.info("Migrations complete. {} total, {} newly applied", len(defined_migrations), applied)
logger.info(
"Migrations complete. {} total, {} newly applied",
len(defined_migrations),
applied,
)


MIGRATION_RE = re.compile(r"([0-9]{4})[a-zA-Z0-9_-]+\.sql")
Expand Down
Loading

0 comments on commit 86fb278

Please sign in to comment.