diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 1443ff4740b94..12df5f40457ee 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -24,6 +24,7 @@
 from airflow.hooks.base_hook import BaseHook
 from requests import exceptions as requests_exceptions
 from requests.auth import AuthBase
+from time import sleep
 
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -47,7 +48,8 @@ def __init__(
             self,
             databricks_conn_id='databricks_default',
             timeout_seconds=180,
-            retry_limit=3):
+            retry_limit=3,
+            retry_delay=1):
         """
         :param databricks_conn_id: The name of the databricks connection to use.
         :type databricks_conn_id: string
@@ -57,12 +59,16 @@ def __init__(
         :param retry_limit: The number of times to retry the connection in case of
             service outages.
         :type retry_limit: int
+        :param retry_delay: The number of seconds to wait between retries (it
+            might be a floating point number).
+        :type retry_delay: float
         """
         self.databricks_conn_id = databricks_conn_id
         self.databricks_conn = self.get_connection(databricks_conn_id)
         self.timeout_seconds = timeout_seconds
         assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
         self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
 
     def _parse_host(self, host):
         """
@@ -117,7 +123,8 @@ def _do_api_call(self, endpoint_info, json):
         else:
             raise AirflowException('Unexpected HTTP Method: ' + method)
 
-        for attempt_num in range(1, self.retry_limit + 1):
+        attempt_num = 1
+        while True:
             try:
                 response = request_func(
                     url,
@@ -125,21 +132,51 @@ def _do_api_call(self, endpoint_info, json):
                     auth=auth,
                     headers=USER_AGENT_HEADER,
                     timeout=self.timeout_seconds)
-                if response.status_code == requests.codes.ok:
-                    return response.json()
-                else:
+                response.raise_for_status()
+                return response.json()
+            except (requests_exceptions.ConnectionError,
+                    requests_exceptions.Timeout) as e:
+                self._log_request_error(attempt_num, e)
+            except requests_exceptions.HTTPError as e:
+                response = e.response
+                if not self._retriable_error(response):
                     # In this case, the user probably made a mistake.
                     # Don't retry.
                     raise AirflowException('Response: {0}, Status Code: {1}'.format(
                         response.content, response.status_code))
-            except (requests_exceptions.ConnectionError,
-                    requests_exceptions.Timeout) as e:
-                self.log.error(
-                    'Attempt %s API Request to Databricks failed with reason: %s',
-                    attempt_num, e
-                )
-        raise AirflowException(('API requests to Databricks failed {} times. ' +
-                               'Giving up.').format(self.retry_limit))
+
+                self._log_request_error(attempt_num, e)
+
+            if attempt_num == self.retry_limit:
+                raise AirflowException(('API requests to Databricks failed {} times. ' +
+                                        'Giving up.').format(self.retry_limit))
+
+            attempt_num += 1
+            sleep(self.retry_delay)
+
+    def _log_request_error(self, attempt_num, error):
+        """Log a failed attempt of an API request together with its cause."""
+        self.log.error(
+            'Attempt %s API Request to Databricks failed with reason: %s',
+            attempt_num, error
+        )
+
+    @staticmethod
+    def _retriable_error(response):
+        """
+        Tell whether a failed request is worth retrying.
+
+        :param response: The HTTP response attached to the raised ``HTTPError``.
+        :return: ``True`` only if the JSON body of the response carries the
+            ``TEMPORARILY_UNAVAILABLE`` error code.
+        :rtype: bool
+        """
+        try:
+            error_code = response.json().get('error_code')
+            return error_code == 'TEMPORARILY_UNAVAILABLE'
+        except ValueError:
+            # not a valid JSON
+            return False
 
     def submit_run(self, json):
         """
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 7b8d522dba85b..3245a99256502 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
     :param databricks_retry_limit: Amount of times retry if the Databricks backend is
         unreachable. Its value must be greater than or equal to 1.
     :type databricks_retry_limit: int
+    :param databricks_retry_delay: Number of seconds to wait between retries (it
+        might be a floating point number).
+    :type databricks_retry_delay: float
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :type do_xcom_push: boolean
     """
@@ -168,6 +171,7 @@ def __init__(
             databricks_conn_id='databricks_default',
             polling_period_seconds=30,
             databricks_retry_limit=3,
+            databricks_retry_delay=1,
             do_xcom_push=False,
             **kwargs):
         """
@@ -178,6 +182,7 @@ def __init__(
         self.databricks_conn_id = databricks_conn_id
         self.polling_period_seconds = polling_period_seconds
         self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
         if spark_jar_task is not None:
             self.json['spark_jar_task'] = spark_jar_task
         if notebook_task is not None:
@@ -232,7 +237,8 @@ def _log_run_page_url(self, url):
     def get_hook(self):
         return DatabricksHook(
             self.databricks_conn_id,
-            retry_limit=self.databricks_retry_limit)
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay)
 
     def execute(self, context):
         hook = self.get_hook()
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index 6052a6d54f1f8..e199f72332d2d 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -114,31 +114,51 @@ def test_init_bad_retry_limit(self):
             DatabricksHook(retry_limit = 0)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
