feat: Write/retrieve chunks using postgres (#17)
* feat: Write/retrieve chunks using postgres

This removes the dependency on Redis, and makes the chunks/embeddings
in the postgres database work.

There are some issues to be addressed, specifically deduplicating cases
where multiple embeddings of the same chunk are retrieved. I plan to
work on those in a follow-up PR, so that we can get the bulk of this in
first.
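
As a hypothetical sketch of that follow-up (not part of this commit), the retrieval query could deduplicate by keeping only the best-scoring embedding per chunk, e.g. with PostgreSQL's DISTINCT ON. The table and column names below come from this change; everything else is an assumption:

# Hypothetical dedup sketch for the follow-up PR: DISTINCT ON keeps the first
# row per chunk_id in ORDER BY order (here, the nearest embedding), then the
# surviving chunks are ranked by score. Cosine distance shown for brevity.
DEDUPED_RETRIEVE_CHUNKS = """
    WITH relevant_embeddings AS (
        SELECT DISTINCT ON (chunk_id)
            chunk_id,
            1 - (embedding <=> $2) AS score
        FROM embedding
        WHERE collection_id = $1
        ORDER BY chunk_id, embedding <=> $2
    )
    SELECT chunk.id AS chunk_id, chunk.text AS text, relevant_embeddings.score AS score
    FROM relevant_embeddings
    JOIN chunk ON chunk.id = relevant_embeddings.chunk_id
    ORDER BY score DESC
    LIMIT $3
"""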
bjchambers authored Jan 26, 2024
1 parent cd8deb9 commit cda4f8b
Showing 24 changed files with 452 additions and 2,497 deletions.
5 changes: 4 additions & 1 deletion dewy/chunks/models.py
@@ -4,7 +4,10 @@
 
 
 class RetrieveRequest(BaseModel):
-    """A request for retrieving unstructured (document) results."""
+    """A request for retrieving chunks from a collection."""
+
+    collection_id: int
+    """The collection to retrieve chunks from."""
 
     query: str
     """The query string to use for retrieval."""
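
With collection_id now on the request, a retrieval request names its collection explicitly. A minimal sketch of the updated shape (assuming the other fields, such as n, keep their existing definitions):

from dewy.chunks.models import RetrieveRequest

# Sketch: collection_id is now required alongside the query string.
request = RetrieveRequest(collection_id=1, query="How do vector indexes work?")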
56 changes: 9 additions & 47 deletions dewy/chunks/router.py
@@ -1,62 +1,24 @@
-from typing import Union
-
 from fastapi import APIRouter
-from llama_index.schema import NodeWithScore
 from loguru import logger
 
-from dewy.ingest.store import StoreDep
+from dewy.common.collection_embeddings import CollectionEmbeddings
+from dewy.common.db import PgPoolDep
 
-from .models import ImageChunk, RetrieveRequest, RetrieveResponse, TextChunk
+from .models import RetrieveRequest, RetrieveResponse
 
 router = APIRouter(prefix="/chunks")
 
 
 @router.post("/retrieve")
 async def retrieve_chunks(
-    store: StoreDep, request: RetrieveRequest
+    pg_pool: PgPoolDep, request: RetrieveRequest
 ) -> RetrieveResponse:
     """Retrieve chunks based on a given query."""
-
-    from llama_index.response_synthesizers import ResponseMode
-
-    logger.info("Retrieving statements for query:", request)
-    results = store.index.as_query_engine(
-        similarity_top_k=request.n,
-        response_mode=ResponseMode.TREE_SUMMARIZE
-        if request.include_summary
-        else ResponseMode.NO_TEXT,
-        # TODO: metadata filters / ACLs
-    ).query(request.query)
+    # TODO: Revisit response synthesis and hierarchical fetching.
-
-    statements = [node_to_statement(node) for node in results.source_nodes]
-
-    return RetrieveResponse(
-        summary=results.response,
-        chunks=statements if request.include_statements else [],
+    collection = await CollectionEmbeddings.for_collection_id(
+        pg_pool, request.collection_id
     )
+    chunks = await collection.retrieve_text_chunks(query=request.query, n=request.n)
-
-
-def node_to_statement(node: NodeWithScore) -> Union[TextChunk, ImageChunk]:
-    from llama_index.schema import ImageNode, TextNode
-
-    if isinstance(node.node, TextNode):
-        return TextChunk(
-            raw=True,
-            score=node.score,
-            text=node.node.text,
-            start_char_idx=node.node.start_char_idx,
-            end_char_idx=node.node.end_char_idx,
-        )
-    elif isinstance(node.node, ImageNode):
-        return ImageChunk(
-            score=node.score,
-            text=node.node.text if node.node.text else None,
-            image=node.node.image,
-            image_mimetype=node.node.image_mimetype,
-            image_path=node.node.image_path,
-            image_url=node.node.image_url,
-        )
-    else:
-        raise NotImplementedError(
-            f"Unsupported node type ({node.node.class_name()}): {node!r}"
-        )
+    return RetrieveResponse(summary=None, chunks=chunks)
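
The reworked endpoint can be exercised over HTTP once the app is running — a sketch assuming a local server on port 8000; the payload mirrors RetrieveRequest, and the response shape follows RetrieveResponse/TextChunk above:

import httpx

# Sketch: call the rewritten endpoint. The base URL is an assumption.
response = httpx.post(
    "http://localhost:8000/chunks/retrieve",
    json={"collection_id": 1, "query": "How do vector indexes work?", "n": 10},
)
response.raise_for_status()
for chunk in response.json()["chunks"]:
    # Each chunk carries the score computed by the collection's distance metric.
    print(chunk["score"], chunk["text"][:80])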
25 changes: 23 additions & 2 deletions dewy/collections/models.py
@@ -17,6 +17,24 @@ def vector_ops(self) -> str:
             case DistanceMetric.l2:
                 return "vector_l2_ops"
 
+    def order_by(self, haystack: str, needle: str) -> str:
+        match self:
+            case DistanceMetric.cosine:
+                return f"{haystack} <=> {needle}"
+            case DistanceMetric.inner_product:
+                return f"{haystack} <#> {needle}"
+            case DistanceMetric.l2:
+                return f"{haystack} <-> {needle}"
+
+    def distance(self, haystack: str, needle: str) -> str:
+        match self:
+            case DistanceMetric.cosine:
+                return f"1 - ({haystack} <=> {needle})"
+            case DistanceMetric.inner_product:
+                return f"({haystack} <#> {needle}) * -1"
+            case DistanceMetric.l2:
+                return f"{haystack} <-> {needle}"
+
 
 class Collection(BaseModel):
     model_config = ConfigDict(from_attributes=True)
@@ -49,7 +67,10 @@ class CollectionCreate(BaseModel):
     name: str = Field(examples=["my_collection"])
     """The name of the collection."""
 
-    text_embedding_model: str = Field(examples=["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"])
+    text_embedding_model: str = Field(
+        "openai:text-embedding-ada-002",
+        examples=["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"],
+    )
     """The name of the embedding model.
     NOTE: Changing embedding models is not currently supported.
@@ -58,4 +79,4 @@ class CollectionCreate(BaseModel):
 
     text_distance_metric: DistanceMetric = DistanceMetric.cosine
     """The distance metric to use on the text embedding.
-    NOTE: Changing distance metrics is not currently supported."""
+    NOTE: Changing distance metrics is not currently supported."""
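
The new helpers map each metric onto pgvector's operators: <=> is cosine distance, <#> is negative inner product, and <-> is L2 distance. order_by emits the raw operator expression (the form pgvector indexes can serve), while distance rewrites cosine and inner product into their conventional values (cosine similarity via 1 - distance, and the true inner product by negating <#>'s result); L2 passes through unchanged. A small sketch of the generated SQL fragments:

from dewy.collections.models import DistanceMetric

# Sketch: the fragments the new helpers emit for a cosine-metric collection.
metric = DistanceMetric.cosine
print(metric.order_by("embedding", "$2"))  # embedding <=> $2
print(metric.distance("embedding", "$2"))  # 1 - (embedding <=> $2)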
271 changes: 271 additions & 0 deletions dewy/common/collection_embeddings.py
@@ -0,0 +1,271 @@
from typing import List, Self, Tuple

import asyncpg
from llama_index.embeddings import OpenAIEmbedding
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from loguru import logger

from dewy.chunks.models import TextChunk
from dewy.collections.models import DistanceMetric
from dewy.collections.router import get_dimensions

from .extract import extract


class CollectionEmbeddings:
    """Helper class for working with the embeddings in a collection."""

    def __init__(
        self,
        pg_pool: asyncpg.Pool,
        *,
        collection_id: int,
        text_embedding_model: str,
        text_distance_metric: DistanceMetric,
    ) -> None:
        """Create a new CollectionEmbeddings."""
        self._pg_pool = pg_pool
        self.collection_id = collection_id
        self.text_embedding_model = text_embedding_model
        self.text_distance_metric = text_distance_metric

        self.extract_tables = False
        self.extract_images = False

        # TODO: Look at a sentence window splitter?
        self._splitter = SentenceSplitter()
        # TODO: Support other embeddings (based on the model).
        self._embedding = OpenAIEmbedding()

        # TODO: Figure out how to limit by the number of *chunks* not the number
        # of embeddings.
        dimensions = get_dimensions(self.text_embedding_model)
        field = f"embedding::vector({dimensions})"

        self._retrieve_embeddings = f"""
        SELECT
            chunk_id,
            {self.text_distance_metric.distance(field, "$2")} AS score
        FROM embedding
        WHERE collection_id = $1
        ORDER BY {self.text_distance_metric.order_by(field, "$2")}
        LIMIT $3
        """

        self._retrieve_chunks = f"""
        WITH relevant_embeddings AS (
            SELECT
                chunk_id,
                {self.text_distance_metric.distance(field, "$2")} AS score
            FROM embedding
            WHERE collection_id = $1
            ORDER BY {self.text_distance_metric.order_by(field, "$2")}
        )
        SELECT
            relevant_embeddings.chunk_id AS chunk_id,
            chunk.text AS text,
            relevant_embeddings.score AS score
        FROM relevant_embeddings
        JOIN chunk
        ON chunk.id = relevant_embeddings.chunk_id
        LIMIT $3
        """

    @staticmethod
    async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
        """Retrieve the collection embeddings of the given collection."""
        async with pg_pool.acquire() as conn:
            result = await conn.fetchrow(
                """
                SELECT
                    id,
                    text_embedding_model,
                    text_distance_metric
                FROM collection
                WHERE collection.id = $1;
                """,
                collection_id,
            )

            return CollectionEmbeddings(
                pg_pool,
                collection_id=result["id"],
                text_embedding_model=result["text_embedding_model"],
                text_distance_metric=DistanceMetric(result["text_distance_metric"]),
            )

    @staticmethod
    async def for_document_id(
        pg_pool: asyncpg.Pool, document_id: int
    ) -> Tuple[str, Self]:
        """Retrieve the collection embeddings and the URL of the given document."""
        # TODO: Ideally the collection embeddings would be cached, and this
        # wouldn't need to exist.
        async with pg_pool.acquire() as conn:
            result = await conn.fetchrow(
                """
                SELECT
                    document.url as url,
                    collection.name,
                    collection.id as id,
                    collection.text_embedding_model,
                    collection.text_distance_metric
                FROM document
                JOIN collection ON document.collection_id = collection.id
                WHERE document.id = $1;
                """,
                document_id,
            )

        # TODO: Cache the configured ingestions, and only recreate when needed?
        configured_ingestion = CollectionEmbeddings(
            pg_pool,
            collection_id=result["id"],
            text_embedding_model=result["text_embedding_model"],
            text_distance_metric=DistanceMetric(result["text_distance_metric"]),
        )
        return (result["url"], configured_ingestion)

    async def retrieve_text_embeddings(
        self, query: str, n: int = 10
    ) -> List[Tuple[int, float]]:
        """Retrieve embeddings related to the given query.

        Parameters:
        - query: The query to retrieve matching embeddings for.
        - n: The number of embeddings to retrieve.

        Returns:
        List of `(chunk_id, score)` pairs from the embeddings.
        """
        embedded_query = await self._embedding.aget_text_embedding(query)

        async with self._pg_pool.acquire() as conn:
            logger.info("Executing SQL query for embeddings from {}", self.collection_id)
            embeddings = await conn.fetch(
                self._retrieve_embeddings,
                self.collection_id,
                embedded_query,
                n,
            )
            return [(e["chunk_id"], e["score"]) for e in embeddings]

    async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextChunk]:
        """Retrieve text chunks related to the given query.

        Parameters:
        - query: The query to retrieve matching chunks for.
        - n: The number of chunks to retrieve.

        Returns:
        List of `TextChunk`s matching the query.
        """
        embedded_query = await self._embedding.aget_text_embedding(query)

        async with self._pg_pool.acquire() as conn:
            logger.info("Executing SQL query for chunks from {}", self.collection_id)
            chunks = await conn.fetch(
                self._retrieve_chunks,
                self.collection_id,
                embedded_query,
                n,
            )
            return [
                TextChunk(raw=True, score=e["score"], text=e["text"]) for e in chunks
            ]

    async def ingest(self, document_id: int, url: str) -> None:
        logger.info("Loading content for document {} from '{}'", document_id, url)
        extracted = await extract(
            url, extract_tables=self.extract_tables, extract_images=self.extract_images
        )
        if extracted.is_empty():
            logger.error(
                "No content retrieved for document {} from '{}'", document_id, url
            )
            return

        logger.info(
            "Chunking text of length {} for {}", len(extracted.text), document_id
        )

        # Extract chunks (snippets) and perform the direct embedding.
        text_chunks = await self._chunk_sentences(extracted.text)

        logger.info("Chunking produced {} chunks for {}", len(text_chunks), document_id)

        # TODO: support non-text chunks
        # TODO: support non-snippet text chunks (e.g., summary values)
        # TODO: support indirect embeddings
        async with self._pg_pool.acquire() as conn:
            async with conn.transaction():
                # First, insert the chunks.
                await conn.executemany(
                    """
                    INSERT INTO chunk (document_id, kind, text)
                    VALUES ($1, $2, $3);
                    """,
                    [(document_id, "text", text_chunk) for text_chunk in text_chunks],
                )

                # Then, embed each of those chunks.
                # We assume no chunks for the document existed before, so we can
                # iterate over the chunks.
                chunks = conn.cursor(
                    "SELECT id, text FROM chunk WHERE document_id = $1", document_id
                )

                # TODO: Write this loop in a cleaner async way, to limit the number
                # of in-flight requests as well as batching up the embedding
                # requests. Currently, this uses Llama Index embeddings, which
                # requires we put all the texts to embed in a list.
                #
                # Ideally, we could take a chunk of embeddings, embed them, and
                # then start writing that to the DB asynchronously.
                embedding_chunks = [
                    (chunk["id"], chunk["text"]) async for chunk in chunks
                ]

                # Extract just the text and embed it.
                logger.info(
                    "Computing {} embeddings for {}", len(embedding_chunks), document_id
                )
                embeddings = await self._embedding.aget_text_embedding_batch(
                    [item[1] for item in embedding_chunks]
                )

                # Change the shape to a list of tuples (for writing to the DB).
                embeddings = [
                    (self.collection_id, chunk_id, chunk_text, embedding)
                    for (chunk_id, chunk_text), embedding in zip(
                        embedding_chunks, embeddings
                    )
                ]

                logger.info(
                    "Writing {} embeddings for {}", len(embeddings), document_id
                )
                await conn.executemany(
                    """
                    INSERT INTO embedding (collection_id, chunk_id, key_text, embedding)
                    VALUES ($1, $2, $3, $4)
                    """,
                    embeddings,
                )
                logger.info("Wrote {} embeddings for {}", len(embeddings), document_id)

            await conn.execute(
                """
                UPDATE document
                SET ingest_state = 'ingested', ingest_error = NULL
                WHERE id = $1
                """,
                document_id,
            )

    async def _chunk_sentences(self, text: str) -> List[str]:
        # This uses llama index a bit oddly. Unfortunately:
        #  - It returns `BaseNode` even though we know these are `TextNode`
        #  - It returns a `List` rather than an `Iterator` / `Generator`, so
        #    all resulting nodes are resident in memory.
        #  - It uses metadata to return the "window" (if using sentence windows).
        return [node.text for node in await self._splitter.acall([TextNode(text=text)])]
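
End to end, ingestion and retrieval now both go through this class — a minimal sketch assuming an asyncpg pool, an existing collection, and a document row with id 1; only the class and method names come from the code above:

import asyncio

import asyncpg

from dewy.common.collection_embeddings import CollectionEmbeddings


async def main() -> None:
    # The DSN is an assumption for illustration.
    pg_pool = await asyncpg.create_pool("postgresql://localhost/dewy")

    # Ingest: look up the document's collection config, then chunk, embed,
    # and write the embeddings to postgres.
    url, embeddings = await CollectionEmbeddings.for_document_id(pg_pool, 1)
    await embeddings.ingest(1, url)

    # Retrieve: embed the query and fetch the nearest chunks back out.
    collection = await CollectionEmbeddings.for_collection_id(pg_pool, 1)
    chunks = await collection.retrieve_text_chunks("How do vector indexes work?", n=5)
    for chunk in chunks:
        print(chunk.score, chunk.text[:80])


asyncio.run(main())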