From 12a8616332b121ef09b45a9b23e9185a3b734599 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 22 Apr 2024 18:39:05 +0530 Subject: [PATCH 1/9] Contribute DatabricksNotebookOperator --- .../databricks/operators/databricks.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 22f37e95555d6..dfa80abfd8e36 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -892,3 +892,125 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator): def __init__(self, *args, **kwargs): super().__init__(deferrable=True, *args, **kwargs) + + +class DatabricksNotebookOperator(BaseOperator): + """ + Runs a notebook on Databricks using an Airflow operator. + + The DatabricksNotebookOperator allows users to launch and monitor notebook + job runs on Databricks as Aiflow tasks. + + :param notebook_path: The path to the notebook in Databricks. + :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved + from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository + defined in git_source. If the value is empty, the task will use GIT if git_source is defined + and WORKSPACE otherwise. For more information please visit + https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate + :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task. + :param notebook_packages: A list of the Python libraries to be installed on the cluster running the + notebook. + :param new_cluster: Specs for a new cluster on which this task will be run. + :param existing_cluster_id: ID for existing cluster on which to run this task. + :param job_cluster_key: The key for the job cluster. + :param databricks_conn_id: The name of the Airflow connection to use. + """ + + template_fields = ("notebook_params",) + + def __init__( + self, + notebook_path: str, + source: str, + notebook_params: dict | None = None, + notebook_packages: list[dict[str, Any]] | None = None, + new_cluster: dict[str, Any] | None = None, + existing_cluster_id: str | None = None, + job_cluster_key: str | None = None, + databricks_conn_id: str = "databricks_default", + **kwargs: Any, + ): + self.notebook_path = notebook_path + self.source = source + self.notebook_params = notebook_params or {} + self.notebook_packages = notebook_packages or [] + self.new_cluster = new_cluster or {} + self.existing_cluster_id = existing_cluster_id or "" + self.job_cluster_key = job_cluster_key or "" + self.databricks_conn_id = databricks_conn_id + self.databricks_run_id = "" + super().__init__(**kwargs) + + def _get_task_base_json(self) -> dict[str, Any]: + """Get task base json to be used for task submissions.""" + return { + # Timeout seconds value of 0 for the Databricks Jobs API means the job runs forever. + # That is also the default behavior of Databricks jobs to run a job forever without a default + # timeout value. + "timeout_seconds": int(self.execution_timeout.total_seconds()) if self.execution_timeout else 0, + "email_notifications": {}, + "notebook_task": { + "notebook_path": self.notebook_path, + "source": self.source, + "base_parameters": self.notebook_params, + }, + "libraries": self.notebook_packages, + } + + def _get_databricks_task_id(self, task_id: str): + """Get the databricks task ID using dag_id and task_id. Removes illegal characters.""" + return f"{self.dag_id}__" + task_id.replace(".", "__") + + def _get_run_json(self): + """Get run json to be used for task submissions.""" + run_json = { + "run_name": self._get_databricks_task_id(self.task_id), + **self._get_task_base_json(), + } + if self.new_cluster and self.existing_cluster_id: + raise ValueError("Both new_cluster and existing_cluster_id are set. Only one should be set.") + if self.new_cluster: + run_json["new_cluster"] = self.new_cluster + elif self.existing_cluster_id: + run_json["existing_cluster_id"] = self.existing_cluster_id + else: + raise ValueError("Must specify either existing_cluster_id or new_cluster.") + return run_json + + def launch_notebook_job(self): + hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) + run_json = self._get_run_json() + self.databricks_run_id = hook.submit_run(run_json) + url = hook.get_run_page_url(self.databricks_run_id) + self.log.info("Check the job run in Databricks: %s", url) + return self.databricks_run_id + + def monitor_databricks_job(self): + hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) + run = hook.get_run(self.databricks_run_id) + run_state = RunState(**run["state"]) + self.log.info("Current state of the job: %s", run_state.life_cycle_state) + while not run_state.is_terminal: + time.sleep(5) + run = hook.get_run(self.databricks_run_id) + run_state = RunState(**run["state"]) + self.log.info( + "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state + ) + self.log.info("Current state of the job: %s", run_state.life_cycle_state) + if run_state.life_cycle_state != "TERMINATED": + raise AirflowException( + f"Databricks job failed with state {run_state.life_cycle_state}. " + f"Message: {run_state.state_message}" + ) + if not run_state.is_successful: + raise AirflowException( + "Task failed. Final state %s. Reason: %s", + run_state.result_state, + run_state.state_message, + ) + self.log.info("Task succeeded. Final state %s.", run_state.result_state) + + def execute(self, context: Context) -> Any: + self.launch_notebook_job() + self.monitor_databricks_job() From 32f11b85e495a994c5296b55fc8633fea9ccb002 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 22 Apr 2024 20:52:07 +0530 Subject: [PATCH 2/9] Address @tatiana's comment to use a cached hook across methods --- .../databricks/operators/databricks.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index dfa80abfd8e36..e6d3beaf08666 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -927,6 +927,10 @@ def __init__( new_cluster: dict[str, Any] | None = None, existing_cluster_id: str | None = None, job_cluster_key: str | None = None, + polling_period_seconds: int = 5, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + databricks_retry_args: dict[Any, Any] | None = None, databricks_conn_id: str = "databricks_default", **kwargs: Any, ): @@ -937,10 +941,27 @@ def __init__( self.new_cluster = new_cluster or {} self.existing_cluster_id = existing_cluster_id or "" self.job_cluster_key = job_cluster_key or "" + self.polling_period_seconds = polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args self.databricks_conn_id = databricks_conn_id self.databricks_run_id = "" super().__init__(**kwargs) + @cached_property + def _hook(self): + return self._get_hook(caller="DatabricksNotebookOperator") + + def _get_hook(self, caller: str) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=caller, + ) + def _get_task_base_json(self) -> dict[str, Any]: """Get task base json to be used for task submissions.""" return { @@ -978,21 +999,19 @@ def _get_run_json(self): return run_json def launch_notebook_job(self): - hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) run_json = self._get_run_json() - self.databricks_run_id = hook.submit_run(run_json) - url = hook.get_run_page_url(self.databricks_run_id) + self.databricks_run_id = self._hook.submit_run(run_json) + url = self._hook.get_run_page_url(self.databricks_run_id) self.log.info("Check the job run in Databricks: %s", url) return self.databricks_run_id def monitor_databricks_job(self): - hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) - run = hook.get_run(self.databricks_run_id) + run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info("Current state of the job: %s", run_state.life_cycle_state) while not run_state.is_terminal: - time.sleep(5) - run = hook.get_run(self.databricks_run_id) + time.sleep(self.polling_period_seconds) + run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info( "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state From e3e1f2ea95b865adcf47d35324f2147fa5f6ee64 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Apr 2024 00:26:32 +0530 Subject: [PATCH 3/9] Add tests --- .../databricks/operators/test_databricks.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 26d59baa61610..67cc29ac9d978 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -28,6 +28,7 @@ from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators.databricks import ( DatabricksCreateJobsOperator, + DatabricksNotebookOperator, DatabricksRunNowDeferrableOperator, DatabricksRunNowOperator, DatabricksSubmitRunDeferrableOperator, @@ -1754,3 +1755,90 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ db_mock.get_run_page_url.assert_called_once_with(RUN_ID) assert op.run_id == RUN_ID assert not mock_defer.called + + +class TestDatabricksNotebookOperator: + def test_execute(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + ) + operator.launch_notebook_job = MagicMock(return_value="12345") + operator.monitor_databricks_job = MagicMock() + + operator.execute({}) + + operator.launch_notebook_job.assert_called_once() + operator.monitor_databricks_job.assert_called_once() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_monitor_databricks_job_successful(self, mock_databricks_hook): + mock_databricks_hook.return_value.get_run.return_value = { + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"} + } + + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + ) + + operator.databricks_run_id = "12345" + operator.monitor_databricks_job() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_monitor_databricks_job_failed(self, mock_databricks_hook): + mock_databricks_hook.return_value.get_run.return_value = { + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED"} + } + + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + ) + + operator.databricks_run_id = "12345" + with pytest.raises(AirflowException): + operator.monitor_databricks_job() + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_launch_notebook_job(self, mock_databricks_hook): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + existing_cluster_id="test_cluster_id", + ) + operator._hook.submit_run.return_value = "12345" + + run_id = operator.launch_notebook_job() + + assert run_id == "12345" + + def test_both_new_and_existing_cluster_set(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + new_cluster={"new_cluster_config_key": "new_cluster_config_value"}, + existing_cluster_id="existing_cluster_id", + databricks_conn_id="test_conn_id", + ) + with pytest.raises(ValueError): + operator._get_run_json() + + def test_both_new_and_existing_cluster_unset(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + ) + with pytest.raises(ValueError): + operator._get_run_json() From a4fcdb2b02038174c64473be5fcfc2424ad46da8 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Apr 2024 00:55:49 +0530 Subject: [PATCH 4/9] Add example DAGs and docs --- .../databricks/operators/databricks.py | 6 ++- airflow/providers/databricks/provider.yaml | 1 + .../operators/notebook.rst | 44 +++++++++++++++ .../databricks/example_databricks.py | 54 +++++++++++++++++++ 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 docs/apache-airflow-providers-databricks/operators/notebook.rst diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index e6d3beaf08666..9a52ec281f51b 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -899,7 +899,11 @@ class DatabricksNotebookOperator(BaseOperator): Runs a notebook on Databricks using an Airflow operator. The DatabricksNotebookOperator allows users to launch and monitor notebook - job runs on Databricks as Aiflow tasks. + job runs on Databricks as Airflow tasks. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksNotebookOperator` :param notebook_path: The path to the notebook in Databricks. :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index ad22b7c34368b..541f6a153aade 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -86,6 +86,7 @@ integrations: external-doc-url: https://databricks.com/ how-to-guide: - /docs/apache-airflow-providers-databricks/operators/jobs_create.rst + - /docs/apache-airflow-providers-databricks/operators/notebook.rst - /docs/apache-airflow-providers-databricks/operators/submit_run.rst - /docs/apache-airflow-providers-databricks/operators/run_now.rst logo: /integration-logos/databricks/Databricks.png diff --git a/docs/apache-airflow-providers-databricks/operators/notebook.rst b/docs/apache-airflow-providers-databricks/operators/notebook.rst new file mode 100644 index 0000000000000..b87d0d20e6f5a --- /dev/null +++ b/docs/apache-airflow-providers-databricks/operators/notebook.rst @@ -0,0 +1,44 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:DatabricksNotebookOperator: + + +DatabricksNotebookOperator +========================== + +Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator` to launch and monitor +notebook job runs on Databricks as Airflow tasks. + + + +Examples +-------- + +Running a notebook in Databricks on a new cluster +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_notebook_new_cluster] + :end-before: [END howto_operator_databricks_notebook_new_cluster] + +Running a notebook in Databricks on an existing cluster +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_notebook_existing_cluster] + :end-before: [END howto_operator_databricks_notebook_existing_cluster] diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index d2fe87db301f9..62b8e5df50fd4 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -39,6 +39,7 @@ from airflow import DAG from airflow.providers.databricks.operators.databricks import ( DatabricksCreateJobsOperator, + DatabricksNotebookOperator, DatabricksRunNowOperator, DatabricksSubmitRunOperator, ) @@ -147,6 +148,59 @@ # [END howto_operator_databricks_named] notebook_task >> spark_jar_task + # [START howto_operator_databricks_notebook_new_cluster] + new_cluster_spec = { + "cluster_name": "", + "spark_version": "11.3.x-scala2.12", + "aws_attributes": { + "first_on_demand": 1, + "availability": "SPOT_WITH_FALLBACK", + "zone_id": "us-east-2b", + "spot_bid_price_percent": 100, + "ebs_volume_count": 0, + }, + "node_type_id": "i3.xlarge", + "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"}, + "enable_elastic_disk": False, + "data_security_mode": "LEGACY_SINGLE_USER_STANDARD", + "runtime_engine": "STANDARD", + "num_workers": 8, + } + + notebook_1 = DatabricksNotebookOperator( + task_id="notebook_1", + notebook_path="/Shared/Notebook_1", + notebook_packages=[ + { + "pypi": { + "package": "simplejson==3.18.0", + "repo": "https://pypi.org/simple", + } + }, + {"pypi": {"package": "Faker"}}, + ], + source="WORKSPACE", + new_cluster=new_cluster_spec, + ) + # [END howto_operator_databricks_notebook_new_cluster] + + # [START howto_operator_databricks_notebook_existing_cluster] + notebook_2 = DatabricksNotebookOperator( + task_id="notebook_2", + notebook_path="/Shared/Notebook_2", + notebook_packages=[ + { + "pypi": { + "package": "simplejson==3.18.0", + "repo": "https://pypi.org/simple", + } + }, + ], + source="WORKSPACE", + existing_cluster_id="existing_cluster_id", + ) + # [END howto_operator_databricks_notebook_existing_cluster] + from tests.system.utils.watcher import watcher # This test needs watcher in order to properly mark success/failure From 0b07c1e5ada43b0b6defeef69f96a76b36b2fb39 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Apr 2024 16:18:52 +0530 Subject: [PATCH 5/9] Address next set of comments from @tatiana --- .../databricks/operators/databricks.py | 20 +++++++++++++++---- .../databricks/operators/test_databricks.py | 12 +++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 9a52ec281f51b..877fc4da9afb9 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -917,6 +917,10 @@ class DatabricksNotebookOperator(BaseOperator): :param new_cluster: Specs for a new cluster on which this task will be run. :param existing_cluster_id: ID for existing cluster on which to run this task. :param job_cluster_key: The key for the job cluster. + :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. + :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. + :param databricks_retry_delay: Number of seconds to wait between retries. + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. :param databricks_conn_id: The name of the Airflow connection to use. """ @@ -968,11 +972,19 @@ def _get_hook(self, caller: str) -> DatabricksHook: def _get_task_base_json(self) -> dict[str, Any]: """Get task base json to be used for task submissions.""" + if self.execution_timeout is None: + # By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when + # execution_timeout is not defined, the task continues to run indefinitely. Therefore, + # to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating + # that the job should run indefinitely. This aligns with the default behavior of Databricks jobs, + # where a timeout seconds value of 0 signifies an indefinite run duration. + # More details can be found in the Databricks documentation: + # See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds + timeout_seconds = 0 + else: + timeout_seconds = int(self.execution_timeout.total_seconds()) return { - # Timeout seconds value of 0 for the Databricks Jobs API means the job runs forever. - # That is also the default behavior of Databricks jobs to run a job forever without a default - # timeout value. - "timeout_seconds": int(self.execution_timeout.total_seconds()) if self.execution_timeout else 0, + "timeout_seconds": timeout_seconds, "email_notifications": {}, "notebook_task": { "notebook_path": self.notebook_path, diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 67cc29ac9d978..ddc2d4a5b2440 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1842,3 +1842,15 @@ def test_both_new_and_existing_cluster_unset(self): ) with pytest.raises(ValueError): operator._get_run_json() + + def test_job_runs_forever_by_default(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + existing_cluster_id="existing_cluster_id", + ) + run_json = operator._get_run_json() + assert operator.execution_timeout is None + assert run_json["timeout_seconds"] == 0 From 7d8d0d8c04a9227eef8971b1a9f0ed6a0b6d4cbe Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Apr 2024 16:42:32 +0530 Subject: [PATCH 6/9] Apply suggestions from code review Co-authored-by: Tatiana Al-Chueyr --- airflow/providers/databricks/operators/databricks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 877fc4da9afb9..b48f9b075b5d6 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -933,8 +933,8 @@ def __init__( notebook_params: dict | None = None, notebook_packages: list[dict[str, Any]] | None = None, new_cluster: dict[str, Any] | None = None, - existing_cluster_id: str | None = None, - job_cluster_key: str | None = None, + existing_cluster_id: str = "", + job_cluster_key: str = "", polling_period_seconds: int = 5, databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, @@ -947,8 +947,8 @@ def __init__( self.notebook_params = notebook_params or {} self.notebook_packages = notebook_packages or [] self.new_cluster = new_cluster or {} - self.existing_cluster_id = existing_cluster_id or "" - self.job_cluster_key = job_cluster_key or "" + self.existing_cluster_id = existing_cluster_id + self.job_cluster_key = job_cluster_key self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay From c086e76c0288b3db3c07cc667617dafa7f812cea Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Apr 2024 17:09:10 +0530 Subject: [PATCH 7/9] Move timeout_seconds calculation to a separate method --- .../databricks/operators/databricks.py | 35 ++++++++++++------- .../databricks/operators/test_databricks.py | 14 +++++++- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index b48f9b075b5d6..4ff0fc06a3db6 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -970,21 +970,32 @@ def _get_hook(self, caller: str) -> DatabricksHook: caller=caller, ) + def _get_task_timeout_seconds(self) -> int: + """ + Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task. + + By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when + execution_timeout is not defined, the task continues to run indefinitely. Therefore, + to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating + that the job should run indefinitely. This aligns with the default behavior of Databricks jobs, + where a timeout seconds value of 0 signifies an indefinite run duration. + More details can be found in the Databricks documentation: + See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds + """ + if self.execution_timeout is None: + return 0 + execution_timeout_seconds = int(self.execution_timeout.total_seconds()) + if execution_timeout_seconds == 0: + raise ValueError( + "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to " + "`None` if you desire the task to run indefinitely." + ) + return execution_timeout_seconds + def _get_task_base_json(self) -> dict[str, Any]: """Get task base json to be used for task submissions.""" - if self.execution_timeout is None: - # By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when - # execution_timeout is not defined, the task continues to run indefinitely. Therefore, - # to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating - # that the job should run indefinitely. This aligns with the default behavior of Databricks jobs, - # where a timeout seconds value of 0 signifies an indefinite run duration. - # More details can be found in the Databricks documentation: - # See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds - timeout_seconds = 0 - else: - timeout_seconds = int(self.execution_timeout.total_seconds()) return { - "timeout_seconds": timeout_seconds, + "timeout_seconds": self._get_task_timeout_seconds(), "email_notifications": {}, "notebook_task": { "notebook_path": self.notebook_path, diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index ddc2d4a5b2440..0f70f7f7fe1c9 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta from unittest import mock from unittest.mock import MagicMock @@ -1854,3 +1854,15 @@ def test_job_runs_forever_by_default(self): run_json = operator._get_run_json() assert operator.execution_timeout is None assert run_json["timeout_seconds"] == 0 + + def test_zero_execution_timeout_raises_error(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + existing_cluster_id="existing_cluster_id", + execution_timeout=timedelta(seconds=0), + ) + with pytest.raises(ValueError): + operator._get_run_json() From aa41fc25c9df307bd6ab0c8142abdcc26a77e41c Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Thu, 25 Apr 2024 14:54:56 +0530 Subject: [PATCH 8/9] Apply suggestions from code review Co-authored-by: Wei Lee --- .../providers/databricks/operators/databricks.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 4ff0fc06a3db6..f0ea5b3d0daa2 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -958,7 +958,7 @@ def __init__( super().__init__(**kwargs) @cached_property - def _hook(self): + def _hook(self) -> DatabricksHook: return self._get_hook(caller="DatabricksNotebookOperator") def _get_hook(self, caller: str) -> DatabricksHook: @@ -1005,11 +1005,11 @@ def _get_task_base_json(self) -> dict[str, Any]: "libraries": self.notebook_packages, } - def _get_databricks_task_id(self, task_id: str): + def _get_databricks_task_id(self, task_id: str) -> str: """Get the databricks task ID using dag_id and task_id. Removes illegal characters.""" - return f"{self.dag_id}__" + task_id.replace(".", "__") + return f"{self.dag_id}__{task_id.replace('.', '__')}" - def _get_run_json(self): + def _get_run_json(self) -> dict[str, Any]: """Get run json to be used for task submissions.""" run_json = { "run_name": self._get_databricks_task_id(self.task_id), @@ -1025,14 +1025,14 @@ def _get_run_json(self): raise ValueError("Must specify either existing_cluster_id or new_cluster.") return run_json - def launch_notebook_job(self): + def launch_notebook_job(self) -> str: run_json = self._get_run_json() self.databricks_run_id = self._hook.submit_run(run_json) url = self._hook.get_run_page_url(self.databricks_run_id) self.log.info("Check the job run in Databricks: %s", url) return self.databricks_run_id - def monitor_databricks_job(self): + def monitor_databricks_job(self) -> None: run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info("Current state of the job: %s", run_state.life_cycle_state) @@ -1057,6 +1057,6 @@ def monitor_databricks_job(self): ) self.log.info("Task succeeded. Final state %s.", run_state.result_state) - def execute(self, context: Context) -> Any: + def execute(self, context: Context) -> None: self.launch_notebook_job() self.monitor_databricks_job() From 20dacc7cec64d0055fad79943fd6afa453dbe775 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Thu, 25 Apr 2024 18:33:19 +0530 Subject: [PATCH 9/9] Addres Lee-W's comments --- .../databricks/operators/databricks.py | 12 +++-- .../databricks/operators/test_databricks.py | 54 ++++++++++++++----- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index f0ea5b3d0daa2..1f16e5667b9a5 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -921,6 +921,7 @@ class DatabricksNotebookOperator(BaseOperator): :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. :param databricks_retry_delay: Number of seconds to wait between retries. :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param databricks_conn_id: The name of the Airflow connection to use. """ @@ -939,6 +940,7 @@ def __init__( databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, databricks_retry_args: dict[Any, Any] | None = None, + wait_for_termination: bool = True, databricks_conn_id: str = "databricks_default", **kwargs: Any, ): @@ -953,8 +955,9 @@ def __init__( self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args + self.wait_for_termination = wait_for_termination self.databricks_conn_id = databricks_conn_id - self.databricks_run_id = "" + self.databricks_run_id: int | None = None super().__init__(**kwargs) @cached_property @@ -1025,7 +1028,7 @@ def _get_run_json(self) -> dict[str, Any]: raise ValueError("Must specify either existing_cluster_id or new_cluster.") return run_json - def launch_notebook_job(self) -> str: + def launch_notebook_job(self) -> int: run_json = self._get_run_json() self.databricks_run_id = self._hook.submit_run(run_json) url = self._hook.get_run_page_url(self.databricks_run_id) @@ -1033,6 +1036,8 @@ def launch_notebook_job(self) -> str: return self.databricks_run_id def monitor_databricks_job(self) -> None: + if self.databricks_run_id is None: + raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.") run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info("Current state of the job: %s", run_state.life_cycle_state) @@ -1059,4 +1064,5 @@ def monitor_databricks_job(self) -> None: def execute(self, context: Context) -> None: self.launch_notebook_job() - self.monitor_databricks_job() + if self.wait_for_termination: + self.monitor_databricks_job() diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 0f70f7f7fe1c9..902aa37e918ea 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1758,23 +1758,41 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ class TestDatabricksNotebookOperator: - def test_execute(self): + def test_execute_with_wait_for_termination(self): operator = DatabricksNotebookOperator( task_id="test_task", notebook_path="test_path", source="test_source", databricks_conn_id="test_conn_id", ) - operator.launch_notebook_job = MagicMock(return_value="12345") + operator.launch_notebook_job = MagicMock(return_value=12345) operator.monitor_databricks_job = MagicMock() operator.execute({}) + assert operator.wait_for_termination is True operator.launch_notebook_job.assert_called_once() operator.monitor_databricks_job.assert_called_once() + def test_execute_without_wait_for_termination(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + wait_for_termination=False, + ) + operator.launch_notebook_job = MagicMock(return_value=12345) + operator.monitor_databricks_job = MagicMock() + + operator.execute({}) + + assert operator.wait_for_termination is False + operator.launch_notebook_job.assert_called_once() + operator.monitor_databricks_job.assert_not_called() + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_monitor_databricks_job_successful(self, mock_databricks_hook): + def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook): mock_databricks_hook.return_value.get_run.return_value = { "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"} } @@ -1786,13 +1804,13 @@ def test_monitor_databricks_job_successful(self, mock_databricks_hook): databricks_conn_id="test_conn_id", ) - operator.databricks_run_id = "12345" + operator.databricks_run_id = 12345 operator.monitor_databricks_job() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_monitor_databricks_job_failed(self, mock_databricks_hook): mock_databricks_hook.return_value.get_run.return_value = { - "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED"} + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"} } operator = DatabricksNotebookOperator( @@ -1802,9 +1820,12 @@ def test_monitor_databricks_job_failed(self, mock_databricks_hook): databricks_conn_id="test_conn_id", ) - operator.databricks_run_id = "12345" - with pytest.raises(AirflowException): + operator.databricks_run_id = 12345 + + exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'" + with pytest.raises(AirflowException) as exc_info: operator.monitor_databricks_job() + assert exception_message in str(exc_info.value) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_launch_notebook_job(self, mock_databricks_hook): @@ -1815,11 +1836,11 @@ def test_launch_notebook_job(self, mock_databricks_hook): databricks_conn_id="test_conn_id", existing_cluster_id="test_cluster_id", ) - operator._hook.submit_run.return_value = "12345" + operator._hook.submit_run.return_value = 12345 run_id = operator.launch_notebook_job() - assert run_id == "12345" + assert run_id == 12345 def test_both_new_and_existing_cluster_set(self): operator = DatabricksNotebookOperator( @@ -1830,8 +1851,10 @@ def test_both_new_and_existing_cluster_set(self): existing_cluster_id="existing_cluster_id", databricks_conn_id="test_conn_id", ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: operator._get_run_json() + exception_message = "Both new_cluster and existing_cluster_id are set. Only one should be set." + assert str(exc_info.value) == exception_message def test_both_new_and_existing_cluster_unset(self): operator = DatabricksNotebookOperator( @@ -1840,8 +1863,10 @@ def test_both_new_and_existing_cluster_unset(self): source="test_source", databricks_conn_id="test_conn_id", ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: operator._get_run_json() + exception_message = "Must specify either existing_cluster_id or new_cluster." + assert str(exc_info.value) == exception_message def test_job_runs_forever_by_default(self): operator = DatabricksNotebookOperator( @@ -1864,5 +1889,10 @@ def test_zero_execution_timeout_raises_error(self): existing_cluster_id="existing_cluster_id", execution_timeout=timedelta(seconds=0), ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: operator._get_run_json() + exception_message = ( + "If you've set an `execution_timeout` for the task, ensure it's not `0`. " + "Set it instead to `None` if you desire the task to run indefinitely." + ) + assert str(exc_info.value) == exception_message