mistralai[patch]: 16k token batching logic embed #17136

Merged
merged 6 commits · Feb 6, 2024
Changes from 3 commits
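In short, the patch makes embed_documents and aembed_documents split their input into batches of at most 16k tokens before calling the Mistral embeddings endpoint, rather than sending every text in one request. A minimal usage sketch of the resulting behaviour (assumes langchain-mistralai with this patch installed and MISTRAL_API_KEY set; the document list is purely illustrative):

from langchain_mistralai import MistralAIEmbeddings

# Large inputs are now split into sub-16k-token batches internally
# before being sent to the Mistral embeddings API.
embedder = MistralAIEmbeddings()  # reads MISTRAL_API_KEY from the environment
docs = ["a fairly long document"] * 500  # illustrative payload
vectors = embedder.embed_documents(docs)
assert len(vectors) == len(docs)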
54 changes: 45 additions & 9 deletions libs/partners/mistralai/langchain_mistralai/embeddings.py
@@ -1,5 +1,6 @@
+import asyncio
 import logging
-from typing import Dict, List, Optional
+from typing import Dict, Iterable, List, Optional

 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import (
@@ -16,9 +17,12 @@
     ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
 )
 from mistralai.exceptions import MistralException
+from tokenizers import Tokenizer

 logger = logging.getLogger(__name__)

+MAX_TOKENS = 16_000
+

 class MistralAIEmbeddings(BaseModel, Embeddings):
     """MistralAI embedding models.
@@ -43,6 +47,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     max_retries: int = 5
     timeout: int = 120
     max_concurrent_requests: int = 64
+    tokenizer: Tokenizer = Field(default=None)

     model: str = "mistral-embed"

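For reference, the MAX_TOKENS budget above is measured with the Hugging Face tokenizers package, not in characters. A small sketch of counting tokens the same way the patched code does (loading the Mixtral tokenizer needs network access to the Hugging Face Hub; the sample sentence is illustrative):

from tokenizers import Tokenizer

# Same tokenizer the validator falls back to when none is supplied;
# len() of the returned Encoding is the number of tokens.
tokenizer = Tokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
print(len(tokenizer.encode("LangChain now batches Mistral embedding requests.")))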
@@ -72,8 +77,29 @@ def validate_environment(cls, values: Dict) -> Dict:
             timeout=values["timeout"],
             max_concurrent_requests=values["max_concurrent_requests"],
         )
+        if values["tokenizer"] is None:
+            values["tokenizer"] = Tokenizer.from_pretrained(
+                "mistralai/Mixtral-8x7B-v0.1"
+            )
         return values

+    def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
+        """Split a list of texts into batches of less than 16k tokens
+        for Mistral API."""
+        batch: List[str] = []
+        batch_tokens = 0
+        for text in texts:
+            text_tokens = len(self.tokenizer.encode(text))
+            if batch_tokens + text_tokens > MAX_TOKENS:
+                yield batch
+                batch = [text]
+                batch_tokens = text_tokens
+            else:
+                batch.append(text)
+                batch_tokens += text_tokens
+        if batch:
+            yield batch
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed a list of document texts.

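To illustrate what _get_batches does: texts are appended to the current batch until adding the next one would push the running token count over MAX_TOKENS, at which point the batch is yielded and a new batch starts with that text. A toy sketch of the same greedy loop using whitespace-split words in place of real tokens (toy_batches and the limit of 5 are illustrative, not part of the PR):

from typing import Iterable, List

TOY_LIMIT = 5  # stand-in for MAX_TOKENS


def toy_batches(texts: List[str]) -> Iterable[List[str]]:
    # Same greedy packing as _get_batches, counting whitespace-split
    # words instead of tokenizer tokens.
    batch: List[str] = []
    batch_tokens = 0
    for text in texts:
        text_tokens = len(text.split())
        if batch_tokens + text_tokens > TOY_LIMIT:
            yield batch
            batch = [text]
            batch_tokens = text_tokens
        else:
            batch.append(text)
            batch_tokens += text_tokens
    if batch:
        yield batch


print(list(toy_batches(["a b c", "d e", "f g h", "i"])))
# [['a b c', 'd e'], ['f g h', 'i']]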
@@ -84,13 +110,17 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
             List of embeddings, one for each text.
         """
         try:
-            embeddings_batch_response = self.client.embeddings(
-                model=self.model,
-                input=texts,
+            batch_responses = (
+                self.client.embeddings(
+                    model=self.model,
+                    input=batch,
+                )
+                for batch in self._get_batches(texts)
             )
             return [
                 list(map(float, embedding_obj.embedding))
-                for embedding_obj in embeddings_batch_response.data
+                for response in batch_responses
+                for embedding_obj in response.data
             ]
         except MistralException as e:
             logger.error(f"An error occurred with MistralAI: {e}")
@@ -106,13 +136,19 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
             List of embeddings, one for each text.
         """
         try:
-            embeddings_batch_response = await self.async_client.embeddings(
-                model=self.model,
-                input=texts,
+            batch_responses = await asyncio.gather(
+                *[
+                    self.async_client.embeddings(
+                        model=self.model,
+                        input=batch,
+                    )
+                    for batch in self._get_batches(texts)
+                ]
             )
             return [
                 list(map(float, embedding_obj.embedding))
-                for embedding_obj in embeddings_batch_response.data
+                for response in batch_responses
+                for embedding_obj in response.data
             ]
         except MistralException as e:
             logger.error(f"An error occurred with MistralAI: {e}")
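The async variant fires one request per batch concurrently through asyncio.gather and then flattens the responses in order. A standalone sketch of that fan-out pattern with a stand-in coroutine (embed_batch and embed_all are hypothetical helpers, not part of the PR):

import asyncio
from typing import List


async def embed_batch(batch: List[str]) -> List[List[float]]:
    # Stand-in for async_client.embeddings(model=..., input=batch).
    await asyncio.sleep(0)  # pretend network latency
    return [[float(len(text))] for text in batch]


async def embed_all(batches: List[List[str]]) -> List[List[float]]:
    # One coroutine per batch, run concurrently; gather preserves order,
    # so the flattened result lines up with the original texts.
    responses = await asyncio.gather(*[embed_batch(b) for b in batches])
    return [vec for response in responses for vec in response]


print(asyncio.run(embed_all([["a", "bb"], ["ccc"]])))
# [[1.0], [2.0], [3.0]]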