Skip to content

Commit

Permalink
Merge branch 'main' into feature/http-hook-custom-adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
jieyao-MilestoneHub authored Nov 23, 2024
2 parents ff2adff + 00a3099 commit 3cd76d0
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 92 deletions.
76 changes: 2 additions & 74 deletions airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
from typing import TYPE_CHECKING, Collection, Iterable

from sqlalchemy import and_, or_, select
from sqlalchemy.orm import lazyload
Expand All @@ -29,7 +29,6 @@
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

if TYPE_CHECKING:
from datetime import datetime
Expand All @@ -38,47 +37,6 @@

from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.utils.types import DagRunType


class _DagRunInfo(NamedTuple):
logical_date: datetime
data_interval: tuple[datetime, datetime]


def _create_dagruns(
dag: DAG,
infos: Iterable[_DagRunInfo],
state: DagRunState,
run_type: DagRunType,
) -> Iterable[DagRun]:
"""
Infers from data intervals which DAG runs need to be created and does so.
:param dag: The DAG to create runs for.
:param infos: List of logical dates and data intervals to evaluate.
:param state: The state to set the dag run to
:param run_type: The prefix will be used to construct dag run id: ``{run_id_prefix}__{logical_date}``.
:return: Newly created and existing dag runs for the logical dates supplied.
"""
# Find out existing DAG runs that we don't need to create.
dag_runs = {
run.logical_date: run
for run in DagRun.find(dag_id=dag.dag_id, logical_date=[info.logical_date for info in infos])
}

for info in infos:
if info.logical_date not in dag_runs:
dag_runs[info.logical_date] = dag.create_dagrun(
logical_date=info.logical_date,
data_interval=info.data_interval,
start_date=timezone.utcnow(),
external_trigger=False,
state=state,
run_type=run_type,
triggered_by=DagRunTriggeredByType.TIMETABLE,
)
return dag_runs.values()


@provide_session
Expand Down Expand Up @@ -131,7 +89,7 @@ def set_state(
task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
# now look for the task instances that are affected

qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)
qry_dag = get_all_dag_task_query(dag, state, task_id_map_index_list, dag_run_ids)

