add deferrable support to DatabricksNotebookOperator (#39295)
related: #39178

This PR makes DatabricksNotebookOperator deferrable: with deferrable=True, the operator hands off run monitoring to a DatabricksExecutionTrigger instead of polling from the worker.
rawwar authored May 14, 2024
1 parent 7db851f commit 1e4663f
Showing 4 changed files with 83 additions and 8 deletions.
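
As a rough usage illustration (not part of this commit's diff): with this change a DAG can opt into deferral per task. All identifiers below — DAG id, notebook path, cluster id, connection id — are placeholders.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator

with DAG(
    dag_id="example_databricks_notebook_deferrable",  # hypothetical DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
) as dag:
    run_notebook = DatabricksNotebookOperator(
        task_id="run_notebook",
        notebook_path="/Shared/example_notebook",    # placeholder notebook path
        source="WORKSPACE",
        existing_cluster_id="1234-567890-abcde123",  # placeholder cluster id
        databricks_conn_id="databricks_default",     # placeholder connection id
        wait_for_termination=True,
        # New in this commit: release the worker slot while the run is polled by the triggerer.
        deferrable=True,
    )
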
1 change: 1 addition & 0 deletions airflow/providers/databricks/hooks/databricks_base.py
@@ -80,6 +80,7 @@ class BaseDatabricksHook(BaseHook):
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param caller: The name of the operator that is calling the hook.
"""

conn_name_attr: str = "databricks_conn_id"
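
For context, caller is a plain keyword argument on the hook; the docstring addition above only documents what the operator below now passes (its own class name). A minimal sketch with a placeholder connection id, mirroring the _get_hook call in the operator diff further down:

from airflow.providers.databricks.hooks.databricks import DatabricksHook

hook = DatabricksHook(
    "databricks_default",                 # placeholder Airflow connection id
    retry_limit=3,
    retry_delay=1.0,
    caller="DatabricksNotebookOperator",  # matches the CALLER constant added below
)
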
40 changes: 34 additions & 6 deletions airflow/providers/databricks/operators/databricks.py
@@ -167,7 +167,7 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)

error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"

if event["repair_run"]:
if event.get("repair_run"):
log.warning(
"%s but since repair run is set, repairing the run with all failed tasks",
error_message,
@@ -923,9 +923,11 @@ class DatabricksNotebookOperator(BaseOperator):
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param databricks_conn_id: The name of the Airflow connection to use.
:param deferrable: Run operator in the deferrable mode.
"""

template_fields = ("notebook_params",)
CALLER = "DatabricksNotebookOperator"

def __init__(
self,
@@ -942,6 +944,7 @@ def __init__(
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
self.notebook_path = notebook_path
@@ -958,19 +961,20 @@ def __init__(
self.wait_for_termination = wait_for_termination
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id: int | None = None
self.deferrable = deferrable
super().__init__(**kwargs)

@cached_property
def _hook(self) -> DatabricksHook:
return self._get_hook(caller="DatabricksNotebookOperator")
return self._get_hook(caller=self.CALLER)

def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=caller,
caller=self.CALLER,
)

def _get_task_timeout_seconds(self) -> int:
@@ -1041,6 +1045,19 @@ def monitor_databricks_job(self) -> None:
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info("Current state of the job: %s", run_state.life_cycle_state)
if self.deferrable and not run_state.is_terminal:
return self.defer(
trigger=DatabricksExecutionTrigger(
run_id=self.databricks_run_id,
databricks_conn_id=self.databricks_conn_id,
polling_period_seconds=self.polling_period_seconds,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=self.CALLER,
),
method_name=DEFER_METHOD_NAME,
)
while not run_state.is_terminal:
time.sleep(self.polling_period_seconds)
run = self._hook.get_run(self.databricks_run_id)
@@ -1056,13 +1073,24 @@ def monitor_databricks_job(self) -> None:
)
if not run_state.is_successful:
raise AirflowException(
"Task failed. Final state %s. Reason: %s",
run_state.result_state,
run_state.state_message,
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)

def execute(self, context: Context) -> None:
self.launch_notebook_job()
if self.wait_for_termination:
self.monitor_databricks_job()

def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
if run_state.life_cycle_state != "TERMINATED":
raise AirflowException(
f"Databricks job failed with state {run_state.life_cycle_state}. "
f"Message: {run_state.state_message}"
)
if not run_state.is_successful:
raise AirflowException(
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)
2 changes: 2 additions & 0 deletions airflow/providers/databricks/triggers/databricks.py
@@ -48,6 +48,7 @@ def __init__(
retry_args: dict[Any, Any] | None = None,
run_page_url: str | None = None,
repair_run: bool = False,
caller: str = "DatabricksExecutionTrigger",
) -> None:
super().__init__()
self.run_id = run_id
@@ -63,6 +64,7 @@ def __init__(
retry_limit=self.retry_limit,
retry_delay=self.retry_delay,
retry_args=retry_args,
caller=caller,
)

def serialize(self) -> tuple[str, dict[str, Any]]:
48 changes: 46 additions & 2 deletions tests/providers/databricks/operators/test_databricks.py
@@ -1865,6 +1865,50 @@ def test_execute_without_wait_for_termination(self):
operator.launch_notebook_job.assert_called_once()
operator.monitor_databricks_job.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_execute_with_deferrable(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {"state": {"life_cycle_state": "PENDING"}}
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
wait_for_termination=True,
deferrable=True,
)
operator.databricks_run_id = 12345

with pytest.raises(TaskDeferred) as exec_info:
operator.monitor_databricks_job()
assert isinstance(
exec_info.value.trigger, DatabricksExecutionTrigger
), "Trigger is not a DatabricksExecutionTrigger"
assert exec_info.value.method_name == "execute_complete"

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_execute_with_deferrable_early_termination(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "FAILURE",
}
}
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
wait_for_termination=True,
deferrable=True,
)
operator.databricks_run_id = 12345

with pytest.raises(AirflowException) as exec_info:
operator.monitor_databricks_job()
exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
assert exception_message == str(exec_info.value)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
@@ -1896,10 +1940,10 @@ def test_monitor_databricks_job_failed(self, mock_databricks_hook):

operator.databricks_run_id = 12345

exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
with pytest.raises(AirflowException) as exc_info:
operator.monitor_databricks_job()
assert exception_message in str(exc_info.value)
exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
assert exception_message == str(exc_info.value)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_launch_notebook_job(self, mock_databricks_hook):
