From fb2ea03146e5b9956becbaa23cec440b21585df3 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 23 Aug 2018 13:46:40 -0400
Subject: [PATCH] [AIRFLOW-2974] Extend Databricks hook with cluster operations (#3817)

Added hooks for cluster start, restart and terminate. Added unit tests
for the added hooks. Added a CLUSTER_ID variable for performing the
cluster operation tests.
---
 README.md                                   |  1 +
 airflow/contrib/hooks/databricks_hook.py    | 15 +++
 setup.cfg                                   |  2 +-
 tests/contrib/hooks/test_databricks_hook.py | 84 +++++++++++++++++++--
 4 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 34296c1cd334c..606f2da745afd 100644
--- a/README.md
+++ b/README.md
@@ -154,6 +154,7 @@ Currently **officially** using Airflow:
 1. [Flipp](https://www.flipp.com) [[@sethwilsonwishabi](https://github.com/sethwilsonwishabi)]
 1. [FreshBooks](https://github.com/freshbooks) [[@DinoCow](https://github.com/DinoCow)]
 1. [Fundera](https://fundera.com) [[@andyxhadji](https://github.com/andyxhadji)]
+1. [G Adventures](https://gadventures.com) [[@samuelmullin](https://github.com/samuelmullin)]
 1. [GameWisp](https://gamewisp.com) [[@tjbiii](https://github.com/TJBIII) & [@theryanwalls](https://github.com/theryanwalls)]
 1. [Gentner Lab](http://github.com/gentnerlab) [[@neuromusic](https://github.com/neuromusic)]
 1. [Glassdoor](https://github.com/Glassdoor) [[@syvineckruyk](https://github.com/syvineckruyk)]
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 54f00e00907c0..52aa139545f4f 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -32,6 +32,9 @@
 except ImportError:
     import urlparse
 
+RESTART_CLUSTER_ENDPOINT = ('POST', 'api/2.0/clusters/restart')
+START_CLUSTER_ENDPOINT = ('POST', 'api/2.0/clusters/start')
+TERMINATE_CLUSTER_ENDPOINT = ('POST', 'api/2.0/clusters/delete')
 SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
 GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
 
@@ -174,6 +177,18 @@ def cancel_run(self, run_id):
         json = {'run_id': run_id}
         self._do_api_call(CANCEL_RUN_ENDPOINT, json)
 
+    def restart_cluster(self, json):
+        """Restarts the cluster identified by ``cluster_id`` in ``json``."""
+        self._do_api_call(RESTART_CLUSTER_ENDPOINT, json)
+
+    def start_cluster(self, json):
+        """Starts the cluster identified by ``cluster_id`` in ``json``."""
+        self._do_api_call(START_CLUSTER_ENDPOINT, json)
+
+    def terminate_cluster(self, json):
+        """Terminates (deletes) the cluster identified by ``cluster_id`` in ``json``."""
+        self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)
+
 
 RUN_LIFE_CYCLE_STATES = [
     'PENDING',
diff --git a/setup.cfg b/setup.cfg
index 622cc1303a173..881fe0107d9b2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 [metadata]
 name = Airflow
 summary = Airflow is a system to programmatically author, schedule and monitor data pipelines.
@@ -34,4 +35,3 @@ all_files = 1
 upload-dir = docs/_build/html
 
 [easy_install]
-
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index aca8dd96004b4..0443a54fe669a 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -46,6 +46,7 @@
     'node_type_id': 'r3.xlarge',
     'num_workers': 1
 }
+CLUSTER_ID = 'cluster_id'
 RUN_ID = 1
 HOST = 'xx.cloud.databricks.com'
 HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
@@ -79,16 +80,40 @@ def get_run_endpoint(host):
     """
     return 'https://{}/api/2.0/jobs/runs/get'.format(host)
 
+
 def cancel_run_endpoint(host):
     """
     Utility function to generate the get run endpoint given the host.
     """
     return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
 
+
+def start_cluster_endpoint(host):
+    """
+    Utility function to generate the start cluster endpoint given the host.
+    """
+    return 'https://{}/api/2.0/clusters/start'.format(host)
+
+
+def restart_cluster_endpoint(host):
+    """
+    Utility function to generate the restart cluster endpoint given the host.
+    """
+    return 'https://{}/api/2.0/clusters/restart'.format(host)
+
+
+def terminate_cluster_endpoint(host):
+    """
+    Utility function to generate the terminate cluster endpoint given the host.
+    """
+    return 'https://{}/api/2.0/clusters/delete'.format(host)
+
+
 class DatabricksHookTest(unittest.TestCase):
     """
     Tests for DatabricksHook.
     """
+
     @db.provide_session
     def setUp(self, session=None):
         conn = session.query(Connection) \
@@ -111,7 +136,7 @@ def test_parse_host_with_scheme(self):
 
     def test_init_bad_retry_limit(self):
         with self.assertRaises(ValueError):
-            DatabricksHook(retry_limit = 0)
+            DatabricksHook(retry_limit=0)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_do_api_call_with_error_retry(self, mock_requests):
@@ -140,8 +165,8 @@ def test_submit_run(self, mock_requests):
         status_code_mock = mock.PropertyMock(return_value=200)
         type(mock_requests.post.return_value).status_code = status_code_mock
         json = {
-          'notebook_task': NOTEBOOK_TASK,
-          'new_cluster': NEW_CLUSTER
+            'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER
         }
         run_id = self.hook.submit_run(json)
 
@@ -209,11 +234,60 @@ def test_cancel_run(self, mock_requests):
             headers=USER_AGENT_HEADER,
             timeout=self.hook.timeout_seconds)
 
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_start_cluster(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.start_cluster({'cluster_id': CLUSTER_ID})
+
+        mock_requests.post.assert_called_once_with(
+            start_cluster_endpoint(HOST),
+            json={'cluster_id': CLUSTER_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_restart_cluster(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.restart_cluster({'cluster_id': CLUSTER_ID})
+
+        mock_requests.post.assert_called_once_with(
+            restart_cluster_endpoint(HOST),
+            json={'cluster_id': CLUSTER_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_terminate_cluster(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.terminate_cluster({'cluster_id': CLUSTER_ID})
+
+        mock_requests.post.assert_called_once_with(
+            terminate_cluster_endpoint(HOST),
+            json={'cluster_id': CLUSTER_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
 
 class DatabricksHookTokenTest(unittest.TestCase):
     """
     Tests for DatabricksHook when auth is done with token.
     """
+
     @db.provide_session
     def setUp(self, session=None):
         conn = session.query(Connection) \
@@ -231,8 +305,8 @@ def test_submit_run(self, mock_requests):
         status_code_mock = mock.PropertyMock(return_value=200)
         type(mock_requests.post.return_value).status_code = status_code_mock
         json = {
-          'notebook_task': NOTEBOOK_TASK,
-          'new_cluster': NEW_CLUSTER
+            'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER
         }
         run_id = self.hook.submit_run(json)
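
Note for reviewers, not part of the patch: a minimal sketch of how the new
methods might be called from a DAG. It assumes a Databricks connection under
the hook's default id 'databricks_default' and a placeholder cluster id; each
method simply POSTs its JSON payload to the corresponding api/2.0/clusters
endpoint, so a {'cluster_id': ...} dict is all the payload that is needed.

    from airflow.contrib.hooks.databricks_hook import DatabricksHook

    def restart_databricks_cluster(**context):
        # Placeholder cluster id; in practice this would come from the DAG
        # definition, an Airflow Variable, or an upstream task via XCom.
        hook = DatabricksHook(databricks_conn_id='databricks_default')
        hook.restart_cluster({'cluster_id': '1234-567890-abc123'})

Wrapped in a PythonOperator callable like the above, the hooks let a DAG
start, bounce, or tear down a long-running cluster between job runs without
going through the Databricks UI.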