Skip to content

Commit

Permalink
Retry GitHub download failures (#3729)
Browse files Browse the repository at this point in the history
* Retry GitHub download failures

* Refactor and add tests

* Fixed linting and added comment

* Fixing unit test assertRaises

Co-authored-by: Kyle Wigley <kyle@fishtownanalytics.com>

* Fixing casing

Co-authored-by: Kyle Wigley <kyle@fishtownanalytics.com>

* Changing to use partial for function calls

Co-authored-by: Kyle Wigley <kyle@fishtownanalytics.com>
  • Loading branch information
leahwicz and Kyle Wigley authored Aug 24, 2021
1 parent 7fa14b6 commit 09ea989
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 36 deletions.
39 changes: 11 additions & 28 deletions core/dbt/clients/registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from functools import wraps
import functools
import requests
from dbt.exceptions import RegistryException
from dbt.utils import memoized
from dbt.utils import memoized, _connection_exception_retry as connection_exception_retry
from dbt.logger import GLOBAL_LOGGER as logger
import os
import time

if os.getenv('DBT_PACKAGE_HUB_URL'):
DEFAULT_REGISTRY_BASE_URL = os.getenv('DBT_PACKAGE_HUB_URL')
Expand All @@ -19,26 +17,11 @@ def _get_url(url, registry_base_url=None):
return '{}{}'.format(registry_base_url, url)


def _wrap_exceptions(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
max_attempts = 5
attempt = 0
while True:
attempt += 1
try:
return fn(*args, **kwargs)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc:
if attempt < max_attempts:
time.sleep(1)
continue
raise RegistryException(
'Unable to connect to registry hub'
) from exc
return wrapper


@_wrap_exceptions
def _get_with_retries(path, registry_base_url=None):
get_fn = functools.partial(_get, path, registry_base_url)
return connection_exception_retry(get_fn, 5)


def _get(path, registry_base_url=None):
url = _get_url(path, registry_base_url)
logger.debug('Making package registry request: GET {}'.format(url))
Expand All @@ -50,22 +33,22 @@ def _get(path, registry_base_url=None):


def index(registry_base_url=None):
return _get('api/v1/index.json', registry_base_url)
return _get_with_retries('api/v1/index.json', registry_base_url)


index_cached = memoized(index)


def packages(registry_base_url=None):
return _get('api/v1/packages.json', registry_base_url)
return _get_with_retries('api/v1/packages.json', registry_base_url)


def package(name, registry_base_url=None):
return _get('api/v1/{}.json'.format(name), registry_base_url)
return _get_with_retries('api/v1/{}.json'.format(name), registry_base_url)


def package_version(name, version, registry_base_url=None):
return _get('api/v1/{}/{}.json'.format(name, version), registry_base_url)
return _get_with_retries('api/v1/{}/{}.json'.format(name, version), registry_base_url)


def get_available_versions(name):
Expand Down
11 changes: 9 additions & 2 deletions core/dbt/clients/system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import errno
import functools
import fnmatch
import json
import os
Expand All @@ -15,9 +16,8 @@
)

import dbt.exceptions
import dbt.utils

from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import _connection_exception_retry as connection_exception_retry

if sys.platform == 'win32':
from ctypes import WinDLL, c_bool
Expand Down Expand Up @@ -441,6 +441,13 @@ def run_cmd(
return out, err


def download_with_retries(
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
download_fn = functools.partial(download, url, path, timeout)
connection_exception_retry(download_fn, 5)


def download(
url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/deps/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def install(self, project, renderer):
system.make_directory(os.path.dirname(tar_path))

download_url = metadata.downloads.tarball
system.download(download_url, tar_path)
system.download_with_retries(download_url, tar_path)
deps_path = project.modules_path
package_name = self.get_project_name(project, renderer)
system.untar_package(tar_path, deps_path, package_name)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def system_error(operation_name):
.format(operation_name))


class RegistryException(Exception):
class ConnectionException(Exception):
pass


Expand Down
21 changes: 21 additions & 0 deletions core/dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import jinja2
import json
import os
import requests
import time

from contextlib import contextmanager
from dbt.exceptions import ConnectionException
from dbt.logger import GLOBAL_LOGGER as logger
from enum import Enum
from typing_extensions import Protocol
from typing import (
Expand Down Expand Up @@ -602,3 +607,19 @@ def __getitem__(self, name: str) -> Any:

def __contains__(self, name) -> bool:
return any((name in entry for entry in self._itersource()))


def _connection_exception_retry(fn, max_attempts: int, attempt: int = 0):
"""Attempts to run a function that makes an external call, if the call fails
on a connection error or timeout, it will be tried up to 5 more times.
"""
try:
return fn()
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc:
if attempt <= max_attempts - 1:
logger.debug('Retrying external call. Attempt: ' +
f'{attempt} Max attempts: {max_attempts}')
time.sleep(1)
_connection_exception_retry(fn, max_attempts, attempt + 1)
else:
raise ConnectionException('External connection exception occurred: ' + str(exc))
42 changes: 42 additions & 0 deletions test/unit/test_core_dbt_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import requests
import unittest

from dbt.exceptions import ConnectionException
from dbt.utils import _connection_exception_retry as connection_exception_retry


class TestCoreDbtUtils(unittest.TestCase):
def test_connection_exception_retry_none(self):
Counter._reset()
connection_exception_retry(lambda: Counter._add(), 5)
self.assertEqual(1, counter)

def test_connection_exception_retry_max(self):
Counter._reset()
with self.assertRaises(ConnectionException):
connection_exception_retry(lambda: Counter._add_with_exception(), 5)
self.assertEqual(6, counter) # 6 = original attempt plus 5 retries

def test_connection_exception_retry_success(self):
Counter._reset()
connection_exception_retry(lambda: Counter._add_with_limited_exception(), 5)
self.assertEqual(2, counter) # 2 = original attempt plus 1 retry


counter:int = 0
class Counter():
def _add():
global counter
counter+=1
def _add_with_exception():
global counter
counter+=1
raise requests.exceptions.ConnectionError
def _add_with_limited_exception():
global counter
counter+=1
if counter < 2:
raise requests.exceptions.ConnectionError
def _reset():
global counter
counter = 0
8 changes: 4 additions & 4 deletions test/unit/test_registry_get_request_exception.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest

from dbt.exceptions import RegistryException
from dbt.clients.registry import _get
from dbt.exceptions import ConnectionException
from dbt.clients.registry import _get_with_retries

class testRegistryGetRequestException(unittest.TestCase):
def test_registry_request_error_catching(self):
# using non routable IP to test connection error logic in the _get function
self.assertRaises(RegistryException, _get, '', 'http://0.0.0.0')
# using non routable IP to test connection error logic in the _get_with_retries function
self.assertRaises(ConnectionException, _get_with_retries, '', 'http://0.0.0.0')

0 comments on commit 09ea989

Please sign in to comment.