Skip to content

Commit

Permalink
AIP-84 Patch Pool (#43266)
Browse files Browse the repository at this point in the history
* AIP-84 Patch Pool

* Fix CI
  • Loading branch information
pierrejeambrun authored Oct 23, 2024
1 parent ca2c809 commit 6a17a62
Show file tree
Hide file tree
Showing 11 changed files with 500 additions and 8 deletions.
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/pool_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def get_pools(
return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries))


@mark_fastapi_migration_done
@security.requires_access_pool("PUT")
@action_logging
@provide_session
Expand Down
91 changes: 91 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,72 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
patch:
tags:
- Pool
summary: Patch Pool
description: Update a Pool.
operationId: patch_pool
parameters:
- name: pool_name
in: path
required: true
schema:
type: string
title: Pool Name
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PoolBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/PoolResponse'
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/pools/:
get:
tags:
Expand Down Expand Up @@ -2222,6 +2288,31 @@ components:
- timetables
title: PluginResponse
description: Plugin serializer.
PoolBody:
properties:
pool:
anyOf:
- type: string
- type: 'null'
title: Pool
slots:
anyOf:
- type: integer
- type: 'null'
title: Slots
description:
anyOf:
- type: string
- type: 'null'
title: Description
include_deferred:
anyOf:
- type: boolean
- type: 'null'
title: Include Deferred
type: object
title: PoolBody
description: Pool serializer for bodies.
PoolCollectionResponse:
properties:
pools:
Expand Down
45 changes: 43 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# under the License.
from __future__ import annotations

from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Query
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from typing_extensions import Annotated
Expand All @@ -25,7 +27,12 @@
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.serializers.pools import PoolCollectionResponse, PoolResponse
from airflow.api_fastapi.core_api.serializers.pools import (
BasePool,
PoolBody,
PoolCollectionResponse,
PoolResponse,
)
from airflow.models.pool import Pool

pools_router = AirflowRouter(tags=["Pool"], prefix="/pools")
Expand Down Expand Up @@ -95,3 +102,37 @@ async def get_pools(
pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools],
total_entries=total_entries,
)


@pools_router.patch("/{pool_name}", responses=create_openapi_http_exception_doc([400, 401, 403, 404]))
async def patch_pool(
pool_name: str,
patch_body: PoolBody,
session: Annotated[Session, Depends(get_session)],
update_mask: list[str] | None = Query(None),
) -> PoolResponse:
"""Update a Pool."""
# Only slots and include_deferred can be modified in 'default_pool'
if pool_name == Pool.DEFAULT_POOL_NAME:
if update_mask and all(mask.strip() in {"slots", "include_deferred"} for mask in update_mask):
pass
else:
raise HTTPException(400, "Only slots and included_deferred can be modified on Default Pool")

pool = session.scalar(select(Pool).where(Pool.pool == pool_name).limit(1))
if not pool:
raise HTTPException(404, detail=f"The Pool with name: `{pool_name}` was not found")

if update_mask:
data = patch_body.model_dump(include=set(update_mask), by_alias=True)
else:
data = patch_body.model_dump(by_alias=True)
try:
BasePool.model_validate(data)
except ValidationError as e:
raise RequestValidationError(errors=e.errors())

for key, value in data.items():
setattr(pool, key, value)

return PoolResponse.model_validate(pool, from_attributes=True)
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ async def patch_variable(
data = patch_body.model_dump(exclude=non_update_fields)
for key, val in data.items():
setattr(variable, key, val)
session.add(variable)
return variable


Expand Down
23 changes: 19 additions & 4 deletions airflow/api_fastapi/core_api/serializers/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from typing import Annotated, Callable

from pydantic import BaseModel, BeforeValidator, Field
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field


def _call_function(function: Callable[[], int]) -> int:
Expand All @@ -31,14 +31,18 @@ def _call_function(function: Callable[[], int]) -> int:
return function()


class PoolResponse(BaseModel):
"""Pool serializer for responses."""
class BasePool(BaseModel):
"""Base serializer for Pool."""

pool: str = Field(serialization_alias="name", validation_alias="pool")
pool: str = Field(serialization_alias="name")
slots: int
description: str | None
include_deferred: bool


class PoolResponse(BasePool):
"""Pool serializer for responses."""

occupied_slots: Annotated[int, BeforeValidator(_call_function)]
running_slots: Annotated[int, BeforeValidator(_call_function)]
queued_slots: Annotated[int, BeforeValidator(_call_function)]
Expand All @@ -52,3 +56,14 @@ class PoolCollectionResponse(BaseModel):

pools: list[PoolResponse]
total_entries: int


class PoolBody(BaseModel):
"""Pool serializer for bodies."""

model_config = ConfigDict(populate_by_name=True)

name: str | None = Field(default=None, alias="pool")
slots: int | None = None
description: str | None = None
include_deferred: bool | None = None
3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ export type DagServicePatchDagMutationResult = Awaited<
export type VariableServicePatchVariableMutationResult = Awaited<
ReturnType<typeof VariableService.patchVariable>
>;
export type PoolServicePatchPoolMutationResult = Awaited<
ReturnType<typeof PoolService.patchPool>
>;
export type DagServiceDeleteDagMutationResult = Awaited<
ReturnType<typeof DagService.deleteDag>
>;
Expand Down
54 changes: 53 additions & 1 deletion airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ import {
ProviderService,
VariableService,
} from "../requests/services.gen";
import { DAGPatchBody, DagRunState, VariableBody } from "../requests/types.gen";
import {
DAGPatchBody,
DagRunState,
PoolBody,
VariableBody,
} from "../requests/types.gen";
import * as Common from "./common";

/**
Expand Down Expand Up @@ -776,6 +781,53 @@ export const useVariableServicePatchVariable = <
}) as unknown as Promise<TData>,
...options,
});
/**
* Patch Pool
* Update a Pool.
* @param data The data for the request.
* @param data.poolName
* @param data.requestBody
* @param data.updateMask
* @returns PoolResponse Successful Response
* @throws ApiError
*/
export const usePoolServicePatchPool = <
TData = Common.PoolServicePatchPoolMutationResult,
TError = unknown,
TContext = unknown,
>(
options?: Omit<
UseMutationOptions<
TData,
TError,
{
poolName: string;
requestBody: PoolBody;
updateMask?: string[];
},
TContext
>,
"mutationFn"
>,
) =>
useMutation<
TData,
TError,
{
poolName: string;
requestBody: PoolBody;
updateMask?: string[];
},
TContext
>({
mutationFn: ({ poolName, requestBody, updateMask }) =>
PoolService.patchPool({
poolName,
requestBody,
updateMask,
}) as unknown as Promise<TData>,
...options,
});
/**
* Delete Dag
* Delete the specific DAG.
Expand Down
52 changes: 52 additions & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,58 @@ export const $PluginResponse = {
description: "Plugin serializer.",
} as const;

export const $PoolBody = {
properties: {
pool: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Pool",
},
slots: {
anyOf: [
{
type: "integer",
},
{
type: "null",
},
],
title: "Slots",
},
description: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Description",
},
include_deferred: {
anyOf: [
{
type: "boolean",
},
{
type: "null",
},
],
title: "Include Deferred",
},
},
type: "object",
title: "PoolBody",
description: "Pool serializer for bodies.",
} as const;

export const $PoolCollectionResponse = {
properties: {
pools: {
Expand Down
Loading

0 comments on commit 6a17a62

Please sign in to comment.