diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index d6118f247f597..38ce8e332455a 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -20,6 +20,7 @@ from __future__ import annotations import time +from abc import ABC, abstractmethod from functools import cached_property from logging import Logger from typing import TYPE_CHECKING, Any, Sequence @@ -899,87 +900,64 @@ def __init__(self, *args, **kwargs): super().__init__(deferrable=True, *args, **kwargs) -class DatabricksNotebookOperator(BaseOperator): +class DatabricksTaskBaseOperator(BaseOperator, ABC): """ - Runs a notebook on Databricks using an Airflow operator. - - The DatabricksNotebookOperator allows users to launch and monitor notebook - job runs on Databricks as Airflow tasks. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:DatabricksNotebookOperator` + Base class for operators that are run as Databricks job tasks or tasks within a Databricks workflow. - :param notebook_path: The path to the notebook in Databricks. - :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved - from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository - defined in git_source. If the value is empty, the task will use GIT if git_source is defined - and WORKSPACE otherwise. For more information please visit - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate - :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task. - :param notebook_packages: A list of the Python libraries to be installed on the cluster running the - notebook. - :param new_cluster: Specs for a new cluster on which this task will be run. + :param caller: The name of the caller operator to be used in the logs. + :param databricks_conn_id: The name of the Airflow connection to use. + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param databricks_retry_delay: Number of seconds to wait between retries. + :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. + :param deferrable: Whether to run the operator in the deferrable mode. :param existing_cluster_id: ID for existing cluster on which to run this task. :param job_cluster_key: The key for the job cluster. + :param new_cluster: Specs for a new cluster on which this task will be run. + :param notebook_packages: A list of the Python libraries to be installed on the cluster running the + notebook. + :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task. :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. - :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. - :param databricks_retry_delay: Number of seconds to wait between retries. - :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. - :param databricks_conn_id: The name of the Airflow connection to use. - :param deferrable: Run operator in the deferrable mode. + :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within + a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow. """ - template_fields = ( - "notebook_params", - "workflow_run_metadata", - ) - CALLER = "DatabricksNotebookOperator" - def __init__( self, - notebook_path: str, - source: str, - notebook_params: dict | None = None, - notebook_packages: list[dict[str, Any]] | None = None, - new_cluster: dict[str, Any] | None = None, + caller: str = "DatabricksTaskBaseOperator", + databricks_conn_id: str = "databricks_default", + databricks_retry_args: dict[Any, Any] | None = None, + databricks_retry_delay: int = 1, + databricks_retry_limit: int = 3, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), existing_cluster_id: str = "", job_cluster_key: str = "", + new_cluster: dict[str, Any] | None = None, polling_period_seconds: int = 5, - databricks_retry_limit: int = 3, - databricks_retry_delay: int = 1, - databricks_retry_args: dict[Any, Any] | None = None, wait_for_termination: bool = True, - databricks_conn_id: str = "databricks_default", - workflow_run_metadata: dict | None = None, - deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + workflow_run_metadata: dict[str, Any] | None = None, **kwargs: Any, ): - self.notebook_path = notebook_path - self.source = source - self.notebook_params = notebook_params or {} - self.notebook_packages = notebook_packages or [] - self.new_cluster = new_cluster or {} + self.caller = caller + self.databricks_conn_id = databricks_conn_id + self.databricks_retry_args = databricks_retry_args + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_limit = databricks_retry_limit + self.deferrable = deferrable self.existing_cluster_id = existing_cluster_id self.job_cluster_key = job_cluster_key + self.new_cluster = new_cluster or {} self.polling_period_seconds = polling_period_seconds - self.databricks_retry_limit = databricks_retry_limit - self.databricks_retry_delay = databricks_retry_delay - self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination - self.databricks_conn_id = databricks_conn_id - self.databricks_run_id: int | None = None - self.deferrable = deferrable + self.workflow_run_metadata = workflow_run_metadata - # This is used to store the metadata of the Databricks job run when the job is launched from within DatabricksWorkflowTaskGroup. - self.workflow_run_metadata: dict | None = workflow_run_metadata + self.databricks_run_id: int | None = None super().__init__(**kwargs) @cached_property def _hook(self) -> DatabricksHook: - return self._get_hook(caller=self.CALLER) + return self._get_hook(caller=self.caller) def _get_hook(self, caller: str) -> DatabricksHook: return DatabricksHook( @@ -987,44 +965,9 @@ def _get_hook(self, caller: str) -> DatabricksHook: retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, - caller=self.CALLER, + caller=caller, ) - def _get_task_timeout_seconds(self) -> int: - """ - Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task. - - By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when - execution_timeout is not defined, the task continues to run indefinitely. Therefore, - to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating - that the job should run indefinitely. This aligns with the default behavior of Databricks jobs, - where a timeout seconds value of 0 signifies an indefinite run duration. - More details can be found in the Databricks documentation: - See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds - """ - if self.execution_timeout is None: - return 0 - execution_timeout_seconds = int(self.execution_timeout.total_seconds()) - if execution_timeout_seconds == 0: - raise ValueError( - "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to " - "`None` if you desire the task to run indefinitely." - ) - return execution_timeout_seconds - - def _get_task_base_json(self) -> dict[str, Any]: - """Get task base json to be used for task submissions.""" - return { - "timeout_seconds": self._get_task_timeout_seconds(), - "email_notifications": {}, - "notebook_task": { - "notebook_path": self.notebook_path, - "source": self.source, - "base_parameters": self.notebook_params, - }, - "libraries": self.notebook_packages, - } - def _get_databricks_task_id(self, task_id: str) -> str: """Get the databricks task ID using dag_id and task_id. Removes illegal characters.""" return f"{self.dag_id}__{task_id.replace('.', '__')}" @@ -1049,58 +992,10 @@ def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None: return None - def _extend_workflow_notebook_packages( - self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup - ) -> None: - """Extend the task group packages into the notebook's packages, without adding any duplicates.""" - for task_group_package in databricks_workflow_task_group.notebook_packages: - exists = any( - task_group_package == existing_package for existing_package in self.notebook_packages - ) - if not exists: - self.notebook_packages.append(task_group_package) - - def _convert_to_databricks_workflow_task( - self, relevant_upstreams: list[BaseOperator], context: Context | None = None - ) -> dict[str, object]: - """Convert the operator to a Databricks workflow task that can be a task in a workflow.""" - databricks_workflow_task_group = self._databricks_workflow_task_group - if not databricks_workflow_task_group: - raise AirflowException( - "Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup." - ) - - if hasattr(databricks_workflow_task_group, "notebook_packages"): - self._extend_workflow_notebook_packages(databricks_workflow_task_group) - - if hasattr(databricks_workflow_task_group, "notebook_params"): - self.notebook_params = { - **self.notebook_params, - **databricks_workflow_task_group.notebook_params, - } - - base_task_json = self._get_task_base_json() - result = { - "task_key": self._get_databricks_task_id(self.task_id), - "depends_on": [ - {"task_key": self._get_databricks_task_id(task_id)} - for task_id in self.upstream_task_ids - if task_id in relevant_upstreams - ], - **base_task_json, - } - - if self.existing_cluster_id and self.job_cluster_key: - raise ValueError( - "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task." - ) - - if self.existing_cluster_id: - result["existing_cluster_id"] = self.existing_cluster_id - elif self.job_cluster_key: - result["job_cluster_key"] = self.job_cluster_key - - return result + @abstractmethod + def _get_task_base_json(self) -> dict[str, Any]: + """Get the base json for the task.""" + raise NotImplementedError() def _get_run_json(self) -> dict[str, Any]: """Get run json to be used for task submissions.""" @@ -1118,7 +1013,8 @@ def _get_run_json(self) -> dict[str, Any]: raise ValueError("Must specify either existing_cluster_id or new_cluster.") return run_json - def launch_notebook_job(self) -> int: + def _launch_job(self) -> int: + """Launch the job on Databricks.""" run_json = self._get_run_json() self.databricks_run_id = self._hook.submit_run(run_json) url = self._hook.get_run_page_url(self.databricks_run_id) @@ -1126,6 +1022,7 @@ def launch_notebook_job(self) -> int: return self.databricks_run_id def _handle_terminal_run_state(self, run_state: RunState) -> None: + """Handle the terminal state of the run.""" if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value: raise AirflowException( f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}" @@ -1136,28 +1033,74 @@ def _handle_terminal_run_state(self, run_state: RunState) -> None: ) self.log.info("Task succeeded. Final state %s.", run_state.result_state) + def _get_current_databricks_task(self) -> dict[str, Any]: + """Retrieve the Databricks task corresponding to the current Airflow task.""" + if self.databricks_run_id is None: + raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.") + return {task["task_key"]: task for task in self._hook.get_run(self.databricks_run_id)["tasks"]}[ + self._get_databricks_task_id(self.task_id) + ] + + def _convert_to_databricks_workflow_task( + self, relevant_upstreams: list[BaseOperator], context: Context | None = None + ) -> dict[str, object]: + """Convert the operator to a Databricks workflow task that can be a task in a workflow.""" + base_task_json = self._get_task_base_json() + result = { + "task_key": self._get_databricks_task_id(self.task_id), + "depends_on": [ + {"task_key": self._get_databricks_task_id(task_id)} + for task_id in self.upstream_task_ids + if task_id in relevant_upstreams + ], + **base_task_json, + } + + if self.existing_cluster_id and self.job_cluster_key: + raise ValueError( + "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task." + ) + if self.existing_cluster_id: + result["existing_cluster_id"] = self.existing_cluster_id + elif self.job_cluster_key: + result["job_cluster_key"] = self.job_cluster_key + + return result + def monitor_databricks_job(self) -> None: + """ + Monitor the Databricks job. + + Wait for the job to terminate. If deferrable, defer the task. + """ if self.databricks_run_id is None: raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.") - run = self._hook.get_run(self.databricks_run_id) + current_task_run_id = self._get_current_databricks_task()["run_id"] + run = self._hook.get_run(current_task_run_id) + run_page_url = run["run_page_url"] + self.log.info("Check the task run in Databricks: %s", run_page_url) run_state = RunState(**run["state"]) - self.log.info("Current state of the job: %s", run_state.life_cycle_state) + self.log.info( + "Current state of the the databricks task %s is %s", + self._get_databricks_task_id(self.task_id), + run_state.life_cycle_state, + ) if self.deferrable and not run_state.is_terminal: self.defer( trigger=DatabricksExecutionTrigger( - run_id=self.databricks_run_id, + run_id=current_task_run_id, databricks_conn_id=self.databricks_conn_id, polling_period_seconds=self.polling_period_seconds, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, - caller=self.CALLER, + caller=self.caller, ), method_name=DEFER_METHOD_NAME, ) while not run_state.is_terminal: time.sleep(self.polling_period_seconds) - run = self._hook.get_run(self.databricks_run_id) + run = self._hook.get_run(current_task_run_id) run_state = RunState(**run["state"]) self.log.info( "Current state of the databricks task %s is %s", @@ -1167,6 +1110,7 @@ def monitor_databricks_job(self) -> None: self._handle_terminal_run_state(run_state) def execute(self, context: Context) -> None: + """Execute the operator. Launch the job and monitor it if wait_for_termination is set to True.""" if self._databricks_workflow_task_group: # If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched. if not self.workflow_run_metadata: @@ -1178,10 +1122,226 @@ def execute(self, context: Context) -> None: self.databricks_run_id = workflow_run_metadata.run_id self.databricks_conn_id = workflow_run_metadata.conn_id else: - self.launch_notebook_job() + self._launch_job() if self.wait_for_termination: self.monitor_databricks_job() def execute_complete(self, context: dict | None, event: dict) -> None: run_state = RunState.from_json(event["run_state"]) self._handle_terminal_run_state(run_state) + + +class DatabricksNotebookOperator(DatabricksTaskBaseOperator): + """ + Runs a notebook on Databricks using an Airflow operator. + + The DatabricksNotebookOperator allows users to launch and monitor notebook job runs on Databricks as + Airflow tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job + clusters, which allows users to run their tasks on cheaper clusters that can be shared between tasks. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksNotebookOperator` + + :param notebook_path: The path to the notebook in Databricks. + :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved + from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository + defined in git_source. If the value is empty, the task will use GIT if git_source is defined + and WORKSPACE otherwise. For more information please visit + https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate + :param databricks_conn_id: The name of the Airflow connection to use. + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param databricks_retry_delay: Number of seconds to wait between retries. + :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. + :param deferrable: Whether to run the operator in the deferrable mode. + :param existing_cluster_id: ID for existing cluster on which to run this task. + :param job_cluster_key: The key for the job cluster. + :param new_cluster: Specs for a new cluster on which this task will be run. + :param notebook_packages: A list of the Python libraries to be installed on the cluster running the + notebook. + :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task. + :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. + :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. + :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within + a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow. + """ + + template_fields = ( + "notebook_params", + "workflow_run_metadata", + ) + CALLER = "DatabricksNotebookOperator" + + def __init__( + self, + notebook_path: str, + source: str, + databricks_conn_id: str = "databricks_default", + databricks_retry_args: dict[Any, Any] | None = None, + databricks_retry_delay: int = 1, + databricks_retry_limit: int = 3, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + existing_cluster_id: str = "", + job_cluster_key: str = "", + new_cluster: dict[str, Any] | None = None, + notebook_packages: list[dict[str, Any]] | None = None, + notebook_params: dict | None = None, + polling_period_seconds: int = 5, + wait_for_termination: bool = True, + workflow_run_metadata: dict | None = None, + **kwargs: Any, + ): + self.notebook_path = notebook_path + self.source = source + self.notebook_packages = notebook_packages or [] + self.notebook_params = notebook_params or {} + + super().__init__( + caller=self.CALLER, + databricks_conn_id=databricks_conn_id, + databricks_retry_args=databricks_retry_args, + databricks_retry_delay=databricks_retry_delay, + databricks_retry_limit=databricks_retry_limit, + deferrable=deferrable, + existing_cluster_id=existing_cluster_id, + job_cluster_key=job_cluster_key, + new_cluster=new_cluster, + polling_period_seconds=polling_period_seconds, + wait_for_termination=wait_for_termination, + workflow_run_metadata=workflow_run_metadata, + **kwargs, + ) + + def _get_task_timeout_seconds(self) -> int: + """ + Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task. + + By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when + execution_timeout is not defined, the task continues to run indefinitely. Therefore, + to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating + that the job should run indefinitely. This aligns with the default behavior of Databricks jobs, + where a timeout seconds value of 0 signifies an indefinite run duration. + More details can be found in the Databricks documentation: + See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds + """ + if self.execution_timeout is None: + return 0 + execution_timeout_seconds = int(self.execution_timeout.total_seconds()) + if execution_timeout_seconds == 0: + raise ValueError( + "If you've set an `execution_timeout` for the task, ensure it's not `0`. Set it instead to " + "`None` if you desire the task to run indefinitely." + ) + return execution_timeout_seconds + + def _get_task_base_json(self) -> dict[str, Any]: + """Get task base json to be used for task submissions.""" + return { + "timeout_seconds": self._get_task_timeout_seconds(), + "email_notifications": {}, + "notebook_task": { + "notebook_path": self.notebook_path, + "source": self.source, + "base_parameters": self.notebook_params, + }, + "libraries": self.notebook_packages, + } + + def _extend_workflow_notebook_packages( + self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup + ) -> None: + """Extend the task group packages into the notebook's packages, without adding any duplicates.""" + for task_group_package in databricks_workflow_task_group.notebook_packages: + exists = any( + task_group_package == existing_package for existing_package in self.notebook_packages + ) + if not exists: + self.notebook_packages.append(task_group_package) + + def _convert_to_databricks_workflow_task( + self, relevant_upstreams: list[BaseOperator], context: Context | None = None + ) -> dict[str, object]: + """Convert the operator to a Databricks workflow task that can be a task in a workflow.""" + databricks_workflow_task_group = self._databricks_workflow_task_group + if not databricks_workflow_task_group: + raise AirflowException( + "Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup." + ) + + if hasattr(databricks_workflow_task_group, "notebook_packages"): + self._extend_workflow_notebook_packages(databricks_workflow_task_group) + + if hasattr(databricks_workflow_task_group, "notebook_params"): + self.notebook_params = { + **self.notebook_params, + **databricks_workflow_task_group.notebook_params, + } + + return super()._convert_to_databricks_workflow_task(relevant_upstreams, context=context) + + +class DatabricksTaskOperator(DatabricksTaskBaseOperator): + """ + Runs a task on Databricks using an Airflow operator. + + The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow + tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which + allows users to run their tasks on cheaper clusters that can be shared between tasks. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksTaskOperator` + + :param task_config: The configuration of the task to be run on Databricks. + :param databricks_conn_id: The name of the Airflow connection to use. + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param databricks_retry_delay: Number of seconds to wait between retries. + :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. + :param deferrable: Whether to run the operator in the deferrable mode. + :param existing_cluster_id: ID for existing cluster on which to run this task. + :param job_cluster_key: The key for the job cluster. + :param new_cluster: Specs for a new cluster on which this task will be run. + :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. + :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. + """ + + CALLER = "DatabricksTaskOperator" + template_fields = ("workflow_run_metadata",) + + def __init__( + self, + task_config: dict, + databricks_conn_id: str = "databricks_default", + databricks_retry_args: dict[Any, Any] | None = None, + databricks_retry_delay: int = 1, + databricks_retry_limit: int = 3, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + existing_cluster_id: str = "", + job_cluster_key: str = "", + new_cluster: dict[str, Any] | None = None, + polling_period_seconds: int = 5, + wait_for_termination: bool = True, + workflow_run_metadata: dict | None = None, + **kwargs, + ): + self.task_config = task_config + + super().__init__( + caller=self.CALLER, + databricks_conn_id=databricks_conn_id, + databricks_retry_args=databricks_retry_args, + databricks_retry_delay=databricks_retry_delay, + databricks_retry_limit=databricks_retry_limit, + deferrable=deferrable, + existing_cluster_id=existing_cluster_id, + job_cluster_key=job_cluster_key, + new_cluster=new_cluster, + polling_period_seconds=polling_period_seconds, + wait_for_termination=wait_for_termination, + workflow_run_metadata=workflow_run_metadata, + **kwargs, + ) + + def _get_task_base_json(self) -> dict[str, Any]: + """Get task base json to be used for task submissions.""" + return self.task_config diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 80506dc16c15d..c4f24867dd509 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -92,6 +92,7 @@ integrations: - /docs/apache-airflow-providers-databricks/operators/notebook.rst - /docs/apache-airflow-providers-databricks/operators/submit_run.rst - /docs/apache-airflow-providers-databricks/operators/run_now.rst + - /docs/apache-airflow-providers-databricks/operators/task.rst logo: /integration-logos/databricks/Databricks.png tags: [service] - integration-name: Databricks SQL diff --git a/docs/apache-airflow-providers-databricks/operators/task.rst b/docs/apache-airflow-providers-databricks/operators/task.rst new file mode 100644 index 0000000000000..476e72c494b9b --- /dev/null +++ b/docs/apache-airflow-providers-databricks/operators/task.rst @@ -0,0 +1,46 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:DatabricksTaskOperator: + + +DatabricksTaskOperator +====================== + +Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksTaskOperator` to launch and monitor +task runs on Databricks as Airflow tasks. This can be used as a standalone operator in a DAG and as well as part of a +Databricks Workflow by using it as an operator(task) within the +:class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup`. + + + +Examples +-------- + +Running a notebook in Databricks using DatabricksTaskOperator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_task_notebook] + :end-before: [END howto_operator_databricks_task_notebook] + +Running a SQL query in Databricks using DatabricksTaskOperator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_task_sql] + :end-before: [END howto_operator_databricks_task_sql] diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 2774385ea5a71..6822bb53697d5 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -33,6 +33,8 @@ DatabricksRunNowOperator, DatabricksSubmitRunDeferrableOperator, DatabricksSubmitRunOperator, + DatabricksTaskBaseOperator, + DatabricksTaskOperator, ) from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger from airflow.providers.databricks.utils import databricks as utils @@ -1832,6 +1834,16 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ class TestDatabricksNotebookOperator: + def test_is_instance_of_databricks_task_base_operator(self): + operator = DatabricksNotebookOperator( + task_id="test_task", + notebook_path="test_path", + source="test_source", + databricks_conn_id="test_conn_id", + ) + + assert isinstance(operator, DatabricksTaskBaseOperator) + def test_execute_with_wait_for_termination(self): operator = DatabricksNotebookOperator( task_id="test_task", @@ -1839,13 +1851,13 @@ def test_execute_with_wait_for_termination(self): source="test_source", databricks_conn_id="test_conn_id", ) - operator.launch_notebook_job = MagicMock(return_value=12345) + operator._launch_job = MagicMock(return_value=12345) operator.monitor_databricks_job = MagicMock() operator.execute({}) assert operator.wait_for_termination is True - operator.launch_notebook_job.assert_called_once() + operator._launch_job.assert_called_once() operator.monitor_databricks_job.assert_called_once() def test_execute_without_wait_for_termination(self): @@ -1856,18 +1868,24 @@ def test_execute_without_wait_for_termination(self): databricks_conn_id="test_conn_id", wait_for_termination=False, ) - operator.launch_notebook_job = MagicMock(return_value=12345) + operator._launch_job = MagicMock(return_value=12345) operator.monitor_databricks_job = MagicMock() operator.execute({}) assert operator.wait_for_termination is False - operator.launch_notebook_job.assert_called_once() + operator._launch_job.assert_called_once() operator.monitor_databricks_job.assert_not_called() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_execute_with_deferrable(self, mock_databricks_hook): - mock_databricks_hook.return_value.get_run.return_value = {"state": {"life_cycle_state": "PENDING"}} + @mock.patch( + "airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator._get_current_databricks_task" + ) + def test_execute_with_deferrable(self, mock_get_current_task, mock_databricks_hook): + mock_databricks_hook.return_value.get_run.return_value = { + "state": {"life_cycle_state": "PENDING"}, + "run_page_url": "test_url", + } operator = DatabricksNotebookOperator( task_id="test_task", notebook_path="test_path", @@ -1886,13 +1904,17 @@ def test_execute_with_deferrable(self, mock_databricks_hook): assert exec_info.value.method_name == "execute_complete" @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_execute_with_deferrable_early_termination(self, mock_databricks_hook): + @mock.patch( + "airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator._get_current_databricks_task" + ) + def test_execute_with_deferrable_early_termination(self, mock_get_current_task, mock_databricks_hook): mock_databricks_hook.return_value.get_run.return_value = { "state": { "life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE", - } + }, + "run_page_url": "test_url", } operator = DatabricksNotebookOperator( task_id="test_task", @@ -1910,9 +1932,15 @@ def test_execute_with_deferrable_early_termination(self, mock_databricks_hook): assert exception_message == str(exec_info.value) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook): + @mock.patch( + "airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator._get_current_databricks_task" + ) + def test_monitor_databricks_job_successful_raises_no_exception( + self, mock_get_current_task, mock_databricks_hook + ): mock_databricks_hook.return_value.get_run.return_value = { - "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"} + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}, + "run_page_url": "test_url", } operator = DatabricksNotebookOperator( @@ -1926,9 +1954,13 @@ def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databr operator.monitor_databricks_job() @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_monitor_databricks_job_failed(self, mock_databricks_hook): + @mock.patch( + "airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator._get_current_databricks_task" + ) + def test_monitor_databricks_job_failed(self, mock_get_current_task, mock_databricks_hook): mock_databricks_hook.return_value.get_run.return_value = { - "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"} + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "FAILURE"}, + "run_page_url": "test_url", } operator = DatabricksNotebookOperator( @@ -1956,7 +1988,7 @@ def test_launch_notebook_job(self, mock_databricks_hook): ) operator._hook.submit_run.return_value = 12345 - run_id = operator.launch_notebook_job() + run_id = operator._launch_job() assert run_id == 12345 @@ -2133,3 +2165,41 @@ def test_convert_to_databricks_workflow_task_cluster_conflict(self): match="Both existing_cluster_id and job_cluster_key are set. Only one can be set per task.", ): operator._convert_to_databricks_workflow_task(relevant_upstreams) + + +class TestDatabricksTaskOperator: + def test_is_instance_of_databricks_task_base_operator(self): + task_config = { + "sql_task": { + "query": { + "query_id": "c9cf6468-babe-41a6-abc3-10ac358c71ee", + }, + "warehouse_id": "cf414a2206dfb397", + } + } + operator = DatabricksTaskOperator( + task_id="test_task", + databricks_conn_id="test_conn_id", + task_config=task_config, + ) + + assert isinstance(operator, DatabricksTaskBaseOperator) + + def test_get_task_base_json(self): + task_config = { + "sql_task": { + "query": { + "query_id": "c9cf646-8babe-41a6-abc3-10ac358c71ee", + }, + "warehouse_id": "cf414a2206dfb397", + } + } + operator = DatabricksTaskOperator( + task_id="test_task", + databricks_conn_id="test_conn_id", + task_config=task_config, + ) + task_base_json = operator._get_task_base_json() + + assert operator.task_config == task_config + assert task_base_json == task_config diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index 62b8e5df50fd4..82c5d313421a9 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -42,11 +42,15 @@ DatabricksNotebookOperator, DatabricksRunNowOperator, DatabricksSubmitRunOperator, + DatabricksTaskOperator, ) ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_databricks_operator" +QUERY_ID = os.environ.get("QUERY_ID", "c9cf6468-babe-41a6-abc3-10ac358c71ee") +WAREHOUSE_ID = os.environ.get("WAREHOUSE_ID", "cf414a2206dfb397") + with DAG( dag_id=DAG_ID, schedule="@daily", @@ -201,6 +205,39 @@ ) # [END howto_operator_databricks_notebook_existing_cluster] + # [START howto_operator_databricks_task_notebook] + task_operator_nb_1 = DatabricksTaskOperator( + task_id="nb_1", + databricks_conn_id="databricks_conn", + job_cluster_key="Shared_job_cluster", + task_config={ + "notebook_task": { + "notebook_path": "/Shared/Notebook_1", + "source": "WORKSPACE", + }, + "libraries": [ + {"pypi": {"package": "Faker"}}, + {"pypi": {"package": "simplejson"}}, + ], + }, + ) + # [END howto_operator_databricks_task_notebook] + + # [START howto_operator_databricks_task_sql] + task_operator_sql_query = DatabricksTaskOperator( + task_id="sql_query", + databricks_conn_id="databricks_conn", + task_config={ + "sql_task": { + "query": { + "query_id": QUERY_ID, + }, + "warehouse_id": WAREHOUSE_ID, + } + }, + ) + # [END howto_operator_databricks_task_sql] + from tests.system.utils.watcher import watcher # This test needs watcher in order to properly mark success/failure diff --git a/tests/system/providers/databricks/example_databricks_workflow.py b/tests/system/providers/databricks/example_databricks_workflow.py index 6b05f34684c9d..6639708b532fb 100644 --- a/tests/system/providers/databricks/example_databricks_workflow.py +++ b/tests/system/providers/databricks/example_databricks_workflow.py @@ -23,7 +23,10 @@ from datetime import timedelta from airflow.models.dag import DAG -from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator +from airflow.providers.databricks.operators.databricks import ( + DatabricksNotebookOperator, + DatabricksTaskOperator, +) from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup from airflow.utils.timezone import datetime @@ -35,6 +38,9 @@ GROUP_ID = os.getenv("DATABRICKS_GROUP_ID", "1234").replace(".", "_") USER = os.environ.get("USER") +QUERY_ID = os.environ.get("QUERY_ID", "c9cf6468-babe-41a6-abc3-10ac358c71ee") +WAREHOUSE_ID = os.environ.get("WAREHOUSE_ID", "cf414a2206dfb397") + job_cluster_spec = [ { "job_cluster_key": "Shared_job_cluster", @@ -95,6 +101,7 @@ job_cluster_key="Shared_job_cluster", execution_timeout=timedelta(seconds=600), ) + notebook_2 = DatabricksNotebookOperator( task_id="workflow_notebook_2", databricks_conn_id=DATABRICKS_CONN_ID, @@ -103,7 +110,37 @@ job_cluster_key="Shared_job_cluster", notebook_params={"foo": "bar", "ds": "{{ ds }}"}, ) - notebook_1 >> notebook_2 + + task_operator_nb_1 = DatabricksTaskOperator( + task_id="nb_1", + databricks_conn_id="databricks_conn", + job_cluster_key="Shared_job_cluster", + task_config={ + "notebook_task": { + "notebook_path": "/Shared/Notebook_1", + "source": "WORKSPACE", + }, + "libraries": [ + {"pypi": {"package": "Faker"}}, + {"pypi": {"package": "simplejson"}}, + ], + }, + ) + + sql_query = DatabricksTaskOperator( + task_id="sql_query", + databricks_conn_id="databricks_conn", + task_config={ + "sql_task": { + "query": { + "query_id": QUERY_ID, + }, + "warehouse_id": WAREHOUSE_ID, + } + }, + ) + + notebook_1 >> notebook_2 >> task_operator_nb_1 >> sql_query # [END howto_databricks_workflow_notebook] from tests.system.utils.watcher import watcher