diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 22f37e95555d6..1f16e5667b9a5 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -892,3 +892,177 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
     def __init__(self, *args, **kwargs):
         super().__init__(deferrable=True, *args, **kwargs)
+
+
+class DatabricksNotebookOperator(BaseOperator):
+    """
+    Runs a notebook on Databricks using an Airflow operator.
+
+    The DatabricksNotebookOperator allows users to launch and monitor notebook
+    job runs on Databricks as Airflow tasks.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DatabricksNotebookOperator`
+
+    :param notebook_path: The path to the notebook in Databricks.
+    :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
+        from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
+        defined in git_source. If the value is empty, the task will use GIT if git_source is defined
+        and WORKSPACE otherwise. For more information please visit
+        https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
+    :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
+    :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+        notebook.
+    :param new_cluster: Specs for a new cluster on which this task will be run.
+    :param existing_cluster_id: ID of the existing cluster on which to run this task.
+    :param job_cluster_key: The key for the job cluster.
+    :param polling_period_seconds: Controls the rate at which we poll for the result of this notebook job run.
+    :param databricks_retry_limit: Number of times to retry if the Databricks backend is unreachable.
+    :param databricks_retry_delay: Number of seconds to wait between retries.
+    :param databricks_retry_args: An optional dictionary with arguments passed to the ``tenacity.Retrying`` class.
+    :param wait_for_termination: Whether we should wait for termination of the job run. ``True`` by default.
+    :param databricks_conn_id: The name of the Airflow connection to use.
+    """
+
+    template_fields = ("notebook_params",)
+
+    def __init__(
+        self,
+        notebook_path: str,
+        source: str,
+        notebook_params: dict | None = None,
+        notebook_packages: list[dict[str, Any]] | None = None,
+        new_cluster: dict[str, Any] | None = None,
+        existing_cluster_id: str = "",
+        job_cluster_key: str = "",
+        polling_period_seconds: int = 5,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        wait_for_termination: bool = True,
+        databricks_conn_id: str = "databricks_default",
+        **kwargs: Any,
+    ):
+        self.notebook_path = notebook_path
+        self.source = source
+        self.notebook_params = notebook_params or {}
+        self.notebook_packages = notebook_packages or []
+        self.new_cluster = new_cluster or {}
+        self.existing_cluster_id = existing_cluster_id
+        self.job_cluster_key = job_cluster_key
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.databricks_conn_id = databricks_conn_id
+        self.databricks_run_id: int | None = None
+        super().__init__(**kwargs)
+
+    @cached_property
+    def _hook(self) -> DatabricksHook:
+        return self._get_hook(caller="DatabricksNotebookOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def _get_task_timeout_seconds(self) -> int:
+        """
+        Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
+
+        By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
+        execution_timeout is not defined, the task continues to run indefinitely. Therefore,
+        to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
+        that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
+        where a timeout seconds value of 0 signifies an indefinite run duration.
+        More details can be found in the Databricks documentation:
+        See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
+        """
+        if self.execution_timeout is None:
+            return 0
+        execution_timeout_seconds = int(self.execution_timeout.total_seconds())
+        if execution_timeout_seconds == 0:
+            raise ValueError(
+                "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to "
+                "`None` if you desire the task to run indefinitely."
+            )
+        return execution_timeout_seconds
+
+    def _get_task_base_json(self) -> dict[str, Any]:
+        """Get task base json to be used for task submissions."""
+        return {
+            "timeout_seconds": self._get_task_timeout_seconds(),
+            "email_notifications": {},
+            "notebook_task": {
+                "notebook_path": self.notebook_path,
+                "source": self.source,
+                "base_parameters": self.notebook_params,
+            },
+            "libraries": self.notebook_packages,
+        }
+
+    def _get_databricks_task_id(self, task_id: str) -> str:
+        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
+        return f"{self.dag_id}__{task_id.replace('.', '__')}"
+
+    def _get_run_json(self) -> dict[str, Any]:
+        """Get run json to be used for task submissions."""
+        run_json = {
+            "run_name": self._get_databricks_task_id(self.task_id),
+            **self._get_task_base_json(),
+        }
+        if self.new_cluster and self.existing_cluster_id:
+            raise ValueError("Both new_cluster and existing_cluster_id are set. Only one should be set.")
+        if self.new_cluster:
+            run_json["new_cluster"] = self.new_cluster
+        elif self.existing_cluster_id:
+            run_json["existing_cluster_id"] = self.existing_cluster_id
+        else:
+            raise ValueError("Must specify either existing_cluster_id or new_cluster.")
+        return run_json
+
+    def launch_notebook_job(self) -> int:
+        run_json = self._get_run_json()
+        self.databricks_run_id = self._hook.submit_run(run_json)
+        url = self._hook.get_run_page_url(self.databricks_run_id)
+        self.log.info("Check the job run in Databricks: %s", url)
+        return self.databricks_run_id
+
+    def monitor_databricks_job(self) -> None:
+        if self.databricks_run_id is None:
+            raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
+        run = self._hook.get_run(self.databricks_run_id)
+        run_state = RunState(**run["state"])
+        self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        while not run_state.is_terminal:
+            time.sleep(self.polling_period_seconds)
+            run = self._hook.get_run(self.databricks_run_id)
+            run_state = RunState(**run["state"])
+            self.log.info(
+                "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
+            )
+            self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        if run_state.life_cycle_state != "TERMINATED":
+            raise AirflowException(
+                f"Databricks job failed with state {run_state.life_cycle_state}. "
+                f"Message: {run_state.state_message}"
+            )
+        if not run_state.is_successful:
+            raise AirflowException(
+                "Task failed. Final state %s. Reason: %s",
+                run_state.result_state,
+                run_state.state_message,
+            )
+        self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+
+    def execute(self, context: Context) -> None:
+        self.launch_notebook_job()
+        if self.wait_for_termination:
+            self.monitor_databricks_job()
diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index ad22b7c34368b..541f6a153aade 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -86,6 +86,7 @@ integrations:
     external-doc-url: https://databricks.com/
     how-to-guide:
       - /docs/apache-airflow-providers-databricks/operators/jobs_create.rst
+      - /docs/apache-airflow-providers-databricks/operators/notebook.rst
       - /docs/apache-airflow-providers-databricks/operators/submit_run.rst
       - /docs/apache-airflow-providers-databricks/operators/run_now.rst
     logo: /integration-logos/databricks/Databricks.png
diff --git a/docs/apache-airflow-providers-databricks/operators/notebook.rst b/docs/apache-airflow-providers-databricks/operators/notebook.rst
new file mode 100644
index 0000000000000..b87d0d20e6f5a
--- /dev/null
+++ b/docs/apache-airflow-providers-databricks/operators/notebook.rst
@@ -0,0 +1,44 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:DatabricksNotebookOperator:
+
+
+DatabricksNotebookOperator
+==========================
+
+Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator` to launch and monitor
+notebook job runs on Databricks as Airflow tasks.
+
+
+
+Examples
+--------
+
+Running a notebook in Databricks on a new cluster
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
+    :language: python
+    :start-after: [START howto_operator_databricks_notebook_new_cluster]
+    :end-before: [END howto_operator_databricks_notebook_new_cluster]
+
+Running a notebook in Databricks on an existing cluster
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
+    :language: python
+    :start-after: [START howto_operator_databricks_notebook_existing_cluster]
+    :end-before: [END howto_operator_databricks_notebook_existing_cluster]
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 26d59baa61610..902aa37e918ea 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import datetime
+from datetime import datetime, timedelta
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -28,6 +28,7 @@
 from airflow.providers.databricks.hooks.databricks import RunState
 from airflow.providers.databricks.operators.databricks import (
     DatabricksCreateJobsOperator,
+    DatabricksNotebookOperator,
     DatabricksRunNowDeferrableOperator,
     DatabricksRunNowOperator,
     DatabricksSubmitRunDeferrableOperator,
@@ -1754,3 +1755,144 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         assert op.run_id == RUN_ID
         assert not mock_defer.called
+
+
+class TestDatabricksNotebookOperator:
+    def test_execute_with_wait_for_termination(self):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+        )
+        operator.launch_notebook_job = MagicMock(return_value=12345)
+        operator.monitor_databricks_job = MagicMock()
+
+        operator.execute({})
+
+        assert operator.wait_for_termination is True
+        operator.launch_notebook_job.assert_called_once()
+        operator.monitor_databricks_job.assert_called_once()
+
+    def test_execute_without_wait_for_termination(self):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=False,
+        )
+        operator.launch_notebook_job = MagicMock(return_value=12345)
+        operator.monitor_databricks_job = MagicMock()
+
+        operator.execute({})
+
+        assert operator.wait_for_termination is False
+        operator.launch_notebook_job.assert_called_once()
+        operator.monitor_databricks_job.assert_not_called()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {
+            "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}
+        }
+
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+        )
+
+        operator.databricks_run_id = 12345
+        operator.monitor_databricks_job()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_monitor_databricks_job_failed(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {
+            "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"}
+        }
+
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+        )
+
+        operator.databricks_run_id = 12345
+
+        exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
+        with pytest.raises(AirflowException) as exc_info:
+            operator.monitor_databricks_job()
+        assert exception_message in str(exc_info.value)
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_launch_notebook_job(self, mock_databricks_hook):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            existing_cluster_id="test_cluster_id",
+        )
+        operator._hook.submit_run.return_value = 12345
+
+        run_id = operator.launch_notebook_job()
+
+        assert run_id == 12345
+
+    def test_both_new_and_existing_cluster_set(self):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            new_cluster={"new_cluster_config_key": "new_cluster_config_value"},
+            existing_cluster_id="existing_cluster_id",
+            databricks_conn_id="test_conn_id",
+        )
+        with pytest.raises(ValueError) as exc_info:
+            operator._get_run_json()
+        exception_message = "Both new_cluster and existing_cluster_id are set. Only one should be set."
+        assert str(exc_info.value) == exception_message
+
+    def test_both_new_and_existing_cluster_unset(self):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+        )
+        with pytest.raises(ValueError) as exc_info:
+            operator._get_run_json()
+        exception_message = "Must specify either existing_cluster_id or new_cluster."
+        assert str(exc_info.value) == exception_message
+
+    def test_job_runs_forever_by_default(self):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            existing_cluster_id="existing_cluster_id",
+        )
+        run_json = operator._get_run_json()
+        assert operator.execution_timeout is None
+        assert run_json["timeout_seconds"] == 0
+
+    def test_zero_execution_timeout_raises_error(self):
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            existing_cluster_id="existing_cluster_id",
+            execution_timeout=timedelta(seconds=0),
+        )
+        with pytest.raises(ValueError) as exc_info:
+            operator._get_run_json()
+        exception_message = (
+            "If you've set an `execution_timeout` for the task, ensure it's not `0`. "
+            "Set it instead to `None` if you desire the task to run indefinitely."
+        )
+        assert str(exc_info.value) == exception_message
diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py
index d2fe87db301f9..62b8e5df50fd4 100644
--- a/tests/system/providers/databricks/example_databricks.py
+++ b/tests/system/providers/databricks/example_databricks.py
@@ -39,6 +39,7 @@
 from airflow import DAG
 from airflow.providers.databricks.operators.databricks import (
     DatabricksCreateJobsOperator,
+    DatabricksNotebookOperator,
     DatabricksRunNowOperator,
     DatabricksSubmitRunOperator,
 )
@@ -147,6 +148,59 @@
     # [END howto_operator_databricks_named]
     notebook_task >> spark_jar_task
 
+    # [START howto_operator_databricks_notebook_new_cluster]
+    new_cluster_spec = {
+        "cluster_name": "",
+        "spark_version": "11.3.x-scala2.12",
+        "aws_attributes": {
+            "first_on_demand": 1,
+            "availability": "SPOT_WITH_FALLBACK",
+            "zone_id": "us-east-2b",
+            "spot_bid_price_percent": 100,
+            "ebs_volume_count": 0,
+        },
+        "node_type_id": "i3.xlarge",
+        "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"},
+        "enable_elastic_disk": False,
+        "data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
+        "runtime_engine": "STANDARD",
+        "num_workers": 8,
+    }
+
+    notebook_1 = DatabricksNotebookOperator(
+        task_id="notebook_1",
+        notebook_path="/Shared/Notebook_1",
+        notebook_packages=[
+            {
+                "pypi": {
+                    "package": "simplejson==3.18.0",
+                    "repo": "https://pypi.org/simple",
+                }
+            },
+            {"pypi": {"package": "Faker"}},
+        ],
+        source="WORKSPACE",
+        new_cluster=new_cluster_spec,
+    )
+    # [END howto_operator_databricks_notebook_new_cluster]
+
+    # [START howto_operator_databricks_notebook_existing_cluster]
+    notebook_2 = DatabricksNotebookOperator(
+        task_id="notebook_2",
+        notebook_path="/Shared/Notebook_2",
+        notebook_packages=[
+            {
+                "pypi": {
+                    "package": "simplejson==3.18.0",
+                    "repo": "https://pypi.org/simple",
+                }
+            },
+        ],
+        source="WORKSPACE",
+        existing_cluster_id="existing_cluster_id",
+    )
+    # [END howto_operator_databricks_notebook_existing_cluster]
+
     from tests.system.utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure