Fix catchup by limiting queued dagrun creation using max_active_runs #18897

Merged · 3 commits · Oct 20, 2021
5 changes: 5 additions & 0 deletions UPDATING.md
@@ -91,6 +91,11 @@ Now if you resolve a ``Param`` without a default and don't pass a value, you wil
Param().resolve() # raises TypeError
```

### `max_queued_runs_per_dag` configuration has been removed

The `max_queued_runs_per_dag` configuration option in the `[core]` section has been removed. Previously, it controlled the number of queued DAG runs
the scheduler could create for a single DAG. The maximum number is now controlled internally by the DAG's `max_active_runs`.
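
For illustration, a minimal sketch of the per-DAG cap after this change (the DAG id, schedule, and dates below are hypothetical, assuming the Airflow 2.x DAG API):

```python
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.dummy import DummyOperator

# max_active_runs now also bounds how many queued runs the scheduler
# will create for this DAG; there is no separate global queued-run limit.
with DAG(
    dag_id="example_capped_dag",  # hypothetical DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=timedelta(days=1),
    catchup=True,
    max_active_runs=3,  # at most 3 queued + running runs at a time
) as dag:
    DummyOperator(task_id="mytask")
```

With `catchup=True`, the scheduler now backfills missed intervals a few runs at a time instead of queueing the whole backlog at once.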

## Airflow 2.2.0

### `worker_log_server_port` configuration has been moved to the ``logging`` section.
8 changes: 0 additions & 8 deletions airflow/config_templates/config.yml
@@ -195,14 +195,6 @@
type: string
example: ~
default: "16"
- name: max_queued_runs_per_dag
description: |
The maximum number of queued dagruns for a single DAG. The scheduler will not create more DAG runs
if it reaches the limit. This is not configurable at the DAG level.
version_added: 2.1.4
type: string
example: ~
default: "16"
- name: load_examples
description: |
Whether to load the DAG examples that ship with Airflow. It's good to
4 changes: 0 additions & 4 deletions airflow/config_templates/default_airflow.cfg
@@ -131,10 +131,6 @@ dags_are_paused_at_creation = True
# which is defaulted as ``max_active_runs_per_dag``.
max_active_runs_per_dag = 16

# The maximum number of queued dagruns for a single DAG. The scheduler will not create more DAG runs
# if it reaches the limit. This is not configurable at the DAG level.
max_queued_runs_per_dag = 16

# Whether to load the DAG examples that ship with Airflow. It's good to
# get started, but you probably want to set this to ``False`` in a production
# environment
58 changes: 32 additions & 26 deletions airflow/jobs/scheduler_job.py
@@ -839,30 +839,19 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
existing_dagruns = (
session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all()
)
max_queued_dagruns = conf.getint('core', 'max_queued_runs_per_dag')

queued_runs_of_dags = defaultdict(
active_runs_of_dags = defaultdict(
int,
session.query(DagRun.dag_id, func.count('*'))
.filter( # We use `list` here because SQLA doesn't accept a set
# We use set to avoid duplicate dag_ids
DagRun.dag_id.in_(list({dm.dag_id for dm in dag_models})),
DagRun.state == State.QUEUED,
)
.group_by(DagRun.dag_id)
.all(),
DagRun.active_runs_of_dags(dag_ids=(dm.dag_id for dm in dag_models), session=session),
)

for dag_model in dag_models:
# Lets quickly check if we have exceeded the number of queued dagruns per dags
total_queued = queued_runs_of_dags[dag_model.dag_id]
if total_queued >= max_queued_dagruns:
continue

dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
continue

dag_hash = self.dagbag.dags_hash.get(dag.dag_id)

data_interval = dag.get_next_data_interval(dag_model)
@@ -885,12 +874,28 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
dag_hash=dag_hash,
creating_job_id=self.id,
)
queued_runs_of_dags[dag_model.dag_id] += 1
dag_model.calculate_dagrun_date_fields(dag, data_interval)

active_runs_of_dags[dag.dag_id] += 1
self._update_dag_next_dagruns(dag, dag_model, active_runs_of_dags[dag.dag_id])
# TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
# memory for larger dags? or expunge_all()

def _update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active_runs) -> None:
"""
Update the next_dagrun, next_dagrun_data_interval_start/end
and next_dagrun_create_after for this dag.
"""
if total_active_runs >= dag_model.max_active_runs:
self.log.info(
"DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
dag_model.dag_id,
total_active_runs,
dag_model.max_active_runs,
)
dag_model.next_dagrun_create_after = None
else:
data_interval = dag.get_next_data_interval(dag_model)
dag_model.calculate_dagrun_date_fields(dag, data_interval)

def _start_queued_dagruns(
self,
session: Session,
@@ -899,15 +904,8 @@
dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session)

active_runs_of_dags = defaultdict(
lambda: 0,
session.query(DagRun.dag_id, func.count('*'))
.filter( # We use `list` here because SQLA doesn't accept a set
# We use set to avoid duplicate dag_ids
DagRun.dag_id.in_(list({dr.dag_id for dr in dag_runs})),
DagRun.state == State.RUNNING,
)
.group_by(DagRun.dag_id)
.all(),
int,
DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), only_running=True, session=session),
)

def _update_state(dag: DAG, dag_run: DagRun):
@@ -958,6 +956,7 @@ def _schedule_dag_run(
if not dag:
self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
return 0
dag_model = DM.get_dagmodel(dag.dag_id, session)

if (
dag_run.start_date
@@ -976,6 +975,9 @@
session.merge(task_instance)
session.flush()
self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)
active_runs = dag.get_num_active_runs(only_running=False, session=session)
# Work out whether we should allow creating a new DagRun now
self._update_dag_next_dagruns(dag, dag_model, active_runs)

callback_to_execute = DagCallbackRequest(
full_filepath=dag.fileloc,
@@ -997,6 +999,10 @@
self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
if dag_run.state in State.finished:
active_runs = dag.get_num_active_runs(only_running=False, session=session)
# Work out whether we should allow creating a new DagRun now
self._update_dag_next_dagruns(dag, dag_model, active_runs)

# This will do one query per dag run. We "could" build up a complex
# query to update all the TIs across all the execution dates and dag
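
The core of the new gating lives in `_update_dag_next_dagruns`: once the count of queued plus running runs reaches `max_active_runs`, the scheduler clears `next_dagrun_create_after`, which keeps the DAG out of the `dags_needing_dagruns` query until a run finishes and the field is recomputed. A simplified, self-contained model of that decision (the `FakeDagModel` class below is a stand-in for illustration, not Airflow's ORM model):

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional


@dataclass
class FakeDagModel:
    """Stand-in for airflow.models.DagModel (illustration only)."""

    dag_id: str
    max_active_runs: int
    next_dagrun: datetime
    next_dagrun_create_after: Optional[datetime]


def update_next_dagrun(dag_model: FakeDagModel, active_runs: int, interval: timedelta) -> None:
    """Mirror the scheduler's gating: park the DAG at the cap, advance it otherwise."""
    if active_runs >= dag_model.max_active_runs:
        # At (or above) the cap: clearing next_dagrun_create_after keeps the
        # DAG out of the "needs a new run" query until a run completes.
        dag_model.next_dagrun_create_after = None
    else:
        dag_model.next_dagrun += interval
        dag_model.next_dagrun_create_after = dag_model.next_dagrun


dm = FakeDagModel("example", max_active_runs=2, next_dagrun=datetime(2021, 10, 1),
                  next_dagrun_create_after=datetime(2021, 10, 2))
update_next_dagrun(dm, active_runs=2, interval=timedelta(days=1))
assert dm.next_dagrun_create_after is None  # parked until a run completes
```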
21 changes: 14 additions & 7 deletions airflow/models/dag.py
@@ -1138,7 +1138,7 @@ def get_active_runs(self):
return active_dates

@provide_session
def get_num_active_runs(self, external_trigger=None, session=None):
def get_num_active_runs(self, external_trigger=None, only_running=True, session=None):
"""
Returns the number of active "running" dag runs

@@ -1148,11 +1148,11 @@ def get_num_active_runs(self, external_trigger=None, session=None):
:return: number greater than 0 for active dag runs
"""
# .count() is inefficient
query = (
session.query(func.count())
.filter(DagRun.dag_id == self.dag_id)
.filter(DagRun.state == State.RUNNING)
)
query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id)
if only_running:
query = query.filter(DagRun.state == State.RUNNING)
else:
query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED}))

if external_trigger is not None:
query = query.filter(
@@ -2423,6 +2423,10 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=None):
)
most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter}

# Get the number of active dagruns for all dags we are processing as a single query.

num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dag_ids, session=session)

filelocs = []

for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
@@ -2451,7 +2455,10 @@
data_interval = None
else:
data_interval = dag.get_run_data_interval(run)
orm_dag.calculate_dagrun_date_fields(dag, data_interval)
if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs:
orm_dag.next_dagrun_create_after = None
else:
orm_dag.calculate_dagrun_date_fields(dag, data_interval)

for orm_tag in list(orm_dag.tags):
if orm_tag.name not in set(dag.tags):
18 changes: 17 additions & 1 deletion airflow/models/dagrun.py
@@ -17,7 +17,7 @@
# under the License.
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union

from sqlalchemy import (
Boolean,
@@ -207,6 +207,22 @@ def refresh_from_db(self, session: Session = None):
self.id = dr.id
self.state = dr.state

@classmethod
@provide_session
def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> Dict[str, int]:
"""Get the number of active dag runs for each dag."""
query = session.query(cls.dag_id, func.count('*'))
if dag_ids is not None:
# 'set' called to avoid duplicate dag_ids, but converted back to 'list'
# because SQLAlchemy doesn't accept a set here.
query = query.filter(cls.dag_id.in_(list(set(dag_ids))))
if only_running:
query = query.filter(cls.state == State.RUNNING)
else:
query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
query = query.group_by(cls.dag_id)
return {dag_id: count for dag_id, count in query.all()}

@classmethod
def next_dagruns_to_examine(
cls,
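
As a usage note, `active_runs_of_dags` returns a plain `{dag_id: count}` dict, so callers that need a zero default wrap it in `defaultdict(int)`, as the scheduler changes above do. A short sketch of calling it directly (this assumes an initialized Airflow metadata database; the DAG id is illustrative):

```python
from collections import defaultdict

from airflow.models.dagrun import DagRun
from airflow.utils.session import create_session

with create_session() as session:
    # Counts queued + running runs per DAG in one grouped query;
    # pass only_running=True to count RUNNING runs alone.
    active_runs = defaultdict(
        int,
        DagRun.active_runs_of_dags(dag_ids=["example_dag"], session=session),
    )
    print(active_runs["example_dag"])  # 0 if the DAG has no active runs
```

Counting queued and running runs in one grouped query replaces the two ad-hoc `func.count` queries that `_create_dag_runs` and `_start_queued_dagruns` previously built independently.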
110 changes: 97 additions & 13 deletions tests/jobs/test_scheduler_job.py
@@ -49,7 +49,7 @@
from airflow.utils.callback_requests import DagCallbackRequest
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars, env_vars
@@ -1082,9 +1082,9 @@ def test_cleanup_methods_all_called(self, mock_processor_agent):
self.scheduler_job.executor.end.assert_called_once()
mock_processor_agent.return_value.end.reset_mock(side_effect=True)

def test_theres_limit_to_queued_dagruns_in_a_dag(self, dag_maker):
"""This tests that there's limit to the number of queued dagrun scheduler can create in a dag"""
with dag_maker() as dag:
def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_maker):
"""This tests that queued dagruns stops creating once max_active_runs is reached"""
with dag_maker(max_active_runs=10) as dag:
DummyOperator(task_id='mytask')

session = settings.Session()
@@ -1099,13 +1099,20 @@ def test_theres_limit_to_queued_dagruns_in_a_dag(self, dag_maker):
assert orm_dag is not None
for _ in range(20):
self.scheduler_job._create_dag_runs([orm_dag], session)
assert session.query(DagRun).count() == 16
drs = session.query(DagRun).all()
assert len(drs) == 10

with conf_vars({('core', 'max_queued_runs_per_dag'): '5'}):
clear_db_runs()
for i in range(20):
self.scheduler_job._create_dag_runs([orm_dag], session)
assert session.query(DagRun).count() == 5
for dr in drs:
dr.state = State.RUNNING
session.merge(dr)
session.commit()
assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10
for _ in range(20):
self.scheduler_job._create_dag_runs([orm_dag], session)
assert session.query(DagRun).count() == 10
assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10
assert session.query(DagRun.state).filter(DagRun.state == State.QUEUED).count() == 0
assert orm_dag.next_dagrun_create_after is None

def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
"""
@@ -1137,9 +1144,7 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
assert len(drs) == 1
dr = drs[0]

# This should have a value since we control max_active_runs
# by DagRun State.
assert orm_dag.next_dagrun_create_after
assert orm_dag.next_dagrun_create_after is None
# But we should record the date of _what run_ it would be
assert isinstance(orm_dag.next_dagrun, datetime.datetime)
assert isinstance(orm_dag.next_dagrun_data_interval_start, datetime.datetime)
@@ -1219,6 +1224,40 @@ def test_dagrun_timeout_fails_run(self, dag_maker):
session.rollback()
session.close()

def test_dagrun_timeout_fails_run_and_update_next_dagrun(self, dag_maker):
"""
Test that a dagrun timeout fails the run and updates the next dagrun
"""
session = settings.Session()
with dag_maker(
max_active_runs=1,
dag_id='test_scheduler_fail_dagrun_timeout',
dagrun_timeout=datetime.timedelta(seconds=60),
):
DummyOperator(task_id='dummy')

dr = dag_maker.create_dagrun(start_date=timezone.utcnow() - datetime.timedelta(days=1))
# check that next_dagrun is dr.execution_date
assert dag_maker.dag_model.next_dagrun == dr.execution_date
self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.dagbag = dag_maker.dagbag

# Mock that processor_agent is started
self.scheduler_job.processor_agent = mock.Mock()
self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()

self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
session.refresh(dr)
assert dr.state == State.FAILED
# check that next_dagrun has been updated by SchedulerJob._update_dag_next_dagruns
assert dag_maker.dag_model.next_dagrun == dr.execution_date + timedelta(days=1)
# check that there are no running/queued runs yet
assert (
session.query(DagRun).filter(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])).count()
== 0
)

@pytest.mark.parametrize(
"state, expected_callback_msg", [(State.SUCCESS, "success"), (State.FAILED, "task_failure")]
)
@@ -2738,6 +2777,51 @@ def test_do_schedule_max_active_runs_task_removed(self, session, dag_maker):
ti.refresh_from_db(session=session)
assert ti.state == State.QUEUED

def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_maker, caplog):
"""
This tests that when max_active_runs is reached, _create_dag_runs doesn't create
more dagruns
"""
with dag_maker(max_active_runs=1):
DummyOperator(task_id='task')
self.scheduler_job = SchedulerJob(subdir=os.devnull)
self.scheduler_job.executor = MockExecutor(do_update=False)
self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
session = settings.Session()
assert session.query(DagRun).count() == 0
dag_models = DagModel.dags_needing_dagruns(session).all()
self.scheduler_job._create_dag_runs(dag_models, session)
dr = session.query(DagRun).one()
assert dr.state == DagRunState.QUEUED
assert session.query(DagRun).count() == 1
assert dag_maker.dag_model.next_dagrun_create_after is None
session.flush()
# dags_needing_dagruns query should not return any value
assert len(DagModel.dags_needing_dagruns(session).all()) == 0
self.scheduler_job._create_dag_runs(dag_models, session)
assert session.query(DagRun).count() == 1
assert dag_maker.dag_model.next_dagrun_create_after is None
assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE
# set dagrun to success
dr = session.query(DagRun).one()
dr.state = DagRunState.SUCCESS
ti = dr.get_task_instance('task', session)
ti.state = TaskInstanceState.SUCCESS
session.merge(ti)
session.merge(dr)
session.flush()
# check that next_dagrun is set properly by SchedulerJob._update_dag_next_dagruns
self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
assert len(DagModel.dags_needing_dagruns(session).all()) == 1
# assert next_dagrun has been updated correctly
assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE + timedelta(days=1)
# assert no running or queued dagruns exist yet
assert (
session.query(DagRun).filter(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])).count()
== 0
)

def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker):
"""
Make sure that when a DAG is already at max_active_runs, that manually triggered