def _get_with_retries(path, registry_base_url=None):
    """Perform the registry GET for ``path`` through the shared retry helper.

    The request itself is made by ``_get(path, registry_base_url)``; it is
    handed to ``connection_exception_retry`` with a limit of 5, and whatever
    that helper returns is returned to the caller.
    """
    def _attempt_get():
        # Deferred so the retry helper can invoke the request repeatedly.
        return _get(path, registry_base_url)

    return connection_exception_retry(_attempt_get, 5)
def download_with_retries(
    url: str, path: str, timeout: Optional[Union[float, tuple]] = None
) -> None:
    """Run ``download(url, path, timeout)`` through the shared retry helper.

    The call is handed to ``connection_exception_retry`` with a limit of 5.
    Nothing is returned.
    """
    def _attempt_download():
        # Deferred so the retry helper can invoke the download repeatedly.
        return download(url, path, timeout)

    connection_exception_retry(_attempt_download, 5)
-class RegistryException(Exception): +class ConnectionException(Exception): pass diff --git a/core/dbt/utils.py b/core/dbt/utils.py index 9b0034a9345..2e2fe7eae4f 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -9,7 +9,12 @@ import jinja2 import json import os +import requests +import time + from contextlib import contextmanager +from dbt.exceptions import ConnectionException +from dbt.logger import GLOBAL_LOGGER as logger from enum import Enum from typing_extensions import Protocol from typing import ( @@ -602,3 +607,19 @@ def __getitem__(self, name: str) -> Any: def __contains__(self, name) -> bool: return any((name in entry for entry in self._itersource())) + + +def _connection_exception_retry(fn, max_attempts: int, attempt: int = 0): + """Attempts to run a function that makes an external call, if the call fails + on a connection error or timeout, it will be tried up to 5 more times. + """ + try: + return fn() + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as exc: + if attempt <= max_attempts - 1: + logger.debug('Retrying external call. 
Attempt: ' + + f'{attempt} Max attempts: {max_attempts}') + time.sleep(1) + _connection_exception_retry(fn, max_attempts, attempt + 1) + else: + raise ConnectionException('External connection exception occurred: ' + str(exc)) diff --git a/test/unit/test_core_dbt_utils.py b/test/unit/test_core_dbt_utils.py new file mode 100644 index 00000000000..4c91fb70c44 --- /dev/null +++ b/test/unit/test_core_dbt_utils.py @@ -0,0 +1,42 @@ +import requests +import unittest + +from dbt.exceptions import ConnectionException +from dbt.utils import _connection_exception_retry as connection_exception_retry + + +class TestCoreDbtUtils(unittest.TestCase): + def test_connection_exception_retry_none(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) + self.assertEqual(1, counter) + + def test_connection_exception_retry_max(self): + Counter._reset() + with self.assertRaises(ConnectionException): + connection_exception_retry(lambda: Counter._add_with_exception(), 5) + self.assertEqual(6, counter) # 6 = original attempt plus 5 retries + + def test_connection_exception_retry_success(self): + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_limited_exception(), 5) + self.assertEqual(2, counter) # 2 = original attempt plus 1 retry + + +counter:int = 0 +class Counter(): + def _add(): + global counter + counter+=1 + def _add_with_exception(): + global counter + counter+=1 + raise requests.exceptions.ConnectionError + def _add_with_limited_exception(): + global counter + counter+=1 + if counter < 2: + raise requests.exceptions.ConnectionError + def _reset(): + global counter + counter = 0 diff --git a/test/unit/test_registry_get_request_exception.py b/test/unit/test_registry_get_request_exception.py index 254169d9894..44033fe0546 100644 --- a/test/unit/test_registry_get_request_exception.py +++ b/test/unit/test_registry_get_request_exception.py @@ -1,9 +1,9 @@ import unittest -from dbt.exceptions import RegistryException -from 
class testRegistryGetRequestException(unittest.TestCase):
    """Checks the exception surfaced by ``_get_with_retries`` on a failed request."""

    def test_registry_request_error_catching(self):
        # using non routable IP to test connection error logic in the _get_with_retries function
        with self.assertRaises(ConnectionException):
            _get_with_retries('', 'http://0.0.0.0')