From 5512abb04510c757bbdcbf2e5d7f9aceb96bfe15 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 27 Dec 2023 11:26:54 -0800 Subject: [PATCH] Fix get_leaves calculation for teardown in nested group When arrowing `group` >> `task`, the "leaves" of `group` are connected to `task`. When calculating leaves in the group, teardown tasks are ignored, and we recurse upstream to find non-teardowns. What was happening, and what this fixes, is you might recurse to a work task that already has another non-teardown downstream in the group. In that case you should ignore the work task (because it already has a non-teardown descendant). Resolves #36345 --- airflow/utils/task_group.py | 13 +++++++++- tests/utils/test_task_group.py | 47 ++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index e2f8e16ebbcdf..732205f479d46 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -371,6 +371,16 @@ def get_leaves(self) -> Generator[BaseOperator, None, None]: tasks = list(self) ids = {x.task_id for x in tasks} + def has_non_teardown_downstream(task, exclude: str): + for down_task in task.downstream_list: + if down_task.task_id == exclude: + continue + elif down_task.task_id not in ids: + continue + elif not down_task.is_teardown: + return True + return False + def recurse_for_first_non_teardown(task): for upstream_task in task.upstream_list: if upstream_task.task_id not in ids: @@ -381,7 +391,8 @@ def recurse_for_first_non_teardown(task): elif task.is_teardown and upstream_task.is_setup: # don't go through the teardown-to-setup path continue - else: + # return unless upstream task already has non-teardown downstream in group + elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id): yield upstream_task for task in tasks: diff --git a/tests/utils/test_task_group.py 
b/tests/utils/test_task_group.py index a9f61debc68a6..ad5d355d87500 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -584,6 +584,53 @@ def test_dag_edges_setup_teardown(): ] +def test_dag_edges_setup_teardown_nested(): + from airflow.decorators import task, task_group + from airflow.models.dag import DAG + from airflow.operators.empty import EmptyOperator + + execution_date = pendulum.parse("20200101") + + with DAG(dag_id="s_t_dag", start_date=execution_date) as dag: + + @task + def test_task(): + print("Hello world!") + + @task_group + def inner(): + inner_start = EmptyOperator(task_id="start") + inner_end = EmptyOperator(task_id="end") + + test_task_r = test_task.override(task_id="work")() + inner_start >> test_task_r >> inner_end.as_teardown(setups=inner_start) + + @task_group + def outer(): + outer_work = EmptyOperator(task_id="work") + inner_group = inner() + inner_group >> outer_work + + dag_start = EmptyOperator(task_id="dag_start") + dag_end = EmptyOperator(task_id="dag_end") + dag_start >> outer() >> dag_end + + edges = dag_edges(dag) + + actual = sorted((e["source_id"], e["target_id"], e.get("is_setup_teardown")) for e in edges) + assert actual == [ + ("dag_start", "outer.upstream_join_id", None), + ("outer.downstream_join_id", "dag_end", None), + ("outer.inner.downstream_join_id", "outer.work", None), + ("outer.inner.start", "outer.inner.end", True), + ("outer.inner.start", "outer.inner.work", None), + ("outer.inner.work", "outer.inner.downstream_join_id", None), + ("outer.inner.work", "outer.inner.end", None), + ("outer.upstream_join_id", "outer.inner.start", None), + ("outer.work", "outer.downstream_join_id", None), + ] + + def test_duplicate_group_id(): from airflow.exceptions import DuplicateTaskIdFound