(Feat) add bedrock/deepseek custom import models (#8132)
* add support for using llama spec with bedrock

* fix get_bedrock_invoke_provider

* add support for using bedrock provider in mappings

* working request

* test_bedrock_custom_deepseek

* test_bedrock_custom_deepseek

* fix _get_model_id_for_llama_like_model

* test_bedrock_custom_deepseek

* doc DeepSeek-R1-Distill-Llama-70B

* test_bedrock_custom_deepseek
ishaan-jaff authored Feb 1, 2025
1 parent 29a8a61 commit 9ff2780
Showing 5 changed files with 212 additions and 17 deletions.
70 changes: 69 additions & 1 deletion docs/my-website/docs/providers/bedrock.md
@@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Supported
| Property | Details |
|-------|-------|
| Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy) |
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
@@ -1277,6 +1277,74 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
https://some-api-url/models
```

## Bedrock Imported Models (Deepseek)

| Property | Details |
|----------|---------|
| Provider Route | `bedrock/llama/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |

Use this route to call Bedrock Imported Models that follow the `llama` Invoke request/response spec.
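For reference, a rough sketch of the Invoke payloads involved — the field names are taken from the `test_bedrock_custom_deepseek` test added in this commit rather than the full llama spec, so treat it as illustrative only:

```python
# Illustrative llama-spec Invoke payloads (fields observed in this commit's test;
# real requests may carry additional inference parameters).
request_body = {
    "prompt": "Tell me a joke",  # prompt rendered from the chat messages
    "max_gen_len": 100,          # mapped from max_tokens
}

response_body = {
    "generation": "Here's a joke...",  # completion text LiteLLM reads back
    "stop_reason": "stop",
}
```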


<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

response = completion(
    model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # bedrock/llama/{your-model-arn}
    messages=[{"role": "user", "content": "Tell me a joke"}],
)
```

</TabItem>

<TabItem value="proxy" label="Proxy">


**1. Add to config**

```yaml
model_list:
  - model_name: DeepSeek-R1-Distill-Llama-70B
    litellm_params:
      model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```

**2. Start proxy**

```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```

**3. Test it!** (the `model` in the request body is the `model_name` from your config)

```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "DeepSeek-R1-Distill-Llama-70B",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```

</TabItem>
</Tabs>



## Provisioned throughput models
To use provisioned throughput Bedrock models pass
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
3 changes: 3 additions & 0 deletions litellm/__init__.py
@@ -359,6 +359,9 @@ def identify(event_details):
"meta.llama3-2-11b-instruct-v1:0",
"meta.llama3-2-90b-instruct-v1:0",
]
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
"cohere", "anthropic", "mistral", "amazon", "meta", "llama"
]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
92 changes: 81 additions & 11 deletions litellm/llms/bedrock/chat/invoke_handler.py
@@ -19,12 +19,14 @@
Tuple,
Union,
cast,
get_args,
)

import httpx # type: ignore

import litellm
from litellm import verbose_logger
from litellm._logging import print_verbose
from litellm.caching.caching import InMemoryCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
@@ -206,7 +208,7 @@ async def make_call(
api_key="",
data=data,
messages=messages,
print_verbose=litellm.print_verbose,
print_verbose=print_verbose,
encoding=litellm.encoding,
) # type: ignore
completion_stream: Any = MockResponseIterator(
@@ -286,7 +288,7 @@ def convert_messages_to_prompt(
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
elif provider == "meta" or provider == "llama":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
@@ -318,7 +320,7 @@ def process_response( # noqa: PLR0915
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
provider = model.split(".")[0]
provider = self.get_bedrock_invoke_provider(model)
## LOGGING
logging_obj.post_call(
input=messages,
@@ -465,7 +467,7 @@ def process_response( # noqa: PLR0915
outputText = (
completion_response.get("completions")[0].get("data").get("text")
)
elif provider == "meta":
elif provider == "meta" or provider == "llama":
outputText = completion_response["generation"]
elif provider == "mistral":
outputText = completion_response["outputs"][0]["text"]
@@ -597,13 +599,13 @@ def completion( # noqa: PLR0915

## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model

provider = model.split(".")[0]
provider = self.get_bedrock_invoke_provider(model)
modelId = self.get_bedrock_model_id(
model=model,
provider=provider,
optional_params=optional_params,
)

## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@@ -785,7 +787,7 @@ def completion( # noqa: PLR0915
"textGenerationConfig": inference_params,
}
)
elif provider == "meta":
elif provider == "meta" or provider == "llama":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
@@ -1044,6 +1046,74 @@ async def async_streaming(
)
return streaming_response

    @staticmethod
    def get_bedrock_invoke_provider(
        model: str,
    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the bedrock provider from the model

        handles 2 scenarios:
        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        """
        _split_model = model.split(".")[0]
        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)

        # If not a known provider, check for pattern with two slashes
        provider = BedrockLLM._get_provider_from_model_path(model)
        if provider is not None:
            return provider
        return None

    @staticmethod
    def _get_provider_from_model_path(
        model_path: str,
    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the provider from a model path with format: provider/model-name

        Args:
            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')

        Returns:
            Optional[str]: The provider name, or None if no valid provider found
        """
        parts = model_path.split("/")
        if len(parts) >= 1:
            provider = parts[0]
            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
        return None

    def get_bedrock_model_id(
        self,
        optional_params: dict,
        provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
        model: str,
    ) -> str:
        modelId = optional_params.pop("model_id", None)
        if modelId is not None:
            modelId = self.encode_model_id(model_id=modelId)
        else:
            modelId = model

        if provider == "llama" and "llama/" in modelId:
            modelId = self._get_model_id_for_llama_like_model(modelId)

        return modelId

    def _get_model_id_for_llama_like_model(
        self,
        model: str,
    ) -> str:
        """
        Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
        """
        model_id = model.replace("llama/", "")
        return self.encode_model_id(model_id=model_id)


def get_response_stream_shape():
global _response_stream_shape_cache
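For readers skimming the diff, here is a minimal sketch (not part of the commit) of how the new provider detection is intended to behave, per the docstring of `get_bedrock_invoke_provider`; it assumes `BedrockLLM` is importable from `litellm.llms.bedrock.chat.invoke_handler`:

```python
from litellm.llms.bedrock.chat.invoke_handler import BedrockLLM

# Scenario 1: classic "<provider>.<model>" IDs resolve via the dot prefix.
assert (
    BedrockLLM.get_bedrock_invoke_provider("anthropic.claude-3-5-sonnet-20240620-v1:0")
    == "anthropic"
)

# Scenario 2: "llama/<imported-model-arn>" IDs resolve via the path prefix.
assert (
    BedrockLLM.get_bedrock_invoke_provider(
        "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    )
    == "llama"
)
```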
13 changes: 8 additions & 5 deletions litellm/utils.py
@@ -6045,20 +6045,23 @@ def get_provider_chat_config( # noqa: PLR0915
return litellm.PetalsConfig()
elif litellm.LlmProviders.BEDROCK == provider:
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
if (
base_model in litellm.bedrock_converse_models
or "converse_like" in model
):
return litellm.AmazonConverseConfig()
elif "amazon" in model: # amazon titan llms
elif bedrock_provider == "amazon": # amazon titan llms
return litellm.AmazonTitanConfig()
elif "meta" in model: # amazon / meta llms
elif (
bedrock_provider == "meta" or bedrock_provider == "llama"
): # amazon / meta llms
return litellm.AmazonLlamaConfig()
elif "ai21" in model: # ai21 llms
elif bedrock_provider == "ai21": # ai21 llms
return litellm.AmazonAI21Config()
elif "cohere" in model: # cohere models on bedrock
elif bedrock_provider == "cohere": # cohere models on bedrock
return litellm.AmazonCohereConfig()
elif "mistral" in model: # mistral models on bedrock
elif bedrock_provider == "mistral": # mistral models on bedrock
return litellm.AmazonMistralConfig()
return litellm.OpenAIGPTConfig()

51 changes: 51 additions & 0 deletions tests/llm_translation/test_bedrock_completion.py
@@ -2529,3 +2529,54 @@ def test_bedrock_custom_proxy():
assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"

assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"


def test_bedrock_custom_deepseek():
    from litellm.llms.custom_httpx.http_handler import HTTPHandler
    import json

    litellm._turn_on_debug()
    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        # Mock the response
        mock_response = Mock()
        mock_response.text = json.dumps(
            {"generation": "Here's a joke...", "stop_reason": "stop"}
        )
        mock_response.status_code = 200
        # Add required response attributes
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        try:
            response = completion(
                model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # Updated to specify provider
                messages=[{"role": "user", "content": "Tell me a joke"}],
                max_tokens=100,
                client=client,
            )

            # Print request details
            print("\nRequest Details:")
            print(f"URL: {mock_post.call_args.kwargs['url']}")

            # Verify the URL
            assert (
                mock_post.call_args.kwargs["url"]
                == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
            )

            # Verify the request body format
            request_body = json.loads(mock_post.call_args.kwargs["data"])
            print("request_body=", json.dumps(request_body, indent=4, default=str))
            assert "prompt" in request_body
            assert request_body["prompt"] == "Tell me a joke"

            # follows the llama spec
            assert request_body["max_gen_len"] == 100

        except Exception as e:
            print(f"Error: {str(e)}")
            raise e
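The URL assertion above depends on the imported-model ARN being percent-encoded into a single path segment of the Invoke endpoint (the `encode_model_id` step in the handler). A minimal sketch of that encoding with the standard library, assuming the `us-west-2` runtime endpoint used by the test:

```python
from urllib.parse import quote

# ARN from the test above; ':' and '/' must be escaped so the whole ARN
# fits into one path segment of the Invoke URL.
model_arn = "arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
encoded = quote(model_arn, safe="")

invoke_url = f"https://bedrock-runtime.us-west-2.amazonaws.com/model/{encoded}/invoke"
# -> .../model/arn%3Aaws%3Abedrock%3A...%3Aimported-model%2Fr4c4kewx2s0n/invoke
```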
