From 26025ed5c544719101cd5dd524d99424fa0ddb0b Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 3 Jul 2023 13:38:33 +0200 Subject: [PATCH 1/4] feat: add `users_v1` router --- src/argilla/server/apis/v1/handlers/users.py | 34 ++++++++++++++++++++ src/argilla/server/routes.py | 4 ++- 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 src/argilla/server/apis/v1/handlers/users.py diff --git a/src/argilla/server/apis/v1/handlers/users.py b/src/argilla/server/apis/v1/handlers/users.py new file mode 100644 index 0000000000..c6129eaa65 --- /dev/null +++ b/src/argilla/server/apis/v1/handlers/users.py @@ -0,0 +1,34 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from fastapi import APIRouter, Depends, Security +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla.server.database import get_async_db +from argilla.server.models import User +from argilla.server.security import auth + +router = APIRouter(tags=["users"]) + + +@router.get("/users/{user_id}/workspaces") +async def get_user_workspaces( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: User = Security(auth.get_current_user), +): + pass diff --git a/src/argilla/server/routes.py b/src/argilla/server/routes.py index 2759df5c0a..81988e527f 100644 --- a/src/argilla/server/routes.py +++ b/src/argilla/server/routes.py @@ -38,6 +38,7 @@ from argilla.server.apis.v1.handlers import questions as questions_v1 from argilla.server.apis.v1.handlers import records as records_v1 from argilla.server.apis.v1.handlers import responses as responses_v1 +from argilla.server.apis.v1.handlers import users as users_v1 from argilla.server.apis.v1.handlers import workspaces as workspaces_v1 from argilla.server.errors.base_errors import __ALL__ @@ -63,8 +64,9 @@ # API v1 api_router.include_router(datasets_v1.router, prefix="/v1") -api_router.include_router(workspaces_v1.router, prefix="/v1") api_router.include_router(fields_v1.router, prefix="/v1") api_router.include_router(questions_v1.router, prefix="/v1") api_router.include_router(records_v1.router, prefix="/v1") api_router.include_router(responses_v1.router, prefix="/v1") +api_router.include_router(users_v1.router, prefix="/v1") +api_router.include_router(workspaces_v1.router, prefix="/v1") From f4a0e236b3c4e34cba8d1424dd3abe2a415a4a82 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 3 Jul 2023 16:54:25 +0200 Subject: [PATCH 2/4] feat: add `list_user_workspaces` endpoint --- src/argilla/server/apis/v1/handlers/users.py | 11 ++++++++--- src/argilla/server/contexts/accounts.py | 10 ++++++++++ src/argilla/server/policies.py | 6 ++++++ src/argilla/server/schemas/v1/workspaces.py | 5 +++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/users.py b/src/argilla/server/apis/v1/handlers/users.py index c6129eaa65..81f7fac258 100644 --- a/src/argilla/server/apis/v1/handlers/users.py +++ b/src/argilla/server/apis/v1/handlers/users.py @@ -17,18 +17,23 @@ from fastapi import APIRouter, Depends, Security from sqlalchemy.ext.asyncio import AsyncSession +from argilla.server.contexts import accounts from argilla.server.database import get_async_db from argilla.server.models import User +from argilla.server.policies import UserPolicyV1, authorize +from argilla.server.schemas.v1.workspaces import Workspaces from argilla.server.security import auth router = APIRouter(tags=["users"]) -@router.get("/users/{user_id}/workspaces") -async def get_user_workspaces( +@router.get("/users/{user_id}/workspaces", response_model=Workspaces) +async def list_user_workspaces( *, db: AsyncSession = Depends(get_async_db), user_id: UUID, current_user: User = Security(auth.get_current_user), ): - pass + await authorize(current_user, UserPolicyV1.list_workspaces) + workspaces = await accounts.list_workspaces_by_user_id(db, user_id) + return Workspaces(items=workspaces) diff --git a/src/argilla/server/contexts/accounts.py b/src/argilla/server/contexts/accounts.py index 89d1cb6c9f..0b93d9ebb2 100644 --- a/src/argilla/server/contexts/accounts.py +++ b/src/argilla/server/contexts/accounts.py @@ -70,6 +70,16 @@ async def list_workspaces(db: "AsyncSession") -> List[Workspace]: return result.scalars().all() +async def list_workspaces_by_user_id(db: "AsyncSession", user_id: UUID) -> List[Workspace]: + result = await db.execute( + select(Workspace) + .join(WorkspaceUser) + .filter(WorkspaceUser.user_id == user_id) + .order_by(Workspace.inserted_at.asc()) + ) + return result.scalars().all() + + async def create_workspace(db: "AsyncSession", workspace_create: WorkspaceCreate) -> Workspace: workspace = Workspace(name=workspace_create.name) db.add(workspace) diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index fd1b18109c..b0d544378a 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -117,6 +117,12 @@ async def is_allowed(actor: User) -> bool: return is_allowed +class UserPolicyV1: + @classmethod + async def list_workspaces(cls, actor: User) -> bool: + return actor.is_owner + + class DatasetPolicy: @classmethod async def list(cls, user: User) -> bool: diff --git a/src/argilla/server/schemas/v1/workspaces.py b/src/argilla/server/schemas/v1/workspaces.py index 53bbfe660e..7913f69609 100644 --- a/src/argilla/server/schemas/v1/workspaces.py +++ b/src/argilla/server/schemas/v1/workspaces.py @@ -13,6 +13,7 @@ # limitations under the License. from datetime import datetime +from typing import List from uuid import UUID from pydantic import BaseModel @@ -26,3 +27,7 @@ class Workspace(BaseModel): class Config: orm_mode = True + + +class Workspaces(BaseModel): + items: List[Workspace] From 4c747949100bfcda39e9ce8186355d1548e88582 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 3 Jul 2023 17:08:12 +0200 Subject: [PATCH 3/4] feat: add list user workspaces endpoint unit tests --- tests/server/api/v1/test_users.py | 58 +++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/server/api/v1/test_users.py diff --git a/tests/server/api/v1/test_users.py b/tests/server/api/v1/test_users.py new file mode 100644 index 0000000000..1a67dc5e72 --- /dev/null +++ b/tests/server/api/v1/test_users.py @@ -0,0 +1,58 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +import pytest +from argilla._constants import API_KEY_HEADER_NAME +from argilla.server.models import UserRole + +from tests.factories import UserFactory, WorkspaceFactory + +if TYPE_CHECKING: + from fastapi.testclient import TestClient + + +@pytest.mark.asyncio +class TestsUsersV1Endpoints: + async def test_list_user_workspaces(self, client: "TestClient", owner_auth_header: dict): + workspaces = await WorkspaceFactory.create_batch(3) + user = await UserFactory.create(workspaces=workspaces) + + response = client.get(f"/api/v1/users/{user.id}/workspaces", headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(workspace.id), + "name": workspace.name, + "inserted_at": workspace.inserted_at.isoformat(), + "updated_at": workspace.updated_at.isoformat(), + } + for workspace in workspaces + ] + } + + @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) + async def test_list_user_workspaces_as_restricted_user(self, client: "TestClient", role: UserRole): + workspaces = await WorkspaceFactory.create_batch(3) + user = await UserFactory.create(workspaces=workspaces) + requesting_user = await UserFactory.create(role=role) + + response = client.get( + f"/api/v1/users/{user.id}/workspaces", headers={API_KEY_HEADER_NAME: requesting_user.api_key} + ) + + assert response.status_code == 403 From 56ae37ec1cad3e7b6ebe52dec93548b68d2e94f7 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 3 Jul 2023 17:12:16 +0200 Subject: [PATCH 4/4] docs: updata changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61bff3ec60..988f4bd096 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,10 @@ These are the section headers that we use: ## [Unreleased] +### Added + +- Added `GET /api/v1/users/{user_id}/workspaces` endpoint to list the workspaces to which a user belongs ([#3308](https://github.com/argilla-io/argilla/pull/3308)). + ## [1.12.0](https://github.com/argilla-io/argilla/compare/v1.11.0...v1.12.0) ### Added