Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vertex): add copy and with_options #578

Merged
merged 4 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 162 additions & 4 deletions src/anthropic/lib/vertex/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar
from typing_extensions import override
from typing_extensions import Self, override

import httpx

Expand All @@ -15,7 +15,15 @@
from ..._version import __version__
from ..._streaming import Stream, AsyncStream
from ..._exceptions import APIStatusError
from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient
from ..._base_client import (
DEFAULT_MAX_RETRIES,
DEFAULT_CONNECTION_LIMITS,
BaseClient,
SyncAPIClient,
AsyncAPIClient,
SyncHttpxClientWrapper,
AsyncHttpxClientWrapper,
)
from ...resources.messages import Messages, AsyncMessages

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,6 +123,7 @@ def __init__(
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -128,7 +137,6 @@ def __init__(
proxies: ProxiesTypes | None = None,
# See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration)
connection_pool_limits: httpx.Limits | None = None,
credentials: GoogleCredentials | None = None,
_strict_response_validation: bool = False,
) -> None:
if not is_given(region):
Expand Down Expand Up @@ -192,6 +200,81 @@ def _ensure_access_token(self) -> str:
assert isinstance(self.credentials.token, str)
return self.credentials.token

def copy(
self,
*,
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
connection_pool_limits: httpx.Limits | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

if default_query is not None and set_default_query is not None:
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

headers = self._custom_headers
if default_headers is not None:
headers = {**headers, **default_headers}
elif set_default_headers is not None:
headers = set_default_headers

params = self._custom_query
if default_query is not None:
params = {**params, **default_query}
elif set_default_query is not None:
params = set_default_query

if connection_pool_limits is not None:
if http_client is not None:
raise ValueError("The 'http_client' argument is mutually exclusive with 'connection_pool_limits'")

if not isinstance(self._client, SyncHttpxClientWrapper):
raise ValueError(
"A custom HTTP client has been set and is mutually exclusive with the 'connection_pool_limits' argument"
)

http_client = None
else:
if self._limits is not DEFAULT_CONNECTION_LIMITS:
connection_pool_limits = self._limits
else:
connection_pool_limits = None

http_client = http_client or self._client

return self.__class__(
region=region if is_given(region) else self.region,
project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN,
access_token=access_token or self.access_token,
credentials=credentials or self.credentials,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
**_extra_kwargs,
)

# Alias for `copy` for nicer inline usage, e.g.
# client.with_options(timeout=10).foo.create(...)
with_options = copy


class AsyncAnthropicVertex(BaseVertexClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient):
messages: AsyncMessages
Expand All @@ -202,6 +285,7 @@ def __init__(
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -215,7 +299,6 @@ def __init__(
proxies: ProxiesTypes | None = None,
# See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration)
connection_pool_limits: httpx.Limits | None = None,
credentials: GoogleCredentials | None = None,
_strict_response_validation: bool = False,
) -> None:
if not is_given(region):
Expand Down Expand Up @@ -278,3 +361,78 @@ async def _ensure_access_token(self) -> str:

assert isinstance(self.credentials.token, str)
return self.credentials.token

def copy(
self,
*,
region: str | NotGiven = NOT_GIVEN,
project_id: str | NotGiven = NOT_GIVEN,
access_token: str | None = None,
credentials: GoogleCredentials | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
connection_pool_limits: httpx.Limits | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

if default_query is not None and set_default_query is not None:
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

headers = self._custom_headers
if default_headers is not None:
headers = {**headers, **default_headers}
elif set_default_headers is not None:
headers = set_default_headers

params = self._custom_query
if default_query is not None:
params = {**params, **default_query}
elif set_default_query is not None:
params = set_default_query

if connection_pool_limits is not None:
if http_client is not None:
raise ValueError("The 'http_client' argument is mutually exclusive with 'connection_pool_limits'")

if not isinstance(self._client, AsyncHttpxClientWrapper):
raise ValueError(
"A custom HTTP client has been set and is mutually exclusive with the 'connection_pool_limits' argument"
)

http_client = None
else:
if self._limits is not DEFAULT_CONNECTION_LIMITS:
connection_pool_limits = self._limits
else:
connection_pool_limits = None

http_client = http_client or self._client

return self.__class__(
region=region if is_given(region) else self.region,
project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN,
access_token=access_token or self.access_token,
credentials=credentials or self.credentials,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
**_extra_kwargs,
)

# Alias for `copy` for nicer inline usage, e.g.
# client.with_options(timeout=10).foo.create(...)
with_options = copy
160 changes: 160 additions & 0 deletions tests/lib/test_vertex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

import os

import httpx
import pytest

from anthropic import AnthropicVertex, AsyncAnthropicVertex

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")


class TestAnthropicVertex:
client = AnthropicVertex(region="region", project_id="project")

def test_copy(self) -> None:
copied = self.client.copy()
assert id(copied) != id(self.client)

copied = self.client.copy(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_with_options(self) -> None:
copied = self.client.with_options(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
copied = self.client.copy(max_retries=7)
assert copied.max_retries == 7
assert self.client.max_retries == 2

copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7

# timeout
assert isinstance(self.client.timeout, httpx.Timeout)
copied = self.client.copy(timeout=None)
assert copied.timeout is None
assert isinstance(self.client.timeout, httpx.Timeout)

def test_copy_default_headers(self) -> None:
client = AnthropicVertex(
base_url=base_url,
region="region",
project_id="project",
_strict_response_validation=True,
default_headers={"X-Foo": "bar"},
)
assert client.default_headers["X-Foo"] == "bar"

# does not override the already given value when not specified
copied = client.copy()
assert copied.default_headers["X-Foo"] == "bar"

# merges already given headers
copied = client.copy(default_headers={"X-Bar": "stainless"})
assert copied.default_headers["X-Foo"] == "bar"
assert copied.default_headers["X-Bar"] == "stainless"

# uses new values for any already given headers
copied = client.copy(default_headers={"X-Foo": "stainless"})
assert copied.default_headers["X-Foo"] == "stainless"

# set_default_headers

# completely overrides already set values
copied = client.copy(set_default_headers={})
assert copied.default_headers.get("X-Foo") is None

copied = client.copy(set_default_headers={"X-Bar": "Robert"})
assert copied.default_headers["X-Bar"] == "Robert"

with pytest.raises(
ValueError,
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})


class TestAsyncAnthropicVertex:
client = AsyncAnthropicVertex(region="region", project_id="project")

def test_copy(self) -> None:
copied = self.client.copy()
assert id(copied) != id(self.client)

copied = self.client.copy(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_with_options(self) -> None:
copied = self.client.with_options(region="another-region", project_id="another-project")
assert copied.region == "another-region"
assert self.client.region == "region"
assert copied.project_id == "another-project"
assert self.client.project_id == "project"

def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
copied = self.client.copy(max_retries=7)
assert copied.max_retries == 7
assert self.client.max_retries == 2

copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7

# timeout
assert isinstance(self.client.timeout, httpx.Timeout)
copied = self.client.copy(timeout=None)
assert copied.timeout is None
assert isinstance(self.client.timeout, httpx.Timeout)

def test_copy_default_headers(self) -> None:
client = AsyncAnthropicVertex(
base_url=base_url,
region="region",
project_id="project",
_strict_response_validation=True,
default_headers={"X-Foo": "bar"},
)
assert client.default_headers["X-Foo"] == "bar"

# does not override the already given value when not specified
copied = client.copy()
assert copied.default_headers["X-Foo"] == "bar"

# merges already given headers
copied = client.copy(default_headers={"X-Bar": "stainless"})
assert copied.default_headers["X-Foo"] == "bar"
assert copied.default_headers["X-Bar"] == "stainless"

# uses new values for any already given headers
copied = client.copy(default_headers={"X-Foo": "stainless"})
assert copied.default_headers["X-Foo"] == "stainless"

# set_default_headers

# completely overrides already set values
copied = client.copy(set_default_headers={})
assert copied.default_headers.get("X-Foo") is None

copied = client.copy(set_default_headers={"X-Bar": "Robert"})
assert copied.default_headers["X-Bar"] == "Robert"

with pytest.raises(
ValueError,
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})