Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get_leaves calculation for teardown in nested group #36456

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,16 @@ def get_leaves(self) -> Generator[BaseOperator, None, None]:
tasks = list(self)
ids = {x.task_id for x in tasks}

def has_non_teardown_downstream(task, exclude: str):
    """Return True if *task* has a non-teardown downstream within this group.

    Downstream tasks whose id equals *exclude*, or which live outside the
    group (their id is not in ``ids``), are ignored.
    """
    return any(
        down.task_id != exclude
        and down.task_id in ids
        and not down.is_teardown
        for down in task.downstream_list
    )

def recurse_for_first_non_teardown(task):
for upstream_task in task.upstream_list:
if upstream_task.task_id not in ids:
Expand All @@ -381,7 +391,8 @@ def recurse_for_first_non_teardown(task):
elif task.is_teardown and upstream_task.is_setup:
# don't go through the teardown-to-setup path
continue
else:
# return unless upstream task already has non-teardown downstream in group
elif not has_non_teardown_downstream(upstream_task, exclude=task.task_id):
yield upstream_task

for task in tasks:
Expand Down
47 changes: 47 additions & 0 deletions tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,53 @@ def test_dag_edges_setup_teardown():
]


def test_dag_edges_setup_teardown_nested():
    """Verify dag_edges output for a setup/teardown pair inside a nested task group.

    The teardown inside the inner group must not be treated as a leaf of the
    group, so the group's downstream join connects through the work task.
    """
    from airflow.decorators import task, task_group
    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    start_date = pendulum.parse("20200101")

    with DAG(dag_id="s_t_dag", start_date=start_date) as dag:

        @task
        def test_task():
            print("Hello world!")

        @task_group
        def inner():
            setup_task = EmptyOperator(task_id="start")
            teardown_task = EmptyOperator(task_id="end")
            work_task = test_task.override(task_id="work")()
            setup_task >> work_task >> teardown_task.as_teardown(setups=setup_task)

        @task_group
        def outer():
            outer_work = EmptyOperator(task_id="work")
            inner() >> outer_work

        dag_start = EmptyOperator(task_id="dag_start")
        dag_end = EmptyOperator(task_id="dag_end")
        dag_start >> outer() >> dag_end

    expected = [
        ("dag_start", "outer.upstream_join_id", None),
        ("outer.downstream_join_id", "dag_end", None),
        ("outer.inner.downstream_join_id", "outer.work", None),
        ("outer.inner.start", "outer.inner.end", True),
        ("outer.inner.start", "outer.inner.work", None),
        ("outer.inner.work", "outer.inner.downstream_join_id", None),
        ("outer.inner.work", "outer.inner.end", None),
        ("outer.upstream_join_id", "outer.inner.start", None),
        ("outer.work", "outer.downstream_join_id", None),
    ]
    actual = sorted(
        (e["source_id"], e["target_id"], e.get("is_setup_teardown"))
        for e in dag_edges(dag)
    )
    assert actual == expected


def test_duplicate_group_id():
from airflow.exceptions import DuplicateTaskIdFound

Expand Down