diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 710074d239af0..ee5349add055e 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -29,6 +29,7 @@ from __future__ import annotations import json +from enum import Enum from typing import Any from requests import exceptions as requests_exceptions @@ -63,6 +64,23 @@ SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions") +class RunLifeCycleState(Enum): + """Enum for the run life cycle state concept of Databricks runs. + + See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state + """ + + BLOCKED = "BLOCKED" + INTERNAL_ERROR = "INTERNAL_ERROR" + PENDING = "PENDING" + QUEUED = "QUEUED" + RUNNING = "RUNNING" + SKIPPED = "SKIPPED" + TERMINATED = "TERMINATED" + TERMINATING = "TERMINATING" + WAITING_FOR_RETRY = "WAITING_FOR_RETRY" + + class RunState: """Utility class for the run state concept of Databricks runs.""" diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index ff8de101326be..d6118f247f597 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -29,13 +29,18 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator, BaseOperatorLink, XCom -from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState +from airflow.providers.databricks.operators.databricks_workflow import ( + DatabricksWorkflowTaskGroup, + WorkflowRunMetadata, +) from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event if TYPE_CHECKING: from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context + from airflow.utils.task_group import TaskGroup DEFER_METHOD_NAME = "execute_complete" XCOM_RUN_ID_KEY = "run_id" @@ -926,7 +931,10 @@ class DatabricksNotebookOperator(BaseOperator): :param deferrable: Run operator in the deferrable mode. """ - template_fields = ("notebook_params",) + template_fields = ( + "notebook_params", + "workflow_run_metadata", + ) CALLER = "DatabricksNotebookOperator" def __init__( @@ -944,6 +952,7 @@ def __init__( databricks_retry_args: dict[Any, Any] | None = None, wait_for_termination: bool = True, databricks_conn_id: str = "databricks_default", + workflow_run_metadata: dict | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ): @@ -962,6 +971,10 @@ def __init__( self.databricks_conn_id = databricks_conn_id self.databricks_run_id: int | None = None self.deferrable = deferrable + + # This is used to store the metadata of the Databricks job run when the job is launched from within DatabricksWorkflowTaskGroup. + self.workflow_run_metadata: dict | None = workflow_run_metadata + super().__init__(**kwargs) @cached_property @@ -1016,6 +1029,79 @@ def _get_databricks_task_id(self, task_id: str) -> str: """Get the databricks task ID using dag_id and task_id. Removes illegal characters.""" return f"{self.dag_id}__{task_id.replace('.', '__')}" + @property + def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None: + """ + Traverse up parent TaskGroups until the `is_databricks` flag associated with the root DatabricksWorkflowTaskGroup is found. + + If found, returns the task group. Otherwise, return None. + """ + parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = self.task_group + + while parent_tg: + if getattr(parent_tg, "is_databricks", False): + return parent_tg # type: ignore[return-value] + + if getattr(parent_tg, "task_group", None): + parent_tg = parent_tg.task_group + else: + return None + + return None + + def _extend_workflow_notebook_packages( + self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup + ) -> None: + """Extend the task group packages into the notebook's packages, without adding any duplicates.""" + for task_group_package in databricks_workflow_task_group.notebook_packages: + exists = any( + task_group_package == existing_package for existing_package in self.notebook_packages + ) + if not exists: + self.notebook_packages.append(task_group_package) + + def _convert_to_databricks_workflow_task( + self, relevant_upstreams: list[BaseOperator], context: Context | None = None + ) -> dict[str, object]: + """Convert the operator to a Databricks workflow task that can be a task in a workflow.""" + databricks_workflow_task_group = self._databricks_workflow_task_group + if not databricks_workflow_task_group: + raise AirflowException( + "Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup." + ) + + if hasattr(databricks_workflow_task_group, "notebook_packages"): + self._extend_workflow_notebook_packages(databricks_workflow_task_group) + + if hasattr(databricks_workflow_task_group, "notebook_params"): + self.notebook_params = { + **self.notebook_params, + **databricks_workflow_task_group.notebook_params, + } + + base_task_json = self._get_task_base_json() + result = { + "task_key": self._get_databricks_task_id(self.task_id), + "depends_on": [ + {"task_key": self._get_databricks_task_id(task_id)} + for task_id in self.upstream_task_ids + if task_id in relevant_upstreams + ], + **base_task_json, + } + + if self.existing_cluster_id and self.job_cluster_key: + raise ValueError( + "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task." + ) + + if self.existing_cluster_id: + result["existing_cluster_id"] = self.existing_cluster_id + elif self.job_cluster_key: + result["job_cluster_key"] = self.job_cluster_key + + return result + def _get_run_json(self) -> dict[str, Any]: """Get run json to be used for task submissions.""" run_json = { @@ -1039,6 +1125,17 @@ def launch_notebook_job(self) -> int: self.log.info("Check the job run in Databricks: %s", url) return self.databricks_run_id + def _handle_terminal_run_state(self, run_state: RunState) -> None: + if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value: + raise AirflowException( + f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}" + ) + if not run_state.is_successful: + raise AirflowException( + f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}" + ) + self.log.info("Task succeeded. Final state %s.", run_state.result_state) + def monitor_databricks_job(self) -> None: if self.databricks_run_id is None: raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.") @@ -1063,34 +1160,28 @@ def monitor_databricks_job(self) -> None: run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info( - "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state - ) - self.log.info("Current state of the job: %s", run_state.life_cycle_state) - if run_state.life_cycle_state != "TERMINATED": - raise AirflowException( - f"Databricks job failed with state {run_state.life_cycle_state}. " - f"Message: {run_state.state_message}" + "Current state of the databricks task %s is %s", + self._get_databricks_task_id(self.task_id), + run_state.life_cycle_state, ) - if not run_state.is_successful: - raise AirflowException( - f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}" - ) - self.log.info("Task succeeded. Final state %s.", run_state.result_state) + self._handle_terminal_run_state(run_state) def execute(self, context: Context) -> None: - self.launch_notebook_job() + if self._databricks_workflow_task_group: + # If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched. + if not self.workflow_run_metadata: + launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch")) + self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id) + workflow_run_metadata = WorkflowRunMetadata( # type: ignore[arg-type] + **self.workflow_run_metadata + ) + self.databricks_run_id = workflow_run_metadata.run_id + self.databricks_conn_id = workflow_run_metadata.conn_id + else: + self.launch_notebook_job() if self.wait_for_termination: self.monitor_databricks_job() def execute_complete(self, context: dict | None, event: dict) -> None: run_state = RunState.from_json(event["run_state"]) - if run_state.life_cycle_state != "TERMINATED": - raise AirflowException( - f"Databricks job failed with state {run_state.life_cycle_state}. " - f"Message: {run_state.state_message}" - ) - if not run_state.is_successful: - raise AirflowException( - f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}" - ) - self.log.info("Task succeeded. Final state %s.", run_state.result_state) + self._handle_terminal_run_state(run_state) diff --git a/airflow/providers/databricks/operators/databricks_workflow.py b/airflow/providers/databricks/operators/databricks_workflow.py new file mode 100644 index 0000000000000..8203145314fd0 --- /dev/null +++ b/airflow/providers/databricks/operators/databricks_workflow.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import time +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from mergedeep import merge + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState +from airflow.utils.task_group import TaskGroup + +if TYPE_CHECKING: + from types import TracebackType + + from airflow.models.taskmixin import DAGNode + from airflow.utils.context import Context + + +@dataclass +class WorkflowRunMetadata: + """ + Metadata for a Databricks workflow run. + + :param run_id: The ID of the Databricks workflow run. + :param job_id: The ID of the Databricks workflow job. + :param conn_id: The connection ID used to connect to Databricks. + """ + + conn_id: str + job_id: str + run_id: int + + +def _flatten_node( + node: TaskGroup | BaseOperator | DAGNode, tasks: list[BaseOperator] | None = None +) -> list[BaseOperator]: + """Flatten a node (either a TaskGroup or Operator) to a list of nodes.""" + if tasks is None: + tasks = [] + if isinstance(node, BaseOperator): + return [node] + + if isinstance(node, TaskGroup): + new_tasks = [] + for _, child in node.children.items(): + new_tasks += _flatten_node(child, tasks) + + return tasks + new_tasks + + return tasks + + +class _CreateDatabricksWorkflowOperator(BaseOperator): + """ + Creates a Databricks workflow from a DatabricksWorkflowTaskGroup specified in a DAG. + + :param task_id: The task_id of the operator + :param databricks_conn_id: The connection ID to use when connecting to Databricks. + :param existing_clusters: A list of existing clusters to use for the workflow. + :param extra_job_params: A dictionary of extra properties which will override the default Databricks + Workflow Job definitions. + :param job_clusters: A list of job clusters to use for the workflow. + :param max_concurrent_runs: The maximum number of concurrent runs for the workflow. + :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters + will be passed to all notebooks in the workflow. + :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be + populated after instantiation using the `add_task` method. + """ + + template_fields = ("notebook_params",) + caller = "_CreateDatabricksWorkflowOperator" + + def __init__( + self, + task_id: str, + databricks_conn_id: str, + existing_clusters: list[str] | None = None, + extra_job_params: dict[str, Any] | None = None, + job_clusters: list[dict[str, object]] | None = None, + max_concurrent_runs: int = 1, + notebook_params: dict | None = None, + tasks_to_convert: list[BaseOperator] | None = None, + **kwargs, + ): + self.databricks_conn_id = databricks_conn_id + self.existing_clusters = existing_clusters or [] + self.extra_job_params = extra_job_params or {} + self.job_clusters = job_clusters or [] + self.max_concurrent_runs = max_concurrent_runs + self.notebook_params = notebook_params or {} + self.tasks_to_convert = tasks_to_convert or [] + self.relevant_upstreams = [task_id] + super().__init__(task_id=task_id, **kwargs) + + def _get_hook(self, caller: str) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + caller=caller, + ) + + @cached_property + def _hook(self) -> DatabricksHook: + return self._get_hook(caller=self.caller) + + def add_task(self, task: BaseOperator) -> None: + """Add a task to the list of tasks to convert to a Databricks workflow.""" + self.tasks_to_convert.append(task) + + @property + def job_name(self) -> str: + if not self.task_group: + raise AirflowException("Task group must be set before accessing job_name") + return f"{self.dag_id}.{self.task_group.group_id}" + + def create_workflow_json(self, context: Context | None = None) -> dict[str, object]: + """Create a workflow json to be used in the Databricks API.""" + task_json = [ + task._convert_to_databricks_workflow_task( # type: ignore[attr-defined] + relevant_upstreams=self.relevant_upstreams, context=context + ) + for task in self.tasks_to_convert + ] + + default_json = { + "name": self.job_name, + "email_notifications": {"no_alert_for_skipped_runs": False}, + "timeout_seconds": 0, + "tasks": task_json, + "format": "MULTI_TASK", + "job_clusters": self.job_clusters, + "max_concurrent_runs": self.max_concurrent_runs, + } + return merge(default_json, self.extra_job_params) + + def _create_or_reset_job(self, context: Context) -> int: + job_spec = self.create_workflow_json(context=context) + existing_jobs = self._hook.list_jobs(job_name=self.job_name) + job_id = existing_jobs[0]["job_id"] if existing_jobs else None + if job_id: + self.log.info( + "Updating existing Databricks workflow job %s with spec %s", + self.job_name, + json.dumps(job_spec, indent=2), + ) + self._hook.reset_job(job_id, job_spec) + else: + self.log.info( + "Creating new Databricks workflow job %s with spec %s", + self.job_name, + json.dumps(job_spec, indent=2), + ) + job_id = self._hook.create_job(job_spec) + return job_id + + def _wait_for_job_to_start(self, run_id: int) -> None: + run_url = self._hook.get_run_page_url(run_id) + self.log.info("Check the progress of the Databricks job at %s", run_url) + life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state + if life_cycle_state not in ( + RunLifeCycleState.PENDING.value, + RunLifeCycleState.RUNNING.value, + RunLifeCycleState.BLOCKED.value, + ): + raise AirflowException(f"Could not start the workflow job. State: {life_cycle_state}") + while life_cycle_state in (RunLifeCycleState.PENDING.value, RunLifeCycleState.BLOCKED.value): + self.log.info("Waiting for the Databricks job to start running") + time.sleep(5) + life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state + self.log.info("Databricks job started. State: %s", life_cycle_state) + + def execute(self, context: Context) -> Any: + if not isinstance(self.task_group, DatabricksWorkflowTaskGroup): + raise AirflowException("Task group must be a DatabricksWorkflowTaskGroup") + + job_id = self._create_or_reset_job(context) + + run_id = self._hook.run_now( + { + "job_id": job_id, + "jar_params": self.task_group.jar_params, + "notebook_params": self.notebook_params, + "python_params": self.task_group.python_params, + "spark_submit_params": self.task_group.spark_submit_params, + } + ) + + self._wait_for_job_to_start(run_id) + + return { + "conn_id": self.databricks_conn_id, + "job_id": job_id, + "run_id": run_id, + } + + +class DatabricksWorkflowTaskGroup(TaskGroup): + """ + A task group that takes a list of tasks and creates a databricks workflow. + + The DatabricksWorkflowTaskGroup takes a list of tasks and creates a databricks workflow + based on the metadata produced by those tasks. For a task to be eligible for this + TaskGroup, it must contain the ``_convert_to_databricks_workflow_task`` method. If any tasks + do not contain this method then the Taskgroup will raise an error at parse time. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksWorkflowTaskGroup` + + :param databricks_conn_id: The name of the databricks connection to use. + :param existing_clusters: A list of existing clusters to use for this workflow. + :param extra_job_params: A dictionary containing properties which will override the default + Databricks Workflow Job definitions. + :param jar_params: A list of jar parameters to pass to the workflow. These parameters will be passed to all jar + tasks in the workflow. + :param job_clusters: A list of job clusters to use for this workflow. + :param max_concurrent_runs: The maximum number of concurrent runs for this workflow. + :param notebook_packages: A list of dictionary of Python packages to be installed. Packages defined + at the workflow task group level are installed for each of the notebook tasks under it. And + packages defined at the notebook task level are installed specific for the notebook task. + :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters + will be passed to all notebook tasks in the workflow. + :param python_params: A list of python parameters to pass to the workflow. These parameters will be passed to + all python tasks in the workflow. + :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters + will be passed to all spark submit tasks. + """ + + is_databricks = True + + def __init__( + self, + databricks_conn_id: str, + existing_clusters: list[str] | None = None, + extra_job_params: dict[str, Any] | None = None, + jar_params: list[str] | None = None, + job_clusters: list[dict] | None = None, + max_concurrent_runs: int = 1, + notebook_packages: list[dict[str, Any]] | None = None, + notebook_params: dict | None = None, + python_params: list | None = None, + spark_submit_params: list | None = None, + **kwargs, + ): + self.databricks_conn_id = databricks_conn_id + self.existing_clusters = existing_clusters or [] + self.extra_job_params = extra_job_params or {} + self.jar_params = jar_params or [] + self.job_clusters = job_clusters or [] + self.max_concurrent_runs = max_concurrent_runs + self.notebook_packages = notebook_packages or [] + self.notebook_params = notebook_params or {} + self.python_params = python_params or [] + self.spark_submit_params = spark_submit_params or [] + super().__init__(**kwargs) + + def __exit__( + self, _type: type[BaseException] | None, _value: BaseException | None, _tb: TracebackType | None + ) -> None: + """Exit the context manager and add tasks to a single ``_CreateDatabricksWorkflowOperator``.""" + roots = list(self.get_roots()) + tasks = _flatten_node(self) + + create_databricks_workflow_task = _CreateDatabricksWorkflowOperator( + dag=self.dag, + task_group=self, + task_id="launch", + databricks_conn_id=self.databricks_conn_id, + existing_clusters=self.existing_clusters, + extra_job_params=self.extra_job_params, + job_clusters=self.job_clusters, + max_concurrent_runs=self.max_concurrent_runs, + notebook_params=self.notebook_params, + ) + + for task in tasks: + if not ( + hasattr(task, "_convert_to_databricks_workflow_task") + and callable(task._convert_to_databricks_workflow_task) + ): + raise AirflowException( + f"Task {task.task_id} does not support conversion to databricks workflow task." + ) + + task.workflow_run_metadata = create_databricks_workflow_task.output + create_databricks_workflow_task.relevant_upstreams.append(task.task_id) + create_databricks_workflow_task.add_task(task) + + for root_task in roots: + root_task.set_upstream(create_databricks_workflow_task) + + super().__exit__(_type, _value, _tb) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index a6e2174640a5b..80506dc16c15d 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -72,6 +72,7 @@ dependencies: # The 2.9.1 (to be released soon) already contains the fix - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0 - aiohttp>=3.9.2, <4 + - mergedeep>=1.3.4 additional-extras: # pip install apache-airflow-providers-databricks[sdk] @@ -108,6 +109,12 @@ integrations: - /docs/apache-airflow-providers-databricks/operators/repos_delete.rst logo: /integration-logos/databricks/Databricks.png tags: [service] + - integration-name: Databricks Workflow + external-doc-url: https://docs.databricks.com/en/workflows/index.html + how-to-guide: + - /docs/apache-airflow-providers-databricks/operators/workflow.rst + logo: /integration-logos/databricks/Databricks.png + tags: [service] operators: - integration-name: Databricks @@ -119,6 +126,9 @@ operators: - integration-name: Databricks Repos python-modules: - airflow.providers.databricks.operators.databricks_repos + - integration-name: Databricks Workflow + python-modules: + - airflow.providers.databricks.operators.databricks_workflow hooks: - integration-name: Databricks diff --git a/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png b/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png new file mode 100644 index 0000000000000..3a3cb669e0c08 Binary files /dev/null and b/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png differ diff --git a/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png b/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png new file mode 100644 index 0000000000000..cb189b8105aaf Binary files /dev/null and b/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png differ diff --git a/docs/apache-airflow-providers-databricks/operators/workflow.rst b/docs/apache-airflow-providers-databricks/operators/workflow.rst new file mode 100644 index 0000000000000..f58514dd5187c --- /dev/null +++ b/docs/apache-airflow-providers-databricks/operators/workflow.rst @@ -0,0 +1,71 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:DatabricksWorkflowTaskGroup: + + +DatabricksWorkflowTaskGroup +=========================== + +Use the :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` to launch and monitor +Databricks notebook job runs as Airflow tasks. The task group launches a `Databricks Workflow `_ and runs the notebook jobs from within it, resulting in a `75% cost reduction `_ ($0.40/DBU for all-purpose compute, $0.07/DBU for Jobs compute) when compared to executing ``DatabricksNotebookOperator`` outside of ``DatabricksWorkflowTaskGroup``. + + +There are a few advantages to defining your Databricks Workflows in Airflow: + +======================================= ============================================= ================================= +Authoring interface via Databricks (Web-based with Databricks UI) via Airflow(Code with Airflow DAG) +======================================= ============================================= ================================= +Workflow compute pricing ✅ ✅ +Notebook code in source control ✅ ✅ +Workflow structure in source control ✅ +Retry from beginning ✅ ✅ +Retry single task ✅ +Task groups within Workflows ✅ +Trigger workflows from other DAGs ✅ +Workflow-level parameters ✅ +======================================= ============================================= ================================= + +Examples +-------- + +Example of what a DAG looks like with a DatabricksWorkflowTaskGroup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks_workflow.py + :language: python + :start-after: [START howto_databricks_workflow_notebook] + :end-before: [END howto_databricks_workflow_notebook] + +With this example, Airflow will produce a job named ``.test_workflow__`` that will +run task ``notebook_1`` and then ``notebook_2``. The job will be created in the databricks workspace +if it does not already exist. If the job already exists, it will be updated to match +the workflow defined in the DAG. + +The following image displays the resulting Databricks Workflow in the Airflow UI (based on the above example provided) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. image:: ../img/databricks_workflow_task_group_airflow_graph_view.png + +The corresponding Databricks Workflow in the Databricks UI for the run triggered from the Airflow DAG is depicted below +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. image:: ../img/workflow_run_databricks_graph_view.png + + +To minimize update conflicts, we recommend that you keep parameters in the ``notebook_params`` of the +``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` whenever possible. +This is because, tasks in the ``DatabricksWorkflowTaskGroup`` are passed in on the job trigger time and +do not modify the job definition. diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 7f29984d329b5..01c1e3378589d 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -411,6 +411,7 @@ "apache-airflow-providers-common-sql>=1.10.0", "apache-airflow>=2.7.0", "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0", + "mergedeep>=1.3.4", "requests>=2.27.0,<3" ], "devel-deps": [ diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index d6e7eb3892919..2774385ea5a71 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -2014,3 +2014,122 @@ def test_zero_execution_timeout_raises_error(self): "Set it instead to `None` if you desire the task to run indefinitely." ) assert str(exc_info.value) == exception_message + + def test_extend_workflow_notebook_packages(self): + """Test that the operator can extend the notebook packages of a Databricks workflow task group.""" + databricks_workflow_task_group = MagicMock() + databricks_workflow_task_group.notebook_packages = [ + {"pypi": {"package": "numpy"}}, + {"pypi": {"package": "pandas"}}, + ] + + operator = DatabricksNotebookOperator( + notebook_path="/path/to/notebook", + source="WORKSPACE", + task_id="test_task", + notebook_packages=[ + {"pypi": {"package": "numpy"}}, + {"pypi": {"package": "scipy"}}, + ], + ) + + operator._extend_workflow_notebook_packages(databricks_workflow_task_group) + + assert operator.notebook_packages == [ + {"pypi": {"package": "numpy"}}, + {"pypi": {"package": "scipy"}}, + {"pypi": {"package": "pandas"}}, + ] + + def test_convert_to_databricks_workflow_task(self): + """Test that the operator can convert itself to a Databricks workflow task.""" + dag = DAG(dag_id="example_dag", start_date=datetime.now()) + operator = DatabricksNotebookOperator( + notebook_path="/path/to/notebook", + source="WORKSPACE", + task_id="test_task", + notebook_packages=[ + {"pypi": {"package": "numpy"}}, + {"pypi": {"package": "scipy"}}, + ], + dag=dag, + ) + + databricks_workflow_task_group = MagicMock() + databricks_workflow_task_group.notebook_packages = [ + {"pypi": {"package": "numpy"}}, + ] + databricks_workflow_task_group.notebook_params = {"param1": "value1"} + + operator.notebook_packages = [{"pypi": {"package": "pandas"}}] + operator.notebook_params = {"param2": "value2"} + operator.task_group = databricks_workflow_task_group + operator.task_id = "test_task" + operator.upstream_task_ids = ["upstream_task"] + relevant_upstreams = [MagicMock(task_id="upstream_task")] + + task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams) + + expected_json = { + "task_key": "example_dag__test_task", + "depends_on": [], + "timeout_seconds": 0, + "email_notifications": {}, + "notebook_task": { + "notebook_path": "/path/to/notebook", + "source": "WORKSPACE", + "base_parameters": { + "param2": "value2", + "param1": "value1", + }, + }, + "libraries": [ + {"pypi": {"package": "pandas"}}, + {"pypi": {"package": "numpy"}}, + ], + } + + assert task_json == expected_json + + def test_convert_to_databricks_workflow_task_no_task_group(self): + """Test that an error is raised if the operator is not in a TaskGroup.""" + operator = DatabricksNotebookOperator( + notebook_path="/path/to/notebook", + source="WORKSPACE", + task_id="test_task", + notebook_packages=[ + {"pypi": {"package": "numpy"}}, + {"pypi": {"package": "scipy"}}, + ], + ) + operator.task_group = None + relevant_upstreams = [MagicMock(task_id="upstream_task")] + + with pytest.raises( + AirflowException, + match="Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup.", + ): + operator._convert_to_databricks_workflow_task(relevant_upstreams) + + def test_convert_to_databricks_workflow_task_cluster_conflict(self): + """Test that an error is raised if both `existing_cluster_id` and `job_cluster_key` are set.""" + operator = DatabricksNotebookOperator( + notebook_path="/path/to/notebook", + source="WORKSPACE", + task_id="test_task", + notebook_packages=[ + {"pypi": {"package": "numpy"}}, + {"pypi": {"package": "scipy"}}, + ], + ) + databricks_workflow_task_group = MagicMock() + operator.existing_cluster_id = "existing-cluster-id" + operator.job_cluster_key = "job-cluster-key" + operator.task_group = databricks_workflow_task_group + relevant_upstreams = [MagicMock(task_id="upstream_task")] + + with pytest.raises( + ValueError, + match="Both existing_cluster_id and job_cluster_key are set. Only one can be set per task.", + ): + operator._convert_to_databricks_workflow_task(relevant_upstreams) diff --git a/tests/providers/databricks/operators/test_databricks_workflow.py b/tests/providers/databricks/operators/test_databricks_workflow.py new file mode 100644 index 0000000000000..99f1a9d14815d --- /dev/null +++ b/tests/providers/databricks/operators/test_databricks_workflow.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow import DAG +from airflow.exceptions import AirflowException +from airflow.models.baseoperator import BaseOperator +from airflow.operators.empty import EmptyOperator +from airflow.providers.databricks.hooks.databricks import RunLifeCycleState +from airflow.providers.databricks.operators.databricks_workflow import ( + DatabricksWorkflowTaskGroup, + _CreateDatabricksWorkflowOperator, + _flatten_node, +) +from airflow.utils import timezone + +pytestmark = pytest.mark.db_test + +DEFAULT_DATE = timezone.datetime(2021, 1, 1) + + +@pytest.fixture +def mock_databricks_hook(): + """Provide a mock DatabricksHook.""" + with patch("airflow.providers.databricks.operators.databricks_workflow.DatabricksHook") as mock_hook: + yield mock_hook + + +@pytest.fixture +def context(): + """Provide a mock context object.""" + return MagicMock() + + +@pytest.fixture +def mock_task_group(): + """Provide a mock DatabricksWorkflowTaskGroup with necessary attributes.""" + mock_group = MagicMock(spec=DatabricksWorkflowTaskGroup) + mock_group.group_id = "test_group" + return mock_group + + +def test_flatten_node(): + """Test that _flatten_node returns a flat list of operators.""" + task_group = MagicMock(spec=DatabricksWorkflowTaskGroup) + base_operator = MagicMock(spec=BaseOperator) + task_group.children = {"task1": base_operator, "task2": base_operator} + + result = _flatten_node(task_group) + assert result == [base_operator, base_operator] + + +def test_create_workflow_json(mock_databricks_hook, context, mock_task_group): + """Test that _CreateDatabricksWorkflowOperator.create_workflow_json returns the expected JSON.""" + operator = _CreateDatabricksWorkflowOperator( + task_id="test_task", + databricks_conn_id="databricks_default", + ) + operator.task_group = mock_task_group + + task = MagicMock(spec=BaseOperator) + task._convert_to_databricks_workflow_task = MagicMock(return_value={}) + operator.add_task(task) + + workflow_json = operator.create_workflow_json(context=context) + + assert ".test_group" in workflow_json["name"] + assert "tasks" in workflow_json + assert workflow_json["format"] == "MULTI_TASK" + assert workflow_json["email_notifications"] == {"no_alert_for_skipped_runs": False} + assert workflow_json["job_clusters"] == [] + assert workflow_json["max_concurrent_runs"] == 1 + assert workflow_json["timeout_seconds"] == 0 + + +def test_create_or_reset_job_existing(mock_databricks_hook, context, mock_task_group): + """Test that _CreateDatabricksWorkflowOperator._create_or_reset_job resets the job if it already exists.""" + operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") + operator.task_group = mock_task_group + operator._hook.list_jobs.return_value = [{"job_id": 123}] + operator._hook.create_job.return_value = 123 + + job_id = operator._create_or_reset_job(context) + assert job_id == 123 + operator._hook.reset_job.assert_called_once() + + +def test_create_or_reset_job_new(mock_databricks_hook, context, mock_task_group): + """Test that _CreateDatabricksWorkflowOperator._create_or_reset_job creates a new job if it does not exist.""" + operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") + operator.task_group = mock_task_group + operator._hook.list_jobs.return_value = [] + operator._hook.create_job.return_value = 456 + + job_id = operator._create_or_reset_job(context) + assert job_id == 456 + operator._hook.create_job.assert_called_once() + + +def test_wait_for_job_to_start(mock_databricks_hook): + """Test that _CreateDatabricksWorkflowOperator._wait_for_job_to_start waits for the job to start.""" + operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") + mock_hook_instance = mock_databricks_hook.return_value + mock_hook_instance.get_run_state.side_effect = [ + MagicMock(life_cycle_state=RunLifeCycleState.PENDING.value), + MagicMock(life_cycle_state=RunLifeCycleState.RUNNING.value), + ] + + operator._wait_for_job_to_start(123) + mock_hook_instance.get_run_state.assert_called() + + +def test_execute(mock_databricks_hook, context, mock_task_group): + """Test that _CreateDatabricksWorkflowOperator.execute runs the task group.""" + operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") + operator.task_group = mock_task_group + mock_task_group.jar_params = {} + mock_task_group.python_params = {} + mock_task_group.spark_submit_params = {} + + mock_hook_instance = mock_databricks_hook.return_value + mock_hook_instance.run_now.return_value = 789 + mock_hook_instance.list_jobs.return_value = [{"job_id": 123}] + mock_hook_instance.get_run_state.return_value = MagicMock( + life_cycle_state=RunLifeCycleState.RUNNING.value + ) + + task = MagicMock(spec=BaseOperator) + task._convert_to_databricks_workflow_task = MagicMock(return_value={}) + operator.add_task(task) + + result = operator.execute(context) + + assert result == { + "conn_id": "databricks_default", + "job_id": 123, + "run_id": 789, + } + mock_hook_instance.run_now.assert_called_once() + + +def test_execute_invalid_task_group(context): + """Test that _CreateDatabricksWorkflowOperator.execute raises an exception if the task group is invalid.""" + operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default") + operator.task_group = MagicMock() # Not a DatabricksWorkflowTaskGroup + + with pytest.raises(AirflowException, match="Task group must be a DatabricksWorkflowTaskGroup"): + operator.execute(context) + + +@pytest.fixture +def mock_databricks_workflow_operator(): + with patch( + "airflow.providers.databricks.operators.databricks_workflow._CreateDatabricksWorkflowOperator" + ) as mock_operator: + yield mock_operator + + +def test_task_group_initialization(): + """Test that DatabricksWorkflowTaskGroup initializes correctly.""" + with DAG(dag_id="example_databricks_workflow_dag", start_date=DEFAULT_DATE) as example_dag: + with DatabricksWorkflowTaskGroup( + group_id="test_databricks_workflow", databricks_conn_id="databricks_conn" + ) as task_group: + task_1 = EmptyOperator(task_id="task1") + task_1._convert_to_databricks_workflow_task = MagicMock(return_value={}) + assert task_group.group_id == "test_databricks_workflow" + assert task_group.databricks_conn_id == "databricks_conn" + assert task_group.dag == example_dag + + +def test_task_group_exit_creates_operator(mock_databricks_workflow_operator): + """Test that DatabricksWorkflowTaskGroup creates a _CreateDatabricksWorkflowOperator on exit.""" + with DAG(dag_id="example_databricks_workflow_dag", start_date=DEFAULT_DATE) as example_dag: + with DatabricksWorkflowTaskGroup( + group_id="test_databricks_workflow", + databricks_conn_id="databricks_conn", + ) as task_group: + task1 = MagicMock(task_id="task1") + task1._convert_to_databricks_workflow_task = MagicMock(return_value={}) + task2 = MagicMock(task_id="task2") + task2._convert_to_databricks_workflow_task = MagicMock(return_value={}) + + task_group.add(task1) + task_group.add(task2) + + task1.set_downstream(task2) + + mock_databricks_workflow_operator.assert_called_once_with( + dag=example_dag, + task_group=task_group, + task_id="launch", + databricks_conn_id="databricks_conn", + existing_clusters=[], + extra_job_params={}, + job_clusters=[], + max_concurrent_runs=1, + notebook_params={}, + ) + + +def test_task_group_root_tasks_set_upstream_to_operator(mock_databricks_workflow_operator): + """Test that tasks added to a DatabricksWorkflowTaskGroup are set upstream to the operator.""" + with DAG(dag_id="example_databricks_workflow_dag", start_date=DEFAULT_DATE): + with DatabricksWorkflowTaskGroup( + group_id="test_databricks_workflow1", + databricks_conn_id="databricks_conn", + ) as task_group: + task1 = MagicMock(task_id="task1") + task1._convert_to_databricks_workflow_task = MagicMock(return_value={}) + task_group.add(task1) + + create_operator_instance = mock_databricks_workflow_operator.return_value + task1.set_upstream.assert_called_once_with(create_operator_instance) diff --git a/tests/system/providers/databricks/example_databricks_workflow.py b/tests/system/providers/databricks/example_databricks_workflow.py new file mode 100644 index 0000000000000..6b05f34684c9d --- /dev/null +++ b/tests/system/providers/databricks/example_databricks_workflow.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG for using the DatabricksWorkflowTaskGroup and DatabricksNotebookOperator.""" + +from __future__ import annotations + +import os +from datetime import timedelta + +from airflow.models.dag import DAG +from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator +from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup +from airflow.utils.timezone import datetime + +EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) + +DATABRICKS_CONN_ID = os.getenv("DATABRICKS_CONN_ID", "databricks_conn") +DATABRICKS_NOTIFICATION_EMAIL = os.getenv("DATABRICKS_NOTIFICATION_EMAIL", "your_email@serviceprovider.com") + +GROUP_ID = os.getenv("DATABRICKS_GROUP_ID", "1234").replace(".", "_") +USER = os.environ.get("USER") + +job_cluster_spec = [ + { + "job_cluster_key": "Shared_job_cluster", + "new_cluster": { + "cluster_name": "", + "spark_version": "11.3.x-scala2.12", + "aws_attributes": { + "first_on_demand": 1, + "availability": "SPOT_WITH_FALLBACK", + "zone_id": "us-east-2b", + "spot_bid_price_percent": 100, + "ebs_volume_count": 0, + }, + "node_type_id": "i3.xlarge", + "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"}, + "enable_elastic_disk": False, + "data_security_mode": "LEGACY_SINGLE_USER_STANDARD", + "runtime_engine": "STANDARD", + "num_workers": 8, + }, + } +] +dag = DAG( + dag_id="example_databricks_workflow", + start_date=datetime(2022, 1, 1), + schedule_interval=None, + catchup=False, + tags=["example", "databricks"], +) +with dag: + # [START howto_databricks_workflow_notebook] + task_group = DatabricksWorkflowTaskGroup( + group_id=f"test_workflow_{USER}_{GROUP_ID}", + databricks_conn_id=DATABRICKS_CONN_ID, + job_clusters=job_cluster_spec, + notebook_params={"ts": "{{ ts }}"}, + notebook_packages=[ + { + "pypi": { + "package": "simplejson==3.18.0", # Pin specification version of a package like this. + "repo": "https://pypi.org/simple", # You can specify your required Pypi index here. + } + }, + ], + extra_job_params={ + "email_notifications": { + "on_start": [DATABRICKS_NOTIFICATION_EMAIL], + }, + }, + ) + with task_group: + notebook_1 = DatabricksNotebookOperator( + task_id="workflow_notebook_1", + databricks_conn_id=DATABRICKS_CONN_ID, + notebook_path="/Shared/Notebook_1", + notebook_packages=[{"pypi": {"package": "Faker"}}], + source="WORKSPACE", + job_cluster_key="Shared_job_cluster", + execution_timeout=timedelta(seconds=600), + ) + notebook_2 = DatabricksNotebookOperator( + task_id="workflow_notebook_2", + databricks_conn_id=DATABRICKS_CONN_ID, + notebook_path="/Shared/Notebook_2", + source="WORKSPACE", + job_cluster_key="Shared_job_cluster", + notebook_params={"foo": "bar", "ds": "{{ ds }}"}, + ) + notebook_1 >> notebook_2 + # [END howto_databricks_workflow_notebook] + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)