diff --git a/changes/3424.fix.md b/changes/3424.fix.md new file mode 100644 index 00000000000..5004ff7c3c3 --- /dev/null +++ b/changes/3424.fix.md @@ -0,0 +1 @@ +Revamp `ContainerRegistryNode` API. diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index 65a7e871bcb..0ceef363367 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -49,6 +49,15 @@ type Queries { """Added in 24.09.0.""" order: String + + """Added in 25.2.0. Default is `system`.""" + scope: ScopeField + + """Added in 25.2.0.""" + container_registry_scope: ContainerRegistryScopeField + + """Added in 25.2.0. Default is read_attribute.""" + permission: GroupPermissionField = "read_attribute" offset: Int before: String after: String @@ -770,6 +779,14 @@ type GroupEdge { cursor: String! } +"""Added in 25.2.0.""" +scalar ContainerRegistryScopeField + +""" +Added in 25.2.0. One of ['read_attribute', 'read_sensitive_attribute', 'update_attribute', 'delete_project', 'associate_with_user']. +""" +scalar GroupPermissionField + type Group { id: UUID name: String @@ -1563,6 +1580,9 @@ type ContainerRegistryNode implements Node { """Added in 24.09.3.""" extra: JSONString + + """Added in 25.2.0.""" + allowed_groups(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): GroupConnection } """Added in 24.09.0.""" @@ -1877,69 +1897,17 @@ type Mutations { """Added in 24.09.0.""" create_container_registry_node( - """Added in 24.09.3.""" - extra: JSONString - - """Added in 24.09.0.""" - is_global: Boolean - - """Added in 24.09.0.""" - password: String - - """Added in 24.09.0.""" - project: String - - """Added in 24.09.0.""" - registry_name: String! - - """Added in 24.09.0.""" - ssl_verify: Boolean - - """ - Added in 24.09.0. Registry type. One of ('docker', 'harbor', 'harbor2', 'github', 'gitlab', 'ecr', 'ecr-public', 'local'). - """ - type: ContainerRegistryTypeField! - - """Added in 24.09.0.""" - url: String! - - """Added in 24.09.0.""" - username: String + """Added in 25.2.0.""" + props: CreateContainerRegistryNodeInput! ): CreateContainerRegistryNode """Added in 24.09.0.""" modify_container_registry_node( - """Added in 24.09.3.""" - extra: JSONString - """Object id. Can be either global id or object id. Added in 24.09.0.""" id: String! - """Added in 24.09.0.""" - is_global: Boolean - - """Added in 24.09.0.""" - password: String - - """Added in 24.09.0.""" - project: String - - """Added in 24.09.0.""" - registry_name: String - - """Added in 24.09.0.""" - ssl_verify: Boolean - - """ - Registry type. One of ('docker', 'harbor', 'harbor2', 'github', 'gitlab', 'ecr', 'ecr-public', 'local'). Added in 24.09.0. - """ - type: ContainerRegistryTypeField - - """Added in 24.09.0.""" - url: String - - """Added in 24.09.0.""" - username: String + """Added in 25.2.0.""" + props: ModifyContainerRegistryNodeInput! ): ModifyContainerRegistryNode """Added in 24.09.0.""" @@ -1957,12 +1925,6 @@ type Mutations { """Added in 25.1.0.""" delete_endpoint_auto_scaling_rule_node(id: String!): DeleteEndpointAutoScalingRuleNode - """Added in 25.2.0.""" - associate_container_registry_with_group(group_id: String!, registry_id: String!): AssociateContainerRegistryWithGroup - - """Added in 25.2.0.""" - disassociate_container_registry_with_group(group_id: String!, registry_id: String!): DisassociateContainerRegistryWithGroup - """Deprecated since 24.09.0. use `CreateContainerRegistryNode` instead""" create_container_registry(hostname: String!, props: CreateContainerRegistryInput!): CreateContainerRegistry @@ -2721,11 +2683,83 @@ type CreateContainerRegistryNode { container_registry: ContainerRegistryNode } +input CreateContainerRegistryNodeInput { + """Added in 24.09.0.""" + url: String! + + """Added in 24.09.0.""" + type: ContainerRegistryTypeField! + + """Added in 24.09.0.""" + registry_name: String! + + """Added in 24.09.0.""" + is_global: Boolean + + """Added in 24.09.0.""" + project: String + + """Added in 24.09.0.""" + username: String + + """Added in 24.09.0.""" + password: String + + """Added in 24.09.0.""" + ssl_verify: Boolean + + """Added in 24.09.3.""" + extra: JSONString + + """Added in 25.2.0.""" + allowed_groups: AllowedGroups +} + +input AllowedGroups { + """List of group_ids to add associations. Added in 25.2.0.""" + add: [String] = [] + + """List of group_ids to remove associations. Added in 25.2.0.""" + remove: [String] = [] +} + """Added in 24.09.0.""" type ModifyContainerRegistryNode { container_registry: ContainerRegistryNode } +input ModifyContainerRegistryNodeInput { + """Added in 24.09.0.""" + url: String + + """Added in 24.09.0.""" + type: ContainerRegistryTypeField + + """Added in 24.09.0.""" + registry_name: String + + """Added in 24.09.0.""" + is_global: Boolean + + """Added in 24.09.0.""" + project: String + + """Added in 24.09.0.""" + username: String + + """Added in 24.09.0.""" + password: String + + """Added in 24.09.0.""" + ssl_verify: Boolean + + """Added in 24.09.3.""" + extra: JSONString + + """Added in 25.2.0.""" + allowed_groups: AllowedGroups +} + """Added in 24.09.0.""" type DeleteContainerRegistryNode { container_registry: ContainerRegistryNode @@ -2775,18 +2809,6 @@ type DeleteEndpointAutoScalingRuleNode { msg: String } -"""Added in 25.2.0.""" -type AssociateContainerRegistryWithGroup { - ok: Boolean - msg: String -} - -"""Added in 25.2.0.""" -type DisassociateContainerRegistryWithGroup { - ok: Boolean - msg: String -} - """Deprecated since 24.09.0. use `CreateContainerRegistryNode` instead""" type CreateContainerRegistry { container_registry: ContainerRegistry diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json index 08b7055639a..0f55ff1ec0c 100644 --- a/docs/manager/rest-reference/openapi.json +++ b/docs/manager/rest-reference/openapi.json @@ -20,44 +20,303 @@ } }, "schemas": { - "AssociationRequestModel": { + "AllowedGroupsModel": { "properties": { - "registry_id": { - "description": "Container registry row's ID", - "title": "Registry Id", - "type": "string" + "add": { + "default": [], + "items": { + "type": "string" + }, + "title": "Add", + "type": "array" }, - "group_id": { - "description": "Group row's ID", - "title": "Group Id", - "type": "string" + "remove": { + "default": [], + "items": { + "type": "string" + }, + "title": "Remove", + "type": "array" } }, - "required": [ - "registry_id", - "group_id" - ], - "title": "AssociationRequestModel", + "title": "AllowedGroupsModel", "type": "object" }, - "DisassociationRequestModel": { + "ContainerRegistryType": { + "enum": [ + "docker", + "harbor", + "harbor2", + "github", + "gitlab", + "ecr", + "ecr-public", + "local" + ], + "title": "ContainerRegistryType", + "type": "string" + }, + "PatchContainerRegistryRequestModel": { "properties": { - "registry_id": { - "description": "Container registry row's ID", - "title": "Registry Id", - "type": "string" + "id": { + "anyOf": [ + { + "format": "uuid", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Id" }, - "group_id": { - "description": "Group row's ID", - "title": "Group Id", - "type": "string" + "url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Url" + }, + "registry_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Registry Name" + }, + "type": { + "anyOf": [ + { + "$ref": "#/components/schemas/ContainerRegistryType" + }, + { + "type": "null" + } + ], + "default": null + }, + "project": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Project" + }, + "username": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Username" + }, + "password": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Password" + }, + "ssl_verify": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Ssl Verify" + }, + "is_global": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Is Global" + }, + "extra": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Extra" + }, + "allowed_groups": { + "anyOf": [ + { + "$ref": "#/components/schemas/AllowedGroupsModel" + }, + { + "type": "null" + } + ], + "default": null } }, - "required": [ - "registry_id", - "group_id" - ], - "title": "DisassociationRequestModel", + "title": "PatchContainerRegistryRequestModel", + "type": "object" + }, + "PatchContainerRegistryResponseModel": { + "properties": { + "id": { + "anyOf": [ + { + "format": "uuid", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Id" + }, + "url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Url" + }, + "registry_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Registry Name" + }, + "type": { + "anyOf": [ + { + "$ref": "#/components/schemas/ContainerRegistryType" + }, + { + "type": "null" + } + ], + "default": null + }, + "project": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Project" + }, + "username": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Username" + }, + "password": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Password" + }, + "ssl_verify": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Ssl Verify" + }, + "is_global": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Is Global" + }, + "extra": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Extra" + } + }, + "title": "PatchContainerRegistryResponseModel", "type": "object" }, "VFolderPermission": { @@ -1258,15 +1517,22 @@ "description": "\n**Preconditions:**\n* User privilege required.\n* Manager status required: RUNNING\n" } }, - "/container-registries/associate-with-group": { - "post": { - "operationId": "container-registries.associate_with_group", + "/container-registries/{registry_id}": { + "patch": { + "operationId": "container-registries.patch_container_registry", "tags": [ "container-registries" ], "responses": { "200": { - "description": "Successful response" + "description": "", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PatchContainerRegistryResponseModel" + } + } + } } }, "security": [ @@ -1278,41 +1544,21 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AssociationRequestModel" + "$ref": "#/components/schemas/PatchContainerRegistryRequestModel" } } } }, - "parameters": [], - "description": "\n**Preconditions:**\n* Superadmin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" - } - }, - "/container-registries/disassociate-with-group": { - "post": { - "operationId": "container-registries.disassociate_with_group", - "tags": [ - "container-registries" - ], - "responses": { - "200": { - "description": "Successful response" - } - }, - "security": [ + "parameters": [ { - "TokenAuth": [] - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/DisassociationRequestModel" - } + "name": "registry_id", + "in": "path", + "required": true, + "schema": { + "type": "string" } } - }, - "parameters": [], + ], "description": "\n**Preconditions:**\n* Superadmin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" } }, diff --git a/src/ai/backend/client/func/container_registry.py b/src/ai/backend/client/func/container_registry.py index 570ed80237c..8d3af0531ff 100644 --- a/src/ai/backend/client/func/container_registry.py +++ b/src/ai/backend/client/func/container_registry.py @@ -1,8 +1,11 @@ from __future__ import annotations -import textwrap +from ai.backend.client.request import Request +from ai.backend.common.container_registry import ( + PatchContainerRegistryRequestModel, + PatchContainerRegistryResponseModel, +) -from ..session import api_session from .base import BaseFunction, api_function __all__ = ("ContainerRegistry",) @@ -10,58 +13,26 @@ class ContainerRegistry(BaseFunction): """ - Provides a shortcut of :func:`Admin.query() - ` that fetches, modifies various container registry - information. - - .. note:: - - All methods in this function class require your API access key to - have the *admin* privilege. + Provides functions to manage container registries. """ @api_function @classmethod - async def associate_group(cls, registry_id: str, group_id: str) -> dict: + async def patch_container_registry( + cls, registry_id: str, params: PatchContainerRegistryRequestModel + ) -> PatchContainerRegistryResponseModel: """ - Associate container_registry with group. + Updates the container registry information, and return the container registry. :param registry_id: ID of the container registry. - :param group_id: ID of the group. - """ - query = textwrap.dedent( - """\ - mutation($registry_id: String!, $group_id: String!) { - associate_container_registry_with_group( - registry_id: $registry_id, group_id: $group_id) { - ok msg - } - } - """ - ) - variables = {"registry_id": registry_id, "group_id": group_id} - data = await api_session.get().Admin._query(query, variables) - return data["associate_container_registry_with_group"] - - @api_function - @classmethod - async def disassociate_group(cls, registry_id: str, group_id: str) -> dict: + :param params: Parameters to update the container registry. """ - Disassociate container_registry with group. - :param registry_id: ID of the container registry. - :param group_id: ID of the group. - """ - query = textwrap.dedent( - """\ - mutation($registry_id: String!, $group_id: String!) { - disassociate_container_registry_with_group( - registry_id: $registry_id, group_id: $group_id) { - ok msg - } - } - """ + request = Request( + "PATCH", + f"/container-registries/{registry_id}", ) - variables = {"registry_id": registry_id, "group_id": group_id} - data = await api_session.get().Admin._query(query, variables) - return data["disassociate_container_registry_with_group"] + request.set_json(params) + + async with request.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/common/container_registry.py b/src/ai/backend/common/container_registry.py new file mode 100644 index 00000000000..dd41729dd63 --- /dev/null +++ b/src/ai/backend/common/container_registry.py @@ -0,0 +1,45 @@ +import enum +import uuid +from typing import Any, Optional + +from pydantic import BaseModel + + +class ContainerRegistryType(enum.StrEnum): + DOCKER = "docker" + HARBOR = "harbor" + HARBOR2 = "harbor2" + GITHUB = "github" + GITLAB = "gitlab" + ECR = "ecr" + ECR_PUB = "ecr-public" + LOCAL = "local" + + +class AllowedGroupsModel(BaseModel): + add: list[str] = [] + remove: list[str] = [] + + +class ContainerRegistryRowModel(BaseModel): + id: Optional[uuid.UUID] = None + url: Optional[str] = None + registry_name: Optional[str] = None + type: Optional[ContainerRegistryType] = None + project: Optional[str] = None + username: Optional[str] = None + password: Optional[str] = None + ssl_verify: Optional[bool] = None + is_global: Optional[bool] = None + extra: Optional[dict[str, Any]] = None + + class Config: + from_attributes = True + + +class PatchContainerRegistryRequestModel(ContainerRegistryRowModel): + allowed_groups: Optional[AllowedGroupsModel] = None + + +class PatchContainerRegistryResponseModel(ContainerRegistryRowModel): + pass diff --git a/src/ai/backend/manager/api/container_registry.py b/src/ai/backend/manager/api/container_registry.py index ab0c282b1c6..488fb69c010 100644 --- a/src/ai/backend/manager/api/container_registry.py +++ b/src/ai/backend/manager/api/container_registry.py @@ -1,20 +1,25 @@ from __future__ import annotations import logging +import uuid from typing import TYPE_CHECKING, Iterable, Tuple import aiohttp_cors import sqlalchemy as sa from aiohttp import web -from pydantic import AliasChoices, BaseModel, Field from sqlalchemy.exc import IntegrityError +from ai.backend.common.container_registry import ( + PatchContainerRegistryRequestModel, + PatchContainerRegistryResponseModel, +) from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.association_container_registries_groups import ( - AssociationContainerRegistriesGroupsRow, +from ai.backend.manager.models.container_registry import ( + ContainerRegistryRow, + handle_allowed_groups_update, ) -from .exceptions import ContainerRegistryNotFound, GenericBadRequest +from .exceptions import ContainerRegistryNotFound, GenericBadRequest, InternalServerError if TYPE_CHECKING: from .context import RootContext @@ -27,76 +32,40 @@ log = BraceStyleAdapter(logging.getLogger(__spec__.name)) -class AssociationRequestModel(BaseModel): - registry_id: str = Field( - validation_alias=AliasChoices("registry_id", "registry"), - description="Container registry row's ID", - ) - group_id: str = Field( - validation_alias=AliasChoices("group_id", "group"), - description="Group row's ID", - ) - - @server_status_required(READ_ALLOWED) @superadmin_required -@pydantic_params_api_handler(AssociationRequestModel) -async def associate_with_group( - request: web.Request, params: AssociationRequestModel -) -> web.Response: - log.info("ASSOCIATE_WITH_GROUP (cr:{}, gr:{})", params.registry_id, params.group_id) +@pydantic_params_api_handler(PatchContainerRegistryRequestModel) +async def patch_container_registry( + request: web.Request, params: PatchContainerRegistryRequestModel +) -> PatchContainerRegistryResponseModel: + registry_id = uuid.UUID(request.match_info["registry_id"]) + log.info("PATCH_CONTAINER_REGISTRY (cr:{})", registry_id) root_ctx: RootContext = request.app["_root.context"] - registry_id = params.registry_id - group_id = params.group_id - - async with root_ctx.db.begin_session() as db_sess: - insert_query = sa.insert(AssociationContainerRegistriesGroupsRow).values({ - "registry_id": registry_id, - "group_id": group_id, - }) - - try: - await db_sess.execute(insert_query) - except IntegrityError: - raise GenericBadRequest("Association already exists.") - - return web.Response(status=204) + registry_row_updates = params.model_dump(exclude={"allowed_groups"}, exclude_none=True) + try: + async with root_ctx.db.begin_session() as db_session: + if registry_row_updates: + update_stmt = ( + sa.update(ContainerRegistryRow) + .where(ContainerRegistryRow.id == registry_id) + .values(registry_row_updates) + ) + await db_session.execute(update_stmt) -class DisassociationRequestModel(BaseModel): - registry_id: str = Field( - validation_alias=AliasChoices("registry_id", "registry"), - description="Container registry row's ID", - ) - group_id: str = Field( - validation_alias=AliasChoices("group_id", "group"), - description="Group row's ID", - ) - - -@server_status_required(READ_ALLOWED) -@superadmin_required -@pydantic_params_api_handler(DisassociationRequestModel) -async def disassociate_with_group( - request: web.Request, params: DisassociationRequestModel -) -> web.Response: - log.info("DISASSOCIATE_WITH_GROUP (cr:{}, gr:{})", params.registry_id, params.group_id) - root_ctx: RootContext = request.app["_root.context"] - registry_id = params.registry_id - group_id = params.group_id - - async with root_ctx.db.begin_session() as db_sess: - delete_query = ( - sa.delete(AssociationContainerRegistriesGroupsRow) - .where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id) - .where(AssociationContainerRegistriesGroupsRow.group_id == group_id) - ) + query = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == registry_id) + container_registry = (await db_session.execute(query)).fetchone()[0] - result = await db_sess.execute(delete_query) - if result.rowcount == 0: - raise ContainerRegistryNotFound() + if params.allowed_groups: + await handle_allowed_groups_update(root_ctx.db, registry_id, params.allowed_groups) + except ContainerRegistryNotFound as e: + raise e + except IntegrityError as e: + raise GenericBadRequest(f"Failed to update allowed groups! Details: {str(e)}") + except Exception as e: + raise InternalServerError(f"Failed to update container registry! Details: {str(e)}") - return web.Response(status=204) + return PatchContainerRegistryResponseModel.model_validate(container_registry) def create_app( @@ -106,6 +75,5 @@ def create_app( app["api_versions"] = (1, 2, 3, 4, 5) app["prefix"] = "container-registries" cors = aiohttp_cors.setup(app, defaults=default_cors_options) - cors.add(app.router.add_route("POST", "/associate-with-group", associate_with_group)) - cors.add(app.router.add_route("POST", "/disassociate-with-group", disassociate_with_group)) + cors.add(app.router.add_route("PATCH", "/{registry_id}", patch_container_registry)) return app, [] diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index cea53c04002..89ba53dde0f 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -4,7 +4,8 @@ import yarl -from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType +from ai.backend.common.container_registry import ContainerRegistryType +from ai.backend.manager.models.container_registry import ContainerRegistryRow if TYPE_CHECKING: from .base import BaseContainerRegistry diff --git a/src/ai/backend/manager/models/container_registry.py b/src/ai/backend/manager/models/container_registry.py index 61fffad431e..f7aac2f72aa 100644 --- a/src/ai/backend/manager/models/container_registry.py +++ b/src/ai/backend/manager/models/container_registry.py @@ -1,37 +1,49 @@ from __future__ import annotations -import enum import logging import uuid from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, TypeAlias, cast import graphene import graphql import sqlalchemy as sa import yarl -from graphql import Undefined, UndefinedType +from graphql import GraphQLError, Undefined from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import load_only, relationship from sqlalchemy.orm.exc import NoResultFound +from ai.backend.common.container_registry import AllowedGroupsModel, ContainerRegistryType from ai.backend.common.exception import UnknownImageRegistry from ai.backend.common.logging_utils import BraceStyleAdapter +from ai.backend.manager.api.exceptions import ContainerRegistryNotFound +from ai.backend.manager.models.rbac import ( + ContainerRegistryScope, + ContainerRegistryScopeType, + SystemScope, +) from ..defs import PASSWORD_PLACEHOLDER +from .association_container_registries_groups import ( + AssociationContainerRegistriesGroupsRow, +) from .base import ( Base, FilterExprArg, IDColumn, OrderExprArg, + PaginatedConnectionField, StrEnumType, generate_sql_info_for_gql_connection, set_if_set, ) +from .gql_models.group import GroupConnection, GroupNode from .gql_relay import AsyncNode, Connection, ConnectionResolverResult from .minilang.ordering import OrderSpecItem, QueryOrderParser from .minilang.queryfilter import FieldSpecItem, QueryFilterParser from .user import UserRole +from .utils import ExtendedAsyncSAEngine if TYPE_CHECKING: from .gql import GraphQueryContext @@ -44,18 +56,17 @@ "CreateContainerRegistry", "ModifyContainerRegistry", "DeleteContainerRegistry", + "ContainerRegistryNode", + "ContainerRegistryConnection", + "CreateContainerRegistryNode", + "ModifyContainerRegistryNode", + "DeleteContainerRegistryNode", ) -class ContainerRegistryType(enum.StrEnum): - DOCKER = "docker" - HARBOR = "harbor" - HARBOR2 = "harbor2" - GITHUB = "github" - GITLAB = "gitlab" - ECR = "ecr" - ECR_PUB = "ecr-public" - LOCAL = "local" +WhereClauseType: TypeAlias = ( + sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList +) class ContainerRegistryRow(Base): @@ -286,6 +297,49 @@ async def load_all( return [cls.from_row(ctx, row) for row in rows] +class AllowedGroups(graphene.InputObjectType): + add = graphene.List( + graphene.String, + default_value=[], + description="List of group_ids to add associations. Added in 25.2.0.", + ) + remove = graphene.List( + graphene.String, + default_value=[], + description="List of group_ids to remove associations. Added in 25.2.0.", + ) + + +async def handle_allowed_groups_update( + db: ExtendedAsyncSAEngine, + registry_id: uuid.UUID, + allowed_group_updates: AllowedGroups | AllowedGroupsModel, +): + async with db.begin_session() as db_sess: + if allowed_group_updates.add: + insert_values = [ + {"registry_id": registry_id, "group_id": group_id} + for group_id in allowed_group_updates.add + ] + + insert_query = sa.insert(AssociationContainerRegistriesGroupsRow).values(insert_values) + await db_sess.execute(insert_query) + + if allowed_group_updates.remove: + delete_query = ( + sa.delete(AssociationContainerRegistriesGroupsRow) + .where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id) + .where( + AssociationContainerRegistriesGroupsRow.group_id.in_( + allowed_group_updates.remove + ) + ) + ) + result = await db_sess.execute(delete_query) + if result.rowcount == 0: + raise ContainerRegistryNotFound() + + class ContainerRegistryNode(graphene.ObjectType): class Meta: interfaces = (AsyncNode,) @@ -304,6 +358,7 @@ class Meta: password = graphene.String(description="Added in 24.09.0.") ssl_verify = graphene.Boolean(description="Added in 24.09.0.") extra = graphene.JSONString(description="Added in 24.09.3.") + allowed_groups = PaginatedConnectionField(GroupConnection, description="Added in 25.2.0.") _queryfilter_fieldspec: dict[str, FieldSpecItem] = { "row_id": ("id", None), @@ -389,6 +444,39 @@ def from_row(cls, ctx: GraphQueryContext, row: ContainerRegistryRow) -> Containe extra=row.extra, ) + async def resolve_allowed_groups( + self, + info: graphene.ResolveInfo, + filter: Optional[str] = None, + order: Optional[str] = None, + offset: Optional[int] = None, + after: Optional[str] = None, + first: Optional[int] = None, + before: Optional[str] = None, + last: Optional[int] = None, + ) -> ConnectionResolverResult[GroupNode]: + scope = SystemScope() + + if self.is_global: + container_registry_scope = None + else: + container_registry_scope = ContainerRegistryScope.parse( + f"{ContainerRegistryScopeType.PROJECT}:{self.id}" + ) + + return await GroupNode.get_connection( + info, + scope, + container_registry_scope, + filter_expr=filter, + order_expr=order, + offset=offset, + after=after, + first=first, + before=before, + last=last, + ) + class ContainerRegistryConnection(Connection): """Added in 24.09.0.""" @@ -398,95 +486,100 @@ class Meta: description = "Added in 24.09.0." +class CreateContainerRegistryNodeInput(graphene.InputObjectType): + url = graphene.String(required=True, description="Added in 24.09.0.") + type = ContainerRegistryTypeField(required=True, description="Added in 24.09.0.") + registry_name = graphene.String(required=True, description="Added in 24.09.0.") + is_global = graphene.Boolean(description="Added in 24.09.0.") + project = graphene.String(description="Added in 24.09.0.") + username = graphene.String(description="Added in 24.09.0.") + password = graphene.String(description="Added in 24.09.0.") + ssl_verify = graphene.Boolean(description="Added in 24.09.0.") + extra = graphene.JSONString(description="Added in 24.09.3.") + allowed_groups = AllowedGroups(description="Added in 25.2.0.") + + class CreateContainerRegistryNode(graphene.Mutation): class Meta: description = "Added in 24.09.0." allowed_roles = (UserRole.SUPERADMIN,) - container_registry = graphene.Field(ContainerRegistryNode) class Arguments: - url = graphene.String(required=True, description="Added in 24.09.0.") - type = ContainerRegistryTypeField( - required=True, - description=f"Added in 24.09.0. Registry type. One of {ContainerRegistryTypeField.allowed_values}.", - ) - registry_name = graphene.String(required=True, description="Added in 24.09.0.") - is_global = graphene.Boolean(description="Added in 24.09.0.") - project = graphene.String(description="Added in 24.09.0.") - username = graphene.String(description="Added in 24.09.0.") - password = graphene.String(description="Added in 24.09.0.") - ssl_verify = graphene.Boolean(description="Added in 24.09.0.") - extra = graphene.JSONString(description="Added in 24.09.3.") + props = CreateContainerRegistryNodeInput(required=True, description="Added in 25.2.0.") + + container_registry = graphene.Field(ContainerRegistryNode) @classmethod async def mutate( cls, root, info: graphene.ResolveInfo, - url: str, - type: ContainerRegistryType, - registry_name: str, - is_global: bool | UndefinedType = Undefined, - project: str | UndefinedType = Undefined, - username: str | UndefinedType = Undefined, - password: str | UndefinedType = Undefined, - ssl_verify: bool | UndefinedType = Undefined, - extra: dict | UndefinedType = Undefined, + props: CreateContainerRegistryNodeInput, ) -> CreateContainerRegistryNode: ctx: GraphQueryContext = info.context input_config: dict[str, Any] = { - "registry_name": registry_name, - "url": url, - "type": type, + "registry_name": props.registry_name, + "url": props.url, + "type": props.type, } def _set_if_set(name: str, val: Any) -> None: if val is not Undefined: input_config[name] = val - _set_if_set("project", project) - _set_if_set("username", username) - _set_if_set("password", password) - _set_if_set("ssl_verify", ssl_verify) - _set_if_set("is_global", is_global) - _set_if_set("extra", extra) + _set_if_set("project", props.project) + _set_if_set("username", props.username) + _set_if_set("password", props.password) + _set_if_set("ssl_verify", props.ssl_verify) + _set_if_set("is_global", props.is_global) + _set_if_set("extra", props.extra) - async with ctx.db.begin_session() as db_session: - reg_row = ContainerRegistryRow(**input_config) - db_session.add(reg_row) - await db_session.flush() - await db_session.refresh(reg_row) + try: + async with ctx.db.begin_session() as db_session: + reg_row = ContainerRegistryRow(**input_config) + db_session.add(reg_row) + await db_session.flush() + await db_session.refresh(reg_row) + + if props.allowed_groups: + await handle_allowed_groups_update(ctx.db, reg_row.id, props.allowed_groups) return cls( container_registry=ContainerRegistryNode.from_row(ctx, reg_row), ) + except Exception as e: + raise GraphQLError(str(e)) + + +class ModifyContainerRegistryNodeInput(graphene.InputObjectType): + url = graphene.String(description="Added in 24.09.0.") + type = ContainerRegistryTypeField(description="Added in 24.09.0.") + registry_name = graphene.String(description="Added in 24.09.0.") + is_global = graphene.Boolean(description="Added in 24.09.0.") + project = graphene.String(description="Added in 24.09.0.") + username = graphene.String(description="Added in 24.09.0.") + password = graphene.String(description="Added in 24.09.0.") + ssl_verify = graphene.Boolean(description="Added in 24.09.0.") + extra = graphene.JSONString(description="Added in 24.09.3.") + allowed_groups = AllowedGroups(description="Added in 25.2.0.") class ModifyContainerRegistryNode(graphene.Mutation): allowed_roles = (UserRole.SUPERADMIN,) - container_registry = graphene.Field(ContainerRegistryNode) class Meta: description = "Added in 24.09.0." + container_registry = graphene.Field(ContainerRegistryNode) + class Arguments: id = graphene.String( required=True, description="Object id. Can be either global id or object id. Added in 24.09.0.", ) - url = graphene.String(description="Added in 24.09.0.") - type = ContainerRegistryTypeField( - description=f"Registry type. One of {ContainerRegistryTypeField.allowed_values}. Added in 24.09.0." - ) - registry_name = graphene.String(description="Added in 24.09.0.") - is_global = graphene.Boolean(description="Added in 24.09.0.") - project = graphene.String(description="Added in 24.09.0.") - username = graphene.String(description="Added in 24.09.0.") - password = graphene.String(description="Added in 24.09.0.") - ssl_verify = graphene.Boolean(description="Added in 24.09.0.") - extra = graphene.JSONString(description="Added in 24.09.3.") + props = ModifyContainerRegistryNodeInput(required=True, description="Added in 25.2.0.") @classmethod async def mutate( @@ -494,15 +587,7 @@ async def mutate( root, info: graphene.ResolveInfo, id: str, - url: str | UndefinedType = Undefined, - type: ContainerRegistryType | UndefinedType = Undefined, - registry_name: str | UndefinedType = Undefined, - is_global: bool | UndefinedType = Undefined, - project: str | UndefinedType = Undefined, - username: str | UndefinedType = Undefined, - password: str | UndefinedType = Undefined, - ssl_verify: bool | UndefinedType = Undefined, - extra: dict | UndefinedType = Undefined, + props: ModifyContainerRegistryNodeInput, ) -> ModifyContainerRegistryNode: ctx: GraphQueryContext = info.context @@ -512,33 +597,39 @@ def _set_if_set(name: str, val: Any) -> None: if val is not Undefined: input_config[name] = val - _set_if_set("url", url) - _set_if_set("type", type) - _set_if_set("registry_name", registry_name) - _set_if_set("username", username) - _set_if_set("password", password) - _set_if_set("project", project) - _set_if_set("ssl_verify", ssl_verify) - _set_if_set("is_global", is_global) - _set_if_set("extra", extra) + _set_if_set("url", props.url) + _set_if_set("type", props.type) + _set_if_set("registry_name", props.registry_name) + _set_if_set("username", props.username) + _set_if_set("password", props.password) + _set_if_set("project", props.project) + _set_if_set("ssl_verify", props.ssl_verify) + _set_if_set("is_global", props.is_global) + _set_if_set("extra", props.extra) _, _id = AsyncNode.resolve_global_id(info, id) reg_id = uuid.UUID(_id) if _id else uuid.UUID(id) - async with ctx.db.begin_session() as session: - stmt = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) - reg_row = await session.scalar(stmt) - if reg_row is None: - raise ValueError(f"ContainerRegistry not found (id: {reg_id})") - for field, val in input_config.items(): - setattr(reg_row, field, val) + try: + async with ctx.db.begin_session() as session: + stmt = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) + reg_row = await session.scalar(stmt) + if reg_row is None: + raise ValueError(f"ContainerRegistry not found (id: {reg_id})") + for field, val in input_config.items(): + setattr(reg_row, field, val) + + if props.allowed_groups: + await handle_allowed_groups_update(ctx.db, reg_row.id, props.allowed_groups) return cls(container_registry=ContainerRegistryNode.from_row(ctx, reg_row)) + except Exception as e: + raise GraphQLError(str(e)) + class DeleteContainerRegistryNode(graphene.Mutation): allowed_roles = (UserRole.SUPERADMIN,) - container_registry = graphene.Field(ContainerRegistryNode) class Meta: description = "Added in 24.09.0." @@ -549,6 +640,8 @@ class Arguments: description="Object id. Can be either global id or object id. Added in 24.09.0.", ) + container_registry = graphene.Field(ContainerRegistryNode) + @classmethod async def mutate( cls, @@ -560,19 +653,24 @@ async def mutate( _, _id = AsyncNode.resolve_global_id(info, id) reg_id = uuid.UUID(_id) if _id else uuid.UUID(id) - async with ctx.db.begin_session() as db_session: - reg_row = await ContainerRegistryRow.get(db_session, reg_id) - reg_row = await db_session.scalar( - sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) - ) - if reg_row is None: - raise ValueError(f"Container registry not found (id:{reg_id})") - container_registry = ContainerRegistryNode.from_row(ctx, reg_row) - await db_session.execute( - sa.delete(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) - ) - return cls(container_registry=container_registry) + try: + async with ctx.db.begin_session() as db_session: + reg_row = await ContainerRegistryRow.get(db_session, reg_id) + reg_row = await db_session.scalar( + sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) + ) + if reg_row is None: + raise ValueError(f"Container registry not found (id:{reg_id})") + container_registry = ContainerRegistryNode.from_row(ctx, reg_row) + await db_session.execute( + sa.delete(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) + ) + + return cls(container_registry=container_registry) + + except Exception as e: + raise GraphQLError(str(e)) # Legacy mutations @@ -699,3 +797,32 @@ async def mutate( ) await session.execute(stmt) return cls(container_registry=container_registry) + + +class ContainerRegistryScopeField(graphene.Scalar): + class Meta: + description = "Added in 25.2.0." + + @staticmethod + def serialize(val: ContainerRegistryScope) -> str: + if isinstance(val, ContainerRegistryScope): + return str(val) + raise ValueError("Invalid ContainerRegistryScope") + + @staticmethod + def parse_value(value): + if isinstance(value, str): + try: + return ContainerRegistryScope.parse(value) + except Exception as e: + raise ValueError(f"Invalid ContainerRegistryScope: {e}") + raise ValueError("Invalid ContainerRegistryScope") + + @staticmethod + def parse_literal(node): + if isinstance(node, graphql.language.ast.StringValueNode): + try: + return ContainerRegistryScope.parse(node.value) + except Exception as e: + raise ValueError(f"Invalid ContainerRegistryScope: {e}") + return None diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 85105103784..53e294e4abe 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -28,6 +28,7 @@ ContainerRegistry, ContainerRegistryConnection, ContainerRegistryNode, + ContainerRegistryScopeField, CreateContainerRegistry, CreateContainerRegistryNode, DeleteContainerRegistry, @@ -35,6 +36,7 @@ ModifyContainerRegistry, ModifyContainerRegistryNode, ) +from .rbac import ContainerRegistryScope if TYPE_CHECKING: from ai.backend.common.bgtask import BackgroundTaskManager @@ -74,10 +76,6 @@ AgentSummaryList, ModifyAgent, ) -from .gql_models.container_registry import ( - AssociateContainerRegistryWithGroup, - DisassociateContainerRegistryWithGroup, -) from .gql_models.domain import ( CreateDomainNode, DomainConnection, @@ -93,7 +91,7 @@ ModifyEndpointAutoScalingRuleNode, ) from .gql_models.fields import AgentPermissionField, ScopeField -from .gql_models.group import GroupConnection, GroupNode +from .gql_models.group import GroupConnection, GroupNode, GroupPermissionField from .gql_models.image import ( AliasImage, ClearImages, @@ -144,7 +142,12 @@ from .keypair import CreateKeyPair, DeleteKeyPair, KeyPair, KeyPairList, ModifyKeyPair from .network import CreateNetwork, DeleteNetwork, ModifyNetwork, NetworkConnection, NetworkNode from .rbac import ProjectScope, ScopeType, SystemScope -from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission +from .rbac.permission_defs import ( + AgentPermission, + ComputeSessionPermission, + DomainPermission, + ProjectPermission, +) from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission from .resource_policy import ( CreateKeyPairResourcePolicy, @@ -355,12 +358,6 @@ class Mutations(graphene.ObjectType): delete_endpoint_auto_scaling_rule_node = DeleteEndpointAutoScalingRuleNode.Field( description="Added in 25.1.0." ) - associate_container_registry_with_group = AssociateContainerRegistryWithGroup.Field( - description="Added in 25.2.0." - ) - disassociate_container_registry_with_group = DisassociateContainerRegistryWithGroup.Field( - description="Added in 25.2.0." - ) # Legacy mutations create_container_registry = CreateContainerRegistry.Field() @@ -474,6 +471,14 @@ class Queries(graphene.ObjectType): description="Added in 24.03.0.", filter=graphene.String(description="Added in 24.09.0."), order=graphene.String(description="Added in 24.09.0."), + scope=ScopeField( + description="Added in 25.2.0. Default is `system`.", + ), + container_registry_scope=ContainerRegistryScopeField(description="Added in 25.2.0."), + permission=GroupPermissionField( + default_value=ProjectPermission.READ_ATTRIBUTE, + description=f"Added in 25.2.0. Default is {ProjectPermission.READ_ATTRIBUTE.value}.", + ), ) group = graphene.Field( @@ -1165,16 +1170,23 @@ async def resolve_group_nodes( root: Any, info: graphene.ResolveInfo, *, - filter: str | None = None, - order: str | None = None, - offset: int | None = None, - after: str | None = None, - first: int | None = None, - before: str | None = None, - last: int | None = None, + scope: Optional[ScopeType] = None, + container_registry_scope: Optional[ContainerRegistryScope] = None, + permission: ProjectPermission = ProjectPermission.READ_ATTRIBUTE, + filter: Optional[str] = None, + order: Optional[str] = None, + offset: Optional[int] = None, + after: Optional[str] = None, + first: Optional[int] = None, + before: Optional[str] = None, + last: Optional[int] = None, ) -> ConnectionResolverResult[GroupNode]: + _scope = scope or SystemScope() return await GroupNode.get_connection( info, + _scope, + container_registry_scope, + permission, filter, order, offset, diff --git a/src/ai/backend/manager/models/gql_models/container_registry.py b/src/ai/backend/manager/models/gql_models/container_registry.py deleted file mode 100644 index 8f995257915..00000000000 --- a/src/ai/backend/manager/models/gql_models/container_registry.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Self - -import graphene -import sqlalchemy as sa - -from ai.backend.logging import BraceStyleAdapter - -from ..association_container_registries_groups import ( - AssociationContainerRegistriesGroupsRow, -) -from ..base import simple_db_mutate -from ..user import UserRole - -log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore - - -class AssociateContainerRegistryWithGroup(graphene.Mutation): - """Added in 25.2.0.""" - - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - registry_id = graphene.String(required=True) - group_id = graphene.String(required=True) - - ok = graphene.Boolean() - msg = graphene.String() - - @classmethod - async def mutate( - cls, - root, - info: graphene.ResolveInfo, - registry_id: str, - group_id: str, - ) -> Self: - insert_query = sa.insert(AssociationContainerRegistriesGroupsRow).values({ - "registry_id": registry_id, - "group_id": group_id, - }) - return await simple_db_mutate(cls, info.context, insert_query) - - -class DisassociateContainerRegistryWithGroup(graphene.Mutation): - """Added in 25.2.0.""" - - allowed_roles = (UserRole.SUPERADMIN,) - - class Arguments: - registry_id = graphene.String(required=True) - group_id = graphene.String(required=True) - - ok = graphene.Boolean() - msg = graphene.String() - - @classmethod - async def mutate( - cls, - root, - info: graphene.ResolveInfo, - registry_id: str, - group_id: str, - ) -> Self: - delete_query = ( - sa.delete(AssociationContainerRegistriesGroupsRow) - .where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id) - .where(AssociationContainerRegistriesGroupsRow.group_id == group_id) - ) - return await simple_db_mutate(cls, info.context, delete_query) diff --git a/src/ai/backend/manager/models/gql_models/group.py b/src/ai/backend/manager/models/gql_models/group.py index d4d5fab7fbf..0524e3a1359 100644 --- a/src/ai/backend/manager/models/gql_models/group.py +++ b/src/ai/backend/manager/models/gql_models/group.py @@ -3,11 +3,14 @@ from collections.abc import Mapping from typing import ( TYPE_CHECKING, + Any, + Optional, Self, Sequence, ) import graphene +import graphql import sqlalchemy as sa from dateutil.parser import parse as dtparse from graphene.types.datetime import DateTime as GQLDateTime @@ -23,13 +26,16 @@ Connection, ConnectionResolverResult, ) -from ..group import AssocGroupUserRow, GroupRow, ProjectType +from ..group import AssocGroupUserRow, GroupRow, ProjectType, get_permission_ctx from ..minilang.ordering import OrderSpecItem, QueryOrderParser from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser +from ..rbac.context import ClientContext +from ..rbac.permission_defs import ProjectPermission from .user import UserConnection, UserNode if TYPE_CHECKING: from ..gql import GraphQueryContext + from ..rbac import ContainerRegistryScope, ScopeType from ..scaling_group import ScalingGroup _queryfilter_fieldspec: Mapping[str, FieldSpecItem] = { @@ -217,13 +223,16 @@ async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: async def get_connection( cls, info: graphene.ResolveInfo, - filter_expr: str | None = None, - order_expr: str | None = None, - offset: int | None = None, - after: str | None = None, - first: int | None = None, - before: str | None = None, - last: int | None = None, + scope: ScopeType, + container_registry_scope: Optional[ContainerRegistryScope] = None, + permission: ProjectPermission = ProjectPermission.READ_ATTRIBUTE, + filter_expr: Optional[str] = None, + order_expr: Optional[str] = None, + offset: Optional[int] = None, + after: Optional[str] = None, + first: Optional[int] = None, + before: Optional[str] = None, + last: Optional[int] = None, ) -> ConnectionResolverResult[Self]: graph_ctx: GraphQueryContext = info.context _filter_arg = ( @@ -255,14 +264,47 @@ async def get_connection( before=before, last=last, ) - async with graph_ctx.db.begin_readonly_session() as db_session: - group_rows = (await db_session.scalars(query)).all() - result = [cls.from_row(graph_ctx, row) for row in group_rows] - total_cnt = await db_session.scalar(cnt_query) - return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) + async with graph_ctx.db.connect() as db_conn: + user = graph_ctx.user + client_ctx = ClientContext( + graph_ctx.db, user["domain_name"], user["uuid"], user["role"] + ) + permission_ctx = await get_permission_ctx( + db_conn, client_ctx, permission, scope, container_registry_scope + ) + cond = permission_ctx.query_condition + if cond is None: + return ConnectionResolverResult([], cursor, pagination_order, page_size, 0) + query = query.where(cond) + cnt_query = cnt_query.where(cond) + + async with graph_ctx.db.begin_readonly_session(db_conn) as db_session: + group_rows = (await db_session.scalars(query)).all() + total_cnt = await db_session.scalar(cnt_query) + result = [cls.from_row(graph_ctx, row) for row in group_rows] + + return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) class GroupConnection(Connection): class Meta: node = GroupNode description = "Added in 24.03.0" + + +class GroupPermissionField(graphene.Scalar): + class Meta: + description = f"Added in 25.2.0. One of {[val.value for val in ProjectPermission]}." + + @staticmethod + def serialize(val: ProjectPermission) -> str: + return val.value + + @staticmethod + def parse_literal(node: Any, _variables=None): + if isinstance(node, graphql.language.ast.StringValueNode): + return ProjectPermission(node.value) + + @staticmethod + def parse_value(value: str) -> ProjectPermission: + return ProjectPermission(value) diff --git a/src/ai/backend/manager/models/gql_models/image.py b/src/ai/backend/manager/models/gql_models/image.py index cb081b355ce..382f27f3d27 100644 --- a/src/ai/backend/manager/models/gql_models/image.py +++ b/src/ai/backend/manager/models/gql_models/image.py @@ -21,13 +21,14 @@ from sqlalchemy.orm import load_only, selectinload from ai.backend.common import redis_helper +from ai.backend.common.container_registry import ContainerRegistryType from ai.backend.common.docker import ImageRef from ai.backend.common.exception import UnknownImageReference from ai.backend.common.types import ( ImageAlias, ) from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType +from ai.backend.manager.models.container_registry import ContainerRegistryRow from ...api.exceptions import ImageNotFound, ObjectNotFound from ...defs import DEFAULT_IMAGE_ARCH diff --git a/src/ai/backend/manager/models/group.py b/src/ai/backend/manager/models/group.py index 42273bd8fa6..e1f0263e0bb 100644 --- a/src/ai/backend/manager/models/group.py +++ b/src/ai/backend/manager/models/group.py @@ -38,6 +38,9 @@ from ai.backend.common import msgpack from ai.backend.common.types import ResourceSlot, VFolderID from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.models.association_container_registries_groups import ( + AssociationContainerRegistriesGroupsRow, +) from ..api.exceptions import VFolderOperationFailed from ..defs import RESERVED_DOTFILES @@ -76,6 +79,7 @@ if TYPE_CHECKING: from .gql import GraphQueryContext + from .rbac import ContainerRegistryScope from .scaling_group import ScalingGroup from .storage import StorageSessionManager @@ -993,6 +997,10 @@ def verify_dotfile_name(dotfile: str) -> bool: @dataclass class ProjectPermissionContext(AbstractPermissionContext[ProjectPermission, GroupRow, uuid.UUID]): + registry_id_to_additional_permission_map: dict[uuid.UUID, frozenset[ProjectPermission]] = field( + default_factory=dict + ) + @property def query_condition(self) -> WhereClauseType | None: cond: WhereClauseType | None = None @@ -1003,6 +1011,15 @@ def _OR_coalesce( ) -> WhereClauseType: return base_cond | _cond if base_cond is not None else _cond + if self.registry_id_to_additional_permission_map: + registry_id = list(self.registry_id_to_additional_permission_map)[0] + + cond = _OR_coalesce( + cond, + GroupRow.association_container_registries_groups_rows.any( + AssociationContainerRegistriesGroupsRow.registry_id == registry_id + ), + ) if self.domain_name_to_permission_map: cond = _OR_coalesce( cond, GroupRow.domain_name.in_(self.domain_name_to_permission_map.keys()) @@ -1096,6 +1113,14 @@ async def build_ctx_in_user_scope( ) -> ProjectPermissionContext: return ProjectPermissionContext() + async def build_ctx_in_container_registry_scope( + self, ctx: ClientContext, scope: ContainerRegistryScope + ) -> ProjectPermissionContext: + permissions = MEMBER_PERMISSIONS + return ProjectPermissionContext( + registry_id_to_additional_permission_map={scope.registry_id: permissions} + ) + @override @classmethod async def _permission_for_owner( @@ -1156,3 +1181,21 @@ async def get_projects( permissions = await permission_ctx.calculate_final_permission(row) result.append(ProjectModel.from_row(row, permissions)) return result + + +async def get_permission_ctx( + db_conn: SAConnection, + ctx: ClientContext, + requested_permission: ProjectPermission, + target_scope: ScopeType, + container_registry_scope: Optional[ContainerRegistryScope] = None, +) -> ProjectPermissionContext: + async with ctx.db.begin_readonly_session(db_conn) as db_session: + builder = ProjectPermissionContextBuilder(db_session) + + if container_registry_scope is not None: + return await builder.build_ctx_in_container_registry_scope( + ctx, container_registry_scope + ) + else: + return await builder.build(ctx, target_scope, requested_permission) diff --git a/src/ai/backend/manager/models/rbac/__init__.py b/src/ai/backend/manager/models/rbac/__init__.py index 22a2b28394b..e3a49675837 100644 --- a/src/ai/backend/manager/models/rbac/__init__.py +++ b/src/ai/backend/manager/models/rbac/__init__.py @@ -438,6 +438,36 @@ class ScalingGroup(ExtraScope): name: str +class ContainerRegistryScopeType(enum.StrEnum): + USER = "user" + PROJECT = "project" + + +@dataclass(frozen=True) +class ContainerRegistryScope(ExtraScope): + scope_type: ContainerRegistryScopeType + registry_id: uuid.UUID + + def __str__(self) -> str: + match self.registry_id: + case uuid.UUID(): + return f"{self.scope_type}:{str(self.registry_id)}" + case _: + raise ValueError(f"Invalid container registry scope ID: {str(self.registry_id)!r}") + + def __repr__(self) -> str: + return self.__str__() + + @classmethod + def parse(cls, raw: str) -> ContainerRegistryScope: + scope_type, _, registry_id = raw.partition(":") + match scope_type.lower(): + case ContainerRegistryScopeType.PROJECT | ContainerRegistryScopeType.USER as t: + return cls(t, uuid.UUID(registry_id)) + case _: + raise ValueError(f"Invalid container registry scope type: {scope_type!r}") + + ObjectType = TypeVar("ObjectType") ObjectIDType = TypeVar("ObjectIDType") diff --git a/tests/manager/api/test_container_registries.py b/tests/manager/api/test_container_registries.py index 7e8d194c618..81e0a43b99f 100644 --- a/tests/manager/api/test_container_registries.py +++ b/tests/manager/api/test_container_registries.py @@ -90,19 +90,19 @@ async def test_associate_container_registry_with_group( group_id = test_case["group_id"] registry_id = test_case["registry_id"] - url = "/container-registries/associate-with-group" - params = {"group_id": group_id, "registry_id": registry_id} + url = f"/container-registries/{registry_id}" + params = {"allowed_groups": {"add": [group_id]}} req_bytes = json.dumps(params).encode() - headers = get_headers("POST", url, req_bytes) + headers = get_headers("PATCH", url, req_bytes) - resp = await client.post(url, data=req_bytes, headers=headers) + resp = await client.patch(url, data=req_bytes, headers=headers) association_exist = "association_container_registries_groups" in extra_fixtures if association_exist: assert resp.status == 400 else: - assert resp.status == 204 + assert resp.status == 200 @pytest.mark.asyncio @@ -143,16 +143,16 @@ async def test_disassociate_container_registry_with_group( group_id = test_case["group_id"] registry_id = test_case["registry_id"] - url = "/container-registries/disassociate-with-group" - params = {"group_id": group_id, "registry_id": registry_id} + url = f"/container-registries/{registry_id}" + params = {"allowed_groups": {"remove": [group_id]}} req_bytes = json.dumps(params).encode() - headers = get_headers("POST", url, req_bytes) + headers = get_headers("PATCH", url, req_bytes) - resp = await client.post(url, data=req_bytes, headers=headers) + resp = await client.patch(url, data=req_bytes, headers=headers) association_exist = "association_container_registries_groups" in extra_fixtures if association_exist: - assert resp.status == 204 + assert resp.status == 200 else: assert resp.status == 404 diff --git a/tests/manager/models/gql_models/test_container_registries.py b/tests/manager/models/gql_models/test_container_registries.py index 900bbd3bb1b..fd61b6a9368 100644 --- a/tests/manager/models/gql_models/test_container_registries.py +++ b/tests/manager/models/gql_models/test_container_registries.py @@ -4,7 +4,6 @@ from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries from ai.backend.manager.models.utils import ExtendedAsyncSAEngine -from ai.backend.manager.server import database_ctx @pytest.fixture(scope="module") @@ -33,153 +32,3 @@ def get_graphquery_context(database_engine: ExtendedAsyncSAEngine) -> GraphQuery idle_checker_host=None, # type: ignore network_plugin_ctx=None, # type: ignore ) - - -FIXTURES_WITH_NOASSOC = [ - { - "groups": [ - { - "id": "00000000-0000-0000-0000-000000000001", - "name": "mock_group", - "description": "", - "is_active": True, - "domain_name": "default", - "resource_policy": "default", - "total_resource_slots": {}, - "allowed_vfolder_hosts": {}, - "type": "general", - } - ], - "container_registries": [ - { - "id": "00000000-0000-0000-0000-000000000002", - "url": "https://mock.registry.com", - "type": "docker", - "project": "mock_project", - "registry_name": "mock_registry", - } - ], - } -] - -FIXTURES_WITH_ASSOC = [ - { - **fixture, - "association_container_registries_groups": [ - { - "id": "00000000-0000-0000-0000-000000000000", - "group_id": "00000000-0000-0000-0000-000000000001", - "registry_id": "00000000-0000-0000-0000-000000000002", - } - ], - } - for fixture in FIXTURES_WITH_NOASSOC -] - - -@pytest.mark.dependency() -@pytest.mark.asyncio -@pytest.mark.parametrize( - "extra_fixtures", - FIXTURES_WITH_NOASSOC + FIXTURES_WITH_ASSOC, - ids=["(No association)", "(With association)"], -) -@pytest.mark.parametrize( - "test_case", - [ - { - "group_id": "00000000-0000-0000-0000-000000000001", - "registry_id": "00000000-0000-0000-0000-000000000002", - }, - ], - ids=["Associate One group with one container registry"], -) -async def test_associate_container_registry_with_group( - client: Client, database_fixture, extra_fixtures, test_case, create_app_and_client -): - test_app, _ = await create_app_and_client( - [ - database_ctx, - ], - [], - ) - - root_ctx = test_app["_root.context"] - context = get_graphquery_context(root_ctx.db) - - query = """ - mutation ($group_id: String!, $registry_id: String!) { - associate_container_registry_with_group(group_id: $group_id, registry_id: $registry_id) { - ok - msg - } - } - """ - - variables = { - "group_id": test_case["group_id"], - "registry_id": test_case["registry_id"], - } - - response = await client.execute_async(query, variables=variables, context_value=context) - already_associated = "association_container_registries_groups" in extra_fixtures - - if already_associated: - assert not response["data"]["associate_container_registry_with_group"]["ok"] - else: - assert response["data"]["associate_container_registry_with_group"]["ok"] - assert response["data"]["associate_container_registry_with_group"]["msg"] == "success" - - -@pytest.mark.dependency() -@pytest.mark.asyncio -@pytest.mark.parametrize( - "extra_fixtures", - FIXTURES_WITH_ASSOC + FIXTURES_WITH_NOASSOC, - ids=["(With association)", "(No association)"], -) -@pytest.mark.parametrize( - "test_case", - [ - { - "group_id": "00000000-0000-0000-0000-000000000001", - "registry_id": "00000000-0000-0000-0000-000000000002", - }, - ], - ids=["Disassociate One group with one container registry"], -) -async def test_disassociate_container_registry_with_group( - client: Client, database_fixture, extra_fixtures, test_case, create_app_and_client -): - test_app, _ = await create_app_and_client( - [ - database_ctx, - ], - [], - ) - - root_ctx = test_app["_root.context"] - context = get_graphquery_context(root_ctx.db) - - query = """ - mutation ($group_id: String!, $registry_id: String!) { - disassociate_container_registry_with_group(group_id: $group_id, registry_id: $registry_id) { - ok - msg - } - } - """ - - variables = { - "group_id": test_case["group_id"], - "registry_id": test_case["registry_id"], - } - - response = await client.execute_async(query, variables=variables, context_value=context) - association_exist = "association_container_registries_groups" in extra_fixtures - - if association_exist: - assert response["data"]["disassociate_container_registry_with_group"]["ok"] - assert response["data"]["disassociate_container_registry_with_group"]["msg"] == "success" - else: - assert not response["data"]["disassociate_container_registry_with_group"]["ok"] diff --git a/tests/manager/models/test_container_registry_nodes.py b/tests/manager/models/test_container_registry_nodes.py index a99d9ae8fd2..623719034f5 100644 --- a/tests/manager/models/test_container_registry_nodes.py +++ b/tests/manager/models/test_container_registry_nodes.py @@ -4,10 +4,11 @@ from graphene import Schema from graphene.test import Client +from ai.backend.common.container_registry import ContainerRegistryType from ai.backend.manager.defs import PASSWORD_PLACEHOLDER -from ai.backend.manager.models.container_registry import ContainerRegistryType from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.server import database_ctx CONTAINER_REGISTRY_FIELDS = """ row_id @@ -22,6 +23,48 @@ """ +FIXTURES_WITH_NOASSOC = [ + { + "groups": [ + { + "id": "00000000-0000-0000-0000-000000000001", + "name": "mock_group", + "description": "", + "is_active": True, + "domain_name": "default", + "resource_policy": "default", + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "type": "general", + } + ], + "container_registries": [ + { + "id": "00000000-0000-0000-0000-000000000002", + "url": "https://mock.registry.com", + "type": "docker", + "project": "mock_project", + "registry_name": "mock_registry", + } + ], + } +] + +FIXTURES_WITH_ASSOC = [ + { + **fixture, + "association_container_registries_groups": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "group_id": "00000000-0000-0000-0000-000000000001", + "registry_id": "00000000-0000-0000-0000-000000000002", + } + ], + } + for fixture in FIXTURES_WITH_NOASSOC +] + + @pytest.fixture(scope="module") def client() -> Client: return Client(Schema(query=Queries, mutation=Mutations, auto_camelcase=False)) @@ -74,8 +117,8 @@ async def test_create_container_registry(client: Client, database_engine: Extend context = get_graphquery_context(database_engine) query = """ - mutation CreateContainerRegistryNode($type: ContainerRegistryTypeField!, $registry_name: String!, $url: String!, $project: String!, $username: String!, $password: String!, $ssl_verify: Boolean!, $is_global: Boolean!) { - create_container_registry_node(type: $type, registry_name: $registry_name, url: $url, project: $project, username: $username, password: $password, ssl_verify: $ssl_verify, is_global: $is_global) { + mutation ($props: CreateContainerRegistryNodeInput!) { + create_container_registry_node(props: $props) { container_registry { $CONTAINER_REGISTRY_FIELDS } @@ -84,18 +127,19 @@ async def test_create_container_registry(client: Client, database_engine: Extend """.replace("$CONTAINER_REGISTRY_FIELDS", CONTAINER_REGISTRY_FIELDS) variables = { - "registry_name": "cr.example.com", - "url": "http://cr.example.com", - "type": ContainerRegistryType.DOCKER, - "project": "default", - "username": "username", - "password": "password", - "ssl_verify": False, - "is_global": False, + "props": { + "registry_name": "cr.example.com", + "url": "http://cr.example.com", + "type": ContainerRegistryType.DOCKER, + "project": "default", + "username": "username", + "password": "password", + "ssl_verify": False, + "is_global": False, + } } response = await client.execute_async(query, variables=variables, context_value=context) - container_registry = response["data"]["create_container_registry_node"]["container_registry"] id = container_registry.pop("row_id", None) @@ -112,7 +156,7 @@ async def test_create_container_registry(client: Client, database_engine: Extend "is_global": False, } - variables["project"] = "default2" + variables["props"]["project"] = "default2" await client.execute_async(query, variables=variables, context_value=context) @@ -122,7 +166,7 @@ async def test_modify_container_registry(client: Client, database_engine: Extend context = get_graphquery_context(database_engine) query = """ - query ContainerRegistryNodes($filter: String!) { + query ($filter: String!) { container_registry_nodes (filter: $filter) { edges { node { @@ -151,8 +195,8 @@ async def test_modify_container_registry(client: Client, database_engine: Extend target_container_registry = target_container_registries[0]["node"] query = """ - mutation ModifyContainerRegistryNode($id: String!, $type: ContainerRegistryTypeField, $registry_name: String, $url: String, $project: String, $username: String, $password: String, $ssl_verify: Boolean, $is_global: Boolean) { - modify_container_registry_node(id: $id, type: $type, registry_name: $registry_name, url: $url, project: $project, username: $username, password: $password, ssl_verify: $ssl_verify, is_global: $is_global) { + mutation ($id: String!, $props: ModifyContainerRegistryNodeInput!) { + modify_container_registry_node(id: $id, props: $props) { container_registry { $CONTAINER_REGISTRY_FIELDS } @@ -162,8 +206,10 @@ async def test_modify_container_registry(client: Client, database_engine: Extend variables = { "id": target_container_registry["row_id"], - "registry_name": "cr.example.com", - "username": "username2", + "props": { + "registry_name": "cr.example.com", + "username": "username2", + }, } response = await client.execute_async(query, variables=variables, context_value=context) @@ -179,10 +225,12 @@ async def test_modify_container_registry(client: Client, database_engine: Extend variables = { "id": target_container_registry["row_id"], - "registry_name": "cr.example.com", - "url": "http://cr2.example.com", - "type": ContainerRegistryType.HARBOR2, - "project": "example", + "props": { + "registry_name": "cr.example.com", + "url": "http://cr2.example.com", + "type": ContainerRegistryType.HARBOR2, + "project": "example", + }, } response = await client.execute_async(query, variables=variables, context_value=context) @@ -206,7 +254,7 @@ async def test_modify_container_registry_allows_empty_string( context = get_graphquery_context(database_engine) query = """ - query ContainerRegistryNodes($filter: String!) { + query ($filter: String!) { container_registry_nodes (filter: $filter) { edges { node { @@ -233,8 +281,8 @@ async def test_modify_container_registry_allows_empty_string( target_container_registry = target_container_registries[0]["node"] query = """ - mutation ModifyContainerRegistryNode($id: String!, $type: ContainerRegistryTypeField, $registry_name: String, $url: String, $project: String, $username: String, $password: String, $ssl_verify: Boolean, $is_global: Boolean) { - modify_container_registry_node(id: $id, type: $type, registry_name: $registry_name, url: $url, project: $project, username: $username, password: $password, ssl_verify: $ssl_verify, is_global: $is_global) { + mutation ($id: String!, $props: ModifyContainerRegistryNodeInput!) { + modify_container_registry_node(id: $id, props: $props) { container_registry { $CONTAINER_REGISTRY_FIELDS } @@ -245,8 +293,10 @@ async def test_modify_container_registry_allows_empty_string( # Given an empty string to password variables = { "id": target_container_registry["row_id"], - "registry_name": "cr.example.com", - "password": "", + "props": { + "registry_name": "cr.example.com", + "password": "", + }, } # Then password is set to empty string @@ -271,7 +321,7 @@ async def test_modify_container_registry_allows_null_for_unset( context = get_graphquery_context(database_engine) query = """ - query ContainerRegistryNodes($filter: String!) { + query ($filter: String!) { container_registry_nodes (filter: $filter) { edges { node { @@ -283,7 +333,7 @@ async def test_modify_container_registry_allows_null_for_unset( } """.replace("$CONTAINER_REGISTRY_FIELDS", CONTAINER_REGISTRY_FIELDS) - variables: dict[str, str | None] = { + variables: dict[str, dict | str] = { "filter": 'registry_name == "cr.example.com"', } @@ -299,8 +349,8 @@ async def test_modify_container_registry_allows_null_for_unset( target_container_registry = target_container_registries[0]["node"] query = """ - mutation ModifyContainerRegistryNode($id: String!, $type: ContainerRegistryTypeField, $registry_name: String, $url: String, $project: String, $username: String, $password: String, $ssl_verify: Boolean, $is_global: Boolean) { - modify_container_registry_node(id: $id, type: $type, registry_name: $registry_name, url: $url, project: $project, username: $username, password: $password, ssl_verify: $ssl_verify, is_global: $is_global) { + mutation ($id: String!, $props: ModifyContainerRegistryNodeInput!) { + modify_container_registry_node(id: $id, props: $props) { container_registry { $CONTAINER_REGISTRY_FIELDS } @@ -311,8 +361,10 @@ async def test_modify_container_registry_allows_null_for_unset( # Given a null to password variables = { "id": target_container_registry["row_id"], - "registry_name": "cr.example.com", - "password": None, + "props": { + "registry_name": "cr.example.com", + "password": None, + }, } # Then password is unset @@ -334,7 +386,7 @@ async def test_delete_container_registry(client: Client, database_engine: Extend context = get_graphquery_context(database_engine) query = """ - query ContainerRegistryNodes($filter: String!) { + query ($filter: String!) { container_registry_nodes (filter: $filter) { edges { node { @@ -362,7 +414,7 @@ async def test_delete_container_registry(client: Client, database_engine: Extend target_container_registry = target_container_registries[0]["node"] query = """ - mutation DeleteContainerRegistryNode($id: String!) { + mutation ($id: String!) { delete_container_registry_node(id: $id) { container_registry { $CONTAINER_REGISTRY_FIELDS @@ -380,7 +432,7 @@ async def test_delete_container_registry(client: Client, database_engine: Extend assert container_registry["registry_name"] == "cr.example.com" query = """ - query ContainerRegistryNodes($filter: String!) { + query ($filter: String!) { container_registry_nodes (filter: $filter) { edges { node { @@ -398,3 +450,131 @@ async def test_delete_container_registry(client: Client, database_engine: Extend response = await client.execute_async(query, variables=variables, context_value=context) assert response["data"]["container_registry_nodes"] is None + + +@pytest.mark.dependency() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "extra_fixtures", + FIXTURES_WITH_NOASSOC + FIXTURES_WITH_ASSOC, + ids=["(No association)", "(With association)"], +) +@pytest.mark.parametrize( + "test_case", + [ + { + "group_id": "00000000-0000-0000-0000-000000000001", + "registry_id": "00000000-0000-0000-0000-000000000002", + }, + ], + ids=["Associate One group with one container registry"], +) +async def test_associate_container_registry_with_group( + client: Client, database_fixture, extra_fixtures, test_case, create_app_and_client +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + query = """ + mutation ($id: String!, $props: ModifyContainerRegistryNodeInput!) { + modify_container_registry_node(id: $id, props: $props) { + container_registry { + $CONTAINER_REGISTRY_FIELDS + } + } + } + """.replace("$CONTAINER_REGISTRY_FIELDS", CONTAINER_REGISTRY_FIELDS) + + variables = { + "id": test_case["registry_id"], + "props": { + "allowed_groups": { + "add": [test_case["group_id"]], + } + }, + } + + response = await client.execute_async(query, variables=variables, context_value=context) + already_associated = "association_container_registries_groups" in extra_fixtures + + if already_associated: + assert response["data"]["modify_container_registry_node"] is None + assert response["errors"] is not None + else: + assert ( + response["data"]["modify_container_registry_node"]["container_registry"][ + "registry_name" + ] + == "mock_registry" + ) + + +@pytest.mark.dependency() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "extra_fixtures", + FIXTURES_WITH_ASSOC + FIXTURES_WITH_NOASSOC, + ids=["(With association)", "(No association)"], +) +@pytest.mark.parametrize( + "test_case", + [ + { + "group_id": "00000000-0000-0000-0000-000000000001", + "registry_id": "00000000-0000-0000-0000-000000000002", + }, + ], + ids=["Disassociate One group with one container registry"], +) +async def test_disassociate_container_registry_with_group( + client: Client, database_fixture, extra_fixtures, test_case, create_app_and_client +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + query = """ + mutation ($id: String!, $props: ModifyContainerRegistryNodeInput!) { + modify_container_registry_node(id: $id, props: $props) { + container_registry { + $CONTAINER_REGISTRY_FIELDS + } + } + } + """.replace("$CONTAINER_REGISTRY_FIELDS", CONTAINER_REGISTRY_FIELDS) + + variables = { + "id": test_case["registry_id"], + "props": { + "allowed_groups": { + "remove": [test_case["group_id"]], + } + }, + } + + response = await client.execute_async(query, variables=variables, context_value=context) + association_exist = "association_container_registries_groups" in extra_fixtures + + if association_exist: + assert ( + response["data"]["modify_container_registry_node"]["container_registry"][ + "registry_name" + ] + == "mock_registry" + ) + else: + assert response["data"]["modify_container_registry_node"] is None + assert response["errors"] is not None