AIP-65: Track the serialized DAG across DagRun & TaskInstance #42690

Closed
wants to merge 12 commits
4 changes: 3 additions & 1 deletion airflow/api/common/trigger_dag.py
@@ -58,6 +58,8 @@ def _trigger_dag(
:param replace_microseconds: whether microseconds should be zeroed
:return: list of triggered dags
"""
from airflow.models.serialized_dag import SerializedDagModel

dag = dag_bag.get_dag(dag_id) # prefetch dag if it is stored serialized

if dag is None or dag_id not in dag_bag.dags:
@@ -99,7 +101,7 @@ def _trigger_dag(
state=DagRunState.QUEUED,
conf=run_conf,
external_trigger=True,
dag_hash=dag_bag.dags_hash.get(dag_id),
serialized_dag=SerializedDagModel.get(dag_id),
data_interval=data_interval,
triggered_by=triggered_by,
)
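
The substitution here (and, identically, in the REST API endpoint below) trades a stored MD5 digest for a reference to the serialized row itself. A minimal sketch contrasting the two conventions, assuming the DAG has already been serialized to the database:

from airflow.models.serialized_dag import SerializedDagModel

def run_reference(dag_bag, dag_id: str):
    # Before: the run carried only a digest of the serialized DAG, looked
    # up from the in-memory DagBag.
    old_style = {"dag_hash": dag_bag.dags_hash.get(dag_id)}
    # After: the run points at the SerializedDagModel row, so the exact
    # DAG version a run executed against can be joined back later.
    new_style = {"serialized_dag": SerializedDagModel.get(dag_id)}
    return old_style, new_style
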
3 changes: 2 additions & 1 deletion airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -61,6 +61,7 @@
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.exceptions import ParamValidationError
from airflow.models import DagModel, DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.timetables.base import DataInterval
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
@@ -347,7 +348,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
state=DagRunState.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
serialized_dag=SerializedDagModel.get(dag_id),
session=session,
triggered_by=DagRunTriggeredByType.REST_API,
)
15 changes: 8 additions & 7 deletions airflow/jobs/scheduler_job_runner.py
@@ -1318,7 +1318,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
continue

dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
serialized_dag = SerializedDagModel.get(dag.dag_id, session=session)

data_interval = dag.get_next_data_interval(dag_model)
# Explicitly check if the DagRun already exists. This is an edge case
@@ -1338,7 +1338,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
data_interval=data_interval,
external_trigger=False,
session=session,
dag_hash=dag_hash,
serialized_dag=serialized_dag,
creating_job_id=self.job.id,
triggered_by=DagRunTriggeredByType.TIMETABLE,
)
@@ -1397,7 +1397,7 @@ def _create_dag_runs_asset_triggered(
)
continue

dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
serialized_dag = SerializedDagModel.get(dag.dag_id, session=session)

# Explicitly check if the DagRun already exists. This is an edge case
# where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after`
@@ -1452,7 +1452,7 @@ def _create_dag_runs_asset_triggered(
state=DagRunState.QUEUED,
external_trigger=False,
session=session,
dag_hash=dag_hash,
serialized_dag=serialized_dag,
creating_job_id=self.job.id,
triggered_by=DagRunTriggeredByType.DATASET,
)
@@ -1701,12 +1701,13 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->

Return True if we determine that DAG still exists.
"""
latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)
if dag_run.dag_hash == latest_version:
latest_version = SerializedDagModel.get(dag_run.dag_id, session=session)

if latest_version and dag_run.serialized_dag_id == latest_version.id:
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
return True

dag_run.dag_hash = latest_version
dag_run.serialized_dag = latest_version

# Refresh the DAG
dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
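
The integrity check in _verify_integrity_if_dag_changed no longer compares hash strings; it compares the run's foreign key against the id of the newest serialized row. A compact restatement (the helper name is hypothetical; the real logic is in the hunk above):

from airflow.models.serialized_dag import SerializedDagModel

def dag_run_is_current(dag_run, session) -> bool:
    latest = SerializedDagModel.get(dag_run.dag_id, session=session)
    if latest is None:
        # The DAG vanished from the serialized_dag table entirely.
        return False
    # Matching ids mean the run was created from the newest serialization,
    # so the more expensive dag_run.verify_integrity() can be skipped.
    return dag_run.serialized_dag_id == latest.id
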
66 changes: 66 additions & 0 deletions airflow/migrations/versions/0036_3_0_0_add_serial_id_to_sdm.py
@@ -0,0 +1,66 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add serial ID to SDM.

Revision ID: e1ff90d3efe9
Revises: 0d9e73a75ee4
Create Date: 2024-09-27 09:32:46.514067

"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

from airflow.models.base import naming_convention

# revision identifiers, used by Alembic.
revision = "e1ff90d3efe9"
down_revision = "0d9e73a75ee4"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply add serial pkey to SerializedDag."""
with op.batch_alter_table(
"serialized_dag", recreate="always", naming_convention=naming_convention
) as batch_op:
batch_op.drop_constraint("serialized_dag_pkey", type_="primary")
# Hack: primary_key=True on the new column is what enables autoincrement
batch_op.add_column(sa.Column("id", sa.Integer(), primary_key=True), insert_before="dag_id")
batch_op.create_primary_key("serialized_dag_pkey", ["id"])
batch_op.add_column(
sa.Column("version_number", sa.Integer(), nullable=False, default=1), insert_before="dag_id"
)
batch_op.create_unique_constraint(
batch_op.f("dag_hash_version_number_unique"), ["dag_hash", "version_number"]
)


def downgrade():
"""Unapply add serial pkey to SerializedDag."""
with op.batch_alter_table("serialized_dag", naming_convention=naming_convention) as batch_op:
batch_op.drop_constraint(batch_op.f("dag_hash_version_number_unique"), type_="unique")
batch_op.drop_column("id")
batch_op.create_primary_key("serialized_dag_pkey", ["dag_id"])
batch_op.drop_column("version_number")
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add SDM foreign key to DagRun, TI & TIH.

Revision ID: 4235395d5ec5
Revises: e1ff90d3efe9
Create Date: 2024-10-03 13:37:55.678831

"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "4235395d5ec5"
down_revision = "e1ff90d3efe9"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply Add SDM foreignkey to DagRun."""
with op.batch_alter_table("dag_run") as batch_op:
batch_op.add_column(sa.Column("serialized_dag_id", sa.Integer()))
batch_op.create_foreign_key(
"dag_run_serialized_dag_fkey",
"serialized_dag",
["serialized_dag_id"],
["id"],
ondelete="SET NULL",
)
batch_op.drop_column("dag_hash")

with op.batch_alter_table("task_instance") as batch_op:
batch_op.add_column(sa.Column("serialized_dag_id", sa.Integer()))
batch_op.create_foreign_key(
"task_instance_serialized_dag_fkey",
"serialized_dag",
["serialized_dag_id"],
["id"],
ondelete="SET NULL",
)

with op.batch_alter_table("task_instance_history") as batch_op:
batch_op.add_column(sa.Column("serialized_dag_id", sa.Integer()))


def downgrade():
"""Unapply Add SDM foreignkey to DagRun."""
with op.batch_alter_table("dag_run") as batch_op:
batch_op.add_column(sa.Column("dag_hash", sa.String(32)))
batch_op.drop_constraint("dag_run_serialized_dag_fkey", type_="foreignkey")
batch_op.drop_column("serialized_dag_id")

with op.batch_alter_table("task_instance") as batch_op:
batch_op.drop_constraint("task_instance_serialized_dag_fkey", type_="foreignkey")
batch_op.drop_column("serialized_dag_id")

with op.batch_alter_table("task_instance_history") as batch_op:
batch_op.drop_column("serialized_dag_id")
2 changes: 1 addition & 1 deletion airflow/models/backfill.py
@@ -124,7 +124,7 @@ def _create_backfill(
dag_run_conf: dict | None,
) -> Backfill | None:
with create_session() as session:
serdag = session.get(SerializedDagModel, dag_id)
serdag = session.scalar(SerializedDagModel.latest_item_select_object(dag_id))
if not serdag:
raise NotFound(f"Could not find dag {dag_id}")

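
session.get(SerializedDagModel, dag_id) relied on dag_id being the primary key; with the surrogate id that lookup is gone, hence the dedicated select. latest_item_select_object is not shown in this diff — a plausible shape for it, assuming the newest version carries the highest id:

from sqlalchemy import select

from airflow.models.serialized_dag import SerializedDagModel as SDM

def latest_item_select_object(dag_id: str):
    # One dag_id now maps to many rows; pick the most recent version.
    return select(SDM).where(SDM.dag_id == dag_id).order_by(SDM.id.desc()).limit(1)
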
10 changes: 5 additions & 5 deletions airflow/models/dag.py
@@ -302,7 +302,7 @@ def _create_orm_dagrun(
conf,
state,
run_type,
dag_hash,
serialized_dag,
creating_job_id,
data_interval,
session,
@@ -317,7 +317,7 @@
conf=conf,
state=state,
run_type=run_type,
dag_hash=dag_hash,
serialized_dag=serialized_dag,
creating_job_id=creating_job_id,
data_interval=data_interval,
triggered_by=triggered_by,
@@ -2542,7 +2542,7 @@ def create_dagrun(
conf: dict | None = None,
run_type: DagRunType | None = None,
session: Session = NEW_SESSION,
dag_hash: str | None = None,
serialized_dag: SerializedDagModel | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
):
@@ -2561,7 +2561,7 @@
:param conf: Dict containing configuration/parameters to pass to the DAG
:param creating_job_id: id of the job creating this DagRun
:param session: database session
:param dag_hash: Hash of Serialized DAG
:param serialized_dag: The serialized Dag Model object
:param data_interval: Data interval of the DagRun
"""
logical_date = timezone.coerce_datetime(execution_date)
@@ -2627,7 +2627,7 @@
conf=conf,
state=state,
run_type=run_type,
dag_hash=dag_hash,
serialized_dag=serialized_dag,
creating_job_id=creating_job_id,
data_interval=data_interval,
session=session,
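
A sketch of a caller under the renamed parameter (run values illustrative; assumes `dag` is a DAG whose serialized form is already in the database, and that the remaining keyword arguments are unchanged by this PR):

from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

run = dag.create_dagrun(
    state=DagRunState.QUEUED,
    execution_date=timezone.utcnow(),
    run_type=DagRunType.MANUAL,
    serialized_dag=SerializedDagModel.get(dag.dag_id),  # previously dag_hash=...
)
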
22 changes: 17 additions & 5 deletions airflow/models/dagrun.py
@@ -80,6 +80,7 @@

from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
@@ -141,7 +142,6 @@ class DagRun(Base, LoggingMixin):
data_interval_end = Column(UtcDateTime)
# When a scheduler last attempted to schedule TIs for this DagRun
last_scheduling_decision = Column(UtcDateTime)
dag_hash = Column(String(32))
# Foreign key to LogTemplate. DagRun rows created prior to this column's
# existence have this set to NULL. Later rows automatically populate this on
# insert to point to the latest LogTemplate entry.
@@ -155,6 +155,11 @@
# This number is incremented only when the DagRun is re-Queued,
# when the DagRun is cleared.
clear_number = Column(Integer, default=0, nullable=False, server_default="0")
serialized_dag_id = Column(
Integer,
ForeignKey("serialized_dag.id", name="dag_run_serialized_dag_fkey", ondelete="SET NULL"),
)
serialized_dag = relationship("SerializedDagModel", back_populates="dag_run")

# Remove this `if` after upgrading Sphinx-AutoAPI
if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
@@ -218,7 +223,7 @@ def __init__(
conf: Any | None = None,
state: DagRunState | None = None,
run_type: str | None = None,
dag_hash: str | None = None,
serialized_dag: SerializedDagModel | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
triggered_by: DagRunTriggeredByType | None = None,
@@ -242,7 +247,7 @@
else:
self.queued_at = queued_at
self.run_type = run_type
self.dag_hash = dag_hash
self.serialized_dag = serialized_dag
self.creating_job_id = creating_job_id
self.clear_number = 0
self.triggered_by = triggered_by
@@ -354,6 +359,12 @@ def set_state(self, state: DagRunState) -> None:
def state(self):
return synonym("_state", descriptor=property(self.get_state, self.set_state))

@provide_session
def dag_hash(self, session: Session = NEW_SESSION):
from airflow.models.serialized_dag import SerializedDagModel as SDM

return str(session.scalar(select(SDM.dag_hash).where(SDM.id == self.serialized_dag_id)))

@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
"""
@@ -939,6 +950,7 @@ def recalculate(self) -> _UnfinishedStates:
"state=%s, external_trigger=%s, run_type=%s, "
"data_interval_start=%s, data_interval_end=%s, dag_hash=%s"
)

self.log.info(
msg,
self.dag_id,
@@ -956,7 +968,7 @@
self.run_type,
self.data_interval_start,
self.data_interval_end,
self.dag_hash,
self.dag_hash(session),
)

with Trace.start_span_from_dagrun(dagrun=self) as span:
@@ -980,7 +992,7 @@
"run_type": str(self.run_type),
"data_interval_start": str(self.data_interval_start),
"data_interval_end": str(self.data_interval_end),
"dag_hash": str(self.dag_hash),
"dag_hash": str(self.dag_hash(session)),
"conf": str(self.conf),
}
if span.is_recording():
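
With the dag_hash column dropped, the hash is resolved on demand through the foreign key; the dag_hash() method above wraps exactly the query below. A minimal equivalent for callers already holding a session and a persisted dag_run:

from sqlalchemy import select

from airflow.models.serialized_dag import SerializedDagModel as SDM

hash_str = session.scalar(
    select(SDM.dag_hash).where(SDM.id == dag_run.serialized_dag_id)
)
# Yields None (stringified to "None" by the method above) when the run's
# serialized row has been pruned, since the FK is ondelete="SET NULL".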