From 5a5151fc0690df2b46d536e8053518a6e716f751 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com> Date: Mon, 28 Aug 2023 21:33:47 +0200 Subject: [PATCH] Refactor unneeded 'continue' jumps in api --- airflow/api/common/delete_dag.py | 4 +--- airflow/api/common/mark_tasks.py | 26 ++++++++++++-------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py index c94b3c39dfeca..1a3346775544d 100644 --- a/airflow/api/common/delete_dag.py +++ b/airflow/api/common/delete_dag.py @@ -82,9 +82,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session = count = 0 for model in get_sqla_model_classes(): - if hasattr(model, "dag_id"): - if keep_records_in_log and model.__name__ == "Log": - continue + if hasattr(model, "dag_id") and (not keep_records_in_log or model.__name__ != "Log"): count += session.execute( delete(model) .where(model.dag_id.in_(dags_to_delete)) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index d71957f86b81d..cfd7471d246fc 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -67,16 +67,15 @@ def _create_dagruns( } for info in infos: - if info.logical_date in dag_runs: - continue - dag_runs[info.logical_date] = dag.create_dagrun( - execution_date=info.logical_date, - data_interval=info.data_interval, - start_date=timezone.utcnow(), - external_trigger=False, - state=state, - run_type=run_type, - ) + if info.logical_date not in dag_runs: + dag_runs[info.logical_date] = dag.create_dagrun( + execution_date=info.logical_date, + data_interval=info.data_interval, + start_date=timezone.utcnow(), + external_trigger=False, + state=state, + run_type=run_type, + ) return dag_runs.values() @@ -493,10 +492,9 @@ def set_dag_run_state_to_failed( tasks = [] for task in dag.tasks: - if task.task_id not in task_ids_of_running_tis: - continue - task.dag = dag - tasks.append(task) + if task.task_id in task_ids_of_running_tis: + task.dag = dag + tasks.append(task) # Mark non-finished tasks as SKIPPED. tis = session.scalars(