diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 710074d239af0..ee5349add055e 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -29,6 +29,7 @@
from __future__ import annotations
import json
+from enum import Enum
from typing import Any
from requests import exceptions as requests_exceptions
@@ -63,6 +64,23 @@
SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
+class RunLifeCycleState(Enum):
+ """Enum for the run life cycle state concept of Databricks runs.
+
+ See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
+ """
+
+ BLOCKED = "BLOCKED"
+ INTERNAL_ERROR = "INTERNAL_ERROR"
+ PENDING = "PENDING"
+ QUEUED = "QUEUED"
+ RUNNING = "RUNNING"
+ SKIPPED = "SKIPPED"
+ TERMINATED = "TERMINATED"
+ TERMINATING = "TERMINATING"
+ WAITING_FOR_RETRY = "WAITING_FOR_RETRY"
+
+
class RunState:
"""Utility class for the run state concept of Databricks runs."""
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index ff8de101326be..d6118f247f597 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -29,13 +29,18 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState
+from airflow.providers.databricks.operators.databricks_workflow import (
+ DatabricksWorkflowTaskGroup,
+ WorkflowRunMetadata,
+)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context
+ from airflow.utils.task_group import TaskGroup
DEFER_METHOD_NAME = "execute_complete"
XCOM_RUN_ID_KEY = "run_id"
@@ -926,7 +931,10 @@ class DatabricksNotebookOperator(BaseOperator):
:param deferrable: Run operator in the deferrable mode.
"""
- template_fields = ("notebook_params",)
+ template_fields = (
+ "notebook_params",
+ "workflow_run_metadata",
+ )
CALLER = "DatabricksNotebookOperator"
def __init__(
@@ -944,6 +952,7 @@ def __init__(
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
+ workflow_run_metadata: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
@@ -962,6 +971,10 @@ def __init__(
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id: int | None = None
self.deferrable = deferrable
+
+ # This is used to store the metadata of the Databricks job run when the job is launched from within DatabricksWorkflowTaskGroup.
+ self.workflow_run_metadata: dict | None = workflow_run_metadata
+
super().__init__(**kwargs)
@cached_property
@@ -1016,6 +1029,79 @@ def _get_databricks_task_id(self, task_id: str) -> str:
"""Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
return f"{self.dag_id}__{task_id.replace('.', '__')}"
+ @property
+ def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
+ """
+ Traverse up parent TaskGroups until the `is_databricks` flag associated with the root DatabricksWorkflowTaskGroup is found.
+
+ If found, returns the task group. Otherwise, return None.
+ """
+ parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = self.task_group
+
+ while parent_tg:
+ if getattr(parent_tg, "is_databricks", False):
+ return parent_tg # type: ignore[return-value]
+
+ if getattr(parent_tg, "task_group", None):
+ parent_tg = parent_tg.task_group
+ else:
+ return None
+
+ return None
+
+ def _extend_workflow_notebook_packages(
+ self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
+ ) -> None:
+ """Extend the task group packages into the notebook's packages, without adding any duplicates."""
+ for task_group_package in databricks_workflow_task_group.notebook_packages:
+ exists = any(
+ task_group_package == existing_package for existing_package in self.notebook_packages
+ )
+ if not exists:
+ self.notebook_packages.append(task_group_package)
+
+ def _convert_to_databricks_workflow_task(
+ self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+ ) -> dict[str, object]:
+ """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
+ databricks_workflow_task_group = self._databricks_workflow_task_group
+ if not databricks_workflow_task_group:
+ raise AirflowException(
+ "Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup."
+ )
+
+ if hasattr(databricks_workflow_task_group, "notebook_packages"):
+ self._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+ if hasattr(databricks_workflow_task_group, "notebook_params"):
+ self.notebook_params = {
+ **self.notebook_params,
+ **databricks_workflow_task_group.notebook_params,
+ }
+
+ base_task_json = self._get_task_base_json()
+ result = {
+ "task_key": self._get_databricks_task_id(self.task_id),
+ "depends_on": [
+ {"task_key": self._get_databricks_task_id(task_id)}
+ for task_id in self.upstream_task_ids
+ if task_id in relevant_upstreams
+ ],
+ **base_task_json,
+ }
+
+ if self.existing_cluster_id and self.job_cluster_key:
+ raise ValueError(
+ "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
+ )
+
+ if self.existing_cluster_id:
+ result["existing_cluster_id"] = self.existing_cluster_id
+ elif self.job_cluster_key:
+ result["job_cluster_key"] = self.job_cluster_key
+
+ return result
+
def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
@@ -1039,6 +1125,17 @@ def launch_notebook_job(self) -> int:
self.log.info("Check the job run in Databricks: %s", url)
return self.databricks_run_id
+ def _handle_terminal_run_state(self, run_state: RunState) -> None:
+ if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value:
+ raise AirflowException(
+ f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}"
+ )
+ if not run_state.is_successful:
+ raise AirflowException(
+ f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
+ )
+ self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+
def monitor_databricks_job(self) -> None:
if self.databricks_run_id is None:
raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
@@ -1063,34 +1160,28 @@ def monitor_databricks_job(self) -> None:
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info(
- "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
- )
- self.log.info("Current state of the job: %s", run_state.life_cycle_state)
- if run_state.life_cycle_state != "TERMINATED":
- raise AirflowException(
- f"Databricks job failed with state {run_state.life_cycle_state}. "
- f"Message: {run_state.state_message}"
+ "Current state of the databricks task %s is %s",
+ self._get_databricks_task_id(self.task_id),
+ run_state.life_cycle_state,
)
- if not run_state.is_successful:
- raise AirflowException(
- f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
- )
- self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+ self._handle_terminal_run_state(run_state)
def execute(self, context: Context) -> None:
- self.launch_notebook_job()
+ if self._databricks_workflow_task_group:
+ # If we are in a DatabricksWorkflowTaskGroup, the workflow job run was already launched by the upstream launch task.
+ if not self.workflow_run_metadata:
+ launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch"))
+ self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id)
+ workflow_run_metadata = WorkflowRunMetadata( # type: ignore[arg-type]
+ **self.workflow_run_metadata
+ )
+ self.databricks_run_id = workflow_run_metadata.run_id
+ self.databricks_conn_id = workflow_run_metadata.conn_id
+ else:
+ self.launch_notebook_job()
if self.wait_for_termination:
self.monitor_databricks_job()
def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
- if run_state.life_cycle_state != "TERMINATED":
- raise AirflowException(
- f"Databricks job failed with state {run_state.life_cycle_state}. "
- f"Message: {run_state.state_message}"
- )
- if not run_state.is_successful:
- raise AirflowException(
- f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
- )
- self.log.info("Task succeeded. Final state %s.", run_state.result_state)
+ self._handle_terminal_run_state(run_state)
diff --git a/airflow/providers/databricks/operators/databricks_workflow.py b/airflow/providers/databricks/operators/databricks_workflow.py
new file mode 100644
index 0000000000000..8203145314fd0
--- /dev/null
+++ b/airflow/providers/databricks/operators/databricks_workflow.py
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import time
+from dataclasses import dataclass
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from mergedeep import merge
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState
+from airflow.utils.task_group import TaskGroup
+
+if TYPE_CHECKING:
+ from types import TracebackType
+
+ from airflow.models.taskmixin import DAGNode
+ from airflow.utils.context import Context
+
+
+@dataclass
+class WorkflowRunMetadata:
+ """
+ Metadata for a Databricks workflow run.
+
+ :param run_id: The ID of the Databricks workflow run.
+ :param job_id: The ID of the Databricks workflow job.
+ :param conn_id: The connection ID used to connect to Databricks.
+ """
+
+ conn_id: str
+ job_id: str
+ run_id: int
+
+
+def _flatten_node(
+ node: TaskGroup | BaseOperator | DAGNode, tasks: list[BaseOperator] | None = None
+) -> list[BaseOperator]:
+ """Flatten a node (either a TaskGroup or Operator) to a list of nodes."""
+ if tasks is None:
+ tasks = []
+ if isinstance(node, BaseOperator):
+ return [node]
+
+ if isinstance(node, TaskGroup):
+ new_tasks = []
+ for _, child in node.children.items():
+ new_tasks += _flatten_node(child, tasks)
+
+ return tasks + new_tasks
+
+ return tasks
+
+
+class _CreateDatabricksWorkflowOperator(BaseOperator):
+ """
+ Creates a Databricks workflow from a DatabricksWorkflowTaskGroup specified in a DAG.
+
+ :param task_id: The task_id of the operator
+ :param databricks_conn_id: The connection ID to use when connecting to Databricks.
+ :param existing_clusters: A list of existing clusters to use for the workflow.
+ :param extra_job_params: A dictionary of extra properties which will override the default Databricks
+ Workflow Job definitions.
+ :param job_clusters: A list of job clusters to use for the workflow.
+ :param max_concurrent_runs: The maximum number of concurrent runs for the workflow.
+ :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
+ will be passed to all notebooks in the workflow.
+ :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
+ populated after instantiation using the `add_task` method.
+ """
+
+ template_fields = ("notebook_params",)
+ caller = "_CreateDatabricksWorkflowOperator"
+
+ def __init__(
+ self,
+ task_id: str,
+ databricks_conn_id: str,
+ existing_clusters: list[str] | None = None,
+ extra_job_params: dict[str, Any] | None = None,
+ job_clusters: list[dict[str, object]] | None = None,
+ max_concurrent_runs: int = 1,
+ notebook_params: dict | None = None,
+ tasks_to_convert: list[BaseOperator] | None = None,
+ **kwargs,
+ ):
+ self.databricks_conn_id = databricks_conn_id
+ self.existing_clusters = existing_clusters or []
+ self.extra_job_params = extra_job_params or {}
+ self.job_clusters = job_clusters or []
+ self.max_concurrent_runs = max_concurrent_runs
+ self.notebook_params = notebook_params or {}
+ self.tasks_to_convert = tasks_to_convert or []
+ self.relevant_upstreams = [task_id]
+ super().__init__(task_id=task_id, **kwargs)
+
+ def _get_hook(self, caller: str) -> DatabricksHook:
+ return DatabricksHook(
+ self.databricks_conn_id,
+ caller=caller,
+ )
+
+ @cached_property
+ def _hook(self) -> DatabricksHook:
+ return self._get_hook(caller=self.caller)
+
+ def add_task(self, task: BaseOperator) -> None:
+ """Add a task to the list of tasks to convert to a Databricks workflow."""
+ self.tasks_to_convert.append(task)
+
+ @property
+ def job_name(self) -> str:
+ if not self.task_group:
+ raise AirflowException("Task group must be set before accessing job_name")
+ return f"{self.dag_id}.{self.task_group.group_id}"
+
+ def create_workflow_json(self, context: Context | None = None) -> dict[str, object]:
+ """Create a workflow json to be used in the Databricks API."""
+ task_json = [
+ task._convert_to_databricks_workflow_task( # type: ignore[attr-defined]
+ relevant_upstreams=self.relevant_upstreams, context=context
+ )
+ for task in self.tasks_to_convert
+ ]
+
+ default_json = {
+ "name": self.job_name,
+ "email_notifications": {"no_alert_for_skipped_runs": False},
+ "timeout_seconds": 0,
+ "tasks": task_json,
+ "format": "MULTI_TASK",
+ "job_clusters": self.job_clusters,
+ "max_concurrent_runs": self.max_concurrent_runs,
+ }
+ return merge(default_json, self.extra_job_params)
+
+ def _create_or_reset_job(self, context: Context) -> int:
+ job_spec = self.create_workflow_json(context=context)
+ existing_jobs = self._hook.list_jobs(job_name=self.job_name)
+ job_id = existing_jobs[0]["job_id"] if existing_jobs else None
+ if job_id:
+ self.log.info(
+ "Updating existing Databricks workflow job %s with spec %s",
+ self.job_name,
+ json.dumps(job_spec, indent=2),
+ )
+ self._hook.reset_job(job_id, job_spec)
+ else:
+ self.log.info(
+ "Creating new Databricks workflow job %s with spec %s",
+ self.job_name,
+ json.dumps(job_spec, indent=2),
+ )
+ job_id = self._hook.create_job(job_spec)
+ return job_id
+
+ def _wait_for_job_to_start(self, run_id: int) -> None:
+ run_url = self._hook.get_run_page_url(run_id)
+ self.log.info("Check the progress of the Databricks job at %s", run_url)
+ life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+ if life_cycle_state not in (
+ RunLifeCycleState.PENDING.value,
+ RunLifeCycleState.RUNNING.value,
+ RunLifeCycleState.BLOCKED.value,
+ ):
+ raise AirflowException(f"Could not start the workflow job. State: {life_cycle_state}")
+ while life_cycle_state in (RunLifeCycleState.PENDING.value, RunLifeCycleState.BLOCKED.value):
+ self.log.info("Waiting for the Databricks job to start running")
+ time.sleep(5)
+ life_cycle_state = self._hook.get_run_state(run_id).life_cycle_state
+ self.log.info("Databricks job started. State: %s", life_cycle_state)
+
+ def execute(self, context: Context) -> Any:
+ if not isinstance(self.task_group, DatabricksWorkflowTaskGroup):
+ raise AirflowException("Task group must be a DatabricksWorkflowTaskGroup")
+
+ job_id = self._create_or_reset_job(context)
+
+ run_id = self._hook.run_now(
+ {
+ "job_id": job_id,
+ "jar_params": self.task_group.jar_params,
+ "notebook_params": self.notebook_params,
+ "python_params": self.task_group.python_params,
+ "spark_submit_params": self.task_group.spark_submit_params,
+ }
+ )
+
+ self._wait_for_job_to_start(run_id)
+
+ return {
+ "conn_id": self.databricks_conn_id,
+ "job_id": job_id,
+ "run_id": run_id,
+ }
+
+
+class DatabricksWorkflowTaskGroup(TaskGroup):
+ """
+ A task group that takes a list of tasks and creates a Databricks workflow.
+
+ The DatabricksWorkflowTaskGroup takes a list of tasks and creates a Databricks workflow
+ based on the metadata produced by those tasks. For a task to be eligible for this
+ TaskGroup, it must contain the ``_convert_to_databricks_workflow_task`` method. If any task
+ does not contain this method, the TaskGroup will raise an error at parse time.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksWorkflowTaskGroup`
+
+ :param databricks_conn_id: The name of the Databricks connection to use.
+ :param existing_clusters: A list of existing clusters to use for this workflow.
+ :param extra_job_params: A dictionary containing properties which will override the default
+ Databricks Workflow Job definitions.
+ :param jar_params: A list of jar parameters to pass to the workflow. These parameters will be passed to all jar
+ tasks in the workflow.
+ :param job_clusters: A list of job clusters to use for this workflow.
+ :param max_concurrent_runs: The maximum number of concurrent runs for this workflow.
+ :param notebook_packages: A list of dictionaries specifying Python packages to be installed. Packages defined
+ at the workflow task group level are installed for each of the notebook tasks under it, while
+ packages defined at the notebook task level are installed only for that specific notebook task.
+ :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters
+ will be passed to all notebook tasks in the workflow.
+ :param python_params: A list of python parameters to pass to the workflow. These parameters will be passed to
+ all python tasks in the workflow.
+ :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters
+ will be passed to all spark submit tasks.
+ """
+
+ is_databricks = True
+
+ def __init__(
+ self,
+ databricks_conn_id: str,
+ existing_clusters: list[str] | None = None,
+ extra_job_params: dict[str, Any] | None = None,
+ jar_params: list[str] | None = None,
+ job_clusters: list[dict] | None = None,
+ max_concurrent_runs: int = 1,
+ notebook_packages: list[dict[str, Any]] | None = None,
+ notebook_params: dict | None = None,
+ python_params: list | None = None,
+ spark_submit_params: list | None = None,
+ **kwargs,
+ ):
+ self.databricks_conn_id = databricks_conn_id
+ self.existing_clusters = existing_clusters or []
+ self.extra_job_params = extra_job_params or {}
+ self.jar_params = jar_params or []
+ self.job_clusters = job_clusters or []
+ self.max_concurrent_runs = max_concurrent_runs
+ self.notebook_packages = notebook_packages or []
+ self.notebook_params = notebook_params or {}
+ self.python_params = python_params or []
+ self.spark_submit_params = spark_submit_params or []
+ super().__init__(**kwargs)
+
+ def __exit__(
+ self, _type: type[BaseException] | None, _value: BaseException | None, _tb: TracebackType | None
+ ) -> None:
+ """Exit the context manager and add tasks to a single ``_CreateDatabricksWorkflowOperator``."""
+ roots = list(self.get_roots())
+ tasks = _flatten_node(self)
+
+ create_databricks_workflow_task = _CreateDatabricksWorkflowOperator(
+ dag=self.dag,
+ task_group=self,
+ task_id="launch",
+ databricks_conn_id=self.databricks_conn_id,
+ existing_clusters=self.existing_clusters,
+ extra_job_params=self.extra_job_params,
+ job_clusters=self.job_clusters,
+ max_concurrent_runs=self.max_concurrent_runs,
+ notebook_params=self.notebook_params,
+ )
+
+ for task in tasks:
+ if not (
+ hasattr(task, "_convert_to_databricks_workflow_task")
+ and callable(task._convert_to_databricks_workflow_task)
+ ):
+ raise AirflowException(
+ f"Task {task.task_id} does not support conversion to databricks workflow task."
+ )
+
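+ # Each converted task reads the workflow run metadata (conn_id, job_id, run_id) from the launch
+ # operator's XCom output at runtime, so it monitors its task in the already-running workflow
+ # instead of launching a new job run.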
+ task.workflow_run_metadata = create_databricks_workflow_task.output
+ create_databricks_workflow_task.relevant_upstreams.append(task.task_id)
+ create_databricks_workflow_task.add_task(task)
+
+ for root_task in roots:
+ root_task.set_upstream(create_databricks_workflow_task)
+
+ super().__exit__(_type, _value, _tb)
diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index a6e2174640a5b..80506dc16c15d 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -72,6 +72,7 @@ dependencies:
# The 2.9.1 (to be released soon) already contains the fix
- databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
- aiohttp>=3.9.2, <4
+ - mergedeep>=1.3.4
additional-extras:
# pip install apache-airflow-providers-databricks[sdk]
@@ -108,6 +109,12 @@ integrations:
- /docs/apache-airflow-providers-databricks/operators/repos_delete.rst
logo: /integration-logos/databricks/Databricks.png
tags: [service]
+ - integration-name: Databricks Workflow
+ external-doc-url: https://docs.databricks.com/en/workflows/index.html
+ how-to-guide:
+ - /docs/apache-airflow-providers-databricks/operators/workflow.rst
+ logo: /integration-logos/databricks/Databricks.png
+ tags: [service]
operators:
- integration-name: Databricks
@@ -119,6 +126,9 @@ operators:
- integration-name: Databricks Repos
python-modules:
- airflow.providers.databricks.operators.databricks_repos
+ - integration-name: Databricks Workflow
+ python-modules:
+ - airflow.providers.databricks.operators.databricks_workflow
hooks:
- integration-name: Databricks
diff --git a/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png b/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png
new file mode 100644
index 0000000000000..3a3cb669e0c08
Binary files /dev/null and b/docs/apache-airflow-providers-databricks/img/databricks_workflow_task_group_airflow_graph_view.png differ
diff --git a/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png b/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png
new file mode 100644
index 0000000000000..cb189b8105aaf
Binary files /dev/null and b/docs/apache-airflow-providers-databricks/img/workflow_run_databricks_graph_view.png differ
diff --git a/docs/apache-airflow-providers-databricks/operators/workflow.rst b/docs/apache-airflow-providers-databricks/operators/workflow.rst
new file mode 100644
index 0000000000000..f58514dd5187c
--- /dev/null
+++ b/docs/apache-airflow-providers-databricks/operators/workflow.rst
@@ -0,0 +1,71 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/operator:DatabricksWorkflowTaskGroup:
+
+
+DatabricksWorkflowTaskGroup
+===========================
+
+Use the :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` to launch and monitor
+Databricks notebook job runs as Airflow tasks. The task group launches a `Databricks Workflow <https://docs.databricks.com/en/workflows/index.html>`_ and runs the notebook jobs from within it, resulting in a 75% cost reduction ($0.40/DBU for all-purpose compute vs. $0.07/DBU for Jobs compute) when compared to executing ``DatabricksNotebookOperator`` outside of ``DatabricksWorkflowTaskGroup``.
+
+
+There are a few advantages to defining your Databricks Workflows in Airflow:
+
+======================================  ==============================================  ===================================
+Authoring interface                     via Databricks (Web-based with Databricks UI)  via Airflow (Code with Airflow DAG)
+======================================  ==============================================  ===================================
+Workflow compute pricing                ✅                                              ✅
+Notebook code in source control         ✅                                              ✅
+Workflow structure in source control                                                    ✅
+Retry from beginning                    ✅                                              ✅
+Retry single task                                                                       ✅
+Task groups within Workflows                                                            ✅
+Trigger workflows from other DAGs                                                       ✅
+Workflow-level parameters                                                               ✅
+======================================  ==============================================  ===================================
+
+Examples
+--------
+
+Example of what a DAG looks like with a DatabricksWorkflowTaskGroup
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks_workflow.py
+ :language: python
+ :start-after: [START howto_databricks_workflow_notebook]
+ :end-before: [END howto_databricks_workflow_notebook]
+
+With this example, Airflow will produce a job named ``<dag_id>.test_workflow_<USER>_<GROUP_ID>`` that will
+run task ``notebook_1`` and then ``notebook_2``. The job will be created in the Databricks workspace
+if it does not already exist. If the job already exists, it will be updated to match
+the workflow defined in the DAG.
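+
+As a quick illustration (the DAG id and the ``USER``/``GROUP_ID`` values below are hypothetical), the job name
+follows the launch operator's ``job_name`` property, which joins the DAG id and the task group id:
+
+.. code-block:: python
+
+    dag_id = "example_databricks_workflow"  # hypothetical DAG id
+    group_id = "test_workflow_jdoe_1234"  # i.e. f"test_workflow_{USER}_{GROUP_ID}" from the example above
+
+    # Same rule as _CreateDatabricksWorkflowOperator.job_name: "<dag_id>.<task_group_id>"
+    job_name = f"{dag_id}.{group_id}"
+    assert job_name == "example_databricks_workflow.test_workflow_jdoe_1234"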
+
+The following image displays the resulting Databricks Workflow in the Airflow UI (based on the example above)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. image:: ../img/databricks_workflow_task_group_airflow_graph_view.png
+
+The corresponding Databricks Workflow in the Databricks UI for the run triggered from the Airflow DAG is depicted below
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. image:: ../img/workflow_run_databricks_graph_view.png
+
+
+To minimize update conflicts, we recommend that you keep parameters in the ``notebook_params`` of the
+``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` whenever possible. This is
+because ``notebook_params`` defined at the ``DatabricksWorkflowTaskGroup`` level are passed to the job at
+trigger time and do not modify the job definition.
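+
+As a minimal sketch of this recommendation (the DAG id, group id, connection id, and notebook path below are
+placeholders, and cluster configuration is omitted for brevity), shared parameters live on the task group while
+the notebook task itself stays free of per-run parameters:
+
+.. code-block:: python
+
+    from airflow.models.dag import DAG
+    from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
+    from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup
+    from airflow.utils.timezone import datetime
+
+    with DAG(dag_id="my_databricks_dag", start_date=datetime(2024, 1, 1), schedule_interval=None, catchup=False):
+        with DatabricksWorkflowTaskGroup(
+            group_id="my_workflow",
+            databricks_conn_id="databricks_default",
+            # Passed to every notebook task at job trigger time; changing it does not reset the job definition.
+            notebook_params={"run_date": "{{ ds }}"},
+        ):
+            DatabricksNotebookOperator(
+                task_id="my_notebook_task",
+                databricks_conn_id="databricks_default",
+                notebook_path="/Shared/my_notebook",
+                source="WORKSPACE",
+            )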
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 7f29984d329b5..01c1e3378589d 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -411,6 +411,7 @@
"apache-airflow-providers-common-sql>=1.10.0",
"apache-airflow>=2.7.0",
"databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
+ "mergedeep>=1.3.4",
"requests>=2.27.0,<3"
],
"devel-deps": [
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index d6e7eb3892919..2774385ea5a71 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -2014,3 +2014,122 @@ def test_zero_execution_timeout_raises_error(self):
"Set it instead to `None` if you desire the task to run indefinitely."
)
assert str(exc_info.value) == exception_message
+
+ def test_extend_workflow_notebook_packages(self):
+ """Test that the operator can extend the notebook packages of a Databricks workflow task group."""
+ databricks_workflow_task_group = MagicMock()
+ databricks_workflow_task_group.notebook_packages = [
+ {"pypi": {"package": "numpy"}},
+ {"pypi": {"package": "pandas"}},
+ ]
+
+ operator = DatabricksNotebookOperator(
+ notebook_path="/path/to/notebook",
+ source="WORKSPACE",
+ task_id="test_task",
+ notebook_packages=[
+ {"pypi": {"package": "numpy"}},
+ {"pypi": {"package": "scipy"}},
+ ],
+ )
+
+ operator._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+ assert operator.notebook_packages == [
+ {"pypi": {"package": "numpy"}},
+ {"pypi": {"package": "scipy"}},
+ {"pypi": {"package": "pandas"}},
+ ]
+
+ def test_convert_to_databricks_workflow_task(self):
+ """Test that the operator can convert itself to a Databricks workflow task."""
+ dag = DAG(dag_id="example_dag", start_date=datetime.now())
+ operator = DatabricksNotebookOperator(
+ notebook_path="/path/to/notebook",
+ source="WORKSPACE",
+ task_id="test_task",
+ notebook_packages=[
+ {"pypi": {"package": "numpy"}},
+ {"pypi": {"package": "scipy"}},
+ ],
+ dag=dag,
+ )
+
+ databricks_workflow_task_group = MagicMock()
+ databricks_workflow_task_group.notebook_packages = [
+ {"pypi": {"package": "numpy"}},
+ ]
+ databricks_workflow_task_group.notebook_params = {"param1": "value1"}
+
+ operator.notebook_packages = [{"pypi": {"package": "pandas"}}]
+ operator.notebook_params = {"param2": "value2"}
+ operator.task_group = databricks_workflow_task_group
+ operator.task_id = "test_task"
+ operator.upstream_task_ids = ["upstream_task"]
+ relevant_upstreams = [MagicMock(task_id="upstream_task")]
+
+ task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams)
+
+ expected_json = {
+ "task_key": "example_dag__test_task",
+ "depends_on": [],
+ "timeout_seconds": 0,
+ "email_notifications": {},
+ "notebook_task": {
+ "notebook_path": "/path/to/notebook",
+ "source": "WORKSPACE",
+ "base_parameters": {
+ "param2": "value2",
+ "param1": "value1",
+ },
+ },
+ "libraries": [
+ {"pypi": {"package": "pandas"}},
+ {"pypi": {"package": "numpy"}},
+ ],
+ }
+
+ assert task_json == expected_json
+
+ def test_convert_to_databricks_workflow_task_no_task_group(self):
+ """Test that an error is raised if the operator is not in a TaskGroup."""
+ operator = DatabricksNotebookOperator(
+ notebook_path="/path/to/notebook",
+ source="WORKSPACE",
+ task_id="test_task",
+ notebook_packages=[
+ {"pypi": {"package": "numpy"}},
+ {"pypi": {"package": "scipy"}},
+ ],
+ )
+ operator.task_group = None
+ relevant_upstreams = [MagicMock(task_id="upstream_task")]
+
+ with pytest.raises(
+ AirflowException,
+ match="Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup.",
+ ):
+ operator._convert_to_databricks_workflow_task(relevant_upstreams)
+
+ def test_convert_to_databricks_workflow_task_cluster_conflict(self):
+ """Test that an error is raised if both `existing_cluster_id` and `job_cluster_key` are set."""
+ operator = DatabricksNotebookOperator(
+ notebook_path="/path/to/notebook",
+ source="WORKSPACE",
+ task_id="test_task",
+ notebook_packages=[
+ {"pypi": {"package": "numpy"}},
+ {"pypi": {"package": "scipy"}},
+ ],
+ )
+ databricks_workflow_task_group = MagicMock()
+ operator.existing_cluster_id = "existing-cluster-id"
+ operator.job_cluster_key = "job-cluster-key"
+ operator.task_group = databricks_workflow_task_group
+ relevant_upstreams = [MagicMock(task_id="upstream_task")]
+
+ with pytest.raises(
+ ValueError,
+ match="Both existing_cluster_id and job_cluster_key are set. Only one can be set per task.",
+ ):
+ operator._convert_to_databricks_workflow_task(relevant_upstreams)
diff --git a/tests/providers/databricks/operators/test_databricks_workflow.py b/tests/providers/databricks/operators/test_databricks_workflow.py
new file mode 100644
index 0000000000000..99f1a9d14815d
--- /dev/null
+++ b/tests/providers/databricks/operators/test_databricks_workflow.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow import DAG
+from airflow.exceptions import AirflowException
+from airflow.models.baseoperator import BaseOperator
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.databricks.hooks.databricks import RunLifeCycleState
+from airflow.providers.databricks.operators.databricks_workflow import (
+ DatabricksWorkflowTaskGroup,
+ _CreateDatabricksWorkflowOperator,
+ _flatten_node,
+)
+from airflow.utils import timezone
+
+pytestmark = pytest.mark.db_test
+
+DEFAULT_DATE = timezone.datetime(2021, 1, 1)
+
+
+@pytest.fixture
+def mock_databricks_hook():
+ """Provide a mock DatabricksHook."""
+ with patch("airflow.providers.databricks.operators.databricks_workflow.DatabricksHook") as mock_hook:
+ yield mock_hook
+
+
+@pytest.fixture
+def context():
+ """Provide a mock context object."""
+ return MagicMock()
+
+
+@pytest.fixture
+def mock_task_group():
+ """Provide a mock DatabricksWorkflowTaskGroup with necessary attributes."""
+ mock_group = MagicMock(spec=DatabricksWorkflowTaskGroup)
+ mock_group.group_id = "test_group"
+ return mock_group
+
+
+def test_flatten_node():
+ """Test that _flatten_node returns a flat list of operators."""
+ task_group = MagicMock(spec=DatabricksWorkflowTaskGroup)
+ base_operator = MagicMock(spec=BaseOperator)
+ task_group.children = {"task1": base_operator, "task2": base_operator}
+
+ result = _flatten_node(task_group)
+ assert result == [base_operator, base_operator]
+
+
+def test_create_workflow_json(mock_databricks_hook, context, mock_task_group):
+ """Test that _CreateDatabricksWorkflowOperator.create_workflow_json returns the expected JSON."""
+ operator = _CreateDatabricksWorkflowOperator(
+ task_id="test_task",
+ databricks_conn_id="databricks_default",
+ )
+ operator.task_group = mock_task_group
+
+ task = MagicMock(spec=BaseOperator)
+ task._convert_to_databricks_workflow_task = MagicMock(return_value={})
+ operator.add_task(task)
+
+ workflow_json = operator.create_workflow_json(context=context)
+
+ assert ".test_group" in workflow_json["name"]
+ assert "tasks" in workflow_json
+ assert workflow_json["format"] == "MULTI_TASK"
+ assert workflow_json["email_notifications"] == {"no_alert_for_skipped_runs": False}
+ assert workflow_json["job_clusters"] == []
+ assert workflow_json["max_concurrent_runs"] == 1
+ assert workflow_json["timeout_seconds"] == 0
+
+
+def test_create_or_reset_job_existing(mock_databricks_hook, context, mock_task_group):
+ """Test that _CreateDatabricksWorkflowOperator._create_or_reset_job resets the job if it already exists."""
+ operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
+ operator.task_group = mock_task_group
+ operator._hook.list_jobs.return_value = [{"job_id": 123}]
+ operator._hook.create_job.return_value = 123
+
+ job_id = operator._create_or_reset_job(context)
+ assert job_id == 123
+ operator._hook.reset_job.assert_called_once()
+
+
+def test_create_or_reset_job_new(mock_databricks_hook, context, mock_task_group):
+ """Test that _CreateDatabricksWorkflowOperator._create_or_reset_job creates a new job if it does not exist."""
+ operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
+ operator.task_group = mock_task_group
+ operator._hook.list_jobs.return_value = []
+ operator._hook.create_job.return_value = 456
+
+ job_id = operator._create_or_reset_job(context)
+ assert job_id == 456
+ operator._hook.create_job.assert_called_once()
+
+
+def test_wait_for_job_to_start(mock_databricks_hook):
+ """Test that _CreateDatabricksWorkflowOperator._wait_for_job_to_start waits for the job to start."""
+ operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
+ mock_hook_instance = mock_databricks_hook.return_value
+ mock_hook_instance.get_run_state.side_effect = [
+ MagicMock(life_cycle_state=RunLifeCycleState.PENDING.value),
+ MagicMock(life_cycle_state=RunLifeCycleState.RUNNING.value),
+ ]
+
+ operator._wait_for_job_to_start(123)
+ mock_hook_instance.get_run_state.assert_called()
+
+
+def test_execute(mock_databricks_hook, context, mock_task_group):
+ """Test that _CreateDatabricksWorkflowOperator.execute runs the task group."""
+ operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
+ operator.task_group = mock_task_group
+ mock_task_group.jar_params = {}
+ mock_task_group.python_params = {}
+ mock_task_group.spark_submit_params = {}
+
+ mock_hook_instance = mock_databricks_hook.return_value
+ mock_hook_instance.run_now.return_value = 789
+ mock_hook_instance.list_jobs.return_value = [{"job_id": 123}]
+ mock_hook_instance.get_run_state.return_value = MagicMock(
+ life_cycle_state=RunLifeCycleState.RUNNING.value
+ )
+
+ task = MagicMock(spec=BaseOperator)
+ task._convert_to_databricks_workflow_task = MagicMock(return_value={})
+ operator.add_task(task)
+
+ result = operator.execute(context)
+
+ assert result == {
+ "conn_id": "databricks_default",
+ "job_id": 123,
+ "run_id": 789,
+ }
+ mock_hook_instance.run_now.assert_called_once()
+
+
+def test_execute_invalid_task_group(context):
+ """Test that _CreateDatabricksWorkflowOperator.execute raises an exception if the task group is invalid."""
+ operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
+ operator.task_group = MagicMock() # Not a DatabricksWorkflowTaskGroup
+
+ with pytest.raises(AirflowException, match="Task group must be a DatabricksWorkflowTaskGroup"):
+ operator.execute(context)
+
+
+@pytest.fixture
+def mock_databricks_workflow_operator():
+ with patch(
+ "airflow.providers.databricks.operators.databricks_workflow._CreateDatabricksWorkflowOperator"
+ ) as mock_operator:
+ yield mock_operator
+
+
+def test_task_group_initialization():
+ """Test that DatabricksWorkflowTaskGroup initializes correctly."""
+ with DAG(dag_id="example_databricks_workflow_dag", start_date=DEFAULT_DATE) as example_dag:
+ with DatabricksWorkflowTaskGroup(
+ group_id="test_databricks_workflow", databricks_conn_id="databricks_conn"
+ ) as task_group:
+ task_1 = EmptyOperator(task_id="task1")
+ task_1._convert_to_databricks_workflow_task = MagicMock(return_value={})
+ assert task_group.group_id == "test_databricks_workflow"
+ assert task_group.databricks_conn_id == "databricks_conn"
+ assert task_group.dag == example_dag
+
+
+def test_task_group_exit_creates_operator(mock_databricks_workflow_operator):
+ """Test that DatabricksWorkflowTaskGroup creates a _CreateDatabricksWorkflowOperator on exit."""
+ with DAG(dag_id="example_databricks_workflow_dag", start_date=DEFAULT_DATE) as example_dag:
+ with DatabricksWorkflowTaskGroup(
+ group_id="test_databricks_workflow",
+ databricks_conn_id="databricks_conn",
+ ) as task_group:
+ task1 = MagicMock(task_id="task1")
+ task1._convert_to_databricks_workflow_task = MagicMock(return_value={})
+ task2 = MagicMock(task_id="task2")
+ task2._convert_to_databricks_workflow_task = MagicMock(return_value={})
+
+ task_group.add(task1)
+ task_group.add(task2)
+
+ task1.set_downstream(task2)
+
+ mock_databricks_workflow_operator.assert_called_once_with(
+ dag=example_dag,
+ task_group=task_group,
+ task_id="launch",
+ databricks_conn_id="databricks_conn",
+ existing_clusters=[],
+ extra_job_params={},
+ job_clusters=[],
+ max_concurrent_runs=1,
+ notebook_params={},
+ )
+
+
+def test_task_group_root_tasks_set_upstream_to_operator(mock_databricks_workflow_operator):
+ """Test that tasks added to a DatabricksWorkflowTaskGroup are set upstream to the operator."""
+ with DAG(dag_id="example_databricks_workflow_dag", start_date=DEFAULT_DATE):
+ with DatabricksWorkflowTaskGroup(
+ group_id="test_databricks_workflow1",
+ databricks_conn_id="databricks_conn",
+ ) as task_group:
+ task1 = MagicMock(task_id="task1")
+ task1._convert_to_databricks_workflow_task = MagicMock(return_value={})
+ task_group.add(task1)
+
+ create_operator_instance = mock_databricks_workflow_operator.return_value
+ task1.set_upstream.assert_called_once_with(create_operator_instance)
diff --git a/tests/system/providers/databricks/example_databricks_workflow.py b/tests/system/providers/databricks/example_databricks_workflow.py
new file mode 100644
index 0000000000000..6b05f34684c9d
--- /dev/null
+++ b/tests/system/providers/databricks/example_databricks_workflow.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG for using the DatabricksWorkflowTaskGroup and DatabricksNotebookOperator."""
+
+from __future__ import annotations
+
+import os
+from datetime import timedelta
+
+from airflow.models.dag import DAG
+from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
+from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup
+from airflow.utils.timezone import datetime
+
+EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
+
+DATABRICKS_CONN_ID = os.getenv("DATABRICKS_CONN_ID", "databricks_conn")
+DATABRICKS_NOTIFICATION_EMAIL = os.getenv("DATABRICKS_NOTIFICATION_EMAIL", "your_email@serviceprovider.com")
+
+GROUP_ID = os.getenv("DATABRICKS_GROUP_ID", "1234").replace(".", "_")
+USER = os.environ.get("USER")
+
+job_cluster_spec = [
+ {
+ "job_cluster_key": "Shared_job_cluster",
+ "new_cluster": {
+ "cluster_name": "",
+ "spark_version": "11.3.x-scala2.12",
+ "aws_attributes": {
+ "first_on_demand": 1,
+ "availability": "SPOT_WITH_FALLBACK",
+ "zone_id": "us-east-2b",
+ "spot_bid_price_percent": 100,
+ "ebs_volume_count": 0,
+ },
+ "node_type_id": "i3.xlarge",
+ "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"},
+ "enable_elastic_disk": False,
+ "data_security_mode": "LEGACY_SINGLE_USER_STANDARD",
+ "runtime_engine": "STANDARD",
+ "num_workers": 8,
+ },
+ }
+]
+dag = DAG(
+ dag_id="example_databricks_workflow",
+ start_date=datetime(2022, 1, 1),
+ schedule_interval=None,
+ catchup=False,
+ tags=["example", "databricks"],
+)
+with dag:
+ # [START howto_databricks_workflow_notebook]
+ task_group = DatabricksWorkflowTaskGroup(
+ group_id=f"test_workflow_{USER}_{GROUP_ID}",
+ databricks_conn_id=DATABRICKS_CONN_ID,
+ job_clusters=job_cluster_spec,
+ notebook_params={"ts": "{{ ts }}"},
+ notebook_packages=[
+ {
+ "pypi": {
+ "package": "simplejson==3.18.0", # Pin specification version of a package like this.
+ "repo": "https://pypi.org/simple", # You can specify your required Pypi index here.
+ }
+ },
+ ],
+ extra_job_params={
+ "email_notifications": {
+ "on_start": [DATABRICKS_NOTIFICATION_EMAIL],
+ },
+ },
+ )
+ with task_group:
+ notebook_1 = DatabricksNotebookOperator(
+ task_id="workflow_notebook_1",
+ databricks_conn_id=DATABRICKS_CONN_ID,
+ notebook_path="/Shared/Notebook_1",
+ notebook_packages=[{"pypi": {"package": "Faker"}}],
+ source="WORKSPACE",
+ job_cluster_key="Shared_job_cluster",
+ execution_timeout=timedelta(seconds=600),
+ )
+ notebook_2 = DatabricksNotebookOperator(
+ task_id="workflow_notebook_2",
+ databricks_conn_id=DATABRICKS_CONN_ID,
+ notebook_path="/Shared/Notebook_2",
+ source="WORKSPACE",
+ job_cluster_key="Shared_job_cluster",
+ notebook_params={"foo": "bar", "ds": "{{ ds }}"},
+ )
+ notebook_1 >> notebook_2
+ # [END howto_databricks_workflow_notebook]
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)