From 73f2eab68081e966fd808bfaca923eed1f81bc43 Mon Sep 17 00:00:00 2001 From: Maciej Obuchowski Date: Tue, 5 Nov 2024 11:22:47 +0100 Subject: [PATCH] serialize asset/dataset timetable conditions in OpenLineage info also for Airflow 2 (#43434) Signed-off-by: Maciej Obuchowski --- .../common/compat/assets/__init__.py | 5 +- .../openlineage/extractors/manager.py | 2 + .../providers/openlineage/utils/utils.py | 24 +++- .../tests/openlineage/plugins/test_utils.py | 125 +++++++++++++++++- 4 files changed, 153 insertions(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py index 460204a4e417f..e302395f701ea 100644 --- a/providers/src/airflow/providers/common/compat/assets/__init__.py +++ b/providers/src/airflow/providers/common/compat/assets/__init__.py @@ -47,11 +47,14 @@ _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") _IS_AIRFLOW_2_9_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0") + _IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0") # dataset is renamed to asset since Airflow 3.0 - from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails from airflow.datasets import Dataset as Asset + if _IS_AIRFLOW_2_8_OR_HIGHER: + from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails + if _IS_AIRFLOW_2_9_OR_HIGHER: from airflow.datasets import ( DatasetAll as AssetAll, diff --git a/providers/src/airflow/providers/openlineage/extractors/manager.py b/providers/src/airflow/providers/openlineage/extractors/manager.py index f6d572bae5313..be824335718b1 100644 --- a/providers/src/airflow/providers/openlineage/extractors/manager.py +++ b/providers/src/airflow/providers/openlineage/extractors/manager.py @@ -198,6 +198,8 @@ def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None: except ImportError: return None + if not hasattr(get_hook_lineage_collector(), "has_collected"): + return None if not get_hook_lineage_collector().has_collected: return None diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index a00552eed251f..8c67c32f95b86 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -262,9 +262,31 @@ class DagInfo(InfoJsonEncodable): "start_date", "tags", ] - casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None} + casts = {"timetable": lambda dag: DagInfo.serialize_timetable(dag)} renames = {"_dag_id": "dag_id"} + @classmethod + def serialize_timetable(cls, dag: DAG) -> dict[str, Any]: + serialized = dag.timetable.serialize() + if serialized != {} and serialized is not None: + return serialized + if ( + hasattr(dag, "dataset_triggers") + and isinstance(dag.dataset_triggers, list) + and len(dag.dataset_triggers) + ): + triggers = dag.dataset_triggers + return { + "dataset_condition": { + "__type": "dataset_all", + "objects": [ + {"__type": "dataset", "uri": trigger.uri, "extra": trigger.extra} + for trigger in triggers + ], + } + } + return {} + class DagRunInfo(InfoJsonEncodable): """Defines encoding DagRun object to JSON.""" diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index 624bdecb5b459..531e21d42de1a 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -29,8 +29,10 @@ from pkg_resources import parse_version from airflow.models import DAG as AIRFLOW_DAG, DagModel +from airflow.providers.common.compat.assets import Asset from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.utils.utils import ( + DagInfo, InfoJsonEncodable, OpenLineageRedactor, _get_all_packages_installed, @@ -40,11 +42,18 @@ get_fully_qualified_class_name, is_operator_disabled, ) +from airflow.serialization.enums import DagAttributeTypes from airflow.utils import timezone from airflow.utils.log.secrets_masker import _secrets_masker from airflow.utils.state import State -from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, BashOperator +from tests_common.test_utils.compat import ( + AIRFLOW_V_2_8_PLUS, + AIRFLOW_V_2_9_PLUS, + AIRFLOW_V_2_10_PLUS, + AIRFLOW_V_3_0_PLUS, + BashOperator, +) if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType @@ -322,3 +331,117 @@ def test_does_not_include_full_task_info(mock_include_full_task_info): MagicMock(), )["airflow"].task ) + + +@pytest.mark.db_test +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions") +def test_serialize_timetable(): + from airflow.providers.common.compat.assets import AssetAlias, AssetAll, AssetAny + from airflow.timetables.simple import AssetTriggeredTimetable + + asset = AssetAny( + Asset("2"), + AssetAlias("example-alias"), + Asset("3"), + AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")), + ) + dag = MagicMock() + dag.timetable = AssetTriggeredTimetable(asset) + dag_info = DagInfo(dag) + + assert dag_info.timetable == { + "asset_condition": { + "__type": DagAttributeTypes.ASSET_ANY, + "objects": [ + {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"}, + {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, + {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"}, + { + "__type": DagAttributeTypes.ASSET_ALL, + "objects": [ + {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, + {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"}, + ], + }, + ], + } + } + + +@pytest.mark.db_test +@pytest.mark.skipif( + not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS, + reason="This test checks serialization only in 2.10 conditions", +) +def test_serialize_timetable_2_10(): + from airflow.providers.common.compat.assets import AssetAlias, AssetAll, AssetAny + from airflow.timetables.simple import DatasetTriggeredTimetable + + asset = AssetAny( + Asset("2"), + AssetAlias("example-alias"), + Asset("3"), + AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")), + ) + + dag = MagicMock() + dag.timetable = DatasetTriggeredTimetable(asset) + dag_info = DagInfo(dag) + + assert dag_info.timetable == { + "dataset_condition": { + "__type": DagAttributeTypes.DATASET_ANY, + "objects": [ + {"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "2"}, + {"__type": DagAttributeTypes.DATASET_ANY, "objects": []}, + {"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "3"}, + { + "__type": DagAttributeTypes.DATASET_ALL, + "objects": [ + {"__type": DagAttributeTypes.DATASET_ANY, "objects": []}, + {"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "4"}, + ], + }, + ], + } + } + + +@pytest.mark.skipif( + not AIRFLOW_V_2_9_PLUS or AIRFLOW_V_2_10_PLUS, + reason="This test checks serialization only in 2.9 conditions", +) +def test_serialize_timetable_2_9(): + dag = MagicMock() + dag.timetable.serialize.return_value = {} + dag.dataset_triggers = [Asset("a"), Asset("b")] + dag_info = DagInfo(dag) + assert dag_info.timetable == { + "dataset_condition": { + "__type": "dataset_all", + "objects": [ + {"__type": "dataset", "extra": None, "uri": "a"}, + {"__type": "dataset", "extra": None, "uri": "b"}, + ], + } + } + + +@pytest.mark.skipif( + not AIRFLOW_V_2_8_PLUS or AIRFLOW_V_2_9_PLUS, + reason="This test checks serialization only in 2.8 conditions", +) +def test_serialize_timetable_2_8(): + dag = MagicMock() + dag.timetable.serialize.return_value = {} + dag.dataset_triggers = [Asset("a"), Asset("b")] + dag_info = DagInfo(dag) + assert dag_info.timetable == { + "dataset_condition": { + "__type": "dataset_all", + "objects": [ + {"__type": "dataset", "extra": None, "uri": "a"}, + {"__type": "dataset", "extra": None, "uri": "b"}, + ], + } + }