Skip to content

Commit

Permalink
Fixing task status for non-running and non-committed tasks (#22410)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: becbb4ab443995b21d783cadfba7fbfdf3b1530d
  • Loading branch information
megan-parker authored and Cloud Composer Team committed Sep 12, 2024
1 parent 772f7ce commit eb08ffe
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
15 changes: 14 additions & 1 deletion airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,20 @@ def set_dag_run_state_to_failed(
task.dag = dag
tasks.append(task)

return set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session)
# Mark non-finished tasks as SKIPPED.
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.state.not_in(State.finished),
TaskInstance.state.not_in(State.running),
)

tis = [ti for ti in tis]
if commit:
for ti in tis:
ti.set_state(State.SKIPPED)

return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session)


def __set_dag_run_state_to_running_or_queued(
Expand Down
2 changes: 1 addition & 1 deletion airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2119,7 +2119,7 @@ def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed, origin):

response = self.render_template(
'airflow/confirm.html',
message="Here's the list of task instances you are about to mark as failed",
message="Here's the list of task instances you are about to mark as failed or skipped",
details=details,
)

Expand Down
30 changes: 22 additions & 8 deletions tests/api/common/test_mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,20 @@ def compare(x, y):

return len([s for s in states if compare(s, state)])

def _get_num_tasks_with_non_completed_state(self):
"""
Return the non completed tasks.
:return: number of tasks in non completed state (SUCCESS, FAILED, SKIPPED, UPSTREAM_FAILED)
"""
expected = len(self.INITIAL_TASK_STATES.values()) - self._get_num_tasks_with_starting_state(
State.SUCCESS, inclusion=True
)
expected = expected - self._get_num_tasks_with_starting_state(State.FAILED, inclusion=True)
expected = expected - self._get_num_tasks_with_starting_state(State.SKIPPED, inclusion=True)
expected = expected - self._get_num_tasks_with_starting_state(State.UPSTREAM_FAILED, inclusion=True)

return expected

def _set_default_task_instance_states(self, dr):
for task_id, state in self.INITIAL_TASK_STATES.items():
dr.get_task_instance(task_id).set_state(state)
Expand Down Expand Up @@ -514,8 +528,8 @@ def test_set_running_dag_run_to_failed(self):
self._set_default_task_instance_states(dr)

altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True)
# Only running task should be altered.
expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True)
# Only non-completed tasks should be altered.
expected = self._get_num_tasks_with_non_completed_state()
assert len(altered) == expected
self._verify_dag_run_state(self.dag1, date, State.FAILED)
assert dr.get_task_instance('run_after_loop').state == State.FAILED
Expand Down Expand Up @@ -561,8 +575,8 @@ def test_set_success_dag_run_to_failed(self):
self._set_default_task_instance_states(dr)

altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True)
# Only running task should be altered.
expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True)
# Only non-completed tasks should be altered.
expected = self._get_num_tasks_with_non_completed_state()
assert len(altered) == expected
self._verify_dag_run_state(self.dag1, date, State.FAILED)
assert dr.get_task_instance('run_after_loop').state == State.FAILED
Expand Down Expand Up @@ -609,8 +623,8 @@ def test_set_failed_dag_run_to_failed(self):

altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True)

# Only running task should be altered.
expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True)
# Only non-completed tasks should be altered.
expected = self._get_num_tasks_with_non_completed_state()
assert len(altered) == expected
self._verify_dag_run_state(self.dag1, date, State.FAILED)
assert dr.get_task_instance('run_after_loop').state == State.FAILED
Expand Down Expand Up @@ -655,8 +669,8 @@ def test_set_state_without_commit(self):

will_be_altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=False)

# Only the running task should be altered.
expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True)
# Only the non-completed tasks should be altered.
expected = self._get_num_tasks_with_non_completed_state()
assert len(will_be_altered) == expected
self._verify_dag_run_state(self.dag1, date, State.RUNNING)
self._verify_task_instance_states_remain_default(dr)
Expand Down

0 comments on commit eb08ffe

Please sign in to comment.