diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 7636842209..6e00211567 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -22,10 +22,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance +from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.databricks.hooks.databricks import DatabricksHook if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.context import Context XCOM_RUN_ID_KEY = 'run_id' @@ -107,9 +108,23 @@ class DatabricksJobRunLink(BaseOperatorLink): name = "See Databricks Job Run" - def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - run_page_url = ti.xcom_pull(task_ids=operator.task_id, key=XCOM_RUN_PAGE_URL_KEY) + def get_link( + self, + operator, + dttm=None, + *, + ti_key: Optional["TaskInstanceKey"] = None, + ) -> str: + if ti_key is not None: + run_page_url = XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key) + else: + assert dttm + run_page_url = XCom.get_one( + key=XCOM_RUN_PAGE_URL_KEY, + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + ) return run_page_url