Commit 73f2eab

serialize asset/dataset timetable conditions in OpenLineage info also for Airflow 2 (#43434)

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
mobuchowski authored Nov 5, 2024
1 parent d8f71a2 commit 73f2eab
Showing 4 changed files with 153 additions and 3 deletions.

@@ -47,11 +47,14 @@

 _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
 _IS_AIRFLOW_2_9_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
+_IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")

 # dataset is renamed to asset since Airflow 3.0
-from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
 from airflow.datasets import Dataset as Asset

+if _IS_AIRFLOW_2_8_OR_HIGHER:
+    from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
+
 if _IS_AIRFLOW_2_9_OR_HIGHER:
     from airflow.datasets import (
         DatasetAll as AssetAll,
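
This hunk appears to belong to the airflow.providers.common.compat assets shim (the test file below imports Asset from airflow.providers.common.compat.assets); after it, AssetDetails only exists on Airflow 2.8+. A minimal, hypothetical consumer-side sketch of feature-detecting the name rather than comparing versions at every call site (not part of this commit):

# Hypothetical consumer code: tolerate AssetDetails being absent on
# Airflow < 2.8, where the guarded import above never runs.
try:
    from airflow.providers.common.compat.assets import AssetDetails
except ImportError:
    AssetDetails = None  # the compat shim did not define it on this version

def supports_asset_details() -> bool:
    # Feature-detect once; callers branch on capability, not version string.
    return AssetDetails is not None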

@@ -198,6 +198,8 @@ def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
         except ImportError:
             return None

+        if not hasattr(get_hook_lineage_collector(), "has_collected"):
+            return None
         if not get_hook_lineage_collector().has_collected:
             return None

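The new hasattr check guards against version skew: a hook lineage collector that predates the has_collected attribute would otherwise raise AttributeError here. A self-contained sketch of the probe-before-read pattern, with stand-in classes (illustrative only, not part of this commit):

from typing import Optional

class LegacyCollector:
    """Stand-in for a collector that predates has_collected."""

class CurrentCollector:
    """Stand-in for a collector that exposes has_collected."""
    has_collected = True

def lineage_or_none(collector) -> Optional[str]:
    # Probe for the attribute before reading it, mirroring the added guard.
    if not hasattr(collector, "has_collected"):
        return None
    if not collector.has_collected:
        return None
    return "lineage available"

assert lineage_or_none(LegacyCollector()) is None
assert lineage_or_none(CurrentCollector()) == "lineage available"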
24 changes: 23 additions & 1 deletion providers/src/airflow/providers/openlineage/utils/utils.py

@@ -262,9 +262,31 @@ class DagInfo(InfoJsonEncodable):
         "start_date",
         "tags",
     ]
-    casts = {"timetable": lambda dag: dag.timetable.serialize() if getattr(dag, "timetable", None) else None}
+    casts = {"timetable": lambda dag: DagInfo.serialize_timetable(dag)}
     renames = {"_dag_id": "dag_id"}

+    @classmethod
+    def serialize_timetable(cls, dag: DAG) -> dict[str, Any]:
+        serialized = dag.timetable.serialize()
+        if serialized != {} and serialized is not None:
+            return serialized
+        if (
+            hasattr(dag, "dataset_triggers")
+            and isinstance(dag.dataset_triggers, list)
+            and len(dag.dataset_triggers)
+        ):
+            triggers = dag.dataset_triggers
+            return {
+                "dataset_condition": {
+                    "__type": "dataset_all",
+                    "objects": [
+                        {"__type": "dataset", "uri": trigger.uri, "extra": trigger.extra}
+                        for trigger in triggers
+                    ],
+                }
+            }
+        return {}
+

 class DagRunInfo(InfoJsonEncodable):
     """Defines encoding DagRun object to JSON."""
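
serialize_timetable() first trusts timetable.serialize(); the 2.8/2.9 tests below model that call returning {}, in which case the condition is rebuilt from dag.dataset_triggers as an implicit all-of condition. A self-contained sketch of that fallback path, using SimpleNamespace stubs in place of real DAG and Dataset objects (the URIs are made up):

from types import SimpleNamespace

# Stub DAG shaped like Airflow 2.8/2.9: empty timetable serialization,
# dataset triggers exposed as a plain list.
dag = SimpleNamespace(
    timetable=SimpleNamespace(serialize=lambda: {}),
    dataset_triggers=[
        SimpleNamespace(uri="s3://bucket/a", extra=None),
        SimpleNamespace(uri="s3://bucket/b", extra=None),
    ],
)

serialized = dag.timetable.serialize()
if not serialized and isinstance(getattr(dag, "dataset_triggers", None), list):
    # Same shape the new method emits: an implicit dataset_all condition.
    serialized = {
        "dataset_condition": {
            "__type": "dataset_all",
            "objects": [
                {"__type": "dataset", "uri": t.uri, "extra": t.extra}
                for t in dag.dataset_triggers
            ],
        }
    }

print(serialized["dataset_condition"]["__type"])  # dataset_all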
125 changes: 124 additions & 1 deletion providers/tests/openlineage/plugins/test_utils.py

@@ -29,8 +29,10 @@
 from pkg_resources import parse_version

 from airflow.models import DAG as AIRFLOW_DAG, DagModel
+from airflow.providers.common.compat.assets import Asset
 from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet
 from airflow.providers.openlineage.utils.utils import (
+    DagInfo,
     InfoJsonEncodable,
     OpenLineageRedactor,
     _get_all_packages_installed,
@@ -40,11 +42,18 @@
     get_fully_qualified_class_name,
     is_operator_disabled,
 )
+from airflow.serialization.enums import DagAttributeTypes
 from airflow.utils import timezone
 from airflow.utils.log.secrets_masker import _secrets_masker
 from airflow.utils.state import State

-from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, BashOperator
+from tests_common.test_utils.compat import (
+    AIRFLOW_V_2_8_PLUS,
+    AIRFLOW_V_2_9_PLUS,
+    AIRFLOW_V_2_10_PLUS,
+    AIRFLOW_V_3_0_PLUS,
+    BashOperator,
+)

 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
@@ -322,3 +331,117 @@ def test_does_not_include_full_task_info(mock_include_full_task_info):
             MagicMock(),
         )["airflow"].task
     )
+
+
+@pytest.mark.db_test
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions")
+def test_serialize_timetable():
+    from airflow.providers.common.compat.assets import AssetAlias, AssetAll, AssetAny
+    from airflow.timetables.simple import AssetTriggeredTimetable
+
+    asset = AssetAny(
+        Asset("2"),
+        AssetAlias("example-alias"),
+        Asset("3"),
+        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+    )
+    dag = MagicMock()
+    dag.timetable = AssetTriggeredTimetable(asset)
+    dag_info = DagInfo(dag)
+
+    assert dag_info.timetable == {
+        "asset_condition": {
+            "__type": DagAttributeTypes.ASSET_ANY,
+            "objects": [
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
+                {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
+                {
+                    "__type": DagAttributeTypes.ASSET_ALL,
+                    "objects": [
+                        {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                        {"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
+                    ],
+                },
+            ],
+        }
+    }
+
+
+@pytest.mark.db_test
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS,
+    reason="This test checks serialization only in 2.10 conditions",
+)
+def test_serialize_timetable_2_10():
+    from airflow.providers.common.compat.assets import AssetAlias, AssetAll, AssetAny
+    from airflow.timetables.simple import DatasetTriggeredTimetable
+
+    asset = AssetAny(
+        Asset("2"),
+        AssetAlias("example-alias"),
+        Asset("3"),
+        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+    )
+
+    dag = MagicMock()
+    dag.timetable = DatasetTriggeredTimetable(asset)
+    dag_info = DagInfo(dag)
+
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": DagAttributeTypes.DATASET_ANY,
+            "objects": [
+                {"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "2"},
+                {"__type": DagAttributeTypes.DATASET_ANY, "objects": []},
+                {"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "3"},
+                {
+                    "__type": DagAttributeTypes.DATASET_ALL,
+                    "objects": [
+                        {"__type": DagAttributeTypes.DATASET_ANY, "objects": []},
+                        {"__type": DagAttributeTypes.DATASET, "extra": None, "uri": "4"},
+                    ],
+                },
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_9_PLUS or AIRFLOW_V_2_10_PLUS,
+    reason="This test checks serialization only in 2.9 conditions",
+)
+def test_serialize_timetable_2_9():
+    dag = MagicMock()
+    dag.timetable.serialize.return_value = {}
+    dag.dataset_triggers = [Asset("a"), Asset("b")]
+    dag_info = DagInfo(dag)
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": "dataset_all",
+            "objects": [
+                {"__type": "dataset", "extra": None, "uri": "a"},
+                {"__type": "dataset", "extra": None, "uri": "b"},
+            ],
+        }
+    }
+
+
+@pytest.mark.skipif(
+    not AIRFLOW_V_2_8_PLUS or AIRFLOW_V_2_9_PLUS,
+    reason="This test checks serialization only in 2.8 conditions",
+)
+def test_serialize_timetable_2_8():
+    dag = MagicMock()
+    dag.timetable.serialize.return_value = {}
+    dag.dataset_triggers = [Asset("a"), Asset("b")]
+    dag_info = DagInfo(dag)
+    assert dag_info.timetable == {
+        "dataset_condition": {
+            "__type": "dataset_all",
+            "objects": [
+                {"__type": "dataset", "extra": None, "uri": "a"},
+                {"__type": "dataset", "extra": None, "uri": "b"},
+            ],
+        }
+    }
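
The four tests are mutually exclusive through the compat flags, so exactly one variant runs on any given Airflow installation. To run just this group locally, standard pytest keyword selection works (path taken from the diff header above):

pytest providers/tests/openlineage/plugins/test_utils.py -k serialize_timetable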
