diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 22cde7febf..d11f490247 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -468,7 +468,20 @@ def set_dag_run_state_to_failed( task.dag = dag tasks.append(task) - return set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session) + # Mark non-finished tasks as SKIPPED. + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.run_id == run_id, + TaskInstance.state.not_in(State.finished), + TaskInstance.state.not_in(State.running), + ) + + tis = [ti for ti in tis] + if commit: + for ti in tis: + ti.set_state(State.SKIPPED) + + return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session) def __set_dag_run_state_to_running_or_queued( diff --git a/airflow/www/views.py b/airflow/www/views.py index 41740d9e64..f3113a251c 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2119,7 +2119,7 @@ def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed, origin): response = self.render_template( 'airflow/confirm.html', - message="Here's the list of task instances you are about to mark as failed", + message="Here's the list of task instances you are about to mark as failed or skipped", details=details, ) diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py index 7e461dc4e3..3a3bcfc621 100644 --- a/tests/api/common/test_mark_tasks.py +++ b/tests/api/common/test_mark_tasks.py @@ -450,6 +450,20 @@ def compare(x, y): return len([s for s in states if compare(s, state)]) + def _get_num_tasks_with_non_completed_state(self): + """ + Return the non completed tasks. + :return: number of tasks in non completed state (SUCCESS, FAILED, SKIPPED, UPSTREAM_FAILED) + """ + expected = len(self.INITIAL_TASK_STATES.values()) - self._get_num_tasks_with_starting_state( + State.SUCCESS, inclusion=True + ) + expected = expected - self._get_num_tasks_with_starting_state(State.FAILED, inclusion=True) + expected = expected - self._get_num_tasks_with_starting_state(State.SKIPPED, inclusion=True) + expected = expected - self._get_num_tasks_with_starting_state(State.UPSTREAM_FAILED, inclusion=True) + + return expected + def _set_default_task_instance_states(self, dr): for task_id, state in self.INITIAL_TASK_STATES.items(): dr.get_task_instance(task_id).set_state(state) @@ -514,8 +528,8 @@ def test_set_running_dag_run_to_failed(self): self._set_default_task_instance_states(dr) altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True) - # Only running task should be altered. - expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True) + # Only non-completed tasks should be altered. + expected = self._get_num_tasks_with_non_completed_state() assert len(altered) == expected self._verify_dag_run_state(self.dag1, date, State.FAILED) assert dr.get_task_instance('run_after_loop').state == State.FAILED @@ -561,8 +575,8 @@ def test_set_success_dag_run_to_failed(self): self._set_default_task_instance_states(dr) altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True) - # Only running task should be altered. - expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True) + # Only non-completed tasks should be altered. + expected = self._get_num_tasks_with_non_completed_state() assert len(altered) == expected self._verify_dag_run_state(self.dag1, date, State.FAILED) assert dr.get_task_instance('run_after_loop').state == State.FAILED @@ -609,8 +623,8 @@ def test_set_failed_dag_run_to_failed(self): altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True) - # Only running task should be altered. - expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True) + # Only non-completed tasks should be altered. + expected = self._get_num_tasks_with_non_completed_state() assert len(altered) == expected self._verify_dag_run_state(self.dag1, date, State.FAILED) assert dr.get_task_instance('run_after_loop').state == State.FAILED @@ -655,8 +669,8 @@ def test_set_state_without_commit(self): will_be_altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=False) - # Only the running task should be altered. - expected = self._get_num_tasks_with_starting_state(State.RUNNING, inclusion=True) + # Only the non-completed tasks should be altered. + expected = self._get_num_tasks_with_non_completed_state() assert len(will_be_altered) == expected self._verify_dag_run_state(self.dag1, date, State.RUNNING) self._verify_task_instance_states_remain_default(dr)