[AIRFLOW-2974] Extended Databricks hook with clusters operation (apache#3817)

Add hooks for:
- cluster start,
- restart,
- terminate.
Add unit tests for the added hooks.
Add cluster_id variable for performing cluster operation tests.
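
As a quick orientation, a minimal, hypothetical usage sketch of the new hook methods follows; the connection id and cluster id are illustrative assumptions, not part of this change.

from airflow.contrib.hooks.databricks_hook import DatabricksHook

# Each of the new methods POSTs the supplied JSON payload to the matching
# Databricks clusters endpoint (start / restart / delete).
hook = DatabricksHook(databricks_conn_id='databricks_default')

hook.start_cluster({'cluster_id': '1234-567890-abc123'})      # api/2.0/clusters/start
hook.restart_cluster({'cluster_id': '1234-567890-abc123'})    # api/2.0/clusters/restart
hook.terminate_cluster({'cluster_id': '1234-567890-abc123'})  # api/2.0/clusters/delete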
wmorris75 authored and ashb committed Oct 22, 2018
1 parent e098dec commit a3ecc0a
Showing 3 changed files with 83 additions and 1 deletion.
12 changes: 12 additions & 0 deletions airflow/contrib/hooks/databricks_hook.py
@@ -33,6 +33,9 @@
except ImportError:
import urlparse

RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")

SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
@@ -188,6 +191,15 @@ def cancel_run(self, run_id):
        json = {'run_id': run_id}
        self._do_api_call(CANCEL_RUN_ENDPOINT, json)

    def restart_cluster(self, json):
        """Restarts the Databricks cluster specified in the ``json`` payload."""
        self._do_api_call(RESTART_CLUSTER_ENDPOINT, json)

    def start_cluster(self, json):
        """Starts the Databricks cluster specified in the ``json`` payload."""
        self._do_api_call(START_CLUSTER_ENDPOINT, json)

    def terminate_cluster(self, json):
        """Terminates the Databricks cluster specified in the ``json`` payload."""
        self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)


def _retryable_error(exception):
    return isinstance(exception, requests_exceptions.ConnectionError) \
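For context on how the (method, path) endpoint tuples above are consumed, a simplified, hypothetical sketch of the dispatch follows; the real DatabricksHook._do_api_call additionally resolves the Airflow connection, sets the user-agent header and retries transient errors, all omitted here.

import requests

def _do_api_call_sketch(host, endpoint_info, json, auth, timeout):
    # Illustrative only: build the URL from the (method, path) tuple and issue
    # the request, mirroring the calls asserted in the unit tests below.
    method, endpoint = endpoint_info
    url = 'https://{host}/{endpoint}'.format(host=host, endpoint=endpoint)
    request_func = requests.post if method == 'POST' else requests.get
    response = request_func(url, json=json, auth=auth, timeout=timeout)
    response.raise_for_status()
    return response.json()

# e.g. _do_api_call_sketch('xx.cloud.databricks.com', ('POST', 'api/2.0/clusters/start'),
#                          {'cluster_id': 'some-cluster-id'}, ('user', 'pass'), 180)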
2 changes: 1 addition & 1 deletion setup.cfg
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[metadata]
name = Airflow
summary = Airflow is a system to programmatically author, schedule and monitor data pipelines.
@@ -34,4 +35,3 @@ all_files = 1
upload-dir = docs/_build/html

[easy_install]

70 changes: 70 additions & 0 deletions tests/contrib/hooks/test_databricks_hook.py
@@ -52,6 +52,7 @@
    'node_type_id': 'r3.xlarge',
    'num_workers': 1
}
CLUSTER_ID = 'cluster_id'
RUN_ID = 1
HOST = 'xx.cloud.databricks.com'
HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
@@ -93,6 +94,27 @@ def cancel_run_endpoint(host):
    return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)


def start_cluster_endpoint(host):
    """
    Utility function to generate the cluster start endpoint given the host.
    """
    return 'https://{}/api/2.0/clusters/start'.format(host)


def restart_cluster_endpoint(host):
    """
    Utility function to generate the cluster restart endpoint given the host.
    """
    return 'https://{}/api/2.0/clusters/restart'.format(host)


def terminate_cluster_endpoint(host):
    """
    Utility function to generate the cluster terminate endpoint given the host.
    """
    return 'https://{}/api/2.0/clusters/delete'.format(host)


def create_valid_response_mock(content):
    response = mock.MagicMock()
    response.json.return_value = content
@@ -293,6 +315,54 @@ def test_cancel_run(self, mock_requests):
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_start_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.start_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            start_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_restart_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            restart_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)

    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
    def test_terminate_cluster(self, mock_requests):
        mock_requests.codes.ok = 200
        mock_requests.post.return_value.json.return_value = {}
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})

        mock_requests.post.assert_called_once_with(
            terminate_cluster_endpoint(HOST),
            json={'cluster_id': CLUSTER_ID},
            auth=(LOGIN, PASSWORD),
            headers=USER_AGENT_HEADER,
            timeout=self.hook.timeout_seconds)


class DatabricksHookTokenTest(unittest.TestCase):
"""
