From f52b5fece542e77ce149fd8036dc92e4c1efb295 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Mon, 15 Jul 2024 11:16:23 +0100 Subject: [PATCH] fix(bedrock): correct request options for retries (#593) --- src/anthropic/lib/bedrock/_client.py | 44 +++++++------ tests/lib/test_bedrock.py | 93 ++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 18 deletions(-) create mode 100644 tests/lib/test_bedrock.py diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py index b3f388e5..f7298adc 100644 --- a/src/anthropic/lib/bedrock/_client.py +++ b/src/anthropic/lib/bedrock/_client.py @@ -9,6 +9,7 @@ from ... import _exceptions from ..._types import NOT_GIVEN, Timeout, NotGiven from ..._utils import is_dict, is_given +from ..._compat import model_copy from ..._version import __version__ from ..._streaming import Stream, AsyncStream from ..._exceptions import APIStatusError @@ -29,28 +30,27 @@ _DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]]) -class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]): - @override - def _build_request( - self, - options: FinalRequestOptions, - ) -> httpx.Request: - if is_dict(options.json_data): - options.json_data.setdefault("anthropic_version", DEFAULT_VERSION) +def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions: + options = model_copy(input_options, deep=True) + + if is_dict(options.json_data): + options.json_data.setdefault("anthropic_version", DEFAULT_VERSION) + + if options.url in {"/v1/complete", "/v1/messages"} and options.method == "post": + if not is_dict(options.json_data): + raise RuntimeError("Expected dictionary json_data for post /completions endpoint") - if options.url in {"/v1/complete", "/v1/messages"} and options.method == "post": - if not is_dict(options.json_data): - raise RuntimeError("Expected dictionary json_data for post /completions endpoint") + model = options.json_data.pop("model", None) + stream = options.json_data.pop("stream", False) + if stream: + options.url = f"/model/{model}/invoke-with-response-stream" + else: + options.url = f"/model/{model}/invoke" - model = options.json_data.pop("model", None) - stream = options.json_data.pop("stream", False) - if stream: - options.url = f"/model/{model}/invoke-with-response-stream" - else: - options.url = f"/model/{model}/invoke" + return options - return super()._build_request(options) +class BaseBedrockClient(BaseClient[_HttpxClientT, _DefaultStreamT]): @override def _make_status_error( self, @@ -145,6 +145,10 @@ def __init__( def _make_sse_decoder(self) -> AWSEventStreamDecoder: return AWSEventStreamDecoder() + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + return _prepare_options(options) + @override def _prepare_request(self, request: httpx.Request) -> None: from ._auth import get_auth_headers @@ -280,6 +284,10 @@ def __init__( def _make_sse_decoder(self) -> AWSEventStreamDecoder: return AWSEventStreamDecoder() + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + return _prepare_options(options) + @override async def _prepare_request(self, request: httpx.Request) -> None: from ._auth import get_auth_headers diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py new file mode 100644 index 00000000..d19d9098 --- /dev/null +++ b/tests/lib/test_bedrock.py @@ -0,0 +1,93 @@ +import re +from typing import cast +from typing_extensions import Protocol + +import httpx +import pytest +from respx import MockRouter + +from anthropic import AnthropicBedrock, AsyncAnthropicBedrock + +sync_client = AnthropicBedrock( + aws_region="us-east-1", + aws_access_key="example-access-key", + aws_secret_key="example-secret-key", +) +async_client = AsyncAnthropicBedrock( + aws_region="us-east-1", + aws_access_key="example-access-key", + aws_secret_key="example-secret-key", +) + + +class MockRequestCall(Protocol): + request: httpx.Request + + +@pytest.mark.respx() +def test_messages_retries(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime.us-east-1.amazonaws.com/model/.*/invoke")).mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + sync_client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-sonnet-20240229-v1:0", + ) + + calls = cast("list[MockRequestCall]", respx_mock.calls) + + assert len(calls) == 2 + + assert ( + calls[0].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke" + ) + assert ( + calls[1].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke" + ) + + +@pytest.mark.respx() +@pytest.mark.asyncio() +async def test_messages_retries_async(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime.us-east-1.amazonaws.com/model/.*/invoke")).mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + await async_client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-sonnet-20240229-v1:0", + ) + + calls = cast("list[MockRequestCall]", respx_mock.calls) + + assert len(calls) == 2 + + assert ( + calls[0].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke" + ) + assert ( + calls[1].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke" + )