Skip to content

Commit

Permalink
feat(vertex): add copy and with_options (#578)
Browse files Browse the repository at this point in the history
* feat(vertex): add copy and with_options

Closes #566

* move vertex client tests to a separate file

* add missing `credentials` argument to `copy()`

* minor cleanup

---------

Co-authored-by: Robert Craigie <robert@craigie.dev>
  • Loading branch information
2 people authored and stainless-app[bot] committed Jul 10, 2024
1 parent fe6f0cc commit 1cf40ae
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 4 deletions.
166 changes: 162 additions & 4 deletions src/anthropic/lib/vertex/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar
from typing_extensions import override
from typing_extensions import Self, override

import httpx

Expand All @@ -15,7 +15,15 @@
from ..._version import __version__
from ..._streaming import Stream, AsyncStream
from ..._exceptions import APIStatusError
from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient
from ..._base_client import (
DEFAULT_MAX_RETRIES,
DEFAULT_CONNECTION_LIMITS,
BaseClient,
SyncAPIClient,
AsyncAPIClient,
SyncHttpxClientWrapper,
AsyncHttpxClientWrapper,
)
from ...resources.messages import Messages, AsyncMessages

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,6 +123,7 @@ def __init__(
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -128,7 +137,6 @@ def __init__(
proxies: ProxiesTypes | None = None,
# See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration)
connection_pool_limits: httpx.Limits | None = None,
credentials: GoogleCredentials | None = None,
_strict_response_validation: bool = False,
) -> None:
if not is_given(region):
Expand Down Expand Up @@ -192,6 +200,81 @@ def _ensure_access_token(self) -> str:
assert isinstance(self.credentials.token, str)
return self.credentials.token

def copy(
self,
*,
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
connection_pool_limits: httpx.Limits | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

if default_query is not None and set_default_query is not None:
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

headers = self._custom_headers
if default_headers is not None:
headers = {**headers, **default_headers}
elif set_default_headers is not None:
headers = set_default_headers

params = self._custom_query
if default_query is not None:
params = {**params, **default_query}
elif set_default_query is not None:
params = set_default_query

if connection_pool_limits is not None:
if http_client is not None:
raise ValueError("The 'http_client' argument is mutually exclusive with 'connection_pool_limits'")

if not isinstance(self._client, SyncHttpxClientWrapper):
raise ValueError(
"A custom HTTP client has been set and is mutually exclusive with the 'connection_pool_limits' argument"
)

http_client = None
else:
if self._limits is not DEFAULT_CONNECTION_LIMITS:
connection_pool_limits = self._limits
else:
connection_pool_limits = None

http_client = http_client or self._client

return self.__class__(
region=region if is_given(region) else self.region,
project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN,
access_token=access_token or self.access_token,
credentials=credentials or self.credentials,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
**_extra_kwargs,
)

# Alias for `copy` for nicer inline usage, e.g.
# client.with_options(timeout=10).foo.create(...)
with_options = copy


class AsyncAnthropicVertex(BaseVertexClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient):
messages: AsyncMessages
Expand All @@ -202,6 +285,7 @@ def __init__(
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -215,7 +299,6 @@ def __init__(
proxies: ProxiesTypes | None = None,
# See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration)
connection_pool_limits: httpx.Limits | None = None,
credentials: GoogleCredentials | None = None,
_strict_response_validation: bool = False,
) -> None:
if not is_given(region):
Expand Down Expand Up @@ -278,3 +361,78 @@ async def _ensure_access_token(self) -> str:

assert isinstance(self.credentials.token, str)
return self.credentials.token

def copy(
self,
*,
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
connection_pool_limits: httpx.Limits | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

if default_query is not None and set_default_query is not None:
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

headers = self._custom_headers
if default_headers is not None:
headers = {**headers, **default_headers}
elif set_default_headers is not None:
headers = set_default_headers

params = self._custom_query
if default_query is not None:
params = {**params, **default_query}
elif set_default_query is not None:
params = set_default_query

if connection_pool_limits is not None:
if http_client is not None:
raise ValueError("The 'http_client' argument is mutually exclusive with 'connection_pool_limits'")

if not isinstance(self._client, AsyncHttpxClientWrapper):
raise ValueError(
"A custom HTTP client has been set and is mutually exclusive with the 'connection_pool_limits' argument"
)

http_client = None
else:
if self._limits is not DEFAULT_CONNECTION_LIMITS:
connection_pool_limits = self._limits
else:
connection_pool_limits = None

http_client = http_client or self._client

return self.__class__(
region=region if is_given(region) else self.region,
project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN,
access_token=access_token or self.access_token,
credentials=credentials or self.credentials,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
**_extra_kwargs,
)

# Alias for `copy` for nicer inline usage, e.g.
# client.with_options(timeout=10).foo.create(...)
with_options = copy
160 changes: 160 additions & 0 deletions tests/lib/test_vertex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

import os

import httpx
import pytest

from anthropic import AnthropicVertex, AsyncAnthropicVertex

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")


class TestAnthropicVertex:
client = AnthropicVertex(region="region", project_id="project")

def test_copy(self) -> None:
copied = self.client.copy()
assert id(copied) != id(self.client)

copied = self.client.copy(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_with_options(self) -> None:
copied = self.client.with_options(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
copied = self.client.copy(max_retries=7)
assert copied.max_retries == 7
assert self.client.max_retries == 2

copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7

# timeout
assert isinstance(self.client.timeout, httpx.Timeout)
copied = self.client.copy(timeout=None)
assert copied.timeout is None
assert isinstance(self.client.timeout, httpx.Timeout)

def test_copy_default_headers(self) -> None:
client = AnthropicVertex(
base_url=base_url,
region="region",
project_id="project",
_strict_response_validation=True,
default_headers={"X-Foo": "bar"},
)
assert client.default_headers["X-Foo"] == "bar"

# does not override the already given value when not specified
copied = client.copy()
assert copied.default_headers["X-Foo"] == "bar"

# merges already given headers
copied = client.copy(default_headers={"X-Bar": "stainless"})
assert copied.default_headers["X-Foo"] == "bar"
assert copied.default_headers["X-Bar"] == "stainless"

# uses new values for any already given headers
copied = client.copy(default_headers={"X-Foo": "stainless"})
assert copied.default_headers["X-Foo"] == "stainless"

# set_default_headers

# completely overrides already set values
copied = client.copy(set_default_headers={})
assert copied.default_headers.get("X-Foo") is None

copied = client.copy(set_default_headers={"X-Bar": "Robert"})
assert copied.default_headers["X-Bar"] == "Robert"

with pytest.raises(
ValueError,
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})


class TestAsyncAnthropicVertex:
client = AsyncAnthropicVertex(region="region", project_id="project")

def test_copy(self) -> None:
copied = self.client.copy()
assert id(copied) != id(self.client)

copied = self.client.copy(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_with_options(self) -> None:
copied = self.client.with_options(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
copied = self.client.copy(max_retries=7)
assert copied.max_retries == 7
assert self.client.max_retries == 2

copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7

# timeout
assert isinstance(self.client.timeout, httpx.Timeout)
copied = self.client.copy(timeout=None)
assert copied.timeout is None
assert isinstance(self.client.timeout, httpx.Timeout)

def test_copy_default_headers(self) -> None:
client = AsyncAnthropicVertex(
base_url=base_url,
region="region",
project_id="project",
_strict_response_validation=True,
default_headers={"X-Foo": "bar"},
)
assert client.default_headers["X-Foo"] == "bar"

# does not override the already given value when not specified
copied = client.copy()
assert copied.default_headers["X-Foo"] == "bar"

# merges already given headers
copied = client.copy(default_headers={"X-Bar": "stainless"})
assert copied.default_headers["X-Foo"] == "bar"
assert copied.default_headers["X-Bar"] == "stainless"

# uses new values for any already given headers
copied = client.copy(default_headers={"X-Foo": "stainless"})
assert copied.default_headers["X-Foo"] == "stainless"

# set_default_headers

# completely overrides already set values
copied = client.copy(set_default_headers={})
assert copied.default_headers.get("X-Foo") is None

copied = client.copy(set_default_headers={"X-Bar": "Robert"})
assert copied.default_headers["X-Bar"] == "Robert"

with pytest.raises(
ValueError,
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

0 comments on commit 1cf40ae

Please sign in to comment.