From eea1662c5b35edb1e3410848a6dc7943645e99e8 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Mon, 8 Jul 2024 09:16:04 +0100 Subject: [PATCH] fix(vertex): avoid credentials refresh on every request (#575) --- src/anthropic/lib/vertex/_client.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/anthropic/lib/vertex/_client.py b/src/anthropic/lib/vertex/_client.py index 9d695524..578cb559 100644 --- a/src/anthropic/lib/vertex/_client.py +++ b/src/anthropic/lib/vertex/_client.py @@ -168,13 +168,11 @@ def __init__( @override def _prepare_request(self, request: httpx.Request) -> None: - access_token = self._ensure_access_token() - if request.headers.get("Authorization"): # already authenticated, nothing for us to do return - request.headers["Authorization"] = f"Bearer {access_token}" + request.headers["Authorization"] = f"Bearer {self._ensure_access_token()}" def _ensure_access_token(self) -> str: if self.access_token is not None: @@ -184,7 +182,8 @@ def _ensure_access_token(self) -> str: self.credentials, project_id = load_auth(project_id=self.project_id) if not self.project_id: self.project_id = project_id - else: + + if self.credentials.expired: refresh_auth(self.credentials) if not self.credentials.token: @@ -256,13 +255,11 @@ def __init__( @override async def _prepare_request(self, request: httpx.Request) -> None: - access_token = await self._ensure_access_token() - if request.headers.get("Authorization"): # already authenticated, nothing for us to do return - request.headers["Authorization"] = f"Bearer {access_token}" + request.headers["Authorization"] = f"Bearer {await self._ensure_access_token()}" async def _ensure_access_token(self) -> str: if self.access_token is not None: @@ -272,11 +269,12 @@ async def _ensure_access_token(self) -> str: self.credentials, project_id = await asyncify(load_auth)(project_id=self.project_id) if not self.project_id: self.project_id = project_id - else: + + if self.credentials.expired: await asyncify(refresh_auth)(self.credentials) if not self.credentials.token: raise RuntimeError("Could not resolve API token from the environment") assert isinstance(self.credentials.token, str) - return self.credentials.token \ No newline at end of file + return self.credentials.token