Skip to content

Commit

Permalink
Ensure that manually creating a DAG run doesn't "block" the scheduler…
Browse files Browse the repository at this point in the history
… (#11732)

It was possible to "block" the scheduler such that it would not
schedule or queue tasks for a dag if you triggered a DAG run when the
DAG was already at the max active runs.

This approach works around the problem for now, but a better longer term
fix for this would be to introduce a "queued" state for DagRuns, and
then when manually creating dag runs (or clearing) set it to queued, and
only have the scheduler set DagRuns to running, nothing else -- this
would mean we wouldn't need to examine active runs in the TI part of the
scheduler loop, only in DagRun creation part.

Fixes #11582

GitOrigin-RevId: f603b36aa4a07bf98ebe3b1c81676748173b8b57
  • Loading branch information
ashb authored and Cloud Composer Team committed Sep 15, 2021
1 parent 2b13a2f commit a4c219c
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 30 deletions.
32 changes: 23 additions & 9 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,17 +1467,24 @@ def _do_scheduling(self, session) -> int:
# The longer term fix would be to have `clear` do this, and put DagRuns
# in to the queued state, then take DRs out of queued before creating
# any new ones
# TODO[HA]: Why is this on TI, not on DagRun??
currently_active_runs = dict(session.query(

# Build up a set of execution_dates that are "active" for a given
# dag_id -- only tasks from those runs will be scheduled.
active_runs_by_dag_id = defaultdict(set)

query = session.query(
TI.dag_id,
func.count(TI.execution_date.distinct()),
TI.execution_date,
).filter(
TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
TI.state.notin_(list(State.finished))
).group_by(TI.dag_id).all())
).group_by(TI.dag_id, TI.execution_date)

for dag_id, execution_date in query:
active_runs_by_dag_id[dag_id].add(execution_date)

for dag_run in dag_runs:
self._schedule_dag_run(dag_run, currently_active_runs.get(dag_run.dag_id, 0), session)
self._schedule_dag_run(dag_run, active_runs_by_dag_id.get(dag_run.dag_id, set()), session)

guard.commit()

Expand Down Expand Up @@ -1588,7 +1595,12 @@ def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Sess
dag_model.next_dagrun, dag_model.next_dagrun_create_after = \
dag.next_dagrun_info(dag_model.next_dagrun)

def _schedule_dag_run(self, dag_run: DagRun, currently_active_runs: int, session: Session) -> int:
def _schedule_dag_run(
self,
dag_run: DagRun,
currently_active_runs: Set[datetime.datetime],
session: Session,
) -> int:
"""
Make scheduling decisions about an individual dag run
Expand Down Expand Up @@ -1640,11 +1652,13 @@ def _schedule_dag_run(self, dag_run: DagRun, currently_active_runs: int, session
return 0

if dag.max_active_runs:
if currently_active_runs >= dag.max_active_runs:
if len(currently_active_runs) >= dag.max_active_runs and \
dag_run.execution_date not in currently_active_runs:
self.log.info(
"DAG %s already has %d active runs, not queuing any more tasks",
"DAG %s already has %d active runs, not queuing any tasks for run %s",
dag.dag_id,
currently_active_runs,
len(currently_active_runs),
dag_run.execution_date,
)
return 0

Expand Down
132 changes: 111 additions & 21 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_
ti.start_date = start_date
ti.end_date = end_date

count = scheduler._schedule_dag_run(dr, 0, session)
count = scheduler._schedule_dag_run(dr, set(), session)
assert count == 1

session.refresh(ti)
Expand Down Expand Up @@ -469,7 +469,7 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency(
ti.start_date = start_date
ti.end_date = end_date

count = scheduler._schedule_dag_run(dr, 0, session)
count = scheduler._schedule_dag_run(dr, set(), session)
assert count == 1

session.refresh(ti)
Expand Down Expand Up @@ -531,7 +531,7 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state,
ti.start_date = start_date
ti.end_date = end_date

count = scheduler._schedule_dag_run(dr, 0, session)
count = scheduler._schedule_dag_run(dr, set(), session)
assert count == 2

session.refresh(tis[0])
Expand Down Expand Up @@ -569,7 +569,7 @@ def test_scheduler_job_add_new_task(self):
BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
SerializedDagModel.write_dag(dag=dag)

scheduled_tis = scheduler._schedule_dag_run(dr, 0, session)
scheduled_tis = scheduler._schedule_dag_run(dr, set(), session)
session.flush()
assert scheduled_tis == 2

Expand Down Expand Up @@ -640,11 +640,11 @@ def test_runs_respected_after_clear(self):
# and schedule them in, so we can check how many
# tasks are put on the task_instances_list (should be one, not 3)
with create_session() as session:
num_scheduled = scheduler._schedule_dag_run(dr1, 0, session)
num_scheduled = scheduler._schedule_dag_run(dr1, set(), session)
assert num_scheduled == 1
num_scheduled = scheduler._schedule_dag_run(dr2, 1, session)
num_scheduled = scheduler._schedule_dag_run(dr2, {dr1.execution_date}, session)
assert num_scheduled == 0
num_scheduled = scheduler._schedule_dag_run(dr3, 1, session)
num_scheduled = scheduler._schedule_dag_run(dr3, {dr1.execution_date}, session)
assert num_scheduled == 0

@patch.object(TaskInstance, 'handle_failure')
Expand Down Expand Up @@ -748,7 +748,7 @@ def test_should_mark_dummy_task_as_success(self):
dr = drs[0]

# Schedule TaskInstances
scheduler_job._schedule_dag_run(dr, 0, session)
scheduler_job._schedule_dag_run(dr, {}, session)
with create_session() as session:
tis = session.query(TaskInstance).all()

Expand All @@ -773,7 +773,7 @@ def test_should_mark_dummy_task_as_success(self):
self.assertIsNone(end_date)
self.assertIsNone(duration)

scheduler_job._schedule_dag_run(dr, 0, session)
scheduler_job._schedule_dag_run(dr, {}, session)
with create_session() as session:
tis = session.query(TaskInstance).all()

Expand Down Expand Up @@ -2006,7 +2006,7 @@ def test_dagrun_timeout_verify_max_active_runs(self):
scheduler.processor_agent = mock.Mock()
scheduler.processor_agent.send_callback_to_execute = mock.Mock()

scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)
session.flush()

session.refresh(dr)
Expand Down Expand Up @@ -2068,7 +2068,7 @@ def test_dagrun_timeout_fails_run(self):
scheduler.processor_agent = mock.Mock()
scheduler.processor_agent.send_callback_to_execute = mock.Mock()

scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)
session.flush()

session.refresh(dr)
Expand Down Expand Up @@ -2129,7 +2129,7 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
ti = dr.get_task_instance('dummy')
ti.set_state(state, session)

scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)

expected_callback = DagCallbackRequest(
full_filepath=dr.dag.fileloc,
Expand Down Expand Up @@ -2538,13 +2538,13 @@ def test_scheduler_verify_pool_full(self):
execution_date=DEFAULT_DATE,
state=State.RUNNING,
)
scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)
dr = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=dag.following_schedule(dr.execution_date),
state=State.RUNNING,
)
scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)

task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)

Expand Down Expand Up @@ -2593,7 +2593,7 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self):
execution_date=date,
state=State.RUNNING,
)
scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)
date = dag.following_schedule(date)

task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)
Expand Down Expand Up @@ -2664,7 +2664,7 @@ def test_scheduler_verify_priority_and_slots(self):
execution_date=DEFAULT_DATE,
state=State.RUNNING,
)
scheduler._schedule_dag_run(dr, 0, session)
scheduler._schedule_dag_run(dr, {}, session)

task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session)

Expand Down Expand Up @@ -2712,7 +2712,7 @@ def test_verify_integrity_if_dag_not_changed(self):

# Verify that DagRun.verify_integrity is not called
with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity:
scheduled_tis = scheduler._schedule_dag_run(dr, 0, session)
scheduled_tis = scheduler._schedule_dag_run(dr, {}, session)
mock_verify_integrity.assert_not_called()
session.flush()

Expand Down Expand Up @@ -2771,7 +2771,7 @@ def test_verify_integrity_if_dag_changed(self):
dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
assert dag_version_2 != dag_version_1

scheduled_tis = scheduler._schedule_dag_run(dr, 0, session)
scheduled_tis = scheduler._schedule_dag_run(dr, {}, session)
session.flush()

assert scheduled_tis == 2
Expand Down Expand Up @@ -3523,7 +3523,7 @@ def test_do_schedule_max_active_runs_upstream_failed(self):
"""

with DAG(
dag_id='test_max_active_run_plus_manual_trigger',
dag_id='test_max_active_run_with_upstream_failed',
start_date=DEFAULT_DATE,
schedule_interval='@once',
max_active_runs=1,
Expand Down Expand Up @@ -3565,13 +3565,103 @@ def test_do_schedule_max_active_runs_upstream_failed(self):
job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)

num_queued = job._do_scheduling(session)
session.flush()

assert num_queued == 1
ti = run2.get_task_instance(task1.task_id, session)
assert ti.state == State.QUEUED

session.rollback()
def test_do_schedule_max_active_runs_and_manual_trigger(self):
"""
Make sure that when a DAG is already at max_active_runs, that manually triggering a run doesn't cause
the dag to "stall".
"""

with DAG(
dag_id='test_max_active_run_plus_manual_trigger',
start_date=DEFAULT_DATE,
schedule_interval='@once',
max_active_runs=1,
) as dag:
# Cant use DummyOperator as that goes straight to success
task1 = BashOperator(task_id='dummy1', bash_command='true')
task2 = BashOperator(task_id='dummy2', bash_command='true')

task1 >> task2

task3 = BashOperator(task_id='dummy3', bash_command='true')

session = settings.Session()
dagbag = DagBag(
dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
include_examples=False,
read_dags_from_db=True
)

dagbag.bag_dag(dag=dag, root_dag=dag)
dagbag.sync_to_db(session=session)

dag_run = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=DEFAULT_DATE,
state=State.RUNNING,
session=session,
)

