From b9827c13fb15995bfe355d91a2856495c6dd64a9 Mon Sep 17 00:00:00 2001 From: David Volquartz Lebech Date: Mon, 8 Jul 2024 10:35:37 +0200 Subject: [PATCH] feat(vertex): add copy and with_options (#578) * feat(vertex): add copy and with_options Closes #566 * move vertex client tests to a separate file * add missing `credentials` argument to `copy()` * minor cleanup --------- Co-authored-by: Robert Craigie --- src/anthropic/lib/vertex/_client.py | 166 +++++++++++++++++++++++++++- tests/lib/test_vertex.py | 160 +++++++++++++++++++++++++++ 2 files changed, 322 insertions(+), 4 deletions(-) create mode 100644 tests/lib/test_vertex.py diff --git a/src/anthropic/lib/vertex/_client.py b/src/anthropic/lib/vertex/_client.py index 578cb559..a513ff5d 100644 --- a/src/anthropic/lib/vertex/_client.py +++ b/src/anthropic/lib/vertex/_client.py @@ -2,7 +2,7 @@ import os from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar -from typing_extensions import override +from typing_extensions import Self, override import httpx @@ -15,7 +15,15 @@ from ..._version import __version__ from ..._streaming import Stream, AsyncStream from ..._exceptions import APIStatusError -from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient +from ..._base_client import ( + DEFAULT_MAX_RETRIES, + DEFAULT_CONNECTION_LIMITS, + BaseClient, + SyncAPIClient, + AsyncAPIClient, + SyncHttpxClientWrapper, + AsyncHttpxClientWrapper, +) from ...resources.messages import Messages, AsyncMessages if TYPE_CHECKING: @@ -115,6 +123,7 @@ def __init__( region: str | NotGiven = NOT_GIVEN, project_id: str | NotGiven = NOT_GIVEN, access_token: str | None = None, + credentials: GoogleCredentials | None = None, base_url: str | httpx.URL | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -128,7 +137,6 @@ def __init__( proxies: ProxiesTypes | None = None, # See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration) connection_pool_limits: httpx.Limits | None = None, - credentials: GoogleCredentials | None = None, _strict_response_validation: bool = False, ) -> None: if not is_given(region): @@ -192,6 +200,81 @@ def _ensure_access_token(self) -> str: assert isinstance(self.credentials.token, str) return self.credentials.token + def copy( + self, + *, + region: str | NotGiven = NOT_GIVEN, + project_id: str | NotGiven = NOT_GIVEN, + access_token: str | None = None, + credentials: GoogleCredentials | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + http_client: httpx.Client | None = None, + connection_pool_limits: httpx.Limits | None = None, + max_retries: int | NotGiven = NOT_GIVEN, + default_headers: Mapping[str, str] | None = None, + set_default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + set_default_query: Mapping[str, object] | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: + """ + Create a new client instance re-using the same options given to the current client with optional overriding. + """ + if default_headers is not None and set_default_headers is not None: + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + + if default_query is not None and set_default_query is not None: + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") + + headers = self._custom_headers + if default_headers is not None: + headers = {**headers, **default_headers} + elif set_default_headers is not None: + headers = set_default_headers + + params = self._custom_query + if default_query is not None: + params = {**params, **default_query} + elif set_default_query is not None: + params = set_default_query + + if connection_pool_limits is not None: + if http_client is not None: + raise ValueError("The 'http_client' argument is mutually exclusive with 'connection_pool_limits'") + + if not isinstance(self._client, SyncHttpxClientWrapper): + raise ValueError( + "A custom HTTP client has been set and is mutually exclusive with the 'connection_pool_limits' argument" + ) + + http_client = None + else: + if self._limits is not DEFAULT_CONNECTION_LIMITS: + connection_pool_limits = self._limits + else: + connection_pool_limits = None + + http_client = http_client or self._client + + return self.__class__( + region=region if is_given(region) else self.region, + project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN, + access_token=access_token or self.access_token, + credentials=credentials or self.credentials, + base_url=base_url or self.base_url, + timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, + http_client=http_client, + max_retries=max_retries if is_given(max_retries) else self.max_retries, + default_headers=headers, + default_query=params, + **_extra_kwargs, + ) + + # Alias for `copy` for nicer inline usage, e.g. + # client.with_options(timeout=10).foo.create(...) + with_options = copy + class AsyncAnthropicVertex(BaseVertexClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient): messages: AsyncMessages @@ -202,6 +285,7 @@ def __init__( region: str | NotGiven = NOT_GIVEN, project_id: str | NotGiven = NOT_GIVEN, access_token: str | None = None, + credentials: GoogleCredentials | None = None, base_url: str | httpx.URL | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -215,7 +299,6 @@ def __init__( proxies: ProxiesTypes | None = None, # See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration) connection_pool_limits: httpx.Limits | None = None, - credentials: GoogleCredentials | None = None, _strict_response_validation: bool = False, ) -> None: if not is_given(region): @@ -278,3 +361,78 @@ async def _ensure_access_token(self) -> str: assert isinstance(self.credentials.token, str) return self.credentials.token + + def copy( + self, + *, + region: str | NotGiven = NOT_GIVEN, + project_id: str | NotGiven = NOT_GIVEN, + access_token: str | None = None, + credentials: GoogleCredentials | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + http_client: httpx.AsyncClient | None = None, + connection_pool_limits: httpx.Limits | None = None, + max_retries: int | NotGiven = NOT_GIVEN, + default_headers: Mapping[str, str] | None = None, + set_default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + set_default_query: Mapping[str, object] | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: + """ + Create a new client instance re-using the same options given to the current client with optional overriding. + """ + if default_headers is not None and set_default_headers is not None: + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + + if default_query is not None and set_default_query is not None: + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") + + headers = self._custom_headers + if default_headers is not None: + headers = {**headers, **default_headers} + elif set_default_headers is not None: + headers = set_default_headers + + params = self._custom_query + if default_query is not None: + params = {**params, **default_query} + elif set_default_query is not None: + params = set_default_query + + if connection_pool_limits is not None: + if http_client is not None: + raise ValueError("The 'http_client' argument is mutually exclusive with 'connection_pool_limits'") + + if not isinstance(self._client, AsyncHttpxClientWrapper): + raise ValueError( + "A custom HTTP client has been set and is mutually exclusive with the 'connection_pool_limits' argument" + ) + + http_client = None + else: + if self._limits is not DEFAULT_CONNECTION_LIMITS: + connection_pool_limits = self._limits + else: + connection_pool_limits = None + + http_client = http_client or self._client + + return self.__class__( + region=region if is_given(region) else self.region, + project_id=project_id if is_given(project_id) else self.project_id or NOT_GIVEN, + access_token=access_token or self.access_token, + credentials=credentials or self.credentials, + base_url=base_url or self.base_url, + timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, + http_client=http_client, + max_retries=max_retries if is_given(max_retries) else self.max_retries, + default_headers=headers, + default_query=params, + **_extra_kwargs, + ) + + # Alias for `copy` for nicer inline usage, e.g. + # client.with_options(timeout=10).foo.create(...) + with_options = copy diff --git a/tests/lib/test_vertex.py b/tests/lib/test_vertex.py new file mode 100644 index 00000000..2f741c3f --- /dev/null +++ b/tests/lib/test_vertex.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os + +import httpx +import pytest + +from anthropic import AnthropicVertex, AsyncAnthropicVertex + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestAnthropicVertex: + client = AnthropicVertex(region="region", project_id="project") + + def test_copy(self) -> None: + copied = self.client.copy() + assert id(copied) != id(self.client) + + copied = self.client.copy(region="another-region", project_id="another-project") + assert copied.region == "another-region" + assert self.client.region == "region" + assert copied.project_id == "another-project" + assert self.client.project_id == "project" + + def test_with_options(self) -> None: + copied = self.client.with_options(region="another-region", project_id="another-project") + assert copied.region == "another-region" + assert self.client.region == "region" + assert copied.project_id == "another-project" + assert self.client.project_id == "project" + + def test_copy_default_options(self) -> None: + # options that have a default are overridden correctly + copied = self.client.copy(max_retries=7) + assert copied.max_retries == 7 + assert self.client.max_retries == 2 + + copied2 = copied.copy(max_retries=6) + assert copied2.max_retries == 6 + assert copied.max_retries == 7 + + # timeout + assert isinstance(self.client.timeout, httpx.Timeout) + copied = self.client.copy(timeout=None) + assert copied.timeout is None + assert isinstance(self.client.timeout, httpx.Timeout) + + def test_copy_default_headers(self) -> None: + client = AnthropicVertex( + base_url=base_url, + region="region", + project_id="project", + _strict_response_validation=True, + default_headers={"X-Foo": "bar"}, + ) + assert client.default_headers["X-Foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert copied.default_headers["X-Foo"] == "bar" + + # merges already given headers + copied = client.copy(default_headers={"X-Bar": "stainless"}) + assert copied.default_headers["X-Foo"] == "bar" + assert copied.default_headers["X-Bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_headers={"X-Foo": "stainless"}) + assert copied.default_headers["X-Foo"] == "stainless" + + # set_default_headers + + # completely overrides already set values + copied = client.copy(set_default_headers={}) + assert copied.default_headers.get("X-Foo") is None + + copied = client.copy(set_default_headers={"X-Bar": "Robert"}) + assert copied.default_headers["X-Bar"] == "Robert" + + with pytest.raises( + ValueError, + match="`default_headers` and `set_default_headers` arguments are mutually exclusive", + ): + client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + + +class TestAsyncAnthropicVertex: + client = AsyncAnthropicVertex(region="region", project_id="project") + + def test_copy(self) -> None: + copied = self.client.copy() + assert id(copied) != id(self.client) + + copied = self.client.copy(region="another-region", project_id="another-project") + assert copied.region == "another-region" + assert self.client.region == "region" + assert copied.project_id == "another-project" + assert self.client.project_id == "project" + + def test_with_options(self) -> None: + copied = self.client.with_options(region="another-region", project_id="another-project") + assert copied.region == "another-region" + assert self.client.region == "region" + assert copied.project_id == "another-project" + assert self.client.project_id == "project" + + def test_copy_default_options(self) -> None: + # options that have a default are overridden correctly + copied = self.client.copy(max_retries=7) + assert copied.max_retries == 7 + assert self.client.max_retries == 2 + + copied2 = copied.copy(max_retries=6) + assert copied2.max_retries == 6 + assert copied.max_retries == 7 + + # timeout + assert isinstance(self.client.timeout, httpx.Timeout) + copied = self.client.copy(timeout=None) + assert copied.timeout is None + assert isinstance(self.client.timeout, httpx.Timeout) + + def test_copy_default_headers(self) -> None: + client = AsyncAnthropicVertex( + base_url=base_url, + region="region", + project_id="project", + _strict_response_validation=True, + default_headers={"X-Foo": "bar"}, + ) + assert client.default_headers["X-Foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert copied.default_headers["X-Foo"] == "bar" + + # merges already given headers + copied = client.copy(default_headers={"X-Bar": "stainless"}) + assert copied.default_headers["X-Foo"] == "bar" + assert copied.default_headers["X-Bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_headers={"X-Foo": "stainless"}) + assert copied.default_headers["X-Foo"] == "stainless" + + # set_default_headers + + # completely overrides already set values + copied = client.copy(set_default_headers={}) + assert copied.default_headers.get("X-Foo") is None + + copied = client.copy(set_default_headers={"X-Bar": "Robert"}) + assert copied.default_headers["X-Bar"] == "Robert" + + with pytest.raises( + ValueError, + match="`default_headers` and `set_default_headers` arguments are mutually exclusive", + ): + client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})