[AIRFLOW-2709] Improve error handling in Databricks hook #3570

Merged 7 commits on Aug 29, 2018
49 changes: 35 additions & 14 deletions airflow/contrib/hooks/databricks_hook.py
@@ -24,6 +24,7 @@
from airflow.hooks.base_hook import BaseHook
from requests import exceptions as requests_exceptions
from requests.auth import AuthBase
from time import sleep

from airflow.utils.log.logging_mixin import LoggingMixin

@@ -47,7 +48,8 @@ def __init__(
self,
databricks_conn_id='databricks_default',
timeout_seconds=180,
retry_limit=3):
retry_limit=3,
retry_delay=1.0):
"""
:param databricks_conn_id: The name of the databricks connection to use.
:type databricks_conn_id: string
Expand All @@ -57,13 +59,17 @@ def __init__(
:param retry_limit: The number of times to retry the connection in case of
service outages.
:type retry_limit: int
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:type retry_delay: float
"""
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = self.get_connection(databricks_conn_id)
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError('Retry limit must be greater than or equal to 1')
self.retry_limit = retry_limit
self.retry_delay = retry_delay

@staticmethod
def _parse_host(host):
@@ -119,29 +125,38 @@ def _do_api_call(self, endpoint_info, json):
else:
raise AirflowException('Unexpected HTTP Method: ' + method)

for attempt_num in range(1, self.retry_limit + 1):
attempt_num = 1
while True:
try:
response = request_func(
url,
json=json,
auth=auth,
headers=USER_AGENT_HEADER,
timeout=self.timeout_seconds)
if response.status_code == requests.codes.ok:
return response.json()
else:
response.raise_for_status()
return response.json()
except requests_exceptions.RequestException as e:
if not _retryable_error(e):
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException('Response: {0}, Status Code: {1}'.format(
response.content, response.status_code))
except (requests_exceptions.ConnectionError,
requests_exceptions.Timeout) as e:
self.log.error(
'Attempt %s API Request to Databricks failed with reason: %s',
attempt_num, e
)
raise AirflowException(('API requests to Databricks failed {} times. ' +
'Giving up.').format(self.retry_limit))
e.response.content, e.response.status_code))

self._log_request_error(attempt_num, e)

if attempt_num == self.retry_limit:
raise AirflowException(('API requests to Databricks failed {} times. ' +
'Giving up.').format(self.retry_limit))

Review thread on this line:

Contributor: Remove duplicate braces please.

Contributor (Author): I actually didn't add the extra braces. They are needed because the + operator concatenates the two strings before .format() is applied.
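For context, the precedence point in a quick interpreter session (argument value illustrative): without the outer parentheses, .format() binds only to the second string literal, so the {} placeholder would never be filled:

>>> ('API requests to Databricks failed {} times. ' + 'Giving up.').format(3)
'API requests to Databricks failed 3 times. Giving up.'
>>> 'API requests to Databricks failed {} times. ' + 'Giving up.'.format(3)
'API requests to Databricks failed {} times. Giving up.'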

attempt_num += 1
sleep(self.retry_delay)

def _log_request_error(self, attempt_num, error):
self.log.error(
'Attempt %s API Request to Databricks failed with reason: %s',
attempt_num, error
)
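Taken together, the retry knobs can be exercised like this (a minimal usage sketch, assuming a configured 'databricks_default' connection; the cluster and notebook values are illustrative, not from this PR):

# Retry transient failures up to 5 times, waiting 10 seconds between attempts.
hook = DatabricksHook(retry_limit=5, retry_delay=10.0)
run_id = hook.submit_run({
    'new_cluster': {'spark_version': '2.1.0-db3-scala2.11', 'num_workers': 2},
    'notebook_task': {'notebook_path': '/Users/someone@example.com/MyNotebook'},
})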

def submit_run(self, json):
"""
@@ -175,6 +190,12 @@ def cancel_run(self, run_id):
self._do_api_call(CANCEL_RUN_ENDPOINT, json)


def _retryable_error(exception):
return isinstance(exception, requests_exceptions.ConnectionError) \
or isinstance(exception, requests_exceptions.Timeout) \
or exception.response is not None and exception.response.status_code >= 500
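A quick illustration of what this predicate treats as retryable (a sketch; the Response objects are constructed by hand for the example):

import requests
from requests import exceptions as requests_exceptions

resp = requests.Response()
resp.status_code = 503
assert _retryable_error(requests_exceptions.HTTPError(response=resp))      # 5xx: retry
resp.status_code = 400
assert not _retryable_error(requests_exceptions.HTTPError(response=resp))  # 4xx: fail fast
assert _retryable_error(requests_exceptions.Timeout())                     # timeout: retry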


RUN_LIFE_CYCLE_STATES = [
'PENDING',
'RUNNING',
8 changes: 7 additions & 1 deletion airflow/contrib/operators/databricks_operator.py
@@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param databricks_retry_limit: Number of times to retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:type databricks_retry_limit: int
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:type databricks_retry_delay: float
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
:type do_xcom_push: boolean
"""
@@ -168,6 +171,7 @@ def __init__(
databricks_conn_id='databricks_default',
polling_period_seconds=30,
databricks_retry_limit=3,
databricks_retry_delay=1,
do_xcom_push=False,
**kwargs):
"""
@@ -178,6 +182,7 @@
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
if spark_jar_task is not None:
self.json['spark_jar_task'] = spark_jar_task
if notebook_task is not None:
@@ -232,7 +237,8 @@ def _log_run_page_url(self, url):
def get_hook(self):
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit)
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay)

def execute(self, context):
hook = self.get_hook()
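With the new parameter wired through get_hook, a DAG task can tune both knobs (a hedged sketch; the task id and payload values are illustrative):

notebook_run = DatabricksSubmitRunOperator(
    task_id='notebook_run',
    new_cluster={'spark_version': '2.1.0-db3-scala2.11', 'num_workers': 2},
    notebook_task={'notebook_path': '/Users/someone@example.com/MyNotebook'},
    databricks_retry_limit=5,
    databricks_retry_delay=2.0)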
144 changes: 114 additions & 30 deletions tests/contrib/hooks/test_databricks_hook.py
@@ -18,15 +18,21 @@
# under the License.
#

import itertools
import json
import unittest

from requests import exceptions as requests_exceptions

from airflow import __version__
from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth
from airflow.contrib.hooks.databricks_hook import (
DatabricksHook,
RunState,
SUBMIT_RUN_ENDPOINT
)
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.utils import db
from requests import exceptions as requests_exceptions

try:
from unittest import mock
@@ -79,12 +85,48 @@ def get_run_endpoint(host):
"""
return 'https://{}/api/2.0/jobs/runs/get'.format(host)


def cancel_run_endpoint(host):
"""
Utility function to generate the cancel run endpoint given the host.
"""
return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)


def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
return response


def create_post_side_effect(exception, status_code=500):
if exception != requests_exceptions.HTTPError:
return exception()
else:
response = mock.MagicMock()
response.status_code = status_code
response.raise_for_status.side_effect = exception(response=response)
return response


def setup_mock_requests(
mock_requests,
exception,
status_code=500,
error_count=None,
response_content=None):

side_effect = create_post_side_effect(exception, status_code)

if error_count is None:
# POST requests will fail indefinitely
mock_requests.post.side_effect = itertools.repeat(side_effect)
else:
# POST requests will fail 'error_count' times, and then they will succeed (once)
mock_requests.post.side_effect = \
[side_effect] * error_count + [create_valid_response_mock(response_content)]
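These helpers rely on unittest.mock's side_effect semantics: given an iterable, each call to the mock consumes the next item, raising it if it is an exception and returning it otherwise; itertools.repeat(side_effect) therefore makes every POST fail. A standalone illustration (not from the PR):

from unittest import mock

m = mock.MagicMock(side_effect=[ValueError('boom'), 'ok'])
# The first call m() raises ValueError('boom'); the second call returns 'ok'.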


class DatabricksHookTest(unittest.TestCase):
"""
Tests for DatabricksHook.
@@ -99,7 +141,7 @@ def setUp(self, session=None):
conn.password = PASSWORD
session.commit()

self.hook = DatabricksHook()
self.hook = DatabricksHook(retry_delay=0)

def test_parse_host_with_proper_host(self):
host = self.hook._parse_host(HOST)
@@ -111,34 +153,85 @@ def test_parse_host_with_scheme(self):

def test_init_bad_retry_limit(self):
with self.assertRaises(ValueError):
DatabricksHook(retry_limit = 0)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_with_error_retry(self, mock_requests):
for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
with mock.patch.object(self.hook.log, 'error') as mock_errors:
mock_requests.reset_mock()
mock_requests.post.side_effect = exception()
DatabricksHook(retry_limit=0)

def test_do_api_call_retries_with_retryable_error(self):
for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
requests_exceptions.Timeout,
requests_exceptions.ConnectTimeout,
requests_exceptions.HTTPError]:
with mock.patch(
'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
mock.patch.object(self.hook.log, 'error') as mock_errors:
setup_mock_requests(mock_requests, exception)

with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit)
self.assertEquals(mock_errors.call_count, self.hook.retry_limit)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_with_bad_status_code(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=500)
type(mock_requests.post.return_value).status_code = status_code_mock
with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
setup_mock_requests(
mock_requests, requests_exceptions.HTTPError, status_code=400
)

with mock.patch.object(self.hook.log, 'error') as mock_errors:
with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

mock_errors.assert_not_called()

def test_do_api_call_succeeds_after_retrying(self):
for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
requests_exceptions.Timeout,
requests_exceptions.ConnectTimeout,
requests_exceptions.HTTPError]:
with mock.patch(
'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
mock.patch.object(self.hook.log, 'error') as mock_errors:
setup_mock_requests(
mock_requests,
exception,
error_count=2,
response_content={'run_id': '1'}
)

response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(mock_errors.call_count, 2)
self.assertEquals(response, {'run_id': '1'})

@mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
def test_do_api_call_waits_between_retries(self, mock_sleep):
retry_delay = 5
self.hook = DatabricksHook(retry_delay=retry_delay)

for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
requests_exceptions.Timeout,
requests_exceptions.ConnectTimeout,
requests_exceptions.HTTPError]:
with mock.patch(
'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
mock.patch.object(self.hook.log, 'error'):
mock_sleep.reset_mock()
setup_mock_requests(mock_requests, exception)

with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
mock_sleep.assert_called_with(retry_delay)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_submit_run(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {'run_id': '1'}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {
'notebook_task': NOTEBOOK_TASK,
'new_cluster': NEW_CLUSTER
@@ -158,10 +251,7 @@ def test_submit_run(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_get_run_page_url(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock

run_page_url = self.hook.get_run_page_url(RUN_ID)

@@ -175,10 +265,7 @@ def test_get_run_page_url(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_get_run_state(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock

run_state = self.hook.get_run_state(RUN_ID)

@@ -195,10 +282,7 @@

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_cancel_run(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock

self.hook.cancel_run(RUN_ID)

10 changes: 6 additions & 4 deletions tests/contrib/operators/test_databricks_operator.py
@@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class):
'run_name': TASK_ID
})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit)
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
@@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class):
'run_name': TASK_ID,
})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit)
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)