dag.sync_to_db(session=session) # Update the date fields

job = SchedulerJob()
job.executor = MockExecutor(do_update=False)
job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)

num_queued = job._do_scheduling(session)
# Add it back in to the session so we can refresh it. (_do_scheduling does an expunge_all to reduce
# memory)
session.add(dag_run)
session.refresh(dag_run)

assert num_queued == 2
assert dag_run.state == State.RUNNING
ti1 = dag_run.get_task_instance(task1.task_id, session)
assert ti1.state == State.QUEUED

# Set task1 to success (so task2 can run) but keep task3 as "running"
ti1.state = State.SUCCESS

ti3 = dag_run.get_task_instance(task3.task_id, session)
ti3.state = State.RUNNING

session.flush()

# At this point, ti2 and ti3 of the scheduled dag run should be running
num_queued = job._do_scheduling(session)

assert num_queued == 1
# Should have queued task2
ti2 = dag_run.get_task_instance(task2.task_id, session)
assert ti2.state == State.QUEUED

ti2.state = None
session.flush()

# Now that this one is running, manually trigger a dag.

manual_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE + timedelta(hours=1),
state=State.RUNNING,
session=session,
)
session.flush()

num_queued = job._do_scheduling(session)

assert num_queued == 1
# Should have queued task2 again.
ti2 = dag_run.get_task_instance(task2.task_id, session)
assert ti2.state == State.QUEUED
# Manual run shouldn't have been started, because we're at max_active_runs with DR1
ti1 = manual_run.get_task_instance(task1.task_id, session)
assert ti1.state is None


@pytest.mark.xfail(reason="Work out where this goes")
Expand Down

0 comments on commit a4c219c

Please sign in to comment.