From c54811d61f46decbab30bca90ac16a75675bb9c3 Mon Sep 17 00:00:00 2001 From: Omkar P <45419097+omkar-foss@users.noreply.github.com> Date: Wed, 20 Nov 2024 22:55:18 +0530 Subject: [PATCH] Migrate public endpoint Clear Task Instances to FastAPI --- .../endpoints/task_instance_endpoint.py | 1 + airflow/api_fastapi/common/types.py | 27 +- .../core_api/datamodels/task_instances.py | 57 +- .../core_api/openapi/v1-generated.yaml | 149 +++- .../core_api/routes/public/task_instances.py | 116 ++- airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 44 ++ .../ui/openapi-gen/requests/schemas.gen.ts | 139 ++++ .../ui/openapi-gen/requests/services.gen.ts | 33 +- airflow/ui/openapi-gen/requests/types.gen.ts | 72 +- tests/api_fastapi/conftest.py | 9 + .../routes/public/test_task_instances.py | 669 +++++++++++++++++- 12 files changed, 1289 insertions(+), 30 deletions(-) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index f4f7ac23cb859..95c824656ff27 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -439,6 +439,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: ) +@mark_fastapi_migration_done @security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @action_logging @provide_session diff --git a/airflow/api_fastapi/common/types.py b/airflow/api_fastapi/common/types.py index 4a41f4f0ceed8..7c619370e3ef8 100644 --- a/airflow/api_fastapi/common/types.py +++ b/airflow/api_fastapi/common/types.py @@ -16,11 +16,18 @@ # under the License. from __future__ import annotations -from datetime import timedelta +from datetime import datetime, timedelta from enum import Enum from typing import Annotated -from pydantic import AfterValidator, AliasGenerator, AwareDatetime, BaseModel, BeforeValidator, ConfigDict +from pydantic import ( + AfterValidator, + AliasGenerator, + AwareDatetime, + BaseModel, + BeforeValidator, + ConfigDict, +) from airflow.utils import timezone @@ -29,7 +36,7 @@ def _validate_timedelta_field(td: timedelta | None) -> TimeDelta | None: - """Validate the execution_timeout property.""" + """Validate the timedelta field and return it.""" if td is None: return None return TimeDelta( @@ -59,6 +66,20 @@ class TimeDelta(BaseModel): TimeDeltaWithValidation = Annotated[TimeDelta, BeforeValidator(_validate_timedelta_field)] +def _validate_nonnaive_datetime_field(dt: datetime | None) -> datetime | None: + """Validate and return the datetime field.""" + if dt is None: + return None + if isinstance(dt, str): + dt = datetime.fromisoformat(dt) + if not dt.tzinfo: + raise ValueError("Invalid datetime format, Naive datetime is disallowed") + return dt + + +DatetimeWithNonNaiveValidation = Annotated[datetime, BeforeValidator(_validate_nonnaive_datetime_field)] + + class Mimetype(str, Enum): """Mimetype for the `Content-Type` header.""" diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index 4712df3273a4e..2dc424dba227c 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -17,9 +17,10 @@ from __future__ import annotations from datetime import datetime -from typing import Annotated +from typing import Annotated, Any from pydantic import ( + AliasChoices, AliasPath, AwareDatetime, BaseModel, @@ -27,8 +28,11 @@ 
ConfigDict, Field, NonNegativeInt, + model_validator, ) +from airflow.api_fastapi.common.types import DatetimeWithNonNaiveValidation from airflow.api_fastapi.core_api.datamodels.job import JobResponse from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse from airflow.utils.state import TaskInstanceState @@ -150,3 +154,54 @@ class TaskInstanceHistoryCollectionResponse(BaseModel): task_instances: list[TaskInstanceHistoryResponse] total_entries: int + + +class ClearTaskInstancesBody(BaseModel): + """Request body for Clear Task Instances endpoint.""" + + dry_run: bool = True + start_date: DatetimeWithNonNaiveValidation | None = None + end_date: DatetimeWithNonNaiveValidation | None = None + only_failed: bool = True + only_running: bool = False + reset_dag_runs: bool = False + task_ids: list[str] | None = None + dag_run_id: str | None = None + include_upstream: bool = False + include_downstream: bool = False + include_future: bool = False + include_past: bool = False + + @model_validator(mode="before") + @classmethod + def validate_model(cls, data: Any) -> Any: + """Validate clear task instance form.""" + if data.get("only_failed") and data.get("only_running"): + raise ValueError("only_failed and only_running both are set to True") + if data.get("start_date") and data.get("end_date"): + if data.get("start_date") > data.get("end_date"): + raise ValueError("end_date is sooner than start_date") + if data.get("start_date") and data.get("end_date") and data.get("dag_run_id"): + raise ValueError("Exactly one of dag_run_id or (start_date and end_date) must be provided") + if data.get("start_date") and data.get("dag_run_id"): + raise ValueError("Exactly one of dag_run_id or start_date must be provided") + if data.get("end_date") and data.get("dag_run_id"): + raise ValueError("Exactly one of dag_run_id or end_date must be provided") + if isinstance(data.get("task_ids"), list) and len(data.get("task_ids")) < 1: + raise ValueError("task_ids list should have at least 1 element.") + return data + + +class TaskInstanceReferenceResponse(BaseModel): + """Task Instance Reference serializer for responses.""" + + task_id: str + dag_run_id: str = Field(validation_alias=AliasChoices("run_id", "dagrun_id")) + dag_id: str + logical_date: datetime + + +class TaskInstanceReferenceCollectionResponse(BaseModel): + """Task Instance Reference collection serializer for responses.""" + + task_instances: list[TaskInstanceReferenceResponse] diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 3542430176bb9..bb9fd1ce2e7c8 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3629,7 +3629,7 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' - /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances: + /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/: get: tags: - Task Instance @@ -4030,6 +4030,57 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/clearTaskInstances: + post: + tags: + - Task Instance + summary: Post Clear Task Instances + description: Clear task instances.
+ operationId: post_clear_task_instances + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ClearTaskInstancesBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceReferenceCollectionResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/tasks/: get: tags: @@ -4823,6 +4874,67 @@ components: - status title: BaseInfoSchema description: Base status field for metadatabase and scheduler. + ClearTaskInstancesBody: + properties: + dry_run: + type: boolean + title: Dry Run + default: true + start_date: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date + end_date: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date + only_failed: + type: boolean + title: Only Failed + default: true + only_running: + type: boolean + title: Only Running + default: false + reset_dag_runs: + type: boolean + title: Reset Dag Runs + default: false + task_ids: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Task Ids + dag_run_id: + anyOf: + - type: string + - type: 'null' + title: Dag Run Id + include_upstream: + type: boolean + title: Include Upstream + default: false + include_downstream: + type: boolean + title: Include Downstream + default: false + include_future: + type: boolean + title: Include Future + default: false + include_past: + type: boolean + title: Include Past + default: false + type: object + title: ClearTaskInstancesBody + description: Request body for Clear Task Instances endpoint. Config: properties: sections: @@ -6720,6 +6832,41 @@ components: - executor_config title: TaskInstanceHistoryResponse description: TaskInstanceHistory serializer for responses. + TaskInstanceReferenceCollectionResponse: + properties: + task_instances: + items: + $ref: '#/components/schemas/TaskInstanceReferenceResponse' + type: array + title: Task Instances + type: object + required: + - task_instances + title: TaskInstanceReferenceCollectionResponse + description: Task Instance Reference collection serializer for responses. + TaskInstanceReferenceResponse: + properties: + task_id: + type: string + title: Task Id + dag_run_id: + type: string + title: Dag Run Id + dag_id: + type: string + title: Dag Id + logical_date: + type: string + format: date-time + title: Logical Date + type: object + required: + - task_id + - dag_run_id + - dag_id + - logical_date + title: TaskInstanceReferenceResponse + description: Task Instance Reference serializer for responses. 
TaskInstanceResponse: properties: id: diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index f4769a981b882..ddcae95dccaae 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -48,29 +48,32 @@ ) from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.task_instances import ( + ClearTaskInstancesBody, TaskDependencyCollectionResponse, TaskInstanceCollectionResponse, TaskInstanceHistoryResponse, + TaskInstanceReferenceCollectionResponse, + TaskInstanceReferenceResponse, TaskInstanceResponse, TaskInstancesBatchBody, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.exceptions import TaskNotFound +from airflow.jobs.scheduler_job_runner import DR from airflow.models import Base, DagRun -from airflow.models.taskinstance import TaskInstance as TI +from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS from airflow.utils.db import get_query_count -from airflow.utils.state import TaskInstanceState +from airflow.utils.state import DagRunState, TaskInstanceState -task_instances_router = AirflowRouter( - tags=["Task Instance"], prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances" -) +task_instances_router = AirflowRouter(tags=["Task Instance"], prefix="/dags/{dag_id}") +task_instances_prefix = "/dagRuns/{dag_run_id}/taskInstances" @task_instances_router.get( - "/{task_id}", + task_instances_prefix + "/{task_id}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_task_instance( @@ -99,7 +102,7 @@ def get_task_instance( @task_instances_router.get( - "/{task_id}/listMapped", + task_instances_prefix + "/{task_id}/listMapped", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_mapped_task_instances( @@ -182,11 +185,11 @@ def get_mapped_task_instances( @task_instances_router.get( - "/{task_id}/dependencies", + task_instances_prefix + "/{task_id}/dependencies", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) @task_instances_router.get( - "/{task_id}/{map_index}/dependencies", + task_instances_prefix + "/{task_id}/{map_index}/dependencies", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_task_instance_dependencies( @@ -236,7 +239,7 @@ def get_task_instance_dependencies( @task_instances_router.get( - "/{task_id}/{map_index}", + task_instances_prefix + "/{task_id}/{map_index}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_mapped_task_instance( @@ -265,7 +268,7 @@ def get_mapped_task_instance( @task_instances_router.get( - "", + task_instances_prefix + "/", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_task_instances( @@ -348,7 +351,7 @@ def get_task_instances( @task_instances_router.post( - "/list", + task_instances_prefix + "/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_task_instances_batch( @@ -428,7 +431,7 @@ def get_task_instances_batch( @task_instances_router.get( - "/{task_id}/tries/{task_try_number}", + task_instances_prefix + 
"/{task_id}/tries/{task_try_number}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_task_instance_try_details( @@ -463,7 +466,7 @@ def _query(orm_object: Base) -> TI | TIH | None: @task_instances_router.get( - "/{task_id}/{map_index}/tries/{task_try_number}", + task_instances_prefix + "/{task_id}/{map_index}/tries/{task_try_number}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), ) def get_mapped_task_instance_try_details( @@ -482,3 +485,88 @@ def get_mapped_task_instance_try_details( map_index=map_index, session=session, ) + + +@task_instances_router.post( + "/clearTaskInstances", + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), +) +def post_clear_task_instances( + dag_id: str, + request: Request, + body: ClearTaskInstancesBody, + session: Annotated[Session, Depends(get_session)], +) -> TaskInstanceReferenceCollectionResponse: + """Clear task instances.""" + dag = request.app.state.dag_bag.get_dag(dag_id) + if not dag: + error_message = f"DAG {dag_id} not found" + raise HTTPException(status.HTTP_404_NOT_FOUND, error_message) + + reset_dag_runs = body.reset_dag_runs + dry_run = body.dry_run + # We always pass dry_run here, otherwise this would try to confirm on the terminal! + dag_run_id = body.dag_run_id + future = body.include_future + past = body.include_past + downstream = body.include_downstream + upstream = body.include_upstream + + if dag_run_id is not None: + dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id == dag_id, DR.run_id == dag_run_id)) + if dag_run is None: + error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}" + raise HTTPException(status.HTTP_404_NOT_FOUND, error_message) + body.start_date = dag_run.logical_date + body.end_date = dag_run.logical_date + + if past: + body.start_date = None + + if future: + body.end_date = None + + task_ids = body.task_ids + if task_ids is not None: + task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids] + dag = dag.partial_subset( + task_ids_or_regex=task_id, + include_downstream=downstream, + include_upstream=upstream, + ) + + if len(dag.task_dict) > 1: + # If we had upstream/downstream etc then also include those! 
+ task_ids.extend(tid for tid in dag.task_dict if tid != task_id) + + task_instances = dag.clear( + dry_run=True, + task_ids=body.task_ids, + dag_bag=request.app.state.dag_bag, + **body.model_dump( + include=[ # type: ignore[arg-type] + "start_date", + "end_date", + "only_failed", + "only_running", + ] + ), + ) + + if not dry_run: + clear_task_instances( + task_instances, + session, + dag, + DagRunState.QUEUED if reset_dag_runs else False, + ) + + return TaskInstanceReferenceCollectionResponse( + task_instances=[ + TaskInstanceReferenceResponse.model_validate( + ti, + from_attributes=True, + ) + for ti in task_instances + ] + ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index fe281c5640f9e..4fc26f693b54d 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1274,6 +1274,9 @@ export type PoolServicePostPoolMutationResult = Awaited< export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited< ReturnType >; +export type TaskInstanceServicePostClearTaskInstancesMutationResult = Awaited< + ReturnType +>; export type VariableServicePostVariableMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 74e25c0258a25..176cc256ccfa0 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -32,6 +32,7 @@ import { } from "../requests/services.gen"; import { BackfillPostBody, + ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, @@ -2314,6 +2315,49 @@ export const useTaskInstanceServiceGetTaskInstancesBatch = < }) as unknown as Promise, ...options, }); +/** + * Post Clear Task Instances + * Clear task instances. + * @param data The data for the request. + * @param data.dagId + * @param data.requestBody + * @returns TaskInstanceReferenceCollectionResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServicePostClearTaskInstances = < + TData = Common.TaskInstanceServicePostClearTaskInstancesMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + requestBody: ClearTaskInstancesBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + requestBody: ClearTaskInstancesBody; + }, + TContext + >({ + mutationFn: ({ dagId, requestBody }) => + TaskInstanceService.postClearTaskInstances({ + dagId, + requestBody, + }) as unknown as Promise, + ...options, + }); /** * Post Variable * Create a variable. 
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 0f08034c533dd..f6f75da2a6413 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -459,6 +459,103 @@ export const $BaseInfoSchema = { description: "Base status field for metadatabase and scheduler.", } as const; +export const $ClearTaskInstancesBody = { + properties: { + dry_run: { + type: "boolean", + title: "Dry Run", + default: true, + }, + start_date: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Start Date", + }, + end_date: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "End Date", + }, + only_failed: { + type: "boolean", + title: "Only Failed", + default: true, + }, + only_running: { + type: "boolean", + title: "Only Running", + default: false, + }, + reset_dag_runs: { + type: "boolean", + title: "Reset Dag Runs", + default: false, + }, + task_ids: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Task Ids", + }, + dag_run_id: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Dag Run Id", + }, + include_upstream: { + type: "boolean", + title: "Include Upstream", + default: false, + }, + include_downstream: { + type: "boolean", + title: "Include Downstream", + default: false, + }, + include_future: { + type: "boolean", + title: "Include Future", + default: false, + }, + include_past: { + type: "boolean", + title: "Include Past", + default: false, + }, + }, + type: "object", + title: "ClearTaskInstancesBody", + description: "Request body for Clear Task Instances endpoint.", +} as const; + export const $Config = { properties: { sections: { @@ -3327,6 +3424,48 @@ export const $TaskInstanceHistoryResponse = { description: "TaskInstanceHistory serializer for responses.", } as const; +export const $TaskInstanceReferenceCollectionResponse = { + properties: { + task_instances: { + items: { + $ref: "#/components/schemas/TaskInstanceReferenceResponse", + }, + type: "array", + title: "Task Instances", + }, + }, + type: "object", + required: ["task_instances"], + title: "TaskInstanceReferenceCollectionResponse", + description: "Task Instance Reference collection serializer for responses.", +} as const; + +export const $TaskInstanceReferenceResponse = { + properties: { + task_id: { + type: "string", + title: "Task Id", + }, + dag_run_id: { + type: "string", + title: "Dag Run Id", + }, + dag_id: { + type: "string", + title: "Dag Id", + }, + logical_date: { + type: "string", + format: "date-time", + title: "Logical Date", + }, + }, + type: "object", + required: ["task_id", "dag_run_id", "dag_id", "logical_date"], + title: "TaskInstanceReferenceResponse", + description: "Task Instance Reference serializer for responses.", +} as const; + export const $TaskInstanceResponse = { properties: { id: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 09eb2432d77ad..68a2c40ddd123 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -127,6 +127,8 @@ import type { GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, + PostClearTaskInstancesData, + PostClearTaskInstancesResponse, GetTasksData, GetTasksResponse, GetTaskData, @@ -1998,7 +2000,7 @@ 
export class TaskInstanceService { ): CancelablePromise { return __request(OpenAPI, { method: "GET", - url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/", path: { dag_id: data.dagId, dag_run_id: data.dagRunId, @@ -2130,6 +2132,35 @@ export class TaskInstanceService { }, }); } + + /** + * Post Clear Task Instances + * Clear task instances. + * @param data The data for the request. + * @param data.dagId + * @param data.requestBody + * @returns TaskInstanceReferenceCollectionResponse Successful Response + * @throws ApiError + */ + public static postClearTaskInstances( + data: PostClearTaskInstancesData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/dags/{dag_id}/clearTaskInstances", + path: { + dag_id: data.dagId, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } } export class TaskService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 71a94d534d326..27446e73b4834 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -128,6 +128,24 @@ export type BaseInfoSchema = { status: string | null; }; +/** + * Request body for Clear Task Instances endpoint. + */ +export type ClearTaskInstancesBody = { + dry_run?: boolean; + start_date?: string | null; + end_date?: string | null; + only_failed?: boolean; + only_running?: boolean; + reset_dag_runs?: boolean; + task_ids?: Array | null; + dag_run_id?: string | null; + include_upstream?: boolean; + include_downstream?: boolean; + include_future?: boolean; + include_past?: boolean; +}; + /** * List of config sections with their options. */ @@ -828,6 +846,23 @@ export type TaskInstanceHistoryResponse = { executor_config: string; }; +/** + * Task Instance Reference collection serializer for responses. + */ +export type TaskInstanceReferenceCollectionResponse = { + task_instances: Array; +}; + +/** + * Task Instance Reference serializer for responses. + */ +export type TaskInstanceReferenceResponse = { + task_id: string; + dag_run_id: string; + dag_id: string; + logical_date: string; +}; + /** * TaskInstance serializer for responses. 
*/ @@ -1606,6 +1641,14 @@ export type GetMappedTaskInstanceTryDetailsData = { export type GetMappedTaskInstanceTryDetailsResponse = TaskInstanceHistoryResponse; +export type PostClearTaskInstancesData = { + dagId: string; + requestBody: ClearTaskInstancesBody; +}; + +export type PostClearTaskInstancesResponse = + TaskInstanceReferenceCollectionResponse; + export type GetTasksData = { dagId: string; orderBy?: string; @@ -3221,7 +3264,7 @@ export type $OpenApiTs = { }; }; }; - "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances": { + "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/": { get: { req: GetTaskInstancesData; res: { @@ -3329,6 +3372,33 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/clearTaskInstances": { + post: { + req: PostClearTaskInstancesData; + res: { + /** + * Successful Response + */ + 200: TaskInstanceReferenceCollectionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dags/{dag_id}/tasks/": { get: { req: GetTasksData; diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index 2ef3b14369a31..2928a4d829c70 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -36,3 +36,12 @@ def create_test_client(apps="all"): return TestClient(app) return create_test_client + + +@pytest.fixture(scope="module") +def dagbag(): + from airflow.models import DagBag + + dagbag_instance = DagBag(include_examples=True, read_dags_from_db=False) + dagbag_instance.sync_to_db() + return dagbag_instance diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index f8e75600171b3..f97a858a1efe7 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -35,11 +35,14 @@ from airflow.models.taskmap import TaskMap from airflow.models.trigger import Trigger from airflow.utils.platform import getuser -from airflow.utils.state import State, TaskInstanceState +from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType -from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields +from tests_common.test_utils.db import ( + clear_db_runs, + clear_rendered_ti_fields, +) from tests_common.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test @@ -54,14 +57,18 @@ class TestTaskInstanceEndpoint: - def setup_method(self): + @staticmethod + def clear_db(): clear_db_runs() + def setup_method(self): + self.clear_db() + def teardown_method(self): - clear_db_runs() + self.clear_db() @pytest.fixture(autouse=True) - def setup_attrs(self, session) -> None: + def setup_attrs(self, dagbag) -> None: self.default_time = DEFAULT self.ti_init = { "logical_date": self.default_time, @@ -77,8 +84,6 @@ def setup_attrs(self, session) -> None: } clear_db_runs() clear_rendered_ti_fields() - dagbag = DagBag(include_examples=True, read_dags_from_db=False) - dagbag.sync_to_db() self.dagbag = dagbag def create_task_instances( @@ -973,8 +978,8 @@ def test_return_TI_only_from_readable_dags(self, test_client, session): ) response = test_client.get("/public/dags/~/dagRuns/~/taskInstances") assert response.status_code == 200 - assert 
response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 def test_should_respond_200_for_dag_id_filter(self, test_client, session): self.create_task_instances(session) @@ -1663,3 +1668,649 @@ def test_raises_404_for_nonexistent_task_instance(self, test_client, session): assert response.json() == { "detail": "The Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `nonexistent_task`, try_number: `0` and map_index: `-1` was not found" } + + +class TestPostClearTaskInstances(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "main_dag, task_instances, request_dag, payload, expected_ti", + [ + pytest.param( + "example_python_operator", + [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.FAILED}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.FAILED, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.FAILED, + }, + ], + "example_python_operator", + { + "dry_run": True, + "start_date": DEFAULT_DATETIME_STR_2, + "only_failed": True, + }, + 2, + id="clear start date filter", + ), + pytest.param( + "example_python_operator", + [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.FAILED}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.FAILED, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.FAILED, + }, + ], + "example_python_operator", + { + "dry_run": True, + "end_date": DEFAULT_DATETIME_STR_2, + "only_failed": True, + }, + 2, + id="clear end date filter", + ), + pytest.param( + "example_python_operator", + [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.FAILED, + }, + ], + "example_python_operator", + {"dry_run": True, "only_running": True, "only_failed": False}, + 2, + id="clear only running", + ), + pytest.param( + "example_python_operator", + [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.FAILED}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.FAILED, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + ], + "example_python_operator", + { + "dry_run": True, + "only_failed": True, + }, + 2, + id="clear only failed", + ), + pytest.param( + "example_python_operator", + [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.FAILED}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.FAILED, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.FAILED, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.FAILED, + }, + ], + "example_python_operator", + { + "dry_run": True, + "task_ids": ["print_the_context", "sleep_for_1"], + }, + 2, + id="clear by task ids", + ), + pytest.param( + "example_python_operator", + [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.FAILED}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.FAILED, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + ], + "example_python_operator", + { + "only_failed": True, + }, + 2, + id="dry_run default", + ), + ], + ) + def 
test_should_respond_200( + self, + test_client, + session, + main_dag, + task_instances, + request_dag, + payload, + expected_ti, + ): + self.create_task_instances( + session, + dag_id=main_dag, + task_instances=task_instances, + update_extras=False, + ) + self.dagbag.sync_to_db() + response = test_client.post( + f"/public/dags/{request_dag}/clearTaskInstances", + json=payload, + ) + assert response.status_code == 200 + assert len(response.json()["task_instances"]) == expected_ti + + def test_clear_taskinstance_is_called_with_queued_dr_state(self, test_client, session): + """Test that if reset_dag_runs is True, then clear_task_instances is called with State.QUEUED""" + self.create_task_instances(session) + dag_id = "example_python_operator" + payload = {"reset_dag_runs": True, "dry_run": False} + self.dagbag.sync_to_db() + with mock.patch( + "airflow.api_fastapi.core_api.routes.public.task_instances.clear_task_instances", + ) as mp: + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + assert response.status_code == 200 + mp.assert_called_once() + + def test_clear_taskinstance_is_called_with_invalid_task_ids(self, test_client, session): + """Test that dagrun is running when invalid task_ids are passed to clearTaskInstances API.""" + dag_id = "example_python_operator" + tis = self.create_task_instances(session) + dagrun = tis[0].get_dagrun() + assert dagrun.state == "running" + + payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} + self.dagbag.sync_to_db() + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + assert response.status_code == 200 + + dagrun.refresh_from_db() + assert dagrun.state == "running" + assert all(ti.state == "running" for ti in tis) + + def test_should_respond_200_with_reset_dag_run(self, test_client, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": True, + "only_failed": False, + "only_running": True, + } + task_instances = [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=DagRunState.FAILED, + ) + self.dagbag.sync_to_db() + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + + failed_dag_runs = session.query(DagRun).filter(DagRun.state == "failed").count() + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "logical_date": "2020-01-01T00:00:00Z", + "task_id": "print_the_context", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "logical_date": "2020-01-02T00:00:00Z", + "task_id": "log_sql_query", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "logical_date": "2020-01-03T00:00:00Z", + "task_id": "sleep_for_0", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": 
"TEST_DAG_RUN_ID_3", + "logical_date": "2020-01-04T00:00:00Z", + "task_id": "sleep_for_1", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_4", + "logical_date": "2020-01-05T00:00:00Z", + "task_id": "sleep_for_2", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_5", + "logical_date": "2020-01-06T00:00:00Z", + "task_id": "sleep_for_3", + }, + ] + for task_instance in expected_response: + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) + assert 0 == failed_dag_runs, 0 + + def test_should_respond_200_with_dag_run_id(self, test_client, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "only_running": True, + "dag_run_id": "TEST_DAG_RUN_ID_0", + } + task_instances = [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + self.dagbag.sync_to_db() + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "logical_date": "2020-01-01T00:00:00Z", + "task_id": "print_the_context", + }, + ] + assert response.json()["task_instances"] == expected_response + assert 1 == len(response.json()["task_instances"]) + + def test_should_respond_200_with_include_past(self, test_client, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "include_past": True, + "only_running": True, + } + task_instances = [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.RUNNING, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + self.dagbag.sync_to_db() + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "logical_date": "2020-01-01T00:00:00Z", + "task_id": "print_the_context", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "logical_date": "2020-01-02T00:00:00Z", + "task_id": "log_sql_query", + }, + { + "dag_id": 
"example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "logical_date": "2020-01-03T00:00:00Z", + "task_id": "sleep_for_0", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_3", + "logical_date": "2020-01-04T00:00:00Z", + "task_id": "sleep_for_1", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_4", + "logical_date": "2020-01-05T00:00:00Z", + "task_id": "sleep_for_2", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_5", + "logical_date": "2020-01-06T00:00:00Z", + "task_id": "sleep_for_3", + }, + ] + for task_instance in expected_response: + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) + + def test_should_respond_200_with_include_future(self, test_client, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "include_future": True, + "only_running": False, + } + task_instances = [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.SUCCESS}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.SUCCESS, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.SUCCESS, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.SUCCESS, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.SUCCESS, + }, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.SUCCESS, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + self.dagbag.sync_to_db() + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "logical_date": "2020-01-01T00:00:00Z", + "task_id": "print_the_context", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "logical_date": "2020-01-02T00:00:00Z", + "task_id": "log_sql_query", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "logical_date": "2020-01-03T00:00:00Z", + "task_id": "sleep_for_0", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_3", + "logical_date": "2020-01-04T00:00:00Z", + "task_id": "sleep_for_1", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_4", + "logical_date": "2020-01-05T00:00:00Z", + "task_id": "sleep_for_2", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_5", + "logical_date": "2020-01-06T00:00:00Z", + "task_id": "sleep_for_3", + }, + ] + for task_instance in expected_response: + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) + + def test_should_respond_404_for_nonexistent_dagrun_id(self, test_client, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "only_running": True, + "dag_run_id": "TEST_DAG_RUN_ID_100", + } + task_instances = [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + 
task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + response = test_client.post( + f"/public/dags/{dag_id}/clearTaskInstances", + json=payload, + ) + + assert 404 == response.status_code + assert f"Dag Run id TEST_DAG_RUN_ID_100 not found in dag {dag_id}" in response.text + + @pytest.mark.parametrize( + "payload, expected", + [ + ( + {"end_date": "2020-11-10T12:42:39.442973"}, + { + "detail": [ + { + "type": "value_error", + "loc": ["body", "end_date"], + "msg": "Value error, Invalid datetime format, Naive datetime is disallowed", + "input": "2020-11-10T12:42:39.442973", + "ctx": {"error": {}}, + } + ] + }, + ), + ( + {"end_date": "2020-11-10T12:4po"}, + { + "detail": [ + { + "type": "value_error", + "loc": ["body", "end_date"], + "msg": "Value error, Invalid isoformat string: '2020-11-10T12:4po'", + "input": "2020-11-10T12:4po", + "ctx": {"error": {}}, + } + ] + }, + ), + ( + {"start_date": "2020-11-10T12:42:39.442973"}, + { + "detail": [ + { + "type": "value_error", + "loc": ["body", "start_date"], + "msg": "Value error, Invalid datetime format, Naive datetime is disallowed", + "input": "2020-11-10T12:42:39.442973", + "ctx": {"error": {}}, + } + ] + }, + ), + ( + {"start_date": "2020-11-10T12:4po"}, + { + "detail": [ + { + "type": "value_error", + "loc": ["body", "start_date"], + "msg": "Value error, Invalid isoformat string: '2020-11-10T12:4po'", + "input": "2020-11-10T12:4po", + "ctx": {"error": {}}, + } + ] + }, + ), + ], + ) + def test_should_raise_400_for_naive_and_bad_datetime(self, test_client, session, payload, expected): + task_instances = [ + {"logical_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + ] + self.create_task_instances( + session, + dag_id="example_python_operator", + task_instances=task_instances, + update_extras=False, + ) + self.dagbag.sync_to_db() + response = test_client.post( + "/public/dags/example_python_operator/clearTaskInstances", + json=payload, + ) + assert response.status_code == 422 + assert response.json() == expected + + def test_raises_404_for_non_existent_dag(self, test_client): + response = test_client.post( + "/public/dags/non-existent-dag/clearTaskInstances", + json={ + "dry_run": False, + "reset_dag_runs": True, + "only_failed": False, + "only_running": True, + }, + ) + assert response.status_code == 404 + assert "DAG non-existent-dag not found" in response.text
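
For reference, a minimal client-side sketch of how the new endpoint introduced by this patch could be exercised. The path and request-body fields come from the `ClearTaskInstancesBody` schema and the `/public/dags/{dag_id}/clearTaskInstances` route above; the base URL, DAG id, run id, and task id are illustrative assumptions, not part of this change.

```python
"""Sketch: call the new FastAPI Clear Task Instances endpoint (dry run).

Assumes an Airflow API server reachable at http://localhost:8080 and a DAG
named "example_python_operator" with a run "TEST_DAG_RUN_ID_0"; adjust all of
these for a real deployment.
"""
import requests

BASE_URL = "http://localhost:8080"  # assumed local API server
DAG_ID = "example_python_operator"  # assumed DAG id

payload = {
    # dry_run=True only reports the task instances that would be cleared.
    "dry_run": True,
    "only_failed": False,
    # A dag_run_id may be given, or a start_date/end_date window, but not both.
    "dag_run_id": "TEST_DAG_RUN_ID_0",
    "task_ids": ["print_the_context"],
    "include_downstream": True,
}

response = requests.post(
    f"{BASE_URL}/public/dags/{DAG_ID}/clearTaskInstances",
    json=payload,
    timeout=30,
)
response.raise_for_status()

# The response body follows TaskInstanceReferenceCollectionResponse:
# {"task_instances": [{"task_id": ..., "dag_run_id": ..., "dag_id": ..., "logical_date": ...}]}
for ti in response.json()["task_instances"]:
    print(ti["dag_id"], ti["dag_run_id"], ti["task_id"], ti["logical_date"])
```

Setting `dry_run` to false would actually clear the returned task instances, and with `reset_dag_runs` enabled the affected DAG runs are set back to the queued state, as handled in `post_clear_task_instances` above.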