Skip to content

Commit

Permalink
refactor!: Switch to LangChain for abstractions (#98)
Browse files Browse the repository at this point in the history
* refactor!: Switch to LangChain for abstractions

Since this uses the BGE specific embedding classes, this closes #79.

This also uses the `RecursiveTextSplitter` which we've measured as
performing well.

* ruff

* add langchain dependency
  • Loading branch information
bjchambers authored Feb 27, 2024
1 parent aeb0df8 commit bdbc989
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 23 deletions.
19 changes: 7 additions & 12 deletions dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import List, Optional, Self, Tuple, Union

import asyncpg
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import TextNode
from langchain.text_splitter import RecursiveCharacterTextSplitter
from loguru import logger

from dewy.chunk.models import TextResult
Expand Down Expand Up @@ -48,8 +47,9 @@ def __init__(
self.extract_tables = False
self.extract_images = False

# TODO: Look at a sentence window splitter?
self._splitter = SentenceSplitter(chunk_size=256)
self._splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_overlap=200, add_start_index=True
)
embedding = EMBEDDINGS[self.text_embedding_model]
self._embedding = embedding.factory(config)

Expand Down Expand Up @@ -171,7 +171,7 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult
Returns:
List of chunk_ids from the embeddings.
"""
embedded_query = await self._embedding.aget_text_embedding(query)
embedded_query = await self._embedding.aembed_query(query)

async with self._pg_pool.acquire() as conn:
logger.info("Executing SQL query for chunks from {}", self.collection_id)
Expand Down Expand Up @@ -282,7 +282,7 @@ def encode_chunk(c: str) -> str:

# Extract just the text and embed it.
logger.info("Computing {} embeddings for {}", len(embedding_chunks), document_id)
embeddings = await self._embedding.aget_text_embedding_batch(
embeddings = await self._embedding.aembed_documents(
[item[1] for item in embedding_chunks]
)

Expand Down Expand Up @@ -317,9 +317,4 @@ def encode_chunk(c: str) -> str:
logger.info("Finished updating embeddings for document {}", document_id)

async def _chunk_sentences(self, text: str) -> List[str]:
# This uses llama index a bit oddly. Unfortunately:
# - It returns `BaseNode` even though we know these are `TextNode`
# - It returns a `List` rather than an `Iterator` / `Generator`, so
# all resulting nodes are resident in memory.
# - It uses metadata to return the "window" (if using sentence windows).
return [node.text for node in await self._splitter.acall([TextNode(text=text)])]
return self._splitter.split_text(text)
13 changes: 7 additions & 6 deletions dewy/common/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from typing import Callable

from llama_index import OpenAIEmbedding
from llama_index.embeddings import BaseEmbedding, HuggingFaceEmbedding
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings

from dewy.config import Config

Expand All @@ -11,7 +12,7 @@
class EmbeddingModel:
name: str
dimensions: int
factory: Callable[[Config], BaseEmbedding]
factory: Callable[[Config], Embeddings]


EMBEDDINGS = {
Expand All @@ -20,19 +21,19 @@ class EmbeddingModel:
EmbeddingModel(
name="openai:text-embedding-ada-002",
dimensions=1536,
factory=lambda config: OpenAIEmbedding(
factory=lambda config: OpenAIEmbeddings(
model="text-embedding-ada-002", api_key=config.OPENAI_API_KEY
),
),
EmbeddingModel(
name="hf:BAAI/bge-small-en",
dimensions=384,
factory=lambda _config: HuggingFaceEmbedding("BAAI/bge-small-en"),
factory=lambda _config: HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en"),
),
EmbeddingModel(
name="hf:BAAI/bge-small-en-v1.5",
dimensions=384,
factory=lambda _config: HuggingFaceEmbedding("BAAI/bge-small-en-v1.5"),
factory=lambda _config: HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
),
]
}
6 changes: 3 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ services:
image: dewy
environment:
ENVIRONMENT: LOCAL
LLAMA_INDEX_CACHE_DIR: "/tmp/cache/llama_index"
SENTENCE_TRANSFORMERS_HOME: "/tmp/cache/sentence_transformers"
HF_HOME: "/tmp/cache/hf"
DB: "postgresql://dewydbuser:dewydbpwd@postgres/dewydb"
APPLY_MIGRATIONS: true
Expand All @@ -21,7 +21,7 @@ services:
depends_on:
- postgres
volumes:
- llama-cache:/tmp/cache
- dewy-cache:/tmp/cache

postgres:
image: ankane/pgvector
Expand All @@ -43,7 +43,7 @@ services:

volumes:
db:
llama-cache:
dewy-cache:

networks:
kb-network:
Loading

0 comments on commit bdbc989

Please sign in to comment.