Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support model configurations #33

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions dewy/chunks/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Literal, Optional, Sequence, Union, Annotated
from typing import Annotated, Literal, Optional, Sequence, Union

from pydantic import BaseModel, Field


class TextChunk(BaseModel):
id: int
document_id: int
Expand All @@ -26,7 +27,9 @@ class ImageChunk(BaseModel):
image_path: Optional[str] = Field(..., description="Path of the image.")
image_url: Optional[str] = Field(..., description="URL of the image.")

Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')]

Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator="kind")]


class RetrieveRequest(BaseModel):
"""A request for retrieving chunks from a collection."""
Expand Down Expand Up @@ -61,17 +64,18 @@ class RetrieveRequest(BaseModel):
include_summary: bool = False
"""Whether to include a generated summary."""


class TextResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
"""The ID of the document associated with this result"""

score: float
"""The similarity score of this result."""

text: str
text: str
"Textual description of the chunk."

raw: bool
Expand All @@ -82,10 +86,11 @@ class TextResult(BaseModel):
default=None, description="End char index of the chunk."
)


class ImageResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
"""The ID of the document associated with this result"""

Expand All @@ -97,6 +102,7 @@ class ImageResult(BaseModel):
image_path: Optional[str] = Field(..., description="Path of the image.")
image_url: Optional[str] = Field(..., description="URL of the image.")


class RetrieveResponse(BaseModel):
"""The response from a retrieval request."""

Expand All @@ -107,4 +113,4 @@ class RetrieveResponse(BaseModel):
"""Retrieved text chunks."""

image_results: Sequence[ImageResult]
"""Retrieved image chunks."""
"""Retrieved image chunks."""
20 changes: 15 additions & 5 deletions dewy/chunks/router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union, Annotated, List
from typing import Annotated, List

from fastapi import APIRouter, Query, Path
from fastapi import APIRouter, Path, Query

from dewy.common.collection_embeddings import CollectionEmbeddings
from dewy.common.db import PgPoolDep
Expand All @@ -9,11 +9,16 @@

router = APIRouter(prefix="/chunks")


@router.get("/")
async def list_chunks(
pg_pool: PgPoolDep,
collection_id: Annotated[int | None, Query(description="Limit to chunks associated with this collection")] = None,
document_id: Annotated[int | None, Query(description="Limit to chunks associated with this document")] = None,
collection_id: Annotated[
int | None, Query(description="Limit to chunks associated with this collection")
] = None,
document_id: Annotated[
int | None, Query(description="Limit to chunks associated with this document")
] = None,
) -> List[Chunk]:
"""List chunks."""

Expand All @@ -31,8 +36,10 @@ async def list_chunks(
)
return [Chunk.model_validate(dict(result)) for result in results]


PathChunkId = Annotated[int, Path(..., description="The chunk ID.")]


@router.get("/{id}")
async def get_chunk(
pg_pool: PgPoolDep,
Expand All @@ -48,6 +55,7 @@ async def get_chunk(
)
return Chunk.model_validate(dict(result))


@router.post("/retrieve")
async def retrieve_chunks(
pg_pool: PgPoolDep, request: RetrieveRequest
Expand All @@ -59,7 +67,9 @@ async def retrieve_chunks(
collection = await CollectionEmbeddings.for_collection_id(
pg_pool, request.collection_id
)
text_results = await collection.retrieve_text_chunks(query=request.query, n=request.n)
text_results = await collection.retrieve_text_chunks(
query=request.query, n=request.n
)

return RetrieveResponse(
summary=None,
Expand Down
27 changes: 10 additions & 17 deletions dewy/collections/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,30 @@
from fastapi import APIRouter, Path

from dewy.collections.models import Collection, CollectionCreate
from dewy.common.collection_embeddings import get_dimensions
from dewy.common.db import PgConnectionDep
from loguru import logger

router = APIRouter(prefix="/collections")


def get_dimensions(model: str) -> int:
    """Return the embedding vector size for a supported model identifier.

    Only the two identifiers checked below are recognised; any other value
    raises ``ValueError``.
    """
    # TODO: Consider instantiating the model and applying it to a string
    # to determine the dimensions. This would make it easier to support
    # new models.
    if model == "openai:text-embedding-ada-002":
        return 1536
    if model == "hf:BAAI/bge-small-en":
        return 384
    raise ValueError(f"Unsupported model '{model}'")


@router.put("/")
async def add_collection(
conn: PgConnectionDep, collection: CollectionCreate
) -> Collection:
"""Create a collection."""
dimensions = get_dimensions(collection.text_embedding_model)
dimensions = await get_dimensions(conn, collection.text_embedding_model)
logger.info("Dimensions: {}", dimensions)
async with conn.transaction():
result = await conn.fetchrow(
"""
INSERT INTO collection (name, text_embedding_model, text_distance_metric)
VALUES ($1, $2, $3)
INSERT INTO collection (
name,
text_embedding_model,
text_distance_metric
) VALUES ($1, $2, $3)
RETURNING id, name, text_embedding_model, text_distance_metric
""",
""",
collection.name,
collection.text_embedding_model,
collection.text_distance_metric.value,
Expand Down
88 changes: 76 additions & 12 deletions dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from io import text_encoding
from typing import List, Self, Tuple

import asyncpg
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings import BaseEmbedding
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from loguru import logger

from dewy.chunks.models import TextResult
from dewy.collections.models import DistanceMetric
from dewy.collections.router import get_dimensions
from dewy.config import settings

from .extract import extract

Expand All @@ -22,6 +23,7 @@ def __init__(
*,
collection_id: int,
text_embedding_model: str,
text_embedding_dimensions: int,
text_distance_metric: DistanceMetric,
) -> None:
"""Create a new CollectionEmbeddings."""
Expand All @@ -35,14 +37,12 @@ def __init__(

# TODO: Look at a sentence window splitter?
self._splitter = SentenceSplitter()
# TODO: Support other embeddings (based on the model).
self._embedding = OpenAIEmbedding()
self._embedding = _resolve_embedding_model(self.text_embedding_model)

field = f"embedding::vector({text_embedding_dimensions})"

# TODO: Figure out how to limit by the number of *chunks* not the number
# of embeddings.
dimensions = get_dimensions(self.text_embedding_model)
field = f"embedding::vector({dimensions})"

self._retrieve_embeddings = f"""
SELECT
chunk_id,
Expand Down Expand Up @@ -80,10 +80,13 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
result = await conn.fetchrow(
"""
SELECT
id,
collection.id as id,
text_embedding_model,
text_distance_metric
text_distance_metric,
text_embedding_dimensions.dimensions AS text_embedding_dimensions
FROM collection
JOIN text_embedding_dimensions
ON text_embedding_dimensions.name = collection.text_embedding_model
WHERE collection.id = $1;
""",
collection_id,
Expand All @@ -93,6 +96,7 @@ async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
pg_pool,
collection_id=result["id"],
text_embedding_model=result["text_embedding_model"],
text_embedding_dimensions=result["text_embedding_dimensions"],
text_distance_metric=DistanceMetric(result["text_distance_metric"]),
)

Expand All @@ -110,9 +114,12 @@ async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self
collection.name,
collection.id as id,
collection.text_embedding_model,
collection.text_distance_metric
collection.text_distance_metric,
text_embedding_dimensions.dimensions AS text_embedding_dimensions
FROM document
JOIN collection ON document.collection_id = collection.id
JOIN text_embedding_dimensions
ON text_embedding_dimensions.name = collection.text_embedding_model
WHERE document.id = $1;
""",
document_id,
Expand All @@ -123,6 +130,7 @@ async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> (str, Self
pg_pool,
collection_id=result["id"],
text_embedding_model=result["text_embedding_model"],
text_embedding_dimensions=result["text_embedding_dimensions"],
text_distance_metric=DistanceMetric(result["text_distance_metric"]),
)
return (result["url"], configured_ingestion)
Expand Down Expand Up @@ -170,9 +178,9 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult
TextResult(
chunk_id=e["chunk_id"],
document_id=e["document_id"],
score=e["score"],
score=e["score"],
text=e["text"],
raw=True,
raw=True,
start_char_idx=None,
end_char_idx=None,
)
Expand Down Expand Up @@ -276,3 +284,59 @@ async def _chunk_sentences(self, text: str) -> List[str]:
# all resulting nodes are resident in memory.
# - It uses metadata to return the "window" (if using sentence windows).
return [node.text for node in await self._splitter.acall([TextNode(text=text)])]


# Fallback model identifiers, written as "<provider>:<model name>".
# `_resolve_embedding_model` picks one of these when it is given an empty
# model string: the OpenAI one if `settings.OPENAI_API_KEY` is set, otherwise
# the Hugging Face one.
DEFAULT_OPENAI_EMBEDDING_MODEL: str = "openai:text-embedding-ada-002"
DEFAULT_HF_EMBEDDING_MODEL: str = "hf:BAAI/bge-small-en"


async def get_dimensions(conn: asyncpg.Connection, model_name: str) -> int:
    """Return the embedding dimensions for ``model_name``, caching them in the DB.

    The ``text_embedding_dimensions`` table is consulted first. If it has no
    row for ``model_name``, the model is instantiated, a short test string is
    embedded to measure the vector length, and the result is stored for
    subsequent lookups.

    Args:
        conn: Connection used for both the lookup and the insert.
        model_name: Model identifier, passed to ``_resolve_embedding_model``.

    Returns:
        The number of dimensions of the model's text embeddings.

    Raises:
        ValueError: Propagated from ``_resolve_embedding_model`` when the
            model identifier is not recognised.
    """
    dimensions = await conn.fetchval(
        """
        SELECT dimensions
        FROM text_embedding_dimensions
        WHERE name = $1
        """,
        model_name,
    )

    if dimensions is not None:
        return dimensions

    # Not cached yet: measure the dimensions by embedding a throwaway string.
    model = _resolve_embedding_model(model_name)
    dimensions = len(await model.aget_text_embedding("test string"))

    # A concurrent caller may have inserted the row for this name between the
    # SELECT above and this INSERT. That writer measured the same model, so
    # its value matches ours; skip the conflicting insert instead of failing
    # the request on the uniqueness constraint.
    await conn.execute(
        """
        INSERT INTO text_embedding_dimensions (name, dimensions)
        VALUES ($1, $2)
        ON CONFLICT DO NOTHING
        """,
        model_name,
        dimensions,
    )

    return dimensions


def _resolve_embedding_model(model: str) -> BaseEmbedding:
    """Instantiate the embedding model named by a "<provider>:<name>" string.

    An empty ``model`` falls back to a default: the OpenAI default when
    ``settings.OPENAI_API_KEY`` is set, otherwise the Hugging Face default.

    Args:
        model: Identifier such as ``"openai:text-embedding-ada-002"`` or
            ``"hf:BAAI/bge-small-en"``. Everything after the first ``:`` is
            passed to the provider as the model name.

    Returns:
        An ``OpenAIEmbedding`` for the ``openai`` provider or a
        ``HuggingFaceEmbedding`` for the ``hf`` provider.

    Raises:
        ValueError: If the provider is unknown or the model name is missing.
    """
    if not model:
        if settings.OPENAI_API_KEY:
            model = DEFAULT_OPENAI_EMBEDDING_MODEL
        else:
            model = DEFAULT_HF_EMBEDDING_MODEL

    # Split on the first colon only, so model names that themselves contain
    # a colon are passed through intact rather than truncated.
    provider, _, name = model.partition(":")
    if not name:
        # Covers "openai", "hf" and "openai:" -- previously an IndexError or
        # an empty model name handed to the provider.
        raise ValueError(f"Unrecognized embedding model '{model}'")

    # Provider imports stay local so only the selected backend is loaded.
    if provider == "openai":
        from llama_index.embeddings import OpenAIEmbedding

        return OpenAIEmbedding(model=name)
    if provider == "hf":
        from llama_index.embeddings import HuggingFaceEmbedding

        return HuggingFaceEmbedding(model_name=name)

    raise ValueError(f"Unrecognized embedding model '{model}'")
6 changes: 5 additions & 1 deletion dewy/common/db_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ async def apply_migrations(
if applied_migrations:
logger.warn("Unrecognized migrations applied: {}", applied_migrations)

logger.info("Migrations complete. {} total, {} newly applied", len(defined_migrations), applied)
logger.info(
"Migrations complete. {} total, {} newly applied",
len(defined_migrations),
applied,
)


# Pattern for migration file names: four digits (captured as group 1),
# followed by one or more characters from [a-zA-Z0-9_-], then ".sql".
# NOTE(review): the digits are presumably the migration's sequence number --
# confirm against the code that consumes this pattern.
MIGRATION_RE = re.compile(r"([0-9]{4})[a-zA-Z0-9_-]+\.sql")
Expand Down
Loading