Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(BA-483): Revamp ContainerRegistryNode API #3424

Draft
wants to merge 16 commits into
base: topic/11-11-feat_implement_associatecontainerregistrywithgroup_disassociatecontainerregistrywithgroup_
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3424.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Revamp `ContainerRegistryNode` API.
65 changes: 18 additions & 47 deletions src/ai/backend/client/func/container_registry.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,38 @@
from __future__ import annotations

import textwrap
from ai.backend.client.request import Request
from ai.backend.common.container_registry import (
PatchContainerRegistryRequestModel,
PatchContainerRegistryResponseModel,
)

from ..session import api_session
from .base import BaseFunction, api_function

__all__ = ("ContainerRegistry",)


class ContainerRegistry(BaseFunction):
"""
Provides a shortcut of :func:`Admin.query()
<ai.backend.client.admin.Admin.query>` that fetches, modifies various container registry
information.

.. note::

All methods in this function class require your API access key to
have the *admin* privilege.
Provides functions to manage container registries.
"""

@api_function
@classmethod
async def associate_group(cls, registry_id: str, group_id: str) -> dict:
async def patch_container_registry(
cls, registry_id: str, params: PatchContainerRegistryRequestModel
) -> PatchContainerRegistryResponseModel:
"""
Associate container_registry with group.
Updates the container registry information, and return the container registry.

:param registry_id: ID of the container registry.
:param group_id: ID of the group.
"""
query = textwrap.dedent(
"""\
mutation($registry_id: String!, $group_id: String!) {
associate_container_registry_with_group(
registry_id: $registry_id, group_id: $group_id) {
ok msg
}
}
"""
)
variables = {"registry_id": registry_id, "group_id": group_id}
data = await api_session.get().Admin._query(query, variables)
return data["associate_container_registry_with_group"]

