Skip to content

Commit

Permalink
feat(client): send retry count header (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored and stainless-bot committed Sep 19, 2024
1 parent 84ad451 commit 17c26d5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 47 deletions.
101 changes: 54 additions & 47 deletions src/anthropic/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,7 @@ def _make_status_error(
) -> _exceptions.APIStatusError:
raise NotImplementedError()

def _remaining_retries(
self,
remaining_retries: Optional[int],
options: FinalRequestOptions,
) -> int:
return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)

def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
custom_headers = options.headers or {}
headers_dict = _merge_mappings(self.default_headers, custom_headers)
self._validate_headers(headers_dict, custom_headers)
Expand All @@ -420,6 +413,8 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

headers.setdefault("x-stainless-retry-count", str(retries_taken))

return headers

def _prepare_url(self, url: str) -> URL:
Expand All @@ -441,6 +436,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
def _build_request(
self,
options: FinalRequestOptions,
*,
retries_taken: int = 0,
) -> httpx.Request:
if log.isEnabledFor(logging.DEBUG):
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
Expand All @@ -456,7 +453,7 @@ def _build_request(
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")

headers = self._build_headers(options)
headers = self._build_headers(options, retries_taken=retries_taken)
params = _merge_mappings(self.default_query, options.params)
content_type = headers.get("Content-Type")
files = options.files
Expand Down Expand Up @@ -939,20 +936,25 @@ def request(
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
if remaining_retries is not None:
retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
else:
retries_taken = 0

return self._request(
cast_to=cast_to,
options=options,
stream=stream,
stream_cls=stream_cls,
remaining_retries=remaining_retries,
retries_taken=retries_taken,
)

def _request(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
remaining_retries: int | None,
retries_taken: int,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
Expand All @@ -964,8 +966,8 @@ def _request(
cast_to = self._maybe_override_cast_to(cast_to, options)
options = self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
request = self._build_request(options, retries_taken=retries_taken)
self._prepare_request(request)

kwargs: HttpxSendArgs = {}
Expand All @@ -983,11 +985,11 @@ def _request(
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)

if retries > 0:
if remaining_retries > 0:
return self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -998,11 +1000,11 @@ def _request(
except Exception as err:
log.debug("Encountered Exception", exc_info=True)

if retries > 0:
if remaining_retries > 0:
return self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -1026,13 +1028,13 @@ def _request(
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)

if retries > 0 and self._should_retry(err.response):
if remaining_retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
input_options,
cast_to,
retries,
err.response.headers,
retries_taken=retries_taken,
response_headers=err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
Expand All @@ -1051,26 +1053,26 @@ def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=options.get_max_retries(self.max_retries) - retries,
retries_taken=retries_taken,
)

def _retry_request(
self,
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: int,
response_headers: httpx.Headers | None,
*,
retries_taken: int,
response_headers: httpx.Headers | None,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
remaining = remaining_retries - 1
if remaining == 1:
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining)
log.debug("%i retries left", remaining_retries)

timeout = self._calculate_retry_timeout(remaining, options, response_headers)
timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)

# In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
Expand All @@ -1080,7 +1082,7 @@ def _retry_request(
return self._request(
options=options,
cast_to=cast_to,
remaining_retries=remaining,
retries_taken=retries_taken + 1,
stream=stream,
stream_cls=stream_cls,
)
Expand Down Expand Up @@ -1512,12 +1514,17 @@ async def request(
stream_cls: type[_AsyncStreamT] | None = None,
remaining_retries: Optional[int] = None,
) -> ResponseT | _AsyncStreamT:
if remaining_retries is not None:
retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
else:
retries_taken = 0

return await self._request(
cast_to=cast_to,
options=options,
stream=stream,
stream_cls=stream_cls,
remaining_retries=remaining_retries,
retries_taken=retries_taken,
)

async def _request(
Expand All @@ -1527,7 +1534,7 @@ async def _request(
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
remaining_retries: int | None,
retries_taken: int,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
Expand All @@ -1542,8 +1549,8 @@ async def _request(
cast_to = self._maybe_override_cast_to(cast_to, options)
options = await self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
request = self._build_request(options, retries_taken=retries_taken)
await self._prepare_request(request)

kwargs: HttpxSendArgs = {}
Expand All @@ -1559,11 +1566,11 @@ async def _request(
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)

if retries > 0:
if remaining_retries > 0:
return await self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -1574,11 +1581,11 @@ async def _request(
except Exception as err:
log.debug("Encountered Exception", exc_info=True)

if retries > 0:
if retries_taken > 0:
return await self._retry_request(
input_options,
cast_to,
retries,
retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
Expand All @@ -1596,13 +1603,13 @@ async def _request(
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)

if retries > 0 and self._should_retry(err.response):
if remaining_retries > 0 and self._should_retry(err.response):
await err.response.aclose()
return await self._retry_request(
input_options,
cast_to,
retries,
err.response.headers,
retries_taken=retries_taken,
response_headers=err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
Expand All @@ -1621,34 +1628,34 @@ async def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=options.get_max_retries(self.max_retries) - retries,
retries_taken=retries_taken,
)

async def _retry_request(
self,
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: int,
response_headers: httpx.Headers | None,
*,
retries_taken: int,
response_headers: httpx.Headers | None,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
) -> ResponseT | _AsyncStreamT:
remaining = remaining_retries - 1
if remaining == 1:
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining)
log.debug("%i retries left", remaining_retries)

timeout = self._calculate_retry_timeout(remaining, options, response_headers)
timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)

await anyio.sleep(timeout)

return await self._request(
options=options,
cast_to=cast_to,
remaining_retries=remaining,
retries_taken=retries_taken + 1,
stream=stream,
stream_cls=stream_cls,
)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
)

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("anthropic._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
Expand Down Expand Up @@ -929,6 +930,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
model="claude-3-5-sonnet-20240620",
) as response:
assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success


class TestAsyncAnthropic:
Expand Down Expand Up @@ -1798,6 +1800,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
)

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("anthropic._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
Expand Down Expand Up @@ -1830,3 +1833,4 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
model="claude-3-5-sonnet-20240620",
) as response:
assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

0 comments on commit 17c26d5

Please sign in to comment.