-    def test_do_api_call_with_error_retry(self, mock_requests):
-        for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
+    @mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
+    def test_do_api_call_with_error_retry(self, _, mock_requests):
+        for exception in [
+                requests_exceptions.ConnectionError(),
+                requests_exceptions.Timeout(),
+                self._build_http_error('TEMPORARILY_UNAVAILABLE')]:
             with mock.patch.object(self.hook.log, 'error') as mock_errors:
-                mock_requests.reset_mock()
-                mock_requests.post.side_effect = exception()
+                self._setup_mock_requests(mock_requests, exception)
 
                 with self.assertRaises(AirflowException):
                     self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
                 self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit)
 
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    @mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
+    def test_do_api_call_waits_between_retries(self, mock_sleep, mock_requests):
+        retry_delay = 5
+        self.hook = DatabricksHook(retry_delay=retry_delay)
+
+        for exception in [
+                requests_exceptions.ConnectionError(),
+                requests_exceptions.Timeout(),
+                self._build_http_error('TEMPORARILY_UNAVAILABLE')]:
+            with mock.patch.object(self.hook.log, 'error'):
+                mock_sleep.reset_mock()
+                self._setup_mock_requests(mock_requests, exception)
+
+                with self.assertRaises(AirflowException):
+                    self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+                self.assertEqual(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
+                mock_sleep.assert_called_with(retry_delay)
+
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_do_api_call_with_bad_status_code(self, mock_requests):
-        mock_requests.codes.ok = 200
-        status_code_mock = mock.PropertyMock(return_value=500)
-        type(mock_requests.post.return_value).status_code = status_code_mock
+        response = mock.MagicMock()
+        response.raise_for_status.side_effect = self._build_http_error('ERROR')
+        mock_requests.post.return_value = response
         with self.assertRaises(AirflowException):
             self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_submit_run(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.post.return_value.json.return_value = {'run_id': '1'}
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.post.return_value).status_code = status_code_mock
         json = {
             'notebook_task': NOTEBOOK_TASK,
             'new_cluster': NEW_CLUSTER
@@ -158,10 +178,7 @@ def test_submit_run(self, mock_requests):
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_get_run_page_url(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.get.return_value).status_code = status_code_mock
 
         run_page_url = self.hook.get_run_page_url(RUN_ID)
 
@@ -175,10 +192,7 @@ def test_get_run_page_url(self, mock_requests):
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_get_run_state(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.get.return_value).status_code = status_code_mock
 
         run_state = self.hook.get_run_state(RUN_ID)
 
@@ -195,10 +209,7 @@ def test_get_run_state(self, mock_requests):
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_cancel_run(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.post.return_value).status_code = status_code_mock
 
         self.hook.cancel_run(RUN_ID)
 
@@ -209,6 +220,34 @@ def test_cancel_run(self, mock_requests):
             headers=USER_AGENT_HEADER,
             timeout=self.hook.timeout_seconds)
 
+    @staticmethod
+    def _setup_mock_requests(mock_requests, exception):
+        """
+        Make the mocked ``requests.post`` call fail with ``exception``.
+
+        Connection errors and timeouts are raised by ``post`` itself, while an
+        ``HTTPError`` is raised by ``raise_for_status()`` of the returned response.
+        """
+        mock_requests.reset_mock()
+        response = mock_requests.post.return_value
+        # reset_mock() keeps configured side effects, so drop the ones left
+        # over from a previous call before installing the new one.
+        mock_requests.post.side_effect = None
+        response.raise_for_status.side_effect = None
+        if isinstance(exception, requests_exceptions.HTTPError):
+            response.raise_for_status.side_effect = exception
+        else:
+            mock_requests.post.side_effect = exception
+
+    @staticmethod
+    def _build_http_error(error_code):
+        """Build an ``HTTPError`` whose response body carries ``error_code``."""
+        response = mock.MagicMock()
+        error_info = {'error_code': error_code, 'message': ''}
+        response.json.return_value = error_info
+        response.text = json.dumps(error_info)
+        return requests_exceptions.HTTPError(response=response)
+
 
 class DatabricksHookTokenTest(unittest.TestCase):
     """
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index f77da2ec18eda..afe1a92f28d9e 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class):
           'run_name': TASK_ID
         })
         db_mock_class.assert_called_once_with(
-                DEFAULT_CONN_ID,
-                retry_limit=op.databricks_retry_limit)
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay)
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)
@@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class):
           'run_name': TASK_ID,
         })
         db_mock_class.assert_called_once_with(
-                DEFAULT_CONN_ID,
-                retry_limit=op.databricks_retry_limit)
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay)
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)