if commit:
tis_altered = session.scalars(qry_dag.with_for_update()).all()
Expand All @@ -145,7 +103,6 @@ def set_state(

def get_all_dag_task_query(
dag: DAG,
session: SASession,
state: TaskInstanceState,
task_ids: list[str | tuple[str, int]],
run_ids: Iterable[str],
Expand All @@ -163,13 +120,6 @@ def get_all_dag_task_query(
return qry_dag


def _iter_existing_dag_run_infos(dag: DAG, run_ids: list[str], session: SASession) -> Iterator[_DagRunInfo]:
for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session):
dag_run.dag = dag
dag_run.verify_integrity(session=session)
yield _DagRunInfo(dag_run.logical_date, dag.get_run_data_interval(dag_run))


def find_task_relatives(tasks, downstream, upstream):
"""Yield task ids and optionally ancestor and descendant ids."""
for item in tasks:
Expand Down Expand Up @@ -417,28 +367,6 @@ def __set_dag_run_state_to_running_or_queued(
return res


@provide_session
def set_dag_run_state_to_running(
*,
dag: DAG,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
) -> list[TaskInstance]:
"""
Set the dag run's state to running.
Set for a specific logical date and its task instances to running.
"""
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.RUNNING,
dag=dag,
run_id=run_id,
commit=commit,
session=session,
)


@provide_session
def set_dag_run_state_to_queued(
*,
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/datamodels/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DAGRunClearBody(BaseModel):
class DAGRunResponse(BaseModel):
"""DAG Run serializer for responses."""

dag_run_id: str | None = Field(alias="run_id")
dag_run_id: str | None = Field(validation_alias="run_id")
dag_id: str
logical_date: datetime | None
queued_at: datetime | None
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5959,11 +5959,11 @@ components:
description: Enum for DAG Run states when updating a DAG Run.
DAGRunResponse:
properties:
run_id:
dag_run_id:
anyOf:
- type: string
- type: 'null'
title: Run Id
title: Dag Run Id
dag_id:
type: string
title: Dag Id
Expand Down Expand Up @@ -6028,7 +6028,7 @@ components:
title: Note
type: object
required:
- run_id
- dag_run_id
- dag_id
- logical_date
- queued_at
Expand Down
6 changes: 3 additions & 3 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1670,7 +1670,7 @@ export const $DAGRunPatchStates = {

export const $DAGRunResponse = {
properties: {
run_id: {
dag_run_id: {
anyOf: [
{
type: "string",
Expand All @@ -1679,7 +1679,7 @@ export const $DAGRunResponse = {
type: "null",
},
],
title: "Run Id",
title: "Dag Run Id",
},
dag_id: {
type: "string",
Expand Down Expand Up @@ -1800,7 +1800,7 @@ export const $DAGRunResponse = {
},
type: "object",
required: [
"run_id",
"dag_run_id",
"dag_id",
"logical_date",
"queued_at",
Expand Down
2 changes: 1 addition & 1 deletion airflow/ui/openapi-gen/requests/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ export type DAGRunPatchStates = "queued" | "success" | "failed";
* DAG Run serializer for responses.
*/
export type DAGRunResponse = {
run_id: string | null;
dag_run_id: string | null;
dag_id: string;
logical_date: string | null;
queued_at: string | null;
Expand Down
2 changes: 1 addition & 1 deletion airflow/ui/src/pages/DagsList/RecentRuns.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export const RecentRuns = ({
<Text>Duration: {run.duration.toFixed(2)}s</Text>
</Box>
}
key={run.run_id}
key={run.dag_run_id}
positioning={{
offset: {
crossAxis: 5,
Expand Down
16 changes: 8 additions & 8 deletions tests/api_fastapi/core_api/routes/public/test_dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_get_dag_run(self, test_client, dag_id, run_id, state, run_type, trigger
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == dag_id
assert body["run_id"] == run_id
assert body["dag_run_id"] == run_id
assert body["state"] == state
assert body["run_type"] == run_type
assert body["triggered_by"] == triggered_by.value
Expand All @@ -168,7 +168,7 @@ def parse_datetime(datetime_str):
@staticmethod
def get_dag_run_dict(run: DagRun):
return {
"run_id": run.run_id,
"dag_run_id": run.run_id,
"dag_id": run.dag_id,
"logical_date": TestGetDagRuns.parse_datetime(run.logical_date),
"queued_at": TestGetDagRuns.parse_datetime(run.queued_at),
Expand All @@ -194,7 +194,7 @@ def test_get_dag_runs(self, test_client, session, dag_id, total_entries):
for each in body["dag_runs"]:
run = (
session.query(DagRun)
.where(DagRun.dag_id == each["dag_id"], DagRun.run_id == each["run_id"])
.where(DagRun.dag_id == each["dag_id"], DagRun.run_id == each["dag_run_id"])
.one()
)
expected = self.get_dag_run_dict(run)
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_return_correct_results_with_order_by(self, test_client, order_by, expec
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 2
assert [each["run_id"] for each in body["dag_runs"]] == expected_dag_id_order
assert [each["dag_run_id"] for each in body["dag_runs"]] == expected_dag_id_order

@pytest.mark.parametrize(
"query_params, expected_dag_id_order",
Expand All @@ -254,7 +254,7 @@ def test_limit_and_offset(self, test_client, query_params, expected_dag_id_order
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 2
assert [each["run_id"] for each in body["dag_runs"]] == expected_dag_id_order
assert [each["dag_run_id"] for each in body["dag_runs"]] == expected_dag_id_order

@pytest.mark.parametrize(
"query_params, expected_detail",
Expand Down Expand Up @@ -364,7 +364,7 @@ def test_filters(self, test_client, dag_id, query_params, expected_dag_id_list):
response = test_client.get(f"/public/dags/{dag_id}/dagRuns", params=query_params)
assert response.status_code == 200
body = response.json()
assert [each["run_id"] for each in body["dag_runs"]] == expected_dag_id_list
assert [each["dag_run_id"] for each in body["dag_runs"]] == expected_dag_id_list

def test_bad_filters(self, test_client):
query_params = {
Expand Down Expand Up @@ -474,7 +474,7 @@ def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_b
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == dag_id
assert body["run_id"] == run_id
assert body["dag_run_id"] == run_id
assert body.get("state") == response_body.get("state")
assert body.get("note") == response_body.get("note")

Expand Down Expand Up @@ -623,7 +623,7 @@ def test_clear_dag_run(self, test_client):
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == DAG1_ID
assert body["run_id"] == DAG1_RUN1_ID
assert body["dag_run_id"] == DAG1_RUN1_ID
assert body["state"] == "queued"

@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/api_fastapi/core_api/routes/ui/test_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_recent_dag_runs(self, test_client, query_params, expected_ids, expected
assert response.status_code == 200
body = response.json()
required_dag_run_key = [
"run_id",
"dag_run_id",
"dag_id",
"state",
"logical_date",
Expand Down

0 comments on commit 3cd76d0

Please sign in to comment.