diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 802b8efb87842..4e05ddb0b1af4 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -33,6 +33,9 @@ except ImportError: import urlparse +RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart") +START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") +TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete") SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit') GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get') @@ -188,6 +191,15 @@ def cancel_run(self, run_id): json = {'run_id': run_id} self._do_api_call(CANCEL_RUN_ENDPOINT, json) + def restart_cluster(self, json): + self._do_api_call(RESTART_CLUSTER_ENDPOINT, json) + + def start_cluster(self, json): + self._do_api_call(START_CLUSTER_ENDPOINT, json) + + def terminate_cluster(self, json): + self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json) + def _retryable_error(exception): return isinstance(exception, requests_exceptions.ConnectionError) \ diff --git a/setup.cfg b/setup.cfg index 622cc1303a173..881fe0107d9b2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + [metadata] name = Airflow summary = Airflow is a system to programmatically author, schedule and monitor data pipelines. 
@@ -34,4 +35,3 @@ all_files = 1 upload-dir = docs/_build/html [easy_install] - diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index a022431899b4d..5c1d7876d2f32 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -52,6 +52,7 @@ 'node_type_id': 'r3.xlarge', 'num_workers': 1 } +CLUSTER_ID = 'cluster_id' RUN_ID = 1 HOST = 'xx.cloud.databricks.com' HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com' @@ -93,6 +94,27 @@ def cancel_run_endpoint(host): return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) +def start_cluster_endpoint(host): + """ + Utility function to generate the start cluster endpoint given the host. + """ + return 'https://{}/api/2.0/clusters/start'.format(host) + + +def restart_cluster_endpoint(host): + """ + Utility function to generate the restart cluster endpoint given the host. + """ + return 'https://{}/api/2.0/clusters/restart'.format(host) + + +def terminate_cluster_endpoint(host): + """ + Utility function to generate the terminate cluster endpoint given the host. 
+ """ + return 'https://{}/api/2.0/clusters/delete'.format(host) + + def create_valid_response_mock(content): response = mock.MagicMock() response.json.return_value = content @@ -293,6 +315,54 @@ def test_cancel_run(self, mock_requests): headers=USER_AGENT_HEADER, timeout=self.hook.timeout_seconds) + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_start_cluster(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + self.hook.start_cluster({"cluster_id": CLUSTER_ID}) + + mock_requests.post.assert_called_once_with( + start_cluster_endpoint(HOST), + json={'cluster_id': CLUSTER_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_restart_cluster(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + self.hook.restart_cluster({"cluster_id": CLUSTER_ID}) + + mock_requests.post.assert_called_once_with( + restart_cluster_endpoint(HOST), + json={'cluster_id': CLUSTER_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_terminate_cluster(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + self.hook.terminate_cluster({"cluster_id": CLUSTER_ID}) + + mock_requests.post.assert_called_once_with( + terminate_cluster_endpoint(HOST), + json={'cluster_id': CLUSTER_ID}, 
+ auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + class DatabricksHookTokenTest(unittest.TestCase): """