From a9a4d96da37e0c62a27aa0c2edbe19ab463fe18f Mon Sep 17 00:00:00 2001
From: Weves
Date: Wed, 27 Nov 2024 13:32:32 -0800
Subject: [PATCH 1/3] Add handling for rate limiting

---
 .../natural_language_processing/exceptions.py |  4 ++
 .../search_nlp_models.py                      | 45 ++++++++++++-------
 backend/model_server/encoders.py              | 44 +++++++++---------
 .../tests/daily/embedding/test_embeddings.py  | 40 +++++++++++++++++
 4 files changed, 95 insertions(+), 38 deletions(-)
 create mode 100644 backend/danswer/natural_language_processing/exceptions.py

diff --git a/backend/danswer/natural_language_processing/exceptions.py b/backend/danswer/natural_language_processing/exceptions.py
new file mode 100644
index 00000000000..5ca112f64ea
--- /dev/null
+++ b/backend/danswer/natural_language_processing/exceptions.py
@@ -0,0 +1,4 @@
+class ModelServerRateLimitError(Exception):
+    """
+    Exception raised for rate limiting errors from the model server.
+    """
diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py
index ee80292de63..2346e7de408 100644
--- a/backend/danswer/natural_language_processing/search_nlp_models.py
+++ b/backend/danswer/natural_language_processing/search_nlp_models.py
@@ -6,6 +6,7 @@
 import requests
 from httpx import HTTPError
+from requests import Response
 from retry import retry
 
 from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -16,6 +17,7 @@
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
 from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
+from danswer.natural_language_processing.exceptions import ModelServerRateLimitError
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -99,28 +101,39 @@ def __init__(
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
     def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
-        def _make_request() -> EmbedResponse:
+        def _make_request() -> Response:
             response = requests.post(
                 self.embed_server_endpoint, json=embed_request.model_dump()
             )
-            try:
-                response.raise_for_status()
-            except requests.HTTPError as e:
-                try:
-                    error_detail = response.json().get("detail", str(e))
-                except Exception:
-                    error_detail = response.text
-                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
-            except requests.RequestException as e:
-                raise HTTPError(f"Request failed: {str(e)}") from e
+            # signify that this is a rate limit error
+            if response.status_code == 429:
+                raise ModelServerRateLimitError(response.text)
 
-            return EmbedResponse(**response.json())
+            response.raise_for_status()
+            return response
 
-        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
+        final_make_request_func = _make_request
+
+        # if the text type is a passage, add some default
+        # retries + handling for rate limiting
         if embed_request.text_type == EmbedTextType.PASSAGE:
-            return retry(tries=3, delay=5)(_make_request)()
-        else:
-            return _make_request()
+            final_make_request_func = retry(tries=3, delay=5)(final_make_request_func)
+            # use 10 second delay as per Azure suggestion
+            final_make_request_func = retry(
+                tries=10, delay=10, exceptions=ModelServerRateLimitError
+            )(final_make_request_func)
+
+        try:
+            response = final_make_request_func()
+            return EmbedResponse(**response.json())
+        except requests.HTTPError as e:
+            try:
+                error_detail = response.json().get("detail", str(e))
+            except Exception:
+                error_detail = response.text
+            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+        except requests.RequestException as e:
+            raise HTTPError(f"Request failed: {str(e)}") from e
 
     def _batch_encode_texts(
         self,
diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
index 003953cb29a..c72be9e4ac3 100644
--- a/backend/model_server/encoders.py
+++ b/backend/model_server/encoders.py
@@ -11,6 +11,7 @@
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
 from litellm import embedding
+from litellm.exceptions import RateLimitError
 from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
@@ -205,28 +206,22 @@ def embed(
         model_name: str | None = None,
         deployment_name: str | None = None,
     ) -> list[Embedding]:
-        try:
-            if self.provider == EmbeddingProvider.OPENAI:
-                return self._embed_openai(texts, model_name)
-            elif self.provider == EmbeddingProvider.AZURE:
-                return self._embed_azure(texts, f"azure/{deployment_name}")
-            elif self.provider == EmbeddingProvider.LITELLM:
-                return self._embed_litellm_proxy(texts, model_name)
-
-            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
-            if self.provider == EmbeddingProvider.COHERE:
-                return self._embed_cohere(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.VOYAGE:
-                return self._embed_voyage(texts, model_name, embedding_type)
-            elif self.provider == EmbeddingProvider.GOOGLE:
-                return self._embed_vertex(texts, model_name, embedding_type)
-            else:
-                raise ValueError(f"Unsupported provider: {self.provider}")
-        except Exception as e:
-            raise HTTPException(
-                status_code=500,
-                detail=f"Error embedding text with {self.provider}: {str(e)}",
-            )
+        if self.provider == EmbeddingProvider.OPENAI:
+            return self._embed_openai(texts, model_name)
+        elif self.provider == EmbeddingProvider.AZURE:
+            return self._embed_azure(texts, f"azure/{deployment_name}")
+        elif self.provider == EmbeddingProvider.LITELLM:
+            return self._embed_litellm_proxy(texts, model_name)
+
+        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
+        if self.provider == EmbeddingProvider.COHERE:
+            return self._embed_cohere(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.VOYAGE:
+            return self._embed_voyage(texts, model_name, embedding_type)
+        elif self.provider == EmbeddingProvider.GOOGLE:
+            return self._embed_vertex(texts, model_name, embedding_type)
+        else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
 
     @staticmethod
     def create(
@@ -430,6 +425,11 @@ async def process_embed_request(
             prefix=prefix,
         )
         return EmbedResponse(embeddings=embeddings)
+    except RateLimitError as e:
+        raise HTTPException(
+            status_code=429,
+            detail=str(e),
+        )
     except Exception as e:
         exception_detail = f"Error during embedding process:\n{str(e)}"
         logger.exception(exception_detail)
diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py
index 10a1dd850f6..7182510214f 100644
--- a/backend/tests/daily/embedding/test_embeddings.py
+++ b/backend/tests/daily/embedding/test_embeddings.py
@@ -7,6 +7,7 @@
 from shared_configs.model_server_models import EmbeddingProvider
 
 VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
+VALID_LONG_SAMPLE = ["hi " * 999]
 # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
 # seem to be true
 TOO_LONG_SAMPLE = ["a"] * 2500
@@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
 def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
     _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
     _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
+
+
+@pytest.fixture
+def azure_embedding_model() -> EmbeddingModel:
+    return EmbeddingModel(
+        server_host="localhost",
+        server_port=9000,
+        model_name="text-embedding-3-large",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+        api_key=os.getenv("AZURE_API_KEY"),
+        provider_type=EmbeddingProvider.AZURE,
+        api_url=os.getenv("AZURE_API_URL"),
+    )
+
+
+# NOTE (chris): this test doesn't work, and I do not know why
+# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
+#     """NOTE: this test relies on a very low rate limit for the Azure API
+#     plus this test only being run once in a 1 minute window"""
+#     # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
+#     # limits assuming the limit is 1000 tokens per minute
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+
+#     # this should fail
+#     with pytest.raises(ModelServerRateLimitError):
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
+#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)

+#     # this should succeed, since passage requests retry up to 10 times
+#     start = time.time()
+#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
+#     assert len(result) == 1
+#     assert len(result[0]) == 1536
+#     assert time.time() - start > 30  # make sure we waited, even though we hit rate limits

From 45851b64140d4a4a534e5c4a603d896d8155af6f Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Wed, 27 Nov 2024 13:49:55 -0800
Subject: [PATCH 2/3] fixed logging

---
 .../celery/tasks/external_group_syncing/tasks.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py
index 61ceae4e463..30b13ad8ba3 100644
--- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py
+++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py
@@ -195,7 +195,7 @@ def connector_external_group_sync_generator_task(
     tenant_id: str | None,
 ) -> None:
     """
-    Permission sync task that handles document permission syncing for a given connector credential pair
+    Permission sync task that handles external group syncing for a given connector credential pair
     This task assumes that the task has already been properly fenced
     """
@@ -228,9 +228,13 @@ def connector_external_group_sync_generator_task(
         ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
 
         if ext_group_sync_func is None:
-            raise ValueError(f"No external group sync func found for {source_type}")
+            raise ValueError(
+                f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
+            )
 
-        logger.info(f"Syncing docs for {source_type}")
+        logger.info(
+            f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
+        )
 
         external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)

From 3c699cc112bff0d93fddc2aa74177e75f4e9f323 Mon Sep 17 00:00:00 2001
From: Weves
Date: Wed, 27 Nov 2024 14:03:48 -0800
Subject: [PATCH 3/3] Fix rate limiting stacking

---
 .../natural_language_processing/search_nlp_models.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py
index 2346e7de408..9fed0d489e7 100644
--- a/backend/danswer/natural_language_processing/search_nlp_models.py
+++ b/backend/danswer/natural_language_processing/search_nlp_models.py
@@ -6,6 +6,8 @@
 import requests
 from httpx import HTTPError
+from requests import JSONDecodeError
+from requests import RequestException
 from requests import Response
 from retry import retry
 
 from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -17,7 +19,9 @@
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
 from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
-from danswer.natural_language_processing.exceptions import ModelServerRateLimitError
+from danswer.natural_language_processing.exceptions import (
+    ModelServerRateLimitError,
+)
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -117,7 +121,11 @@
         # if the text type is a passage, add some default
         # retries + handling for rate limiting
         if embed_request.text_type == EmbedTextType.PASSAGE:
-            final_make_request_func = retry(tries=3, delay=5)(final_make_request_func)
+            final_make_request_func = retry(
+                tries=3,
+                delay=5,
+                exceptions=(RequestException, ValueError, JSONDecodeError),
+            )(final_make_request_func)
             # use 10 second delay as per Azure suggestion
             final_make_request_func = retry(
                 tries=10, delay=10, exceptions=ModelServerRateLimitError
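
A note on the retry stacking fixed in PATCH 3/3: with the `retry` package, each
decorator application wraps the function produced by the previous one, so the
exception filter on the inner layer decides which errors ever reach the outer
layer. Before the fix, the inner retry(tries=3, delay=5) used the package's
default of exceptions=Exception and therefore also caught
ModelServerRateLimitError, multiplying the rate-limit retries instead of
leaving them to the dedicated outer layer. The sketch below mirrors that
composition; it is illustrative only, and EMBED_URL plus the request payload
are placeholders rather than Danswer code.

import requests
from retry import retry


class ModelServerRateLimitError(Exception):
    """Raised when the model server responds with HTTP 429."""


EMBED_URL = "http://localhost:9000/encoder/bi-encoder-embed"  # placeholder


def _make_request() -> requests.Response:
    # placeholder payload; the real code posts an EmbedRequest model dump
    response = requests.post(EMBED_URL, json={"texts": ["hello"]})
    if response.status_code == 429:
        # surface rate limiting as a distinct exception type so only the
        # outer retry layer handles it
        raise ModelServerRateLimitError(response.text)
    response.raise_for_status()
    return response


# Inner layer: retry ordinary request failures 3 times, 5s apart. Scoping
# `exceptions` (as PATCH 3/3 does) keeps this layer from also catching
# ModelServerRateLimitError, which is not a subclass of either type here.
func = retry(
    tries=3,
    delay=5,
    exceptions=(requests.RequestException, ValueError),
)(_make_request)

# Outer layer: retry rate-limit errors up to 10 times, 10s apart (the
# 10-second delay follows the Azure suggestion cited in PATCH 1/3).
func = retry(tries=10, delay=10, exceptions=ModelServerRateLimitError)(func)

# A sustained 429 now costs at most 10 attempts, not 10 x 3 as before the fix.
response = func()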