diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index f0ea5b3d0daa2..1f16e5667b9a5 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -921,6 +921,7 @@ class DatabricksNotebookOperator(BaseOperator): :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. :param databricks_retry_delay: Number of seconds to wait between retries. :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param databricks_conn_id: The name of the Airflow connection to use. """ @@ -939,6 +940,7 @@ def __init__( databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, databricks_retry_args: dict[Any, Any] | None = None, + wait_for_termination: bool = True, databricks_conn_id: str = "databricks_default", **kwargs: Any, ): @@ -953,8 +955,9 @@ def __init__( self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args + self.wait_for_termination = wait_for_termination self.databricks_conn_id = databricks_conn_id - self.databricks_run_id = "" + self.databricks_run_id: int | None = None super().__init__(**kwargs) @cached_property @@ -1025,7 +1028,7 @@ def _get_run_json(self) -> dict[str, Any]: raise ValueError("Must specify either existing_cluster_id or new_cluster.") return run_json - def launch_notebook_job(self) -> str: + def launch_notebook_job(self) -> int: run_json = self._get_run_json() self.databricks_run_id = self._hook.submit_run(run_json) url = self._hook.get_run_page_url(self.databricks_run_id) @@ -1033,6 +1036,8 @@ def launch_notebook_job(self) -> str: return self.databricks_run_id def monitor_databricks_job(self) -> None: + if self.databricks_run_id is None: + raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.") run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info("Current state of the job: %s", run_state.life_cycle_state) @@ -1059,4 +1064,5 @@ def monitor_databricks_job(self) -> None: def execute(self, context: Context) -> None: self.launch_notebook_job() - self.monitor_databricks_job() + if self.wait_for_termination: + self.monitor_databricks_job() diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 0f70f7f7fe1c9..902aa37e918ea 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1758,23 +1758,41 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ class TestDatabricksNotebookOperator: - def test_execute(self): + def test_execute_with_wait_for_termination(self): operator = DatabricksNotebookOperator( task_id="test_task", notebook_path="test_path", source="test_source", databricks_conn_id="test_conn_id", ) - operator.launch_notebook_job = MagicMock(return_value="12345") + operator.launch_notebook_job = MagicMock(return_value=12345) operator.monitor_databricks_job = MagicMock() operator.execute({}) + assert operator.wait_for_termination is True operator.launch_notebook_job.assert_called_once() operator.monitor_databricks_job.assert_called_once() + def test_execute_without_wait_for_termination(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + wait_for_termination=False, + ) + operator.launch_notebook_job = MagicMock(return_value=12345) + operator.monitor_databricks_job = MagicMock() + + operator.execute({}) + + assert operator.wait_for_termination is False + operator.launch_notebook_job.assert_called_once() + operator.monitor_databricks_job.assert_not_called() + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_monitor_databricks_job_successful(self, mock_databricks_hook): + def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook): mock_databricks_hook.return_value.get_run.return_value = { "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"} } @@ -1786,13 +1804,13 @@ def test_monitor_databricks_job_successful(self, mock_databricks_hook): databricks_conn_id="test_conn_id", ) - operator.databricks_run_id = "12345" + operator.databricks_run_id = 12345 operator.monitor_databricks_job() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_monitor_databricks_job_failed(self, mock_databricks_hook): mock_databricks_hook.return_value.get_run.return_value = { - "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED"} + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"} } operator = DatabricksNotebookOperator( @@ -1802,9 +1820,12 @@ def test_monitor_databricks_job_failed(self, mock_databricks_hook): databricks_conn_id="test_conn_id", ) - operator.databricks_run_id = "12345" - with pytest.raises(AirflowException): + operator.databricks_run_id = 12345 + + exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'" + with pytest.raises(AirflowException) as exc_info: operator.monitor_databricks_job() + assert exception_message in str(exc_info.value) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_launch_notebook_job(self, mock_databricks_hook): @@ -1815,11 +1836,11 @@ def test_launch_notebook_job(self, mock_databricks_hook): databricks_conn_id="test_conn_id", existing_cluster_id="test_cluster_id", ) - operator._hook.submit_run.return_value = "12345" + operator._hook.submit_run.return_value = 12345 run_id = operator.launch_notebook_job() - assert run_id == "12345" + assert run_id == 12345 def test_both_new_and_existing_cluster_set(self): operator = DatabricksNotebookOperator( @@ -1830,8 +1851,10 @@ def test_both_new_and_existing_cluster_set(self): existing_cluster_id="existing_cluster_id", databricks_conn_id="test_conn_id", ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: operator._get_run_json() + exception_message = "Both new_cluster and existing_cluster_id are set. Only one should be set." + assert str(exc_info.value) == exception_message def test_both_new_and_existing_cluster_unset(self): operator = DatabricksNotebookOperator( @@ -1840,8 +1863,10 @@ def test_both_new_and_existing_cluster_unset(self): source="test_source", databricks_conn_id="test_conn_id", ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: operator._get_run_json() + exception_message = "Must specify either existing_cluster_id or new_cluster." + assert str(exc_info.value) == exception_message def test_job_runs_forever_by_default(self): operator = DatabricksNotebookOperator( @@ -1864,5 +1889,10 @@ def test_zero_execution_timeout_raises_error(self): existing_cluster_id="existing_cluster_id", execution_timeout=timedelta(seconds=0), ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: operator._get_run_json() + exception_message = ( + "If you've set an `execution_timeout` for the task, ensure it's not `0`. " + "Set it instead to `None` if you desire the task to run indefinitely." + ) + assert str(exc_info.value) == exception_message