Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add handling for rate limiting #3280

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""

Expand Down Expand Up @@ -228,9 +228,13 @@ def connector_external_group_sync_generator_task(

ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)

logger.info(f"Syncing docs for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)

external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)

Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/natural_language_processing/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class ModelServerRateLimitError(Exception):
    """Signals that the model server rejected a request due to rate limiting."""
53 changes: 37 additions & 16 deletions backend/danswer/natural_language_processing/search_nlp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry

from danswer.configs.app_configs import LARGE_CHUNK_RATIO
Expand All @@ -16,6 +19,9 @@
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.exceptions import (
ModelServerRateLimitError,
)
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -99,28 +105,43 @@ def __init__(
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
    """Send ``embed_request`` to the embedding endpoint and parse the reply.

    Passage requests are wrapped in two retry layers: a short one for
    generic request failures and a longer one for rate limiting. All other
    text types are attempted exactly once.

    Args:
        embed_request: The request payload; it is serialized with
            ``model_dump()`` and POSTed as JSON to
            ``self.embed_server_endpoint``.

    Returns:
        The ``EmbedResponse`` built from the server's JSON body.

    Raises:
        ModelServerRateLimitError: The server answered with status 429 (for
            passage requests, only after the rate-limit retries are used up).
        HTTPError: The server answered with another error status, or the
            request / response decoding failed with a ``requests`` exception.
    """

    def _make_request() -> Response:
        response = requests.post(
            self.embed_server_endpoint, json=embed_request.model_dump()
        )
        # signify that this is a rate limit error
        if response.status_code == 429:
            raise ModelServerRateLimitError(response.text)

        response.raise_for_status()
        return response

    final_make_request_func = _make_request

    # if the text type is a passage, add some default
    # retries + handling for rate limiting
    if embed_request.text_type == EmbedTextType.PASSAGE:
        final_make_request_func = retry(
            tries=3,
            delay=5,
            exceptions=(RequestException, ValueError, JSONDecodeError),
        )(final_make_request_func)
        # use 10 second delay as per Azure suggestion
        final_make_request_func = retry(
            tries=10, delay=10, exceptions=ModelServerRateLimitError
        )(final_make_request_func)

    try:
        response = final_make_request_func()
        return EmbedResponse(**response.json())
    except requests.HTTPError as e:
        # `response` is never bound on this path (the call above raised before
        # the assignment), so take the failed response from the exception.
        failed_response = e.response
        # Compare against None explicitly: a Response for an error status is
        # falsy, so a plain truthiness check would skip the detail lookup.
        if failed_response is None:
            error_detail = str(e)
        else:
            try:
                error_detail = failed_response.json().get("detail", str(e))
            except Exception:
                # Body was not JSON (or not a JSON object); use the raw text.
                error_detail = failed_response.text
        raise HTTPError(f"HTTP error occurred: {error_detail}") from e
    except requests.RequestException as e:
        raise HTTPError(f"Request failed: {str(e)}") from e

def _batch_encode_texts(
self,
Expand Down
44 changes: 22 additions & 22 deletions backend/model_server/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
Expand Down Expand Up @@ -205,28 +206,22 @@ def embed(
model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)

embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)

embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")

@staticmethod
def create(
Expand Down Expand Up @@ -430,6 +425,11 @@ async def process_embed_request(
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:
raise HTTPException(
status_code=429,
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
Expand Down
40 changes: 40 additions & 0 deletions backend/tests/daily/embedding/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from shared_configs.model_server_models import EmbeddingProvider

VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
VALID_LONG_SAMPLE = ["hi " * 999]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
# seem to be true
TOO_LONG_SAMPLE = ["a"] * 2500
Expand Down Expand Up @@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
_run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
_run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)


@pytest.fixture
def azure_embedding_model() -> EmbeddingModel:
    """Build an ``EmbeddingModel`` configured for the Azure embedding provider.

    The API key and API URL are taken from the ``AZURE_API_KEY`` and
    ``AZURE_API_URL`` environment variables (``None`` when unset); the server
    host and port are fixed to ``localhost`` and ``9000``.
    """
    azure_api_key = os.getenv("AZURE_API_KEY")
    azure_api_url = os.getenv("AZURE_API_URL")

    model_settings = {
        "server_host": "localhost",
        "server_port": 9000,
        "model_name": "text-embedding-3-large",
        "normalize": True,
        "query_prefix": None,
        "passage_prefix": None,
        "api_key": azure_api_key,
        "provider_type": EmbeddingProvider.AZURE,
        "api_url": azure_api_url,
    }
    return EmbeddingModel(**model_settings)


# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
# """NOTE: this test relies on a very low rate limit for the Azure API +
# this test only being run once in a 1 minute window"""
# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
# # limits assuming the limit is 1000 tokens per minute
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# assert len(result) == 1
# assert len(result[0]) == 1536

# # this should fail
# with pytest.raises(ModelServerRateLimitError):
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)

# # this should succeed, since passage requests retry up to 10 times
# start = time.time()
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
# assert len(result) == 1
# assert len(result[0]) == 1536
# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits
Loading