Skip to content

Commit

Permalink
fix(vertex): correct request options in retries
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jul 17, 2024
1 parent 565dfcd commit 460547b
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 34 deletions.
68 changes: 36 additions & 32 deletions src/anthropic/lib/vertex/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._auth import load_auth, refresh_auth
from ..._types import NOT_GIVEN, NotGiven, Transport, ProxiesTypes, AsyncTransport
from ..._utils import is_dict, asyncify, is_given
from ..._compat import typed_cached_property
from ..._compat import model_copy, typed_cached_property
from ..._models import FinalRequestOptions
from ..._version import __version__
from ..._streaming import Stream, AsyncStream
Expand All @@ -37,37 +37,6 @@


class BaseVertexClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
@override
def _build_request(
    self,
    options: FinalRequestOptions,
) -> httpx.Request:
    """Apply Vertex-specific rewrites to ``options`` and build the request.

    The rewrites are performed on ``options`` in place:

    * a dict JSON body gets an ``anthropic_version`` default;
    * a ``post`` to ``/v1/messages`` has its URL replaced by the
      project/region/model path, with the model name taken out of the body;
    * any ``model`` key still present in a dict body is dropped.

    Raises:
        RuntimeError: for a ``post`` to ``/v1/messages`` when no project id
            is available or the JSON body is not a dictionary.
    """
    if is_dict(options.json_data):
        options.json_data.setdefault("anthropic_version", DEFAULT_VERSION)

    is_messages_post = options.method == "post" and options.url == "/v1/messages"
    if is_messages_post:
        resolved_project_id = self.project_id
        if resolved_project_id is None:
            raise RuntimeError(
                "No project_id was given and it could not be resolved from credentials. The client should be instantiated with the `project_id` argument or the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable should be set."
            )

        if not is_dict(options.json_data):
            raise RuntimeError("Expected json data to be a dictionary for post /v1/messages")

        # The model name moves from the request body into the URL path.
        model_name = options.json_data.pop("model")
        if options.json_data.get("stream", False):
            endpoint = "streamRawPredict"
        else:
            endpoint = "rawPredict"

        options.url = (
            f"/projects/{self.project_id}/locations/{self.region}/publishers/anthropic/models/{model_name}:{endpoint}"
        )

    # Requests to other paths must not carry `model` in the body either.
    if is_dict(options.json_data):
        options.json_data.pop("model", None)

    return super()._build_request(options)

