
feat: Write/retrieve chunks using postgres #17

Merged: 5 commits, Jan 26, 2024
Changes from 3 commits

dewy/chunks/models.py (5 changes: 4 additions & 1 deletion)

@@ -4,7 +4,10 @@


class RetrieveRequest(BaseModel):
"""A request for retrieving unstructured (document) results."""
"""A request for retrieving chunks from a collection."""

collection_id: int
"""The collection to retrieve chunks from."""

query: str
"""The query string to use for retrieval."""

dewy/chunks/router.py (56 changes: 9 additions & 47 deletions)

@@ -1,62 +1,24 @@
from typing import Union

from fastapi import APIRouter
from llama_index.schema import NodeWithScore
from loguru import logger

from dewy.ingest.store import StoreDep
from dewy.common.collection_embeddings import CollectionEmbeddings
from dewy.common.db import PgPoolDep

from .models import ImageChunk, RetrieveRequest, RetrieveResponse, TextChunk
from .models import RetrieveRequest, RetrieveResponse

router = APIRouter(prefix="/chunks")


@router.post("/retrieve")
async def retrieve_chunks(
store: StoreDep, request: RetrieveRequest
pg_pool: PgPoolDep, request: RetrieveRequest
) -> RetrieveResponse:
"""Retrieve chunks based on a given query."""

from llama_index.response_synthesizers import ResponseMode

logger.info("Retrieving statements for query:", request)
results = store.index.as_query_engine(
similarity_top_k=request.n,
response_mode=ResponseMode.TREE_SUMMARIZE
if request.include_summary
else ResponseMode.NO_TEXT,
# TODO: metadata filters / ACLs
).query(request.query)
# TODO: Revisit response synthesis and hierarchical fetching.

statements = [node_to_statement(node) for node in results.source_nodes]

return RetrieveResponse(
summary=results.response,
chunks=statements if request.include_statements else [],
collection = await CollectionEmbeddings.for_collection_id(
pg_pool, request.collection_id
)
chunks = await collection.retrieve_text_chunks(query=request.query, n=request.n)


def node_to_statement(node: NodeWithScore) -> Union[TextChunk, ImageChunk]:
from llama_index.schema import ImageNode, TextNode

if isinstance(node.node, TextNode):
return TextChunk(
raw=True,
score=node.score,
text=node.node.text,
start_char_idx=node.node.start_char_idx,
end_char_idx=node.node.end_char_idx,
)
elif isinstance(node.node, ImageNode):
return ImageChunk(
score=node.score,
text=node.node.text if node.node.text else None,
image=node.node.image,
image_mimetype=node.node.image_mimetype,
image_path=node.node.image_path,
image_url=node.node.image_url,
)
else:
raise NotImplementedError(
f"Unsupported node type ({node.node.class_name()}): {node!r}"
)
return RetrieveResponse(summary=None, chunks=chunks)

dewy/collections/models.py (25 changes: 23 additions & 2 deletions)

@@ -17,6 +17,24 @@ def vector_ops(self) -> str:
case DistanceMetric.l2:
return "vector_l2_ops"

def order_by(self, haystack: str, needle: str) -> str:
match self:
case DistanceMetric.cosine:
return f"{haystack} <=> {needle}"
case DistanceMetric.inner_product:
return f"{haystack} <#> {needle}"
case DistanceMetric.l2:
return f"{haystack} <-> {needle}"

def distance(self, haystack: str, needle: str) -> str:
match self:
case DistanceMetric.cosine:
return f"1 - ({haystack} <=> {needle})"
case DistanceMetric.inner_product:
return f"({haystack} <#> {needle}) * -1"
case DistanceMetric.l2:
return f"{haystack} <-> {needle}"


class Collection(BaseModel):
model_config = ConfigDict(from_attributes=True)
@@ -49,7 +67,10 @@ class CollectionCreate(BaseModel):
name: str = Field(examples=["my_collection"])
"""The name of the collection."""

text_embedding_model: str = Field(examples=["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"])
text_embedding_model: str = Field(
"openai:text-embedding-ada-002",
examples=["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"],
)
"""The name of the embedding model.

NOTE: Changing embedding models is not currently supported.
@@ -58,4 +79,4 @@ class CollectionCreate(BaseModel):
text_distance_metric: DistanceMetric = DistanceMetric.cosine
"""The distance metric to use on the text embedding.

NOTE: Changing distance metrics is not currently supported."""

dewy/common/collection_embeddings.py (259 changes: 259 additions & 0 deletions)

@@ -0,0 +1,259 @@
from typing import List, Self, Tuple

import asyncpg
from llama_index.embeddings import OpenAIEmbedding
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from loguru import logger

from dewy.chunks.models import TextChunk
from dewy.collections.models import DistanceMetric
from dewy.collections.router import get_dimensions

from .extract import extract


class CollectionEmbeddings:
"""Helper class for working with the embeddings in a collection."""

def __init__(self, pg_pool: asyncpg.Pool, collection_row: asyncpg.Record) -> None:
"""Create a new CollectionEmbeddings.

Parameters:
- pg_pool: The asyncpg pool for connecting to the database
- collection_row: A record from the database containing the following
fields: collection_id, name, text_embedding_model, text_distance_metric.
"""
self._pg_pool = pg_pool

self.collection_id = collection_row["collection_id"]
self.name = collection_row["name"]
self.text_embedding_model = collection_row["text_embedding_model"]
self.text_distance_metric = DistanceMetric(
collection_row["text_distance_metric"]
)
self.extract_tables = False
self.extract_images = False

# TODO: Look at a sentence window splitter?
self._splitter = SentenceSplitter()
# TODO: Support other embeddings (based on the model).
self._embedding = OpenAIEmbedding()

# TODO: Figure out how to limit by the number of *chunks* not the number
# of embeddings.
dimensions = get_dimensions(self.text_embedding_model)
field = f"embedding::vector({dimensions})"

self._retrieve_embeddings = f"""
SELECT
chunk_id,
{self.text_distance_metric.distance(field, "$1")} AS score
FROM embedding
WHERE collection_id = {self.collection_id}
ORDER BY {self.text_distance_metric.order_by(field, "$1")}
LIMIT $2
"""

self._retrieve_chunks = f"""
WITH relevant_embeddings AS (
SELECT
chunk_id,
{self.text_distance_metric.distance(field, "$1")} AS score
FROM embedding
WHERE collection_id = {self.collection_id}
ORDER BY {self.text_distance_metric.order_by(field, "$1")}
)
SELECT
relevant_embeddings.chunk_id AS chunk_id,
chunk.text AS text,
relevant_embeddings.score AS score
FROM relevant_embeddings
JOIN chunk
ON chunk.id = relevant_embeddings.chunk_id
LIMIT $2
"""

Contributor: It seems like we could invert this and use SELECT DISTINCT ... from chunk to get the deduplicated chunks.

Contributor Author: Maybe? I want to get something in first and then play with it. I'd like to be able to point a pgsql repl at the database with chunks loaded in, and then see what works (and also use explain to see what the query does, etc.). Deferring.

@staticmethod
async def for_collection_id(pg_pool: asyncpg.Pool, collection_id: int) -> Self:
"""Retrieve the collection embeddings of the given collection."""
async with pg_pool.acquire() as conn:
result = await conn.fetchrow(
"""
SELECT
name,
id as collection_id,
text_embedding_model,
text_distance_metric
FROM collection
WHERE collection.id = $1;
""",
collection_id,
)

return CollectionEmbeddings(pg_pool, result)

@staticmethod
async def for_document_id(pg_pool: asyncpg.Pool, document_id: int) -> Tuple[str, Self]:
"""Retrieve the collection embeddings and the URL of the given document."""

# TODO: Ideally the collection embeddings would be cached, and this
# wouldn't need to exist.
async with pg_pool.acquire() as conn:
result = await conn.fetchrow(
"""
SELECT
document.id,
document.url,
collection.name,
collection.id as collection_id,
collection.text_embedding_model,
collection.text_distance_metric
FROM document
JOIN collection ON document.collection_id = collection.id
WHERE document.id = $1;
""",
document_id,
)

# TODO: Cache the configured ingestions, and only recreate when needed?
configured_ingestion = CollectionEmbeddings(pg_pool, result)
return (result["url"], configured_ingestion)

async def retrieve_text_embeddings(
self, query: str, n: int = 10
) -> List[Tuple[int, float]]:
"""Retrieve embeddings related to the given query.

Parameters:
- query: The query to retrieve matching embeddings for.
- n: The number of embeddings to retrieve.

Returns:
List of `(chunk_id, score)` pairs from the embeddings.
"""
embedded_query = await self._embedding.aget_text_embedding(query)

async with self._pg_pool.acquire() as conn:
logger.info("Executing SQL query for chunks from {}", self.collection_id)
embeddings = await conn.fetch(self._retrieve_embeddings, embedded_query, n)
embeddings = [e["chunk_id"] for e in embeddings]
return embeddings

async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextChunk]:
"""Retrieve embeddings related to the given query.

Parameters:
- query: The query to retrieve matching embeddings for.
- n: The number of embeddings to retrieve.

Returns:
List of chunk_ids from the embeddings.
"""
embedded_query = await self._embedding.aget_text_embedding(query)

async with self._pg_pool.acquire() as conn:
logger.info("Executing SQL query for chunks from {}", self.collection_id)
embeddings = await conn.fetch(self._retrieve_chunks, embedded_query, n)
embeddings = [
TextChunk(raw=True, score=e["score"], text=e["text"])
for e in embeddings
]
return embeddings

async def ingest(self, document_id: int, url: str) -> None:
logger.info("Loading content for document {} from '{}'", document_id, url)
extracted = await extract(
url, extract_tables=self.extract_tables, extract_images=self.extract_images
)
if extracted.is_empty():
logger.error(
"No content retrieved for document {} from '{}'", document_id, url
)
return

Contributor: If this is an error, shouldn't it throw an exception?

Contributor Author: It could -- but with background tasks, there isn't really anything to do with that error. What I think we actually need to do is mark the document (or the ingestion associated with the document) as failed and/or do some kind of dead letter. That said -- perhaps we shouldn't treat this as an error?
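
A sketch of the "mark the document as failed" idea from this thread (hypothetical; this PR only logs and returns). It assumes the ingest_state enum accepts a 'failed' value, mirroring the 'ingested' update at the end of ingest():

# Hypothetical follow-up (not in this PR): record the failure on the document
# row instead of only logging it. Assumes ingest_state has a 'failed' value.
async with self._pg_pool.acquire() as conn:
    await conn.execute(
        """
        UPDATE document
        SET ingest_state = 'failed', ingest_error = $2
        WHERE id = $1
        """,
        document_id,
        f"No content retrieved from '{url}'",
    )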

logger.info(
"Chunking text of length {} for {}", len(extracted.text), document_id
)

# Extract chunks (snippets) and perform the direct embedding.
text_chunks = await self._chunk_sentences(extracted.text)

logger.info("Chunking produced {} chunks for {}", len(text_chunks), document_id)

# TODO: support non-text chunks
# TODO: support non-snippet text chunks (eg., summary values)
# TODO: support indirect embeddings
async with self._pg_pool.acquire() as conn:
async with conn.transaction():
# First, insert the chunks.
await conn.executemany(
"""
INSERT INTO chunk (document_id, kind, text)
VALUES ($1, $2, $3);
""",
[(document_id, "text", text_chunk) for text_chunk in text_chunks],
)

# Then, embed each of those chunks.
# We assume no chunks for the document existed before, so we can iterate
# over the chunks.
Contributor: Can we not return the chunk IDs or something? This seems like an assumption that's going to cause bugs as soon as we support updating a document.

Contributor Author: We could. My thinking was that we could write them into the DB rather than trying to keep them in memory and then read them back out. But I think that both llamaindex and various other embeddings will lead to the whole text having to fit in memory anyway during an ingest, so maybe it doesn't matter.

Contributor Author: Actually, I take that back. There isn't a great way to do that. Specifically:

1. This uses executemany, which doesn't return anything.
2. If we use fetch, we can't provide a list of rows to insert -- it needs to be a single query.

I think I'll leave it as is for this PR. I think we could handle updates in a variety of ways:

1. Introduce a new document ID and delete the old one.
2. Add a "version" to each chunk, and query for only the chunks related to the current version.
3. etc.

Contributor: Yeah, including an "ingest version" or something that we could filter on the other side would work.
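
For reference, a single-statement alternative that does return the IDs (a sketch, not part of this PR; it assumes passing the chunk texts as a text[] parameter is acceptable here):

# Hypothetical single-query insert (not in this PR): unnest the chunk texts so
# one INSERT covers all rows, and RETURNING hands back the generated IDs,
# avoiding both executemany and the re-SELECT below.
rows = await conn.fetch(
    """
    INSERT INTO chunk (document_id, kind, text)
    SELECT $1, 'text', t.text
    FROM unnest($2::text[]) AS t(text)
    RETURNING id, text;
    """,
    document_id,
    text_chunks,
)
embedding_chunks = [(row["id"], row["text"]) for row in rows]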

chunks = conn.cursor(
"SELECT id, text FROM chunk WHERE document_id = $1", document_id
)

# TODO: Write this loop in a cleaner async way, to limit the number of
# in-flight requests as well as batching up the embedding requests.
# Currently, this uses Llama Index embeddings, which requires we put
# all the texts to embed in a list.
#
# Ideally, we could take a chunk of embeddings, embed them, and then
# start writing that to the DB asynchronously.
embedding_chunks = [
(chunk["id"], chunk["text"]) async for chunk in chunks
]

# Extract just the text and embed it.
logger.info(
"Computing {} embeddings for {}", len(embedding_chunks), document_id
)
embeddings = await self._embedding.aget_text_embedding_batch(
[item[1] for item in embedding_chunks]
)

# Change the shape to a list of triples (for writing to the DB)
embeddings = [
(self.collection_id, chunk_id, chunk_text, embedding)
for (chunk_id, chunk_text), embedding in zip(
embedding_chunks, embeddings
)
]

logger.info(
"Writing {} embeddings for {}", len(embeddings), document_id
)
await conn.executemany(
"""
INSERT INTO embedding (collection_id, chunk_id, key_text, embedding)
VALUES ($1, $2, $3, $4)
""",
embeddings,
)
logger.info("Wrote {} embeddings for {}", len(embeddings), document_id)

await conn.execute(
"""
UPDATE document
SET ingest_state = 'ingested', ingest_error = NULL
WHERE id = $1
""",
document_id,
)

async def _chunk_sentences(self, text: str) -> List[str]:
# This uses llama index a bit oddly. Unfortunately:
# - It returns `BaseNode` even though we know these are `TextNode`
# - It returns a `List` rather than an `Iterator` / `Generator`, so
# all resulting nodes are resident in memory.
# - It uses metadata to return the "window" (if using sentence windows).
return [node.text for node in await self._splitter.acall([TextNode(text=text)])]
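
Relatedly, a sketch of the batching idea from the TODO in ingest() (hypothetical; the PR embeds every chunk in a single call). Embedding fixed-size batches and writing each one as soon as it is ready bounds memory to one batch of embeddings at a time:

# Hypothetical batched variant of the embed-and-write loop in ingest() above
# (not part of this PR). Only one batch of embeddings is in memory at a time.
async def _embed_in_batches(
    self, conn, chunks: List[Tuple[int, str]], batch_size: int = 64
) -> None:
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start : start + batch_size]
        embeddings = await self._embedding.aget_text_embedding_batch(
            [text for (_chunk_id, text) in batch]
        )
        await conn.executemany(
            """
            INSERT INTO embedding (collection_id, chunk_id, key_text, embedding)
            VALUES ($1, $2, $3, $4)
            """,
            [
                (self.collection_id, chunk_id, text, embedding)
                for (chunk_id, text), embedding in zip(batch, embeddings)
            ],
        )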

dewy/common/db.py (2 changes: 1 addition & 1 deletion)

@@ -26,7 +26,7 @@ async def init_pool(conn: asyncpg.Connection):

pool = await asyncpg.create_pool(dsn, init=init_pool)
yield pool
pool.close()
await pool.close()


def _pg_pool(request: Request) -> asyncpg.Pool: