From af6e6e64bc56c19b5206efcde6e4dc18c2b4b567 Mon Sep 17 00:00:00 2001 From: nsantacruz Date: Thu, 9 May 2024 12:13:52 +0300 Subject: [PATCH] feat: switch to basic langchain impl of voyage ai to use caching --- app/basic_langchain/embeddings.py | 42 +++++++++++++++---- .../topic_source_curation_v2/cluster.py | 4 +- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/app/basic_langchain/embeddings.py b/app/basic_langchain/embeddings.py index 6b5c138..d16e4fc 100644 --- a/app/basic_langchain/embeddings.py +++ b/app/basic_langchain/embeddings.py @@ -1,25 +1,49 @@ +from abc import ABC, abstractmethod from typing import Union from openai import OpenAI +import voyageai from basic_langchain.cache import sqlite_cache import numpy as np -class OpenAIEmbeddings: +class AbstractEmbeddings(ABC): + def embed_documents(self, documents: list[str]) -> np.ndarray: + return np.array(self._call_embedding_api(documents)) + + def embed_query(self, query: str) -> np.ndarray: + return self.embed_documents([query])[0] + + @abstractmethod + def _call_embedding_api(self, documents: Union[list[str], str]) -> list[list[float]]: + pass + + +class OpenAIEmbeddings(AbstractEmbeddings): def __init__(self, model: str = "text-embedding-3-small"): self.client = OpenAI() self.model = model - def embed_documents(self, documents: list[str]) -> np.ndarray: - return np.array([d.embedding for d in self._call_embedding_api(documents).data]) - - @sqlite_cache('embedding') def _call_embedding_api(self, documents: Union[list[str], str]): - return self.client.embeddings.create( + return [d.embedding for d in self.client.embeddings.create( input=documents, model=self.model - ) + ).data] + + +class VoyageAIEmbeddings(AbstractEmbeddings): + + def __init__(self, model: str): + self.client = voyageai.Client() + self.model = model + + @sqlite_cache('embedding') + def _call_embedding_api(self, documents: Union[list[str], str]) -> list[list[float]]: + if isinstance(documents, str): + documents = [documents] + return self.client.embed( + documents, + model=self.model + ).embeddings - def embed_query(self, query: str) -> np.ndarray: - return self.embed_documents([query])[0] diff --git a/experiments/topic_source_curation_v2/cluster.py b/experiments/topic_source_curation_v2/cluster.py index 2668fb7..3b1aeb4 100644 --- a/experiments/topic_source_curation_v2/cluster.py +++ b/experiments/topic_source_curation_v2/cluster.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Union from tqdm import tqdm -from langchain_voyageai import VoyageAIEmbeddings +from basic_langchain.embeddings import VoyageAIEmbeddings from sklearn.metrics import silhouette_score from sklearn.cluster import KMeans import random @@ -103,7 +103,7 @@ def _summarize_source_pbar(pbar, source, llm, topic_str): return summaries def embed_text(text): - return np.array(VoyageAIEmbeddings(model="voyage-large-2-instruct", batch_size=1).embed_query(text)) + return np.array(VoyageAIEmbeddings(model="voyage-large-2-instruct").embed_query(text)) def _cluster_sources(sources: list[SummarizedSource], key: Callable[[SummarizedSource], str]) -> list[Cluster]: return cluster_items(sources, key, embed_text)