Skip to content

Commit

Permalink
Respect retry_timeout_seconds config setting and align retry implem…
Browse files Browse the repository at this point in the history
…entation with Go SDK (#337)

This PR increases the robustness of the API client by porting the same
retry behavior used in the Go SDK, including respecting the
`retry_timeout_seconds` config setting. This improvement is required to
make the Python SDK usable in multi-threaded applications that may
run on unstable networks.

---------

Signed-off-by: Serge Smertin <259697+nfx@users.noreply.github.com>
  • Loading branch information
nfx authored Oct 2, 2023
1 parent 3492167 commit 6254119
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 25 deletions.
120 changes: 95 additions & 25 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import subprocess
import sys
import urllib.parse
from datetime import datetime
from datetime import datetime, timedelta
from json import JSONDecodeError
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Expand All @@ -21,12 +21,12 @@
import requests
import requests.auth
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from .azure import (ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment,
add_sp_management_token, add_workspace_id_header)
from .oauth import (ClientCredentials, OAuthClient, OidcEndpoints, Refreshable,
Token, TokenCache, TokenSource)
from .retries import retried
from .version import __version__

__all__ = ['Config', 'DatabricksError']
Expand Down Expand Up @@ -925,6 +925,7 @@ def __init__(self,
status: str = None,
scimType: str = None,
error: str = None,
retry_after_secs: int = None,
details: List[Dict[str, any]] = None,
**kwargs):
if error:
Expand All @@ -942,6 +943,7 @@ def __init__(self,
error_code = f"SCIM_{status}"
super().__init__(message if message else error)
self.error_code = error_code
self.retry_after_secs = retry_after_secs
self.details = [ErrorDetail.from_dict(detail) for detail in details] if details else []
self.kwargs = kwargs

Expand All @@ -963,25 +965,10 @@ def __init__(self, cfg: Config = None):
cfg = Config()

self._cfg = cfg
# See https://github.com/databricks/databricks-sdk-go/blob/main/client/client.go#L34-L35
self._debug_truncate_bytes = cfg.debug_truncate_bytes if cfg.debug_truncate_bytes else 96
self._retry_timeout_seconds = cfg.retry_timeout_seconds if cfg.retry_timeout_seconds else 300
self._user_agent_base = cfg.user_agent

# Since urllib3 v1.26.0, Retry.DEFAULT_METHOD_WHITELIST is deprecated in favor of
# Retry.DEFAULT_ALLOWED_METHODS. We need to support both versions.
if 'DEFAULT_ALLOWED_METHODS' in dir(Retry):
retry_kwargs = {'allowed_methods': {"POST"} | set(Retry.DEFAULT_ALLOWED_METHODS)}
else:
retry_kwargs = {'method_whitelist': {"POST"} | set(Retry.DEFAULT_METHOD_WHITELIST)}

retry_strategy = Retry(
total=6,
backoff_factor=1,
status_forcelist=[429],
respect_retry_after_header=True,
raise_on_status=False, # return original response when retries have been exhausted
**retry_kwargs,
)

self._session = requests.Session()
self._session.auth = self._authenticate

Expand All @@ -1004,8 +991,10 @@ def __init__(self, cfg: Config = None):
# Prevents platform from flooding. By default, requests library doesn't block.
pool_block = True

http_adapter = HTTPAdapter(max_retries=retry_strategy,
pool_connections=pool_connections,
# We don't use `max_retries` from HTTPAdapter to align with a more production-ready
# retry strategy established in the Databricks SDK for Go. See _is_retryable and
# @retried for more details.
http_adapter = HTTPAdapter(pool_connections=pool_connections,
pool_maxsize=pool_maxsize,
pool_block=pool_block)
self._session.mount("https://", http_adapter)
Expand Down Expand Up @@ -1067,6 +1056,83 @@ def do(self,
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable)
return retryable(self._perform)(method,
path,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data)

@staticmethod
def _is_retryable(err: BaseException) -> Optional[str]:
    """Classify an exception as transient and explain why.

    Databricks-specific port of the urllib3 retry logic
    (https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py)
    and of the Databricks SDK for Go retries
    (https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go).

    :param err: the exception raised by a single request attempt.
    :return: a short, log-friendly reason when the request may be retried,
        or ``None`` when the error is not considered transient.
    """
    from urllib3.exceptions import ProxyError

    # A proxy failure wraps the real cause; classify the wrapped error instead.
    if isinstance(err, ProxyError):
        err = err.original_error

    # The reasons below are deliberately plain strings: they only feed the debug log,
    # because `raise TimeoutError(...) from err` still surfaces the original exception
    # once retries are exhausted.
    if isinstance(err, requests.ConnectionError):
        # Go equivalents: `connection reset by peer` and `connection refused`. These are
        # usually temporary networking glitches, or endpoint protection software
        # (e.g. ZScaler) dropping connections that are not yet authenticated.
        return 'cannot connect'
    if isinstance(err, requests.Timeout):
        # Go equivalents: `TLS handshake timeout` and `i/o timeout`.
        return 'timeout'
    if not isinstance(err, DatabricksError):
        return None

    # Platform errors are retried only when their message contains a known transient marker.
    text = str(err)
    transient_markers = (
        "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
        "does not have any associated worker environments",
        "There is no worker environment with id",
        "Unknown worker environment",
        "ClusterNotReadyException",
        "Unexpected error",
        "Please try again later or try a faster operation.",
    )
    # First marker found (in declaration order) wins; None when nothing matches.
    return next((f'matched {marker}' for marker in transient_markers if marker in text), None)

@staticmethod
def _parse_retry_after(response: requests.Response) -> Optional[int]:
    """Extract the number of seconds to wait from the ``Retry-After`` header.

    The header may carry either a non-negative integer (delay in seconds) or an
    HTTP date. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After

    For simplicity only the integer form is parsed, as this is what the Databricks
    platform returns. Anything else (a date, garbage, or a negative number) falls
    back to a 1 second delay, so that callers never have to deal with an unusable value.

    :param response: the HTTP response of a throttled or unavailable request.
    :return: ``None`` when the header is absent, otherwise a non-negative number of seconds.
    """
    retry_after = response.headers.get("Retry-After")
    if retry_after is None:
        return None
    try:
        seconds = int(retry_after)
    except ValueError:
        seconds = None
    # A negative delay is as invalid as an unparsable one: passing it on would make
    # time.sleep() raise ValueError in the retry loop.
    if seconds is None or seconds < 0:
        logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
        # defaulting to 1 sleep second to make self._is_retryable() simpler
        return 1
    return seconds

def _perform(self,
method: str,
path: str,
query: dict = None,
headers: dict = None,
body: dict = None,
raw: bool = False,
files=None,
data=None):
response = self._session.request(method,
f"{self._cfg.host}{path}",
params=self._fix_query_string(query),
Expand All @@ -1077,11 +1143,11 @@ def do(self,
stream=raw)
try:
self._record_request_log(response, raw=raw or data is not None or files is not None)
if not response.ok:
if not response.ok: # internally calls response.raise_for_status()
# TODO: experiment with traceback pruning for better readability
# See https://stackoverflow.com/a/58821552/277035
payload = response.json()
raise self._make_nicer_error(status_code=response.status_code, **payload) from None
raise self._make_nicer_error(response=response, **payload) from None
if raw:
return StreamingResponse(response)
if not len(response.content):
Expand All @@ -1091,7 +1157,7 @@ def do(self,
message = self._make_sense_from_html(response.text)
if not message:
message = response.reason
raise self._make_nicer_error(message=message) from None
raise self._make_nicer_error(response=response, message=message) from None

@staticmethod
def _make_sense_from_html(txt: str) -> str:
Expand All @@ -1104,11 +1170,15 @@ def _make_sense_from_html(txt: str) -> str:
return match.group(1).strip()
return txt

def _make_nicer_error(self, status_code: int = 200, **kwargs) -> DatabricksError:
def _make_nicer_error(self, *, response: requests.Response, **kwargs) -> DatabricksError:
    """Build a DatabricksError from a failed HTTP response.

    :param response: the failed response; its status code selects the extra handling below.
    :param kwargs: fields forwarded to ``DatabricksError``; ``message`` defaults to
        ``'request failed'`` when the key is absent.
    :return: the error instance, to be raised by the caller.
    """
    code = response.status_code
    text = kwargs.get('message', 'request failed')
    if code in (401, 403):
        # unauthorized / forbidden: let the config decorate the message
        # (presumably with auth debug details - see Config.wrap_debug_info)
        text = self._cfg.wrap_debug_info(text)
    elif code in (429, 503):
        # too many requests / unavailable: carry the server-requested delay on the error
        kwargs['retry_after_secs'] = self._parse_retry_after(response)
    kwargs['message'] = text
    return DatabricksError(**kwargs)

Expand Down
56 changes: 56 additions & 0 deletions databricks/sdk/retries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import functools
import logging
import time
from datetime import timedelta
from random import random
from typing import Callable, List, Optional, Type

logger = logging.getLogger('databricks.sdk')


def retried(*,
            on: List[Type[BaseException]] = None,
            is_retryable: Callable[[BaseException], Optional[str]] = None,
            timeout: timedelta = timedelta(minutes=20)):
    """Create a decorator that retries the wrapped callable until it succeeds or `timeout` elapses.

    Exactly one of `on` and `is_retryable` must be given.

    :param on: exception types that are allowed to be retried. The match is on the exact
        type of the raised exception; subclasses of a listed type are not retried.
    :param is_retryable: callback that receives the raised exception and returns a
        human-readable reason to retry, or ``None`` if the exception must propagate.
    :param timeout: total time budget for all attempts, including the pauses between them.
    :return: a decorator; the decorated function returns whatever the wrapped function returns.
    :raises SyntaxError: if neither or both of `on` and `is_retryable` are supplied.
    :raises TimeoutError: (from the decorated function) when the budget is exhausted;
        the last error is attached as ``__cause__``.

    Any exception carrying a non-``None`` ``retry_after_secs`` attribute is always retried,
    and that value replaces the default back-off (1s, 2s, ... capped at 10s). A random
    jitter of up to one second is added to every pause. Pauses never extend past the
    deadline, so a large ``retry_after_secs`` cannot block the caller beyond `timeout`.
    """
    has_allowlist = on is not None
    has_callback = is_retryable is not None
    if has_allowlist == has_callback:
        # neither or both were given
        raise SyntaxError('either on=[Exception] or is_retryable=lambda x: .. is required')

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            deadline = time.time() + timeout.total_seconds()
            attempt = 1
            last_err = None
            while time.time() < deadline:
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    last_err = err
                    retry_reason = None
                    # sleep 10s max per attempt, unless it's HTTP 429 or 503
                    sleep = min(10, attempt)
                    retry_after_secs = getattr(err, 'retry_after_secs', None)
                    if retry_after_secs is not None:
                        # cannot depend on DatabricksError directly because of circular dependency
                        sleep = retry_after_secs
                        retry_reason = 'throttled by platform'
                    elif is_retryable is not None:
                        retry_reason = is_retryable(err)
                    elif type(err) in on:
                        retry_reason = f'{type(err).__name__} is allowed to retry'

                    if retry_reason is None:
                        # not retryable: propagate with the original traceback
                        raise

                    logger.debug('Retrying: %s (sleeping ~%ss)', retry_reason, sleep)
                    # Never pass a negative value to time.sleep() (it raises ValueError),
                    # and never sleep past the deadline.
                    remaining = max(0.0, deadline - time.time())
                    time.sleep(min(max(0, sleep) + random(), remaining))
                    attempt += 1
            raise TimeoutError(f'Timed out after {timeout}') from last_err

        return wrapper

    return decorator
Loading

0 comments on commit 6254119

Please sign in to comment.