Fix catchup by limiting queued dagrun creation using max_active_runs (#18897)

Currently, when catchup is True, we create a large number of dagruns, limited
only by the max_queued_runs_per_dag setting. This is inefficient, as some
dagruns take longer to run.

This PR brings back the old behavior of not creating dagruns once max_active_runs
is reached, thereby solving the catchup issue.

Now the dagruns appear as though they were created in the running state.
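
For intuition, here is a minimal sketch of the gating idea (hypothetical helper names; the real change lives in `SchedulerJob._create_dag_runs` and `_update_dag_next_dagruns` in the diff below):

```python
# Hedged sketch, not the actual scheduler code: stop creating dagruns for a
# DAG once its active (RUNNING + QUEUED) count reaches max_active_runs,
# instead of capping on a global queued-runs setting.
from collections import defaultdict
from typing import Callable, Dict, Iterable


def create_runs_up_to_cap(
    dag_models: Iterable,           # objects with .dag_id and .max_active_runs
    active_counts: Dict[str, int],  # dag_id -> RUNNING + QUEUED run count
    create_run: Callable,           # creates one run in the QUEUED state
) -> None:
    counts = defaultdict(int, active_counts)
    for dag_model in dag_models:
        if counts[dag_model.dag_id] >= dag_model.max_active_runs:
            continue  # at the cap: wait for a run to finish before creating more
        create_run(dag_model)
        counts[dag_model.dag_id] += 1
```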
ephraimbuddy authored Oct 20, 2021
1 parent 5dc375a commit 05eea00
Showing 8 changed files with 179 additions and 69 deletions.
5 changes: 5 additions & 0 deletions UPDATING.md
@@ -91,6 +91,11 @@ Now if you resolve a ``Param`` without a default and don't pass a value, you will
Param().resolve() # raises TypeError
```

### `max_queued_runs_per_dag` configuration has been removed

The `max_queued_runs_per_dag` configuration option in the `[core]` section has been removed. Previously, this controlled the number of queued dagruns
the scheduler could create for a DAG. Now, the maximum number is controlled internally by the DAG's `max_active_runs`.

## Airflow 2.2.0

### `worker_log_server_port` configuration has been moved to the ``logging`` section.
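For migration purposes, a hedged example of the replacement knob (the dag id, dates, and schedule below are illustrative, not from this commit):

```python
# With max_queued_runs_per_dag removed, cap run creation per DAG through the
# DAG's own max_active_runs argument (or the [core] max_active_runs_per_dag
# default for DAGs that don't set it).
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy import DummyOperator

with DAG(
    dag_id="example_capped_dag",  # hypothetical dag id
    start_date=datetime(2021, 1, 1),
    schedule_interval="@daily",
    catchup=True,
    max_active_runs=3,  # the scheduler stops creating runs past this count
) as dag:
    DummyOperator(task_id="mytask")
```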
8 changes: 0 additions & 8 deletions airflow/config_templates/config.yml
@@ -195,14 +195,6 @@
type: string
example: ~
default: "16"
- name: max_queued_runs_per_dag
description: |
The maximum number of queued dagruns for a single DAG. The scheduler will not create more DAG runs
if it reaches the limit. This is not configurable at the DAG level.
version_added: 2.1.4
type: string
example: ~
default: "16"
- name: load_examples
description: |
Whether to load the DAG examples that ship with Airflow. It's good to
4 changes: 0 additions & 4 deletions airflow/config_templates/default_airflow.cfg
@@ -131,10 +131,6 @@ dags_are_paused_at_creation = True
# which is defaulted as ``max_active_runs_per_dag``.
max_active_runs_per_dag = 16

# The maximum number of queued dagruns for a single DAG. The scheduler will not create more DAG runs
# if it reaches the limit. This is not configurable at the DAG level.
max_queued_runs_per_dag = 16

# Whether to load the DAG examples that ship with Airflow. It's good to
# get started, but you probably want to set this to ``False`` in a production
# environment
58 changes: 32 additions & 26 deletions airflow/jobs/scheduler_job.py
@@ -847,30 +847,19 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
existing_dagruns = (
session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all()
)
max_queued_dagruns = conf.getint('core', 'max_queued_runs_per_dag')

queued_runs_of_dags = defaultdict(
active_runs_of_dags = defaultdict(
int,
session.query(DagRun.dag_id, func.count('*'))
.filter( # We use `list` here because SQLA doesn't accept a set
# We use set to avoid duplicate dag_ids
DagRun.dag_id.in_(list({dm.dag_id for dm in dag_models})),
DagRun.state == State.QUEUED,
)
.group_by(DagRun.dag_id)
.all(),
DagRun.active_runs_of_dags(dag_ids=(dm.dag_id for dm in dag_models), session=session),
)

for dag_model in dag_models:
# Lets quickly check if we have exceeded the number of queued dagruns per dags
total_queued = queued_runs_of_dags[dag_model.dag_id]
if total_queued >= max_queued_dagruns:
continue

dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
continue

dag_hash = self.dagbag.dags_hash.get(dag.dag_id)

data_interval = dag.get_next_data_interval(dag_model)
@@ -893,12 +882,28 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
dag_hash=dag_hash,
creating_job_id=self.id,
)
queued_runs_of_dags[dag_model.dag_id] += 1
dag_model.calculate_dagrun_date_fields(dag, data_interval)

active_runs_of_dags[dag.dag_id] += 1
self._update_dag_next_dagruns(dag, dag_model, active_runs_of_dags[dag.dag_id])
# TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
# memory for larger dags? or expunge_all()

def _update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active_runs) -> None:
"""
Update the next_dagrun, next_dagrun_data_interval_start/end
and next_dagrun_create_after for this dag.
"""
if total_active_runs >= dag_model.max_active_runs:
self.log.info(
"DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
dag_model.dag_id,
total_active_runs,
dag_model.max_active_runs,
)
dag_model.next_dagrun_create_after = None
else:
data_interval = dag.get_next_data_interval(dag_model)
dag_model.calculate_dagrun_date_fields(dag, data_interval)

def _start_queued_dagruns(
self,
session: Session,
@@ -907,15 +912,8 @@
dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session)

active_runs_of_dags = defaultdict(
lambda: 0,
session.query(DagRun.dag_id, func.count('*'))
.filter( # We use `list` here because SQLA doesn't accept a set
# We use set to avoid duplicate dag_ids
DagRun.dag_id.in_(list({dr.dag_id for dr in dag_runs})),
DagRun.state == State.RUNNING,
)
.group_by(DagRun.dag_id)
.all(),
int,
DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), only_running=True, session=session),
)

def _update_state(dag: DAG, dag_run: DagRun):
@@ -966,6 +964,7 @@ def _schedule_dag_run(
if not dag:
self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
return 0
dag_model = DM.get_dagmodel(dag.dag_id, session)

if (
dag_run.start_date
@@ -984,6 +983,9 @@
session.merge(task_instance)
session.flush()
self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)
active_runs = dag.get_num_active_runs(only_running=False, session=session)
# Work out whether we should allow creating a new DagRun now
self._update_dag_next_dagruns(dag, dag_model, active_runs)

callback_to_execute = DagCallbackRequest(
full_filepath=dag.fileloc,
@@ -1005,6 +1007,10 @@
self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
if dag_run.state in State.finished:
active_runs = dag.get_num_active_runs(only_running=False, session=session)
# Work out whether we should allow creating a new DagRun now
self._update_dag_next_dagruns(dag, dag_model, active_runs)

# This will do one query per dag run. We "could" build up a complex
# query to update all the TIs across all the execution dates and dag
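Taken together, the hunks above recompute whether new runs may be created each time a run times out or finishes. A hedged pseudocode summary of that flow (simplified, not the actual `SchedulerJob` code):

```python
# Hedged summary: after a run times out or reaches a finished state, the
# scheduler recounts active (RUNNING + QUEUED) runs and either schedules the
# next creation date or parks the DAG by setting next_dagrun_create_after to
# None, which keeps it out of DagModel.dags_needing_dagruns().
def recheck_run_creation(scheduler, dag, dag_model, session):
    active = dag.get_num_active_runs(only_running=False, session=session)
    scheduler._update_dag_next_dagruns(dag, dag_model, active)
```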
21 changes: 14 additions & 7 deletions airflow/models/dag.py
@@ -1138,7 +1138,7 @@ def get_active_runs(self):
return active_dates

@provide_session
def get_num_active_runs(self, external_trigger=None, session=None):
def get_num_active_runs(self, external_trigger=None, only_running=True, session=None):
"""
Returns the number of active "running" dag runs
@@ -1148,11 +1148,11 @@ def get_num_active_runs(self, external_trigger=None, session=None):
:return: number greater than 0 for active dag runs
"""
# .count() is inefficient
query = (
session.query(func.count())
.filter(DagRun.dag_id == self.dag_id)
.filter(DagRun.state == State.RUNNING)
)
query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id)
if only_running:
query = query.filter(DagRun.state == State.RUNNING)
else:
query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED}))

if external_trigger is not None:
query = query.filter(
@@ -2423,6 +2423,10 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
)
most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter}

# Get the number of active dagruns for all dags we are processing in a single query.

num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dag_ids, session=session)

filelocs = []

for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
@@ -2451,7 +2455,10 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
data_interval = None
else:
data_interval = dag.get_run_data_interval(run)
orm_dag.calculate_dagrun_date_fields(dag, data_interval)
if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs:
orm_dag.next_dagrun_create_after = None
else:
orm_dag.calculate_dagrun_date_fields(dag, data_interval)

for orm_tag in list(orm_dag.tags):
if orm_tag.name not in set(dag.tags):
18 changes: 17 additions & 1 deletion airflow/models/dagrun.py
@@ -17,7 +17,7 @@
# under the License.
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union

from sqlalchemy import (
Boolean,
@@ -207,6 +207,22 @@ def refresh_from_db(self, session: Session = None):
self.id = dr.id
self.state = dr.state

@classmethod
@provide_session
def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> Dict[str, int]:
"""Get the number of active dag runs for each dag."""
query = session.query(cls.dag_id, func.count('*'))
if dag_ids is not None:
# 'set' called to avoid duplicate dag_ids, but converted back to 'list'
# because SQLAlchemy doesn't accept a set here.
query = query.filter(cls.dag_id.in_(list(set(dag_ids))))
if only_running:
query = query.filter(cls.state == State.RUNNING)
else:
query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
query = query.group_by(cls.dag_id)
return {dag_id: count for dag_id, count in query.all()}

@classmethod
def next_dagruns_to_examine(
cls,
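A hedged usage sketch of the new classmethod (the dag id is illustrative; `create_session` is from `airflow.utils.session`, and `@provide_session` supplies a session when one isn't passed explicitly):

```python
from airflow.models.dagrun import DagRun
from airflow.utils.session import create_session

with create_session() as session:
    # dag_id -> number of RUNNING + QUEUED runs (the scheduler's creation cap)
    counts = DagRun.active_runs_of_dags(dag_ids=["example_capped_dag"], session=session)
    # only RUNNING runs, as _start_queued_dagruns uses it
    running = DagRun.active_runs_of_dags(
        dag_ids=["example_capped_dag"], only_running=True, session=session
    )
    print(counts.get("example_capped_dag", 0), running.get("example_capped_dag", 0))
```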
110 changes: 97 additions & 13 deletions tests/jobs/test_scheduler_job.py
@@ -49,7 +49,7 @@
from airflow.utils.callback_requests import DagCallbackRequest
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars, env_vars
@@ -1105,9 +1105,9 @@ def test_cleanup_methods_all_called(self, mock_processor_agent):
self.scheduler_job.executor.end.assert_called_once()
mock_processor_agent.return_value.end.reset_mock(side_effect=True)

def test_theres_limit_to_queued_dagruns_in_a_dag(self, dag_maker):
"""This tests that there's limit to the number of queued dagrun scheduler can create in a dag"""
with dag_maker() as dag:
def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_maker):
"""This tests that queued dagruns stops creating once max_active_runs is reached"""
with dag_maker(max_active_runs=10) as dag:
DummyOperator(task_id='mytask')

session = settings.Session()
@@ -1122,13 +1122,20 @@ def test_theres_limit_to_queued_dagruns_in_a_dag(self, dag_maker):
assert orm_dag is not None
for _ in range(20):
self.scheduler_job._create_dag_runs([orm_dag], session)
assert session.query(DagRun).count() == 16
drs = session.query(DagRun).all()
assert len(drs) == 10

with conf_vars({('core', 'max_queued_runs_per_dag'): '5'}):
clear_db_runs()
for i in range(20):
self.scheduler_job._create_dag_runs([orm_dag], session)
assert session.query(DagRun).count() == 5
for dr in drs:
dr.state = State.RUNNING
session.merge(dr)
session.commit()
assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10
for _ in range(20):
self.scheduler_job._create_dag_runs([orm_dag], session)
assert session.query(DagRun).count() == 10
assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10
assert session.query(DagRun.state).filter(DagRun.state == State.QUEUED).count() == 0
assert orm_dag.next_dagrun_create_after is None

def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
"""
@@ -1160,9 +1167,7 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
assert len(drs) == 1
dr = drs[0]

# This should have a value since we control max_active_runs
# by DagRun State.
assert orm_dag.next_dagrun_create_after
assert orm_dag.next_dagrun_create_after is None
# But we should record the date of _what run_ it would be
assert isinstance(orm_dag.next_dagrun, datetime.datetime)
assert isinstance(orm_dag.next_dagrun_data_interval_start, datetime.datetime)
@@ -1242,6 +1247,40 @@ def test_dagrun_timeout_fails_run(self, dag_maker):
session.rollback()
session.close()

def test_dagrun_timeout_fails_run_and_update_next_dagrun(self, dag_maker):
"""
Test that dagrun timeout fails the run and updates the next dagrun
"""
session = settings.Session()
with dag_maker(
max_active_runs=1,
dag_id='test_scheduler_fail_dagrun_timeout',
dagrun_timeout=datetime.timedelta(seconds=60),
):
DummyOperator(task_id='dummy')

dr = dag_maker.create_dagrun(start_date=timezone.utcnow() - datetime.timedelta(days=1))
# check that next_dagrun is dr.execution_date
assert dag_maker.dag_model.next_dagrun == dr.execution_date
self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.dagbag = dag_maker.dagbag

# Mock that processor_agent is started
self.scheduler_job.processor_agent = mock.Mock()
self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()

self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
session.refresh(dr)
assert dr.state == State.FAILED
# check that next_dagrun has been updated by SchedulerJob._update_dag_next_dagruns
assert dag_maker.dag_model.next_dagrun == dr.execution_date + timedelta(days=1)
# check that there are no running/queued runs yet
assert (
session.query(DagRun).filter(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])).count()
== 0
)

@pytest.mark.parametrize(
"state, expected_callback_msg", [(State.SUCCESS, "success"), (State.FAILED, "task_failure")]
)
@@ -2759,6 +2798,51 @@ def test_do_schedule_max_active_runs_task_removed(self, session, dag_maker):
ti.refresh_from_db(session=session)
assert ti.state == State.QUEUED

def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_maker, caplog):
"""
This tests that when max_active_runs is reached, _create_dag_runs doesn't create
more dagruns
"""
with dag_maker(max_active_runs=1):
DummyOperator(task_id='task')
self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.executor = MockExecutor(do_update=False)
self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
session = settings.Session()
assert session.query(DagRun).count() == 0
dag_models = DagModel.dags_needing_dagruns(session).all()
self.scheduler_job._create_dag_runs(dag_models, session)
dr = session.query(DagRun).one()
assert dr.state == DagRunState.QUEUED
assert session.query(DagRun).count() == 1
assert dag_maker.dag_model.next_dagrun_create_after is None
session.flush()
# dags_needing_dagruns query should not return any value
assert len(DagModel.dags_needing_dagruns(session).all()) == 0
self.scheduler_job._create_dag_runs(dag_models, session)
assert session.query(DagRun).count() == 1
assert dag_maker.dag_model.next_dagrun_create_after is None
assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE
# set dagrun to success
dr = session.query(DagRun).one()
dr.state = DagRunState.SUCCESS
ti = dr.get_task_instance('task', session)
ti.state = TaskInstanceState.SUCCESS
session.merge(ti)
session.merge(dr)
session.flush()
# check that next_dagrun is set properly by SchedulerJob._update_dag_next_dagruns
self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
assert len(DagModel.dags_needing_dagruns(session).all()) == 1
# assert next_dagrun has been updated correctly
assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE + timedelta(days=1)
# assert no dagruns are created yet
assert (
session.query(DagRun).filter(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])).count()
== 0
)

def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker):
"""
Make sure that when a DAG is already at max_active_runs, that manually triggered
Expand Down