From 27560bd5ad95636f7e95da8fc74bc2162a100f37 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Wed, 22 Jan 2025 21:52:07 -0800 Subject: [PATCH] Litellm dev 01 22 2025 p4 (#7932) * feat(main.py): add new 'provider_specific_header' param allows passing extra header for specific provider * fix(litellm_pre_call_utils.py): add unit test for pre call utils * test(test_bedrock_completion.py): skip test now that bedrock supports this --- litellm/main.py | 11 ++++ litellm/proxy/litellm_pre_call_utils.py | 12 ++-- litellm/types/utils.py | 6 ++ tests/local_testing/test_completion.py | 30 ++++++++++ tests/proxy_unit_tests/test_proxy_utils.py | 65 ++++++++++++++++++++++ 5 files changed, 119 insertions(+), 5 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 8042fb1cc80c..37ef97864256 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -179,6 +179,7 @@ HiddenParams, LlmProviders, PromptTokensDetails, + ProviderSpecificHeader, all_litellm_params, ) @@ -832,6 +833,9 @@ def completion( # type: ignore # noqa: PLR0915 model_info = kwargs.get("model_info", None) proxy_server_request = kwargs.get("proxy_server_request", None) fallbacks = kwargs.get("fallbacks", None) + provider_specific_header = cast( + Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None) + ) headers = kwargs.get("headers", None) or extra_headers ensure_alternating_roles: Optional[bool] = kwargs.get( "ensure_alternating_roles", None @@ -937,6 +941,13 @@ def completion( # type: ignore # noqa: PLR0915 api_base=api_base, api_key=api_key, ) + + if ( + provider_specific_header is not None + and provider_specific_header["custom_llm_provider"] == custom_llm_provider + ): + headers.update(provider_specific_header["extra_headers"]) + if model_response is not None and hasattr(model_response, "_hidden_params"): model_response._hidden_params["custom_llm_provider"] = custom_llm_provider model_response._hidden_params["region_name"] = kwargs.get( diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9839a519a2da..94ff51fd2aab 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -20,6 +20,7 @@ from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS from litellm.types.services import ServiceTypes from litellm.types.utils import ( + ProviderSpecificHeader, StandardLoggingUserAPIKeyMetadata, SupportedCacheControls, ) @@ -729,19 +730,20 @@ def add_provider_specific_headers_to_request( data: dict, headers: dict, ): - - extra_headers = data.get("extra_headers", {}) or {} - + anthropic_headers = {} # boolean to indicate if a header was added added_header = False for header in ANTHROPIC_API_HEADERS: if header in headers: header_value = headers[header] - extra_headers.update({header: header_value}) + anthropic_headers[header] = header_value added_header = True if added_header is True: - data["extra_headers"] = extra_headers + data["provider_specific_header"] = ProviderSpecificHeader( + custom_llm_provider="anthropic", + extra_headers=anthropic_headers, + ) return diff --git a/litellm/types/utils.py b/litellm/types/utils.py index d60b263052e0..09f0f864d21c 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1658,6 +1658,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False): "api_key", "api_version", "prompt_id", + "provider_specific_header", "prompt_variables", "api_base", "force_timeout", @@ -1879,3 +1880,8 @@ class HttpHandlerRequestFields(TypedDict, total=False): params: dict # query params files: dict # file uploads content: Any # raw content + + +class ProviderSpecificHeader(TypedDict): + custom_llm_provider: str + extra_headers: dict diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 466369ef6ec3..ef90d56f7011 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -4550,3 +4550,33 @@ def test_deepseek_reasoning_content_completion(): resp.choices[0].message.provider_specific_fields["reasoning_content"] is not None ) + + +@pytest.mark.parametrize( + "custom_llm_provider, expected_result", + [("anthropic", {"anthropic-beta": "test"}), ("bedrock", {}), ("vertex_ai", {})], +) +def test_provider_specific_header(custom_llm_provider, expected_result): + from litellm.types.utils import ProviderSpecificHeader + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from unittest.mock import patch + + litellm.set_verbose = True + client = HTTPHandler() + with patch.object(client, "post", return_value=MagicMock()) as mock_post: + try: + resp = litellm.completion( + model="anthropic/claude-3-5-sonnet-v2@20241022", + messages=[{"role": "user", "content": "Hello world"}], + provider_specific_header=ProviderSpecificHeader( + custom_llm_provider="anthropic", + extra_headers={"anthropic-beta": "test"}, + ), + client=client, + ) + except Exception as e: + print(f"Error: {e}") + + mock_post.assert_called_once() + print(mock_post.call_args.kwargs["headers"]) + assert "anthropic-beta" in mock_post.call_args.kwargs["headers"] diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index a3de35a2abee..d73606761226 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1495,3 +1495,68 @@ def test_custom_openapi(mock_get_openapi_schema): openapi_schema = custom_openapi() assert openapi_schema is not None + + +def test_provider_specific_header(): + from litellm.proxy.litellm_pre_call_utils import ( + add_provider_specific_headers_to_request, + ) + + data = { + "model": "gemini-1.5-flash", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Tell me a joke"}], + } + ], + "stream": True, + "proxy_server_request": { + "url": "http://0.0.0.0:4000/v1/chat/completions", + "method": "POST", + "headers": { + "content-type": "application/json", + "anthropic-beta": "prompt-caching-2024-07-31", + "user-agent": "PostmanRuntime/7.32.3", + "accept": "*/*", + "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529", + "host": "0.0.0.0:4000", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "content-length": "240", + }, + "body": { + "model": "gemini-1.5-flash", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Tell me a joke"}], + } + ], + "stream": True, + }, + }, + } + + headers = { + "content-type": "application/json", + "anthropic-beta": "prompt-caching-2024-07-31", + "user-agent": "PostmanRuntime/7.32.3", + "accept": "*/*", + "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529", + "host": "0.0.0.0:4000", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "content-length": "240", + } + + add_provider_specific_headers_to_request( + data=data, + headers=headers, + ) + assert data["provider_specific_header"] == { + "custom_llm_provider": "anthropic", + "extra_headers": { + "anthropic-beta": "prompt-caching-2024-07-31", + }, + }