@typed_cached_property
def region(self) -> str:
raise RuntimeError("region not set")
Expand Down Expand Up @@ -174,6 +143,10 @@ def __init__(

self.messages = Messages(self)

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
    """Return ``options`` rewritten for Vertex using this client's project id and region.

    Delegates to the module-level ``_prepare_options`` helper.
    """
    prepared = _prepare_options(
        options,
        project_id=self.project_id,
        region=self.region,
    )
    return prepared

@override
def _prepare_request(self, request: httpx.Request) -> None:
if request.headers.get("Authorization"):
Expand Down Expand Up @@ -336,6 +309,10 @@ def __init__(

self.messages = AsyncMessages(self)

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
    """Return ``options`` rewritten for Vertex using this client's project id and region.

    Async counterpart of the sync client's hook; the rewrite itself is the
    synchronous module-level ``_prepare_options`` helper.
    """
    prepared = _prepare_options(
        options,
        project_id=self.project_id,
        region=self.region,
    )
    return prepared

@override
async def _prepare_request(self, request: httpx.Request) -> None:
if request.headers.get("Authorization"):
Expand Down Expand Up @@ -436,3 +413,30 @@ def copy(
# Alias for `copy` for nicer inline usage, e.g.
# client.with_options(timeout=10).foo.create(...)
with_options = copy


def _prepare_options(input_options: FinalRequestOptions, *, project_id: str | None, region: str) -> FinalRequestOptions:
    """Return a Vertex-ready copy of ``input_options``.

    All rewrites are applied to a copy obtained with
    ``model_copy(input_options, deep=True)``; ``input_options`` itself is not
    assigned to or mutated by this function.

    Rewrites applied to the copy:

    * a dict JSON body gets an ``anthropic_version`` default;
    * a ``post`` to ``/v1/messages`` has its URL replaced by the
      project/region/model path, with the model name taken out of the body
      and ``streamRawPredict`` chosen when the body's ``stream`` is truthy,
      ``rawPredict`` otherwise;
    * any ``model`` key still present in a dict body is dropped.

    Raises:
        RuntimeError: for a ``post`` to ``/v1/messages`` when ``project_id``
            is ``None`` or the JSON body is not a dictionary.
    """
    prepared = model_copy(input_options, deep=True)

    if is_dict(prepared.json_data):
        prepared.json_data.setdefault("anthropic_version", DEFAULT_VERSION)

    is_messages_post = prepared.method == "post" and prepared.url == "/v1/messages"
    if is_messages_post:
        if project_id is None:
            raise RuntimeError(
                "No project_id was given and it could not be resolved from credentials. The client should be instantiated with the `project_id` argument or the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable should be set."
            )

        if not is_dict(prepared.json_data):
            raise RuntimeError("Expected json data to be a dictionary for post /v1/messages")

        # The model name moves from the request body into the URL path.
        model_name = prepared.json_data.pop("model")
        if prepared.json_data.get("stream", False):
            endpoint = "streamRawPredict"
        else:
            endpoint = "rawPredict"

        prepared.url = f"/projects/{project_id}/locations/{region}/publishers/anthropic/models/{model_name}:{endpoint}"

    # Requests to other paths must not carry `model` in the body either.
    if is_dict(prepared.json_data):
        prepared.json_data.pop("model", None)

    return prepared
68 changes: 66 additions & 2 deletions tests/lib/test_vertex.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
from __future__ import annotations

import os
from typing import cast
from typing_extensions import Protocol

import httpx
import pytest
from respx import MockRouter

from anthropic import AnthropicVertex, AsyncAnthropicVertex

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")


class MockRequestCall(Protocol):
    """Structural type for the recorded calls read from ``respx_mock.calls`` in these tests.

    Only the ``request`` attribute is declared because that is all the tests
    inspect; it exists so ``cast`` can give the call list a usable type.
    """

    # The outgoing request captured for this call.
    request: httpx.Request


class TestAnthropicVertex:
client = AnthropicVertex(region="region", project_id="project")
client = AnthropicVertex(region="region", project_id="project", access_token="my-access-token")

@pytest.mark.respx()
def test_messages_retries(self, respx_mock: MockRouter) -> None:
    """A retried messages request must hit the same rewritten Vertex URL both times."""
    request_url = "https://region-aiplatform.googleapis.com/v1/projects/project/locations/region/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict"

    # First attempt fails with a short retry hint, second attempt succeeds.
    server_error = httpx.Response(500, json={"error": "server error"}, headers={"retry-after-ms": "10"})
    success = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(request_url).mock(side_effect=[server_error, success])

    self.client.messages.create(
        max_tokens=1024,
        messages=[
            {
                "role": "user",
                "content": "Say hello there!",
            }
        ],
        model="claude-3-sonnet@20240229",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(recorded) == 2

    # Both the initial attempt and the retry must target the Vertex endpoint.
    for call in recorded:
        assert call.request.url == request_url

def test_copy(self) -> None:
copied = self.client.copy()
Expand Down Expand Up @@ -86,7 +121,36 @@ def test_copy_default_headers(self) -> None:


class TestAsyncAnthropicVertex:
client = AsyncAnthropicVertex(region="region", project_id="project")
client = AsyncAnthropicVertex(region="region", project_id="project", access_token="my-access-token")

@pytest.mark.respx()
@pytest.mark.asyncio()
async def test_messages_retries(self, respx_mock: MockRouter) -> None:
    """A retried async messages request must hit the same rewritten Vertex URL both times."""
    request_url = "https://region-aiplatform.googleapis.com/v1/projects/project/locations/region/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict"

    # First attempt fails with a short retry hint, second attempt succeeds.
    server_error = httpx.Response(500, json={"error": "server error"}, headers={"retry-after-ms": "10"})
    success = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post(request_url).mock(side_effect=[server_error, success])

    await self.client.with_options(timeout=0.2).messages.create(
        max_tokens=1024,
        messages=[
            {
                "role": "user",
                "content": "Say hello there!",
            }
        ],
        model="claude-3-sonnet@20240229",
    )

    recorded = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(recorded) == 2

    # Both the initial attempt and the retry must target the Vertex endpoint.
    for call in recorded:
        assert call.request.url == request_url

def test_copy(self) -> None:
copied = self.client.copy()
Expand Down

0 comments on commit 460547b

Please sign in to comment.