@api_function
@classmethod
async def disassociate_group(cls, registry_id: str, group_id: str) -> dict:
:param params: Parameters to update the container registry.
"""
Disassociate container_registry with group.

:param registry_id: ID of the container registry.
:param group_id: ID of the group.
"""
query = textwrap.dedent(
"""\
mutation($registry_id: String!, $group_id: String!) {
disassociate_container_registry_with_group(
registry_id: $registry_id, group_id: $group_id) {
ok msg
}
}
"""
request = Request(
"PATCH",
f"/container-registries/{registry_id}",
)
variables = {"registry_id": registry_id, "group_id": group_id}
data = await api_session.get().Admin._query(query, variables)
return data["disassociate_container_registry_with_group"]
request.set_json(params)

async with request.fetch() as resp:
return await resp.json()
45 changes: 45 additions & 0 deletions src/ai/backend/common/container_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import enum
import uuid
from typing import Any, Optional

from pydantic import BaseModel


class ContainerRegistryType(enum.StrEnum):
DOCKER = "docker"
HARBOR = "harbor"
HARBOR2 = "harbor2"
GITHUB = "github"
GITLAB = "gitlab"
ECR = "ecr"
ECR_PUB = "ecr-public"
LOCAL = "local"


class AllowedGroupsModel(BaseModel):
add: list[str] = []
remove: list[str] = []


class ContainerRegistryRowModel(BaseModel):
id: Optional[uuid.UUID] = None
url: Optional[str] = None
registry_name: Optional[str] = None
type: Optional[ContainerRegistryType] = None
project: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None
ssl_verify: Optional[bool] = None
is_global: Optional[bool] = None
extra: Optional[dict[str, Any]] = None

class Config:
from_attributes = True


class PatchContainerRegistryRequestModel(ContainerRegistryRowModel):
allowed_groups: Optional[AllowedGroupsModel] = None


class PatchContainerRegistryResponseModel(ContainerRegistryRowModel):
pass
106 changes: 37 additions & 69 deletions src/ai/backend/manager/api/container_registry.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from __future__ import annotations

import logging
import uuid
from typing import TYPE_CHECKING, Iterable, Tuple

import aiohttp_cors
import sqlalchemy as sa
from aiohttp import web
from pydantic import AliasChoices, BaseModel, Field
from sqlalchemy.exc import IntegrityError

from ai.backend.common.container_registry import (
PatchContainerRegistryRequestModel,
PatchContainerRegistryResponseModel,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.association_container_registries_groups import (
AssociationContainerRegistriesGroupsRow,
from ai.backend.manager.models.container_registry import (
ContainerRegistryRow,
handle_allowed_groups_update,
)

from .exceptions import ContainerRegistryNotFound, GenericBadRequest
from .exceptions import ContainerRegistryNotFound, GenericBadRequest, InternalServerError

if TYPE_CHECKING:
from .context import RootContext
Expand All @@ -27,76 +32,40 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class AssociationRequestModel(BaseModel):
registry_id: str = Field(
validation_alias=AliasChoices("registry_id", "registry"),
description="Container registry row's ID",
)
group_id: str = Field(
validation_alias=AliasChoices("group_id", "group"),
description="Group row's ID",
)


@server_status_required(READ_ALLOWED)
@superadmin_required
@pydantic_params_api_handler(AssociationRequestModel)
async def associate_with_group(
request: web.Request, params: AssociationRequestModel
) -> web.Response:
log.info("ASSOCIATE_WITH_GROUP (cr:{}, gr:{})", params.registry_id, params.group_id)
@pydantic_params_api_handler(PatchContainerRegistryRequestModel)
async def patch_container_registry(
request: web.Request, params: PatchContainerRegistryRequestModel
) -> PatchContainerRegistryResponseModel:
registry_id = uuid.UUID(request.match_info["registry_id"])
log.info("PATCH_CONTAINER_REGISTRY (cr:{})", registry_id)
root_ctx: RootContext = request.app["_root.context"]
registry_id = params.registry_id
group_id = params.group_id

async with root_ctx.db.begin_session() as db_sess:
insert_query = sa.insert(AssociationContainerRegistriesGroupsRow).values({
"registry_id": registry_id,
"group_id": group_id,
})

try:
await db_sess.execute(insert_query)
except IntegrityError:
raise GenericBadRequest("Association already exists.")

return web.Response(status=204)
registry_row_updates = params.model_dump(exclude={"allowed_groups"}, exclude_none=True)

try:
async with root_ctx.db.begin_session() as db_session:
if registry_row_updates:
update_stmt = (
sa.update(ContainerRegistryRow)
.where(ContainerRegistryRow.id == registry_id)
.values(registry_row_updates)
)
await db_session.execute(update_stmt)

class DisassociationRequestModel(BaseModel):
registry_id: str = Field(
validation_alias=AliasChoices("registry_id", "registry"),
description="Container registry row's ID",
)
group_id: str = Field(
validation_alias=AliasChoices("group_id", "group"),
description="Group row's ID",
)


@server_status_required(READ_ALLOWED)
@superadmin_required
@pydantic_params_api_handler(DisassociationRequestModel)
async def disassociate_with_group(
request: web.Request, params: DisassociationRequestModel
) -> web.Response:
log.info("DISASSOCIATE_WITH_GROUP (cr:{}, gr:{})", params.registry_id, params.group_id)
root_ctx: RootContext = request.app["_root.context"]
registry_id = params.registry_id
group_id = params.group_id

async with root_ctx.db.begin_session() as db_sess:
delete_query = (
sa.delete(AssociationContainerRegistriesGroupsRow)
.where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id)
.where(AssociationContainerRegistriesGroupsRow.group_id == group_id)
)
query = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == registry_id)
container_registry = (await db_session.execute(query)).fetchone()[0]

result = await db_sess.execute(delete_query)
if result.rowcount == 0:
raise ContainerRegistryNotFound()
if params.allowed_groups:
await handle_allowed_groups_update(root_ctx.db, registry_id, params.allowed_groups)
except ContainerRegistryNotFound as e:
raise e
except IntegrityError as e:
raise GenericBadRequest(f"Failed to update allowed groups! Details: {str(e)}")
except Exception as e:
raise InternalServerError(f"Failed to update container registry! Details: {str(e)}")

return web.Response(status=204)
return PatchContainerRegistryResponseModel.model_validate(container_registry)


def create_app(
Expand All @@ -106,6 +75,5 @@ def create_app(
app["api_versions"] = (1, 2, 3, 4, 5)
app["prefix"] = "container-registries"
cors = aiohttp_cors.setup(app, defaults=default_cors_options)
cors.add(app.router.add_route("POST", "/associate-with-group", associate_with_group))
cors.add(app.router.add_route("POST", "/disassociate-with-group", disassociate_with_group))
cors.add(app.router.add_route("PATCH", "/{registry_id}", patch_container_registry))
return app, []
Loading
Loading