Skip to content

Commit

Permalink
[AIRFLOW-2709] Improve error handling in Databricks hook
Browse files Browse the repository at this point in the history
  • Loading branch information
betabandido committed Jul 3, 2018
1 parent 985a433 commit 0f6c493
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 37 deletions.
54 changes: 41 additions & 13 deletions airflow/contrib/hooks/databricks_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from airflow.hooks.base_hook import BaseHook
from requests import exceptions as requests_exceptions
from requests.auth import AuthBase
from time import sleep

from airflow.utils.log.logging_mixin import LoggingMixin

Expand All @@ -47,7 +48,8 @@ def __init__(
self,
databricks_conn_id='databricks_default',
timeout_seconds=180,
retry_limit=3):
retry_limit=3,
retry_delay=1):
"""
:param databricks_conn_id: The name of the databricks connection to use.
:type databricks_conn_id: string
Expand All @@ -57,12 +59,16 @@ def __init__(
:param retry_limit: The number of times to retry the connection in case of
service outages.
:type retry_limit: int
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:type retry_delay: float
"""
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = self.get_connection(databricks_conn_id)
self.timeout_seconds = timeout_seconds
assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
self.retry_limit = retry_limit
self.retry_delay = retry_delay

def _parse_host(self, host):
"""
Expand Down Expand Up @@ -117,29 +123,51 @@ def _do_api_call(self, endpoint_info, json):
else:
raise AirflowException('Unexpected HTTP Method: ' + method)

for attempt_num in range(1, self.retry_limit + 1):
attempt_num = 1
while True:
try:
response = request_func(
url,
json=json,
auth=auth,
headers=USER_AGENT_HEADER,
timeout=self.timeout_seconds)
if response.status_code == requests.codes.ok:
return response.json()
else:
response.raise_for_status()
return response.json()
except (requests_exceptions.ConnectionError,
requests_exceptions.Timeout) as e:
self._log_request_error(attempt_num, e)
except requests_exceptions.HTTPError as e:
response = e.response
if not self._retriable_error(response):
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException('Response: {0}, Status Code: {1}'.format(
response.content, response.status_code))
except (requests_exceptions.ConnectionError,
requests_exceptions.Timeout) as e:
self.log.error(
'Attempt %s API Request to Databricks failed with reason: %s',
attempt_num, e
)
raise AirflowException(('API requests to Databricks failed {} times. ' +
'Giving up.').format(self.retry_limit))

self._log_request_error(attempt_num, e)

if attempt_num == self.retry_limit:
raise AirflowException(('API requests to Databricks failed {} times. ' +
'Giving up.').format(self.retry_limit))

attempt_num += 1
sleep(self.retry_delay)

def _log_request_error(self, attempt_num, error):
self.log.error(
'Attempt %s API Request to Databricks failed with reason: %s',
attempt_num, error
)

@staticmethod
def _retriable_error(response):
try:
error_code = response.json().get('error_code')
return error_code == 'TEMPORARILY_UNAVAILABLE'
except ValueError:
# not a valid JSON
return False

def submit_run(self, json):
"""
Expand Down
8 changes: 7 additions & 1 deletion airflow/contrib/operators/databricks_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param databricks_retry_limit: Amount of times retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:type databricks_retry_limit: int
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:type databricks_retry_delay: float
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
:type do_xcom_push: boolean
"""
Expand All @@ -168,6 +171,7 @@ def __init__(
databricks_conn_id='databricks_default',
polling_period_seconds=30,
databricks_retry_limit=3,
databricks_retry_delay=1,
do_xcom_push=False,
**kwargs):
"""
Expand All @@ -178,6 +182,7 @@ def __init__(
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
if spark_jar_task is not None:
self.json['spark_jar_task'] = spark_jar_task
if notebook_task is not None:
Expand Down Expand Up @@ -232,7 +237,8 @@ def _log_run_page_url(self, url):
def get_hook(self):
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit)
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay)

def execute(self, context):
hook = self.get_hook()
Expand Down
66 changes: 47 additions & 19 deletions tests/contrib/hooks/test_databricks_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,31 +114,51 @@ def test_init_bad_retry_limit(self):
DatabricksHook(retry_limit = 0)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_with_error_retry(self, mock_requests):
for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
@mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
def test_do_api_call_with_error_retry(self, _, mock_requests):
for exception in [
requests_exceptions.ConnectionError(),
requests_exceptions.Timeout(),
self._build_http_error('TEMPORARILY_UNAVAILABLE')]:
with mock.patch.object(self.hook.log, 'error') as mock_errors:
mock_requests.reset_mock()
mock_requests.post.side_effect = exception()
self._setup_mock_requests(mock_requests, exception)

with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
@mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
def test_do_api_call_waits_between_retries(self, mock_sleep, mock_requests):
retry_delay = 5
self.hook = DatabricksHook(retry_delay=retry_delay)

for exception in [
requests_exceptions.ConnectionError(),
requests_exceptions.Timeout(),
self._build_http_error('TEMPORARILY_UNAVAILABLE')]:
with mock.patch.object(self.hook.log, 'error'):
mock_sleep.reset_mock()
self._setup_mock_requests(mock_requests, exception)

with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
mock_sleep.assert_called_with(retry_delay)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_with_bad_status_code(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=500)
type(mock_requests.post.return_value).status_code = status_code_mock
response = mock.MagicMock()
response.raise_for_status.side_effect = self._build_http_error('ERROR')
mock_requests.post.return_value = response
with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_submit_run(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {'run_id': '1'}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {
'notebook_task': NOTEBOOK_TASK,
'new_cluster': NEW_CLUSTER
Expand All @@ -158,10 +178,7 @@ def test_submit_run(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_get_run_page_url(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock

run_page_url = self.hook.get_run_page_url(RUN_ID)

Expand All @@ -175,10 +192,7 @@ def test_get_run_page_url(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_get_run_state(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock

run_state = self.hook.get_run_state(RUN_ID)

Expand All @@ -195,10 +209,7 @@ def test_get_run_state(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_cancel_run(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock

self.hook.cancel_run(RUN_ID)

Expand All @@ -209,6 +220,23 @@ def test_cancel_run(self, mock_requests):
headers=USER_AGENT_HEADER,
timeout=self.hook.timeout_seconds)

@staticmethod
def _setup_mock_requests(mock_requests, exception):
mock_requests.reset_mock()
if type(exception) in [requests_exceptions.ConnectionError,
requests_exceptions.Timeout]:
mock_requests.post.side_effect = exception
elif type(exception) == requests_exceptions.HTTPError:
mock_requests.raise_for_status.side_effect = exception

@staticmethod
def _build_http_error(error_code):
response = mock.MagicMock()
error_info = {'error_code': error_code, 'message': ''}
response.json.return_value = error_info
response.text = json.dumps(error_info)
return requests_exceptions.HTTPError(response=response)


class DatabricksHookTokenTest(unittest.TestCase):
"""
Expand Down
10 changes: 6 additions & 4 deletions tests/contrib/operators/test_databricks_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class):
'run_name': TASK_ID
})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit)
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
Expand Down Expand Up @@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class):
'run_name': TASK_ID,
})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit)
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
Expand Down

0 comments on commit 0f6c493

Please sign in to comment.