Skip to content

Commit

Permalink
Fix race condition with dagrun callbacks (#16741)
Browse files Browse the repository at this point in the history
Instead of immediately sending callbacks to be processed, wait until
after we commit so the dagrun.end_date is guaranteed to be there when
the callback runs.

(cherry picked from commit fb3031a)
  • Loading branch information
jedcunningham authored and kaxil committed Aug 17, 2021
1 parent 7f3bdca commit 62191bd
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 24 deletions.
18 changes: 12 additions & 6 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def _do_scheduling(self, session) -> int:
# Bulk fetch the currently active dag runs for the dags we are
# examining, rather than making one query per DagRun

callback_tuples = []
for dag_run in dag_runs:
# Use try_except to not stop the Scheduler when a Serialized DAG is not found
# This takes care of Dynamic DAGs especially
Expand All @@ -896,13 +897,18 @@ def _do_scheduling(self, session) -> int:
# But this would take care of the scenario when the Scheduler is restarted after DagRun is
# created and the DAG is deleted / renamed
try:
self._schedule_dag_run(dag_run, session)
callback_to_run = self._schedule_dag_run(dag_run, session)
callback_tuples.append((dag_run, callback_to_run))
except SerializedDagNotFound:
self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue

guard.commit()

# Send the callbacks after we commit to ensure the context is up to date when it gets run
for dag_run, callback_to_run in callback_tuples:
self._send_dag_callbacks_to_processor(dag_run, callback_to_run)

# Without this, the session has an invalid view of the DB
session.expunge_all()
# END: schedule TIs
Expand Down Expand Up @@ -1064,12 +1070,12 @@ def _schedule_dag_run(
self,
dag_run: DagRun,
session: Session,
) -> int:
) -> Optional[DagCallbackRequest]:
"""
Make scheduling decisions about an individual dag run
:param dag_run: The DagRun to schedule
:return: Number of tasks scheduled
:return: Callback that needs to be executed
"""
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)

Expand Down Expand Up @@ -1116,13 +1122,13 @@ def _schedule_dag_run(
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)

self._send_dag_callbacks_to_processor(dag_run, callback_to_run)

# This will do one query per dag run. We "could" build up a complex
# query to update all the TIs across all the execution dates and dag
# IDs in a single query, but it turns out that can be _very very slow_
# see #11147/commit ee90807ac for more details
return dag_run.schedule_tis(schedulable_tis, session)
dag_run.schedule_tis(schedulable_tis, session)

return callback_to_run

@provide_session
def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
Expand Down
20 changes: 12 additions & 8 deletions tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def setUpClass(cls):
non_serialized_dagbag.sync_to_db()
cls.dagbag = DagBag(read_dags_from_db=True)

@staticmethod
def assert_scheduled_ti_count(session, count):
assert count == session.query(TaskInstance).filter_by(state=State.SCHEDULED).count()

def test_dag_file_processor_sla_miss_callback(self):
"""
Test that the dag file processor calls the sla miss callback
Expand Down Expand Up @@ -387,8 +391,8 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_
ti.start_date = start_date
ti.end_date = end_date

count = self.scheduler_job._schedule_dag_run(dr, session)
assert count == 1
self.scheduler_job._schedule_dag_run(dr, session)
self.assert_scheduled_ti_count(session, 1)

session.refresh(ti)
assert ti.state == State.SCHEDULED
Expand Down Expand Up @@ -444,8 +448,8 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency(
ti.start_date = start_date
ti.end_date = end_date

count = self.scheduler_job._schedule_dag_run(dr, session)
assert count == 1
self.scheduler_job._schedule_dag_run(dr, session)
self.assert_scheduled_ti_count(session, 1)

session.refresh(ti)
assert ti.state == State.SCHEDULED
Expand Down Expand Up @@ -504,8 +508,8 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state,
ti.start_date = start_date
ti.end_date = end_date

count = self.scheduler_job._schedule_dag_run(dr, session)
assert count == 2
self.scheduler_job._schedule_dag_run(dr, session)
self.assert_scheduled_ti_count(session, 2)

session.refresh(tis[0])
session.refresh(tis[1])
Expand Down Expand Up @@ -547,9 +551,9 @@ def test_scheduler_job_add_new_task(self):
BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
SerializedDagModel.write_dag(dag=dag)

scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
self.scheduler_job._schedule_dag_run(dr, session)
self.assert_scheduled_ti_count(session, 2)
session.flush()
assert scheduled_tis == 2

drs = DagRun.find(dag_id=dag.dag_id, session=session)
assert len(drs) == 1
Expand Down
80 changes: 70 additions & 10 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,10 +1710,11 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
ti = dr.get_task_instance('dummy')
ti.set_state(state, session)

self.scheduler_job._schedule_dag_run(dr, session)
with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
self.scheduler_job._do_scheduling(session)

expected_callback = DagCallbackRequest(
full_filepath=dr.dag.fileloc,
full_filepath=dag.fileloc,
dag_id=dr.dag_id,
is_failure_callback=bool(state == State.FAILED),
execution_date=dr.execution_date,
Expand All @@ -1729,6 +1730,64 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg):
session.rollback()
session.close()

def test_dagrun_callbacks_commited_before_sent(self):
"""
Tests that before any callbacks are sent to the processor, the session is committed. This ensures
that the dagrun details are up to date when the callbacks are run.
"""
dag = DAG(dag_id='test_dagrun_callbacks_commited_before_sent', start_date=DEFAULT_DATE)
DummyOperator(task_id='dummy', dag=dag, owner='airflow')

self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.processor_agent = mock.Mock()
self.scheduler_job._send_dag_callbacks_to_processor = mock.Mock()
self.scheduler_job._schedule_dag_run = mock.Mock()

# Sync DAG into DB
with mock.patch.object(settings, "STORE_DAG_CODE", False):
self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
self.scheduler_job.dagbag.sync_to_db()

session = settings.Session()
orm_dag = session.query(DagModel).get(dag.dag_id)
assert orm_dag is not None

# Create DagRun
self.scheduler_job._create_dag_runs([orm_dag], session)

drs = DagRun.find(dag_id=dag.dag_id, session=session)
assert len(drs) == 1
dr = drs[0]

ti = dr.get_task_instance('dummy')
ti.set_state(State.SUCCESS, session)

with mock.patch.object(settings, "USE_JOB_SCHEDULE", False), mock.patch(
"airflow.jobs.scheduler_job.prohibit_commit"
) as mock_gaurd:
mock_gaurd.return_value.__enter__.return_value.commit.side_effect = session.commit

def mock_schedule_dag_run(*args, **kwargs):
mock_gaurd.reset_mock()
return None

def mock_send_dag_callbacks_to_processor(*args, **kwargs):
mock_gaurd.return_value.__enter__.return_value.commit.assert_called_once()

self.scheduler_job._send_dag_callbacks_to_processor.side_effect = (
mock_send_dag_callbacks_to_processor
)
self.scheduler_job._schedule_dag_run.side_effect = mock_schedule_dag_run

self.scheduler_job._do_scheduling(session)

# Verify dag failure callback request is sent to file processor
self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once()
# and mock_send_dag_callbacks_to_processor has asserted the callback was sent after a commit

session.rollback()
session.close()

@parameterized.expand([(State.SUCCESS,), (State.FAILED,)])
def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state):
"""
Expand Down Expand Up @@ -1765,10 +1824,15 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta
ti = dr.get_task_instance('test_task')
ti.set_state(state, session)

self.scheduler_job._schedule_dag_run(dr, session)
with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
self.scheduler_job._do_scheduling(session)

# Verify Callback is not set (i.e is None) when no callbacks are set on DAG
self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None)
self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once()
call_args = self.scheduler_job._send_dag_callbacks_to_processor.call_args[0]
assert call_args[0].dag_id == dr.dag_id
assert call_args[0].execution_date == dr.execution_date
assert call_args[1] is None

session.rollback()
session.close()
Expand Down Expand Up @@ -2411,12 +2475,10 @@ def test_verify_integrity_if_dag_not_changed(self):

# Verify that DagRun.verify_integrity is not called
with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity:
scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
self.scheduler_job._schedule_dag_run(dr, session)
mock_verify_integrity.assert_not_called()
session.flush()

assert scheduled_tis == 1

tis_count = (
session.query(func.count(TaskInstance.task_id))
.filter(
Expand Down Expand Up @@ -2474,11 +2536,9 @@ def test_verify_integrity_if_dag_changed(self):
dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
assert dag_version_2 != dag_version_1

scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
self.scheduler_job._schedule_dag_run(dr, session)
session.flush()

assert scheduled_tis == 2

drs = DagRun.find(dag_id=dag.dag_id, session=session)
assert len(drs) == 1
dr = drs[0]
Expand Down

0 comments on commit 62191bd

Please sign in to comment.