From f33166a612218a10cee0e42cde36bc43fc184222 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU <68415893+jason810496@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:40:35 +0800 Subject: [PATCH] AIP-81 Add Insert Multiple Pools API (#44121) * Add bulk post pools, refactor post pool * Add 409 case for TestPostPool * Add test for bulk post pools * Remove unused status code, rename post_body to body * Refactor duplicate pool insert handling - handle exception from db level instead of application level * Add global database exception handler for fastapi * Remove manual handle for unique constraint exc * Refactor test_pools * Fix bound for TypeVar, type for comment --- airflow/api_fastapi/app.py | 9 +- airflow/api_fastapi/common/exceptions.py | 64 +++++++++ airflow/api_fastapi/core_api/app.py | 8 ++ .../api_fastapi/core_api/datamodels/pools.py | 8 +- .../core_api/openapi/v1-generated.yaml | 63 +++++++++ .../core_api/routes/public/pools.py | 31 +++- airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 38 +++++ .../ui/openapi-gen/requests/schemas.gen.ts | 17 +++ .../ui/openapi-gen/requests/services.gen.ts | 28 ++++ airflow/ui/openapi-gen/requests/types.gen.ts | 44 ++++++ .../core_api/routes/public/test_pools.py | 133 +++++++++++++++++- 12 files changed, 439 insertions(+), 7 deletions(-) create mode 100644 airflow/api_fastapi/common/exceptions.py diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 4bf6ae9f6b77c..02841c8c211f5 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -22,7 +22,13 @@ from fastapi import FastAPI from starlette.routing import Mount -from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views +from airflow.api_fastapi.core_api.app import ( + init_config, + init_dag_bag, + init_error_handlers, + init_plugins, + init_views, +) from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.configuration import conf @@ -61,6 +67,7 @@ def create_app(apps: str = "all") -> FastAPI: init_dag_bag(app) init_views(app) init_plugins(app) + init_error_handlers(app) init_auth_manager() if "execution" in apps_list or "all" in apps_list: diff --git a/airflow/api_fastapi/common/exceptions.py b/airflow/api_fastapi/common/exceptions.py new file mode 100644 index 0000000000000..1e779a6097576 --- /dev/null +++ b/airflow/api_fastapi/common/exceptions.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from fastapi import HTTPException, Request, status +from sqlalchemy.exc import IntegrityError + +T = TypeVar("T", bound=Exception) + + +class BaseErrorHandler(Generic[T], ABC): + """Base class for error handlers.""" + + def __init__(self, exception_cls: T) -> None: + self.exception_cls = exception_cls + + @abstractmethod + def exception_handler(self, request: Request, exc: T): + """exception_handler method.""" + raise NotImplementedError + + +class _UniqueConstraintErrorHandler(BaseErrorHandler[IntegrityError]): + """Exception raised when trying to insert a duplicate value in a unique column.""" + + def __init__(self): + super().__init__(IntegrityError) + self.unique_constraint_error_messages = [ + "UNIQUE constraint failed", # SQLite + "Duplicate entry", # MySQL + "violates unique constraint", # PostgreSQL + ] + + def exception_handler(self, request: Request, exc: IntegrityError): + """Handle IntegrityError exception.""" + exc_orig_str = str(exc.orig) + if any(error_msg in exc_orig_str for error_msg in self.unique_constraint_error_messages): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Unique constraint violation", + ) + + +DatabaseErrorHandlers = [ + _UniqueConstraintErrorHandler(), +] diff --git a/airflow/api_fastapi/core_api/app.py b/airflow/api_fastapi/core_api/app.py index fc29e51f999fe..0e0a375054b19 100644 --- a/airflow/api_fastapi/core_api/app.py +++ b/airflow/api_fastapi/core_api/app.py @@ -120,3 +120,11 @@ def init_config(app: FastAPI) -> None: app.add_middleware(GZipMiddleware, minimum_size=1024, compresslevel=5) app.state.secret_key = conf.get("webserver", "secret_key") + + +def init_error_handlers(app: FastAPI) -> None: + from airflow.api_fastapi.common.exceptions import DatabaseErrorHandlers + + # register database error handlers + for handler in DatabaseErrorHandlers: + app.add_exception_handler(handler.exception_cls, handler.exception_handler) diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index ef3676a8afec7..137392094cb5d 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -72,6 +72,12 @@ class PoolPatchBody(BaseModel): class PoolPostBody(BasePool): """Pool serializer for post bodies.""" - pool: str = Field(alias="name") + pool: str = Field(alias="name", max_length=256) description: str | None = None include_deferred: bool = False + + +class PoolPostBulkBody(BaseModel): + """Pools serializer for post bodies.""" + + pools: list[PoolPostBody] diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 2f74f2268928f..395a6acea5f6e 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3278,6 +3278,56 @@ paths: schema: $ref: '#/components/schemas/HTTPExceptionResponse' description: Forbidden + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /public/pools/bulk: + post: + tags: + - Pool + summary: Post Pools + description: Create multiple pools. + operationId: post_pools + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PoolPostBulkBody' + required: true + responses: + '201': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/PoolCollectionResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '409': + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' '422': description: Validation Error content: @@ -6544,6 +6594,7 @@ components: properties: name: type: string + maxLength: 256 title: Name slots: type: integer @@ -6563,6 +6614,18 @@ components: - slots title: PoolPostBody description: Pool serializer for post bodies. + PoolPostBulkBody: + properties: + pools: + items: + $ref: '#/components/schemas/PoolPostBody' + type: array + title: Pools + type: object + required: + - pools + title: PoolPostBulkBody + description: Pools serializer for post bodies. PoolResponse: properties: name: diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 582e03ab00dbd..0e67994acfaab 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -32,6 +32,7 @@ PoolCollectionResponse, PoolPatchBody, PoolPostBody, + PoolPostBulkBody, PoolResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc @@ -160,14 +161,38 @@ def patch_pool( @pools_router.post( "", status_code=status.HTTP_201_CREATED, + responses=create_openapi_http_exception_doc( + [status.HTTP_409_CONFLICT] + ), # handled by global exception handler ) def post_pool( - post_body: PoolPostBody, + body: PoolPostBody, session: Annotated[Session, Depends(get_session)], ) -> PoolResponse: """Create a Pool.""" - pool = Pool(**post_body.model_dump()) - + pool = Pool(**body.model_dump()) session.add(pool) return PoolResponse.model_validate(pool, from_attributes=True) + + +@pools_router.post( + "/bulk", + status_code=status.HTTP_201_CREATED, + responses=create_openapi_http_exception_doc( + [ + status.HTTP_409_CONFLICT, # handled by global exception handler + ] + ), +) +def post_pools( + body: PoolPostBulkBody, + session: Annotated[Session, Depends(get_session)], +) -> PoolCollectionResponse: + """Create multiple pools.""" + pools = [Pool(**body.model_dump()) for body in body.pools] + session.add_all(pools) + return PoolCollectionResponse( + pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], + total_entries=len(pools), + ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 736425218022a..968c0617caf80 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1330,6 +1330,9 @@ export type DagRunServiceClearDagRunMutationResult = Awaited< export type PoolServicePostPoolMutationResult = Awaited< ReturnType >; +export type PoolServicePostPoolsMutationResult = Awaited< + ReturnType +>; export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 2ca159f465f51..b2343d0add10c 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -41,6 +41,7 @@ import { DagWarningType, PoolPatchBody, PoolPostBody, + PoolPostBulkBody, TaskInstancesBatchBody, VariableBody, } from "../requests/types.gen"; @@ -2363,6 +2364,43 @@ export const usePoolServicePostPool = < PoolService.postPool({ requestBody }) as unknown as Promise, ...options, }); +/** + * Post Pools + * Create multiple pools. + * @param data The data for the request. + * @param data.requestBody + * @returns PoolCollectionResponse Successful Response + * @throws ApiError + */ +export const usePoolServicePostPools = < + TData = Common.PoolServicePostPoolsMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + requestBody: PoolPostBulkBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + requestBody: PoolPostBulkBody; + }, + TContext + >({ + mutationFn: ({ requestBody }) => + PoolService.postPools({ requestBody }) as unknown as Promise, + ...options, + }); /** * Get Task Instances Batch * Get list of task instances. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index a0bb85ace80ad..3d51b72eba2e5 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -2848,6 +2848,7 @@ export const $PoolPostBody = { properties: { name: { type: "string", + maxLength: 256, title: "Name", }, slots: { @@ -2877,6 +2878,22 @@ export const $PoolPostBody = { description: "Pool serializer for post bodies.", } as const; +export const $PoolPostBulkBody = { + properties: { + pools: { + items: { + $ref: "#/components/schemas/PoolPostBody", + }, + type: "array", + title: "Pools", + }, + }, + type: "object", + required: ["pools"], + title: "PoolPostBulkBody", + description: "Pools serializer for post bodies.", +} as const; + export const $PoolResponse = { properties: { name: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 63272c5e0c470..32338dfbf6d80 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -109,6 +109,8 @@ import type { GetPoolsResponse, PostPoolData, PostPoolResponse, + PostPoolsData, + PostPoolsResponse, GetProvidersData, GetProvidersResponse, GetTaskInstanceData, @@ -1790,6 +1792,32 @@ export class PoolService { errors: { 401: "Unauthorized", 403: "Forbidden", + 409: "Conflict", + 422: "Validation Error", + }, + }); + } + + /** + * Post Pools + * Create multiple pools. + * @param data The data for the request. + * @param data.requestBody + * @returns PoolCollectionResponse Successful Response + * @throws ApiError + */ + public static postPools( + data: PostPoolsData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/pools/bulk", + body: data.requestBody, + mediaType: "application/json", + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 409: "Conflict", 422: "Validation Error", }, }); diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 926932379b350..73772b38502b7 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -712,6 +712,13 @@ export type PoolPostBody = { include_deferred?: boolean; }; +/** + * Pools serializer for post bodies. + */ +export type PoolPostBulkBody = { + pools: Array; +}; + /** * Pool serializer for responses. */ @@ -1510,6 +1517,12 @@ export type PostPoolData = { export type PostPoolResponse = PoolResponse; +export type PostPoolsData = { + requestBody: PoolPostBulkBody; +}; + +export type PostPoolsResponse = PoolCollectionResponse; + export type GetProvidersData = { limit?: number; offset?: number; @@ -3107,6 +3120,37 @@ export type $OpenApiTs = { * Forbidden */ 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; + "/public/pools/bulk": { + post: { + req: PostPoolsData; + res: { + /** + * Successful Response + */ + 201: PoolCollectionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; /** * Validation Error */ diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index 4a774f1a1e379..1cbc62406636b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -324,9 +324,138 @@ class TestPostPool(TestPoolsEndpoint): def test_should_respond_200(self, test_client, session, body, expected_status_code, expected_response): self.create_pools() n_pools = session.query(Pool).count() - response = test_client.post("/public/pools", json=body) + response = test_client.post("/public/pools/", json=body) assert response.status_code == expected_status_code - body = response.json() assert response.json() == expected_response assert session.query(Pool).count() == n_pools + 1 + + @pytest.mark.parametrize( + "body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response", + [ + ( + {"name": "my_pool", "slots": 11}, + 201, + { + "name": "my_pool", + "slots": 11, + "description": None, + "include_deferred": False, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "open_slots": 11, + "deferred_slots": 0, + }, + 409, + {"detail": "Unique constraint violation"}, + ), + ], + ) + def test_should_response_409( + self, + test_client, + session, + body, + first_expected_status_code, + first_expected_response, + second_expected_status_code, + second_expected_response, + ): + self.create_pools() + n_pools = session.query(Pool).count() + response = test_client.post("/public/pools/", json=body) + assert response.status_code == first_expected_status_code + assert response.json() == first_expected_response + assert session.query(Pool).count() == n_pools + 1 + response = test_client.post("/public/pools/", json=body) + assert response.status_code == second_expected_status_code + assert response.json() == second_expected_response + assert session.query(Pool).count() == n_pools + 1 + + +class TestPostPools(TestPoolsEndpoint): + @pytest.mark.parametrize( + "body, expected_status_code, expected_response", + [ + ( + { + "pools": [ + {"name": "my_pool", "slots": 11}, + {"name": "my_pool2", "slots": 12}, + ] + }, + 201, + { + "pools": [ + { + "name": "my_pool", + "slots": 11, + "description": None, + "include_deferred": False, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "open_slots": 11, + "deferred_slots": 0, + }, + { + "name": "my_pool2", + "slots": 12, + "description": None, + "include_deferred": False, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "open_slots": 12, + "deferred_slots": 0, + }, + ], + "total_entries": 2, + }, + ), + ( + { + "pools": [ + {"name": "my_pool", "slots": 11}, + {"name": POOL1_NAME, "slots": 12}, + ] + }, + 409, + {"detail": "Unique constraint violation"}, + ), + ( + { + "pools": [ + {"name": POOL1_NAME, "slots": 11}, + {"name": POOL2_NAME, "slots": 12}, + ] + }, + 409, + {"detail": "Unique constraint violation"}, + ), + ( + { + "pools": [ + {"name": "my_pool", "slots": 11}, + {"name": "my_pool", "slots": 12}, + ] + }, + 409, + {"detail": "Unique constraint violation"}, + ), + ], + ) + def test_post_pools(self, test_client, session, body, expected_status_code, expected_response): + self.create_pools() + n_pools = session.query(Pool).count() + response = test_client.post("/public/pools/bulk", json=body) + assert response.status_code == expected_status_code + assert response.json() == expected_response + if expected_status_code == 201: + assert session.query(Pool).count() == n_pools + 2 + else: + assert session.query(Pool).count() == n_pools