Litellm dev 01 22 2025 p4 (#7932)
* feat(main.py): add new 'provider_specific_header' param

allows passing extra headers for a specific provider

* fix(litellm_pre_call_utils.py): add unit test for pre call utils

* test(test_bedrock_completion.py): skip test now that bedrock supports this
krrishdholakia authored Jan 23, 2025
1 parent 4911cd8 commit 27560bd
Showing 5 changed files with 119 additions and 5 deletions.
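For context, a minimal sketch of the new parameter from the SDK side, mirroring the test added below in tests/local_testing/test_completion.py (the header value is illustrative):

import litellm
from litellm.types.utils import ProviderSpecificHeader

# Forward an Anthropic beta header, scoped so it is only sent when the
# request actually routes to the anthropic provider.
# Requires ANTHROPIC_API_KEY in the environment.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-v2@20241022",
    messages=[{"role": "user", "content": "Hello world"}],
    provider_specific_header=ProviderSpecificHeader(
        custom_llm_provider="anthropic",
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    ),
)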
11 changes: 11 additions & 0 deletions litellm/main.py
@@ -179,6 +179,7 @@
    HiddenParams,
    LlmProviders,
    PromptTokensDetails,
    ProviderSpecificHeader,
    all_litellm_params,
)

@@ -832,6 +833,9 @@ def completion( # type: ignore # noqa: PLR0915
    model_info = kwargs.get("model_info", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    fallbacks = kwargs.get("fallbacks", None)
    provider_specific_header = cast(
        Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None)
    )
    headers = kwargs.get("headers", None) or extra_headers
    ensure_alternating_roles: Optional[bool] = kwargs.get(
        "ensure_alternating_roles", None
@@ -937,6 +941,13 @@
            api_base=api_base,
            api_key=api_key,
        )

        if (
            provider_specific_header is not None
            and provider_specific_header["custom_llm_provider"] == custom_llm_provider
        ):
            headers.update(provider_specific_header["extra_headers"])

        if model_response is not None and hasattr(model_response, "_hidden_params"):
            model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
            model_response._hidden_params["region_name"] = kwargs.get(
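The merge above is gated on a provider match, so a header targeted at one provider is silently dropped when the call routes elsewhere. A self-contained sketch of that behavior (merge_provider_headers is a hypothetical helper, not part of this commit):

from typing import Optional

from litellm.types.utils import ProviderSpecificHeader

def merge_provider_headers(
    headers: dict,
    resolved_provider: str,
    provider_specific_header: Optional[ProviderSpecificHeader],
) -> dict:
    # Same logic as the new block in completion(): merge only on a match.
    if (
        provider_specific_header is not None
        and provider_specific_header["custom_llm_provider"] == resolved_provider
    ):
        headers.update(provider_specific_header["extra_headers"])
    return headers

hdr = ProviderSpecificHeader(
    custom_llm_provider="anthropic", extra_headers={"anthropic-beta": "test"}
)
assert merge_provider_headers({}, "anthropic", hdr) == {"anthropic-beta": "test"}
assert merge_provider_headers({}, "bedrock", hdr) == {}  # mismatch: dropped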
12 changes: 7 additions & 5 deletions litellm/proxy/litellm_pre_call_utils.py
@@ -20,6 +20,7 @@
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
from litellm.types.services import ServiceTypes
from litellm.types.utils import (
    ProviderSpecificHeader,
    StandardLoggingUserAPIKeyMetadata,
    SupportedCacheControls,
)
@@ -729,19 +730,20 @@ def add_provider_specific_headers_to_request(
    data: dict,
    headers: dict,
):
-
-    extra_headers = data.get("extra_headers", {}) or {}
-
+    anthropic_headers = {}
    # boolean to indicate if a header was added
    added_header = False
    for header in ANTHROPIC_API_HEADERS:
        if header in headers:
            header_value = headers[header]
-            extra_headers.update({header: header_value})
+            anthropic_headers[header] = header_value
            added_header = True

    if added_header is True:
-        data["extra_headers"] = extra_headers
+        data["provider_specific_header"] = ProviderSpecificHeader(
+            custom_llm_provider="anthropic",
+            extra_headers=anthropic_headers,
+        )

    return

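A minimal sketch of the new mapping (the model value is illustrative; "anthropic-beta" is assumed to be in ANTHROPIC_API_HEADERS, as the proxy test below exercises): recognized Anthropic headers are collected onto data["provider_specific_header"] rather than the old data["extra_headers"], while unrecognized headers pass through untouched.

from litellm.proxy.litellm_pre_call_utils import (
    add_provider_specific_headers_to_request,
)

data = {"model": "claude-3-5-sonnet", "messages": []}  # illustrative payload
headers = {
    "content-type": "application/json",  # not provider-specific: ignored
    "anthropic-beta": "prompt-caching-2024-07-31",  # collected
}
add_provider_specific_headers_to_request(data=data, headers=headers)
print(data["provider_specific_header"])
# -> {'custom_llm_provider': 'anthropic',
#     'extra_headers': {'anthropic-beta': 'prompt-caching-2024-07-31'}}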
6 changes: 6 additions & 0 deletions litellm/types/utils.py
@@ -1658,6 +1658,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
"api_key",
"api_version",
"prompt_id",
"provider_specific_header",
"prompt_variables",
"api_base",
"force_timeout",
@@ -1879,3 +1880,8 @@ class HttpHandlerRequestFields(TypedDict, total=False):
    params: dict  # query params
    files: dict  # file uploads
    content: Any  # raw content


class ProviderSpecificHeader(TypedDict):
    custom_llm_provider: str
    extra_headers: dict
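ProviderSpecificHeader is a TypedDict, so at runtime it is an ordinary dict; that is why the proxy test below can compare data["provider_specific_header"] against a plain dict literal. A quick sketch:

from litellm.types.utils import ProviderSpecificHeader

hdr = ProviderSpecificHeader(
    custom_llm_provider="anthropic",
    extra_headers={"anthropic-beta": "test"},
)
# TypedDict erases to dict at runtime, so equality with a literal holds.
assert hdr == {
    "custom_llm_provider": "anthropic",
    "extra_headers": {"anthropic-beta": "test"},
}
assert isinstance(hdr, dict)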
30 changes: 30 additions & 0 deletions tests/local_testing/test_completion.py
@@ -4550,3 +4550,33 @@ def test_deepseek_reasoning_content_completion():
        resp.choices[0].message.provider_specific_fields["reasoning_content"]
        is not None
    )


@pytest.mark.parametrize(
    "custom_llm_provider, expected_result",
    [("anthropic", {"anthropic-beta": "test"}), ("bedrock", {}), ("vertex_ai", {})],
)
def test_provider_specific_header(custom_llm_provider, expected_result):
    from litellm.types.utils import ProviderSpecificHeader
    from litellm.llms.custom_httpx.http_handler import HTTPHandler
    from unittest.mock import MagicMock, patch

    litellm.set_verbose = True
    client = HTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        try:
            resp = litellm.completion(
                model="anthropic/claude-3-5-sonnet-v2@20241022",
                messages=[{"role": "user", "content": "Hello world"}],
                provider_specific_header=ProviderSpecificHeader(
                    custom_llm_provider="anthropic",
                    extra_headers={"anthropic-beta": "test"},
                ),
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs["headers"])
        assert "anthropic-beta" in mock_post.call_args.kwargs["headers"]
65 changes: 65 additions & 0 deletions tests/proxy_unit_tests/test_proxy_utils.py
@@ -1495,3 +1495,68 @@ def test_custom_openapi(mock_get_openapi_schema):

    openapi_schema = custom_openapi()
    assert openapi_schema is not None


def test_provider_specific_header():
    from litellm.proxy.litellm_pre_call_utils import (
        add_provider_specific_headers_to_request,
    )

    data = {
        "model": "gemini-1.5-flash",
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "Tell me a joke"}],
            }
        ],
        "stream": True,
        "proxy_server_request": {
            "url": "http://0.0.0.0:4000/v1/chat/completions",
            "method": "POST",
            "headers": {
                "content-type": "application/json",
                "anthropic-beta": "prompt-caching-2024-07-31",
                "user-agent": "PostmanRuntime/7.32.3",
                "accept": "*/*",
                "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
                "host": "0.0.0.0:4000",
                "accept-encoding": "gzip, deflate, br",
                "connection": "keep-alive",
                "content-length": "240",
            },
            "body": {
                "model": "gemini-1.5-flash",
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": "Tell me a joke"}],
                    }
                ],
                "stream": True,
            },
        },
    }

    headers = {
        "content-type": "application/json",
        "anthropic-beta": "prompt-caching-2024-07-31",
        "user-agent": "PostmanRuntime/7.32.3",
        "accept": "*/*",
        "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
        "host": "0.0.0.0:4000",
        "accept-encoding": "gzip, deflate, br",
        "connection": "keep-alive",
        "content-length": "240",
    }

    add_provider_specific_headers_to_request(
        data=data,
        headers=headers,
    )
    assert data["provider_specific_header"] == {
        "custom_llm_provider": "anthropic",
        "extra_headers": {
            "anthropic-beta": "prompt-caching-2024-07-31",
        },
    }
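And the inverse on the proxy side (again hypothetical, not in this commit): with no recognized Anthropic headers present, the function should leave data untouched.

def test_no_provider_specific_header():
    # Hypothetical sketch: nothing in ANTHROPIC_API_HEADERS matches, so the
    # function returns without mutating data.
    from litellm.proxy.litellm_pre_call_utils import (
        add_provider_specific_headers_to_request,
    )

    data = {"model": "gemini-1.5-flash", "messages": []}
    headers = {"content-type": "application/json", "user-agent": "curl/8.4.0"}
    add_provider_specific_headers_to_request(data=data, headers=headers)
    assert "provider_specific_header" not in data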
