Skip to content

Commit

Permalink
fix(bedrock): correct request options for retries (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jul 15, 2024
1 parent d41a880 commit f52b5fe
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ... import _exceptions
from ..._types import NOT_GIVEN, Timeout, NotGiven
from ..._utils import is_dict, is_given
from ..._compat import model_copy
from ..._version import __version__
from ..._streaming import Stream, AsyncStream
from ..._exceptions import APIStatusError
Expand All @@ -29,28 +30,27 @@
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
@override
def _build_request(
self,
options: FinalRequestOptions,
) -> httpx.Request:
if is_dict(options.json_data):
options.json_data.setdefault("anthropic_version", DEFAULT_VERSION)
def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions:
options = model_copy(input_options, deep=True)

if is_dict(options.json_data):
options.json_data.setdefault("anthropic_version", DEFAULT_VERSION)

if options.url in {"/v1/complete", "/v1/messages"} and options.method == "post":
if not is_dict(options.json_data):
raise RuntimeError("Expected dictionary json_data for post /completions endpoint")

if options.url in {"/v1/complete", "/v1/messages"} and options.method == "post":
if not is_dict(options.json_data):
raise RuntimeError("Expected dictionary json_data for post /completions endpoint")
model = options.json_data.pop("model", None)
stream = options.json_data.pop("stream", False)
if stream:
options.url = f"/model/{model}/invoke-with-response-stream"
else:
options.url = f"/model/{model}/invoke"

model = options.json_data.pop("model", None)
stream = options.json_data.pop("stream", False)
if stream:
options.url = f"/model/{model}/invoke-with-response-stream"
else:
options.url = f"/model/{model}/invoke"
return options

return super()._build_request(options)

class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
@override
def _make_status_error(
self,
Expand Down Expand Up @@ -145,6 +145,10 @@ def __init__(
def _make_sse_decoder(self) -> AWSEventStreamDecoder:
return AWSEventStreamDecoder()

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
return _prepare_options(options)

@override
def _prepare_request(self, request: httpx.Request) -> None:
from ._auth import get_auth_headers
Expand Down Expand Up @@ -280,6 +284,10 @@ def __init__(
def _make_sse_decoder(self) -> AWSEventStreamDecoder:
return AWSEventStreamDecoder()

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
return _prepare_options(options)

@override
async def _prepare_request(self, request: httpx.Request) -> None:
from ._auth import get_auth_headers
Expand Down
93 changes: 93 additions & 0 deletions tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import re
from typing import cast
from typing_extensions import Protocol

import httpx
import pytest
from respx import MockRouter

from anthropic import AnthropicBedrock, AsyncAnthropicBedrock

sync_client = AnthropicBedrock(
aws_region="us-east-1",
aws_access_key="example-access-key",
aws_secret_key="example-secret-key",
)
async_client = AsyncAnthropicBedrock(
aws_region="us-east-1",
aws_access_key="example-access-key",
aws_secret_key="example-secret-key",
)


class MockRequestCall(Protocol):
request: httpx.Request


@pytest.mark.respx()
def test_messages_retries(respx_mock: MockRouter) -> None:
respx_mock.post(re.compile(r"https://bedrock-runtime.us-east-1.amazonaws.com/model/.*/invoke")).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

sync_client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
)

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert (
calls[0].request.url
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"
)
assert (
calls[1].request.url
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"
)


@pytest.mark.respx()
@pytest.mark.asyncio()
async def test_messages_retries_async(respx_mock: MockRouter) -> None:
respx_mock.post(re.compile(r"https://bedrock-runtime.us-east-1.amazonaws.com/model/.*/invoke")).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

await async_client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
)

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert (
calls[0].request.url
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"
)
assert (
calls[1].request.url
== "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"
)

0 comments on commit f52b5fe

Please sign in to comment.