Add DatabricksWorkflowTaskGroup (apache#39771)
This pull request introduces the [DatabricksWorkflowTaskGroup](https://github.com/astronomer/astro-provider-databricks/blob/main/src/astro_databricks/operators/workflow.py#L226)
from the [astro-provider-databricks](https://github.com/astronomer/astro-provider-databricks/tree/main)
repository to the Airflow Databricks provider.
It is another pull request in the effort to contribute
operators and features from that repository; the previous one was apache#39178.

The task group launches a [Databricks Workflow](https://docs.databricks.com/en/workflows/index.html)
and runs its notebook jobs from within that workflow, yielding a
[75% cost reduction](https://www.databricks.com/product/pricing) ($0.40/DBU for all-purpose compute
vs. $0.07/DBU for Jobs compute) compared to executing
``DatabricksNotebookOperator`` outside of a ``DatabricksWorkflowTaskGroup``.
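
As a rough usage sketch (not part of this commit), the DAG below wires two ``DatabricksNotebookOperator`` tasks into a ``DatabricksWorkflowTaskGroup``; the connection id, job-cluster spec, and notebook paths are placeholder values you would replace with your own.

from __future__ import annotations

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

# Placeholder job-cluster spec; adjust spark_version/node_type_id/num_workers for your workspace.
job_cluster_spec = [
    {
        "job_cluster_key": "workflow_cluster",
        "new_cluster": {
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 1,
        },
    }
]

with DAG(dag_id="example_databricks_workflow", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    # Every notebook task inside this group runs as a task of a single Databricks Workflow on Jobs compute.
    with DatabricksWorkflowTaskGroup(
        group_id="databricks_workflow",
        databricks_conn_id="databricks_default",
        job_clusters=job_cluster_spec,
    ) as workflow:
        first = DatabricksNotebookOperator(
            task_id="first_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/first_notebook",  # placeholder path
            source="WORKSPACE",
            job_cluster_key="workflow_cluster",
        )
        second = DatabricksNotebookOperator(
            task_id="second_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/second_notebook",  # placeholder path
            source="WORKSPACE",
            job_cluster_key="workflow_cluster",
        )
        first >> second

When such a DAG runs, the task group's generated launch task creates and starts the workflow run, and each notebook task then only monitors its corresponding Databricks task within that run (see the ``execute`` changes below).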

---------
Co-authored-by: Daniel Imberman <daniel.imberman@gmail.com>
Co-authored-by: Tatiana Al-Chueyr <tatiana.alchueyr@gmail.com>
Co-authored-by: Wei Lee <weilee.rx@gmail.com>
pankajkoti authored and fdemiane committed Jun 6, 2024
1 parent 964f6d8 commit cba9c45
Showing 11 changed files with 998 additions and 25 deletions.
18 changes: 18 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
@@ -29,6 +29,7 @@
from __future__ import annotations

import json
from enum import Enum
from typing import Any

from requests import exceptions as requests_exceptions
@@ -63,6 +64,23 @@
SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")


class RunLifeCycleState(Enum):
"""Enum for the run life cycle state concept of Databricks runs.
See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
"""

BLOCKED = "BLOCKED"
INTERNAL_ERROR = "INTERNAL_ERROR"
PENDING = "PENDING"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
SKIPPED = "SKIPPED"
TERMINATED = "TERMINATED"
TERMINATING = "TERMINATING"
WAITING_FOR_RETRY = "WAITING_FOR_RETRY"


class RunState:
"""Utility class for the run state concept of Databricks runs."""

141 changes: 116 additions & 25 deletions airflow/providers/databricks/operators/databricks.py
@@ -29,13 +29,18 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState
from airflow.providers.databricks.operators.databricks_workflow import (
DatabricksWorkflowTaskGroup,
WorkflowRunMetadata,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup

DEFER_METHOD_NAME = "execute_complete"
XCOM_RUN_ID_KEY = "run_id"
@@ -926,7 +931,10 @@ class DatabricksNotebookOperator(BaseOperator):
:param deferrable: Run operator in the deferrable mode.
"""

template_fields = ("notebook_params",)
template_fields = (
"notebook_params",
"workflow_run_metadata",
)
CALLER = "DatabricksNotebookOperator"

def __init__(
@@ -944,6 +952,7 @@ def __init__(
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
workflow_run_metadata: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
@@ -962,6 +971,10 @@
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id: int | None = None
self.deferrable = deferrable

# This is used to store the metadata of the Databricks job run when the job is launched from within DatabricksWorkflowTaskGroup.
self.workflow_run_metadata: dict | None = workflow_run_metadata

super().__init__(**kwargs)

@cached_property
@@ -1016,6 +1029,79 @@ def _get_databricks_task_id(self, task_id: str) -> str:
"""Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
return f"{self.dag_id}__{task_id.replace('.', '__')}"

@property
def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
"""
Traverse up parent TaskGroups until the `is_databricks` flag associated with the root DatabricksWorkflowTaskGroup is found.
If found, returns the task group. Otherwise, return None.
"""
parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = self.task_group

while parent_tg:
if getattr(parent_tg, "is_databricks", False):
return parent_tg # type: ignore[return-value]

if getattr(parent_tg, "task_group", None):
parent_tg = parent_tg.task_group
else:
return None

return None

def _extend_workflow_notebook_packages(
self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
) -> None:
"""Extend the task group packages into the notebook's packages, without adding any duplicates."""
for task_group_package in databricks_workflow_task_group.notebook_packages:
exists = any(
task_group_package == existing_package for existing_package in self.notebook_packages
)
if not exists:
self.notebook_packages.append(task_group_package)

def _convert_to_databricks_workflow_task(
self, relevant_upstreams: list[BaseOperator], context: Context | None = None
) -> dict[str, object]:
"""Convert the operator to a Databricks workflow task that can be a task in a workflow."""
databricks_workflow_task_group = self._databricks_workflow_task_group
if not databricks_workflow_task_group:
raise AirflowException(
"Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup."
)

if hasattr(databricks_workflow_task_group, "notebook_packages"):
self._extend_workflow_notebook_packages(databricks_workflow_task_group)

if hasattr(databricks_workflow_task_group, "notebook_params"):
self.notebook_params = {
**self.notebook_params,
**databricks_workflow_task_group.notebook_params,
}

base_task_json = self._get_task_base_json()
result = {
"task_key": self._get_databricks_task_id(self.task_id),
"depends_on": [
{"task_key": self._get_databricks_task_id(task_id)}
for task_id in self.upstream_task_ids
if task_id in relevant_upstreams
],
**base_task_json,
}

if self.existing_cluster_id and self.job_cluster_key:
raise ValueError(
"Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
)

if self.existing_cluster_id:
result["existing_cluster_id"] = self.existing_cluster_id
elif self.job_cluster_key:
result["job_cluster_key"] = self.job_cluster_key

return result

def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
@@ -1039,6 +1125,17 @@ def launch_notebook_job(self) -> int:
self.log.info("Check the job run in Databricks: %s", url)
return self.databricks_run_id

def _handle_terminal_run_state(self, run_state: RunState) -> None:
if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value:
raise AirflowException(
f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}"
)
if not run_state.is_successful:
raise AirflowException(
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)

def monitor_databricks_job(self) -> None:
if self.databricks_run_id is None:
raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
@@ -1063,34 +1160,28 @@ def monitor_databricks_job(self) -> None:
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info(
"task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
)
self.log.info("Current state of the job: %s", run_state.life_cycle_state)
if run_state.life_cycle_state != "TERMINATED":
raise AirflowException(
f"Databricks job failed with state {run_state.life_cycle_state}. "
f"Message: {run_state.state_message}"
"Current state of the databricks task %s is %s",
self._get_databricks_task_id(self.task_id),
run_state.life_cycle_state,
)
if not run_state.is_successful:
raise AirflowException(
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)
self._handle_terminal_run_state(run_state)

def execute(self, context: Context) -> None:
-self.launch_notebook_job()
+if self._databricks_workflow_task_group:
+    # If we are in a DatabricksWorkflowTaskGroup, an upstream launch task has already created the workflow run.
+    if not self.workflow_run_metadata:
+        launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch"))
+        self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id)
+    workflow_run_metadata = WorkflowRunMetadata(  # type: ignore[arg-type]
+        **self.workflow_run_metadata
+    )
+    self.databricks_run_id = workflow_run_metadata.run_id
+    self.databricks_conn_id = workflow_run_metadata.conn_id
+else:
+    self.launch_notebook_job()
if self.wait_for_termination:
self.monitor_databricks_job()

def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
-if run_state.life_cycle_state != "TERMINATED":
-    raise AirflowException(
-        f"Databricks job failed with state {run_state.life_cycle_state}. "
-        f"Message: {run_state.state_message}"
-    )
-if not run_state.is_successful:
-    raise AirflowException(
-        f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
-    )
-self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+self._handle_terminal_run_state(run_state)
[Diffs for the remaining changed files are not shown.]
