Skip to content

Commit

Permalink
Addres Lee-W's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Apr 25, 2024
1 parent aa41fc2 commit 20dacc7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
12 changes: 9 additions & 3 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,7 @@ class DatabricksNotebookOperator(BaseOperator):
:param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
:param databricks_retry_delay: Number of seconds to wait between retries.
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param databricks_conn_id: The name of the Airflow connection to use.
"""

Expand All @@ -939,6 +940,7 @@ def __init__(
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
**kwargs: Any,
):
Expand All @@ -953,8 +955,9 @@ def __init__(
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id = ""
self.databricks_run_id: int | None = None
super().__init__(**kwargs)

@cached_property
Expand Down Expand Up @@ -1025,14 +1028,16 @@ def _get_run_json(self) -> dict[str, Any]:
raise ValueError("Must specify either existing_cluster_id or new_cluster.")
return run_json

def launch_notebook_job(self) -> str:
def launch_notebook_job(self) -> int:
run_json = self._get_run_json()
self.databricks_run_id = self._hook.submit_run(run_json)
url = self._hook.get_run_page_url(self.databricks_run_id)
self.log.info("Check the job run in Databricks: %s", url)
return self.databricks_run_id

def monitor_databricks_job(self) -> None:
if self.databricks_run_id is None:
raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info("Current state of the job: %s", run_state.life_cycle_state)
Expand All @@ -1059,4 +1064,5 @@ def monitor_databricks_job(self) -> None:

def execute(self, context: Context) -> None:
self.launch_notebook_job()
self.monitor_databricks_job()
if self.wait_for_termination:
self.monitor_databricks_job()
54 changes: 42 additions & 12 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,23 +1758,41 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_


class TestDatabricksNotebookOperator:
def test_execute(self):
def test_execute_with_wait_for_termination(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
)
operator.launch_notebook_job = MagicMock(return_value="12345")
operator.launch_notebook_job = MagicMock(return_value=12345)
operator.monitor_databricks_job = MagicMock()

operator.execute({})

assert operator.wait_for_termination is True
operator.launch_notebook_job.assert_called_once()
operator.monitor_databricks_job.assert_called_once()

def test_execute_without_wait_for_termination(self):
operator = DatabricksNotebookOperator(
task_id="test_task",
notebook_path="test_path",
source="test_source",
databricks_conn_id="test_conn_id",
wait_for_termination=False,
)
operator.launch_notebook_job = MagicMock(return_value=12345)
operator.monitor_databricks_job = MagicMock()

operator.execute({})

assert operator.wait_for_termination is False
operator.launch_notebook_job.assert_called_once()
operator.monitor_databricks_job.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_monitor_databricks_job_successful(self, mock_databricks_hook):
def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
"state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}
}
Expand All @@ -1786,13 +1804,13 @@ def test_monitor_databricks_job_successful(self, mock_databricks_hook):
databricks_conn_id="test_conn_id",
)

operator.databricks_run_id = "12345"
operator.databricks_run_id = 12345
operator.monitor_databricks_job()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_monitor_databricks_job_failed(self, mock_databricks_hook):
mock_databricks_hook.return_value.get_run.return_value = {
"state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED"}
"state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"}
}

operator = DatabricksNotebookOperator(
Expand All @@ -1802,9 +1820,12 @@ def test_monitor_databricks_job_failed(self, mock_databricks_hook):
databricks_conn_id="test_conn_id",
)

operator.databricks_run_id = "12345"
with pytest.raises(AirflowException):
operator.databricks_run_id = 12345

exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
with pytest.raises(AirflowException) as exc_info:
operator.monitor_databricks_job()
assert exception_message in str(exc_info.value)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_launch_notebook_job(self, mock_databricks_hook):
Expand All @@ -1815,11 +1836,11 @@ def test_launch_notebook_job(self, mock_databricks_hook):
databricks_conn_id="test_conn_id",
existing_cluster_id="test_cluster_id",
)
operator._hook.submit_run.return_value = "12345"
operator._hook.submit_run.return_value = 12345

run_id = operator.launch_notebook_job()

assert run_id == "12345"
assert run_id == 12345

def test_both_new_and_existing_cluster_set(self):
operator = DatabricksNotebookOperator(
Expand All @@ -1830,8 +1851,10 @@ def test_both_new_and_existing_cluster_set(self):
existing_cluster_id="existing_cluster_id",
databricks_conn_id="test_conn_id",
)
with pytest.raises(ValueError):
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
exception_message = "Both new_cluster and existing_cluster_id are set. Only one should be set."
assert str(exc_info.value) == exception_message

def test_both_new_and_existing_cluster_unset(self):
operator = DatabricksNotebookOperator(
Expand All @@ -1840,8 +1863,10 @@ def test_both_new_and_existing_cluster_unset(self):
source="test_source",
databricks_conn_id="test_conn_id",
)
with pytest.raises(ValueError):
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
exception_message = "Must specify either existing_cluster_id or new_cluster."
assert str(exc_info.value) == exception_message

def test_job_runs_forever_by_default(self):
operator = DatabricksNotebookOperator(
Expand All @@ -1864,5 +1889,10 @@ def test_zero_execution_timeout_raises_error(self):
existing_cluster_id="existing_cluster_id",
execution_timeout=timedelta(seconds=0),
)
with pytest.raises(ValueError):
with pytest.raises(ValueError) as exc_info:
operator._get_run_json()
exception_message = (
"If you've set an `execution_timeout` for the task, ensure it's not `0`. "
"Set it instead to `None` if you desire the task to run indefinitely."
)
assert str(exc_info.value) == exception_message

0 comments on commit 20dacc7

Please sign in to comment.