From 1594037c61f216980c75dbcd31407def82441538 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 4 Jul 2023 14:05:29 +0200 Subject: [PATCH] feat: add list user workspaces endpoint (#3308) # Description This PR adds a new endpoint to API v1 to list the workspaces to which a user belongs. Closes #3273 **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** Manually and unit tests have been added to test this new endpoint. **Checklist** - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Alvaro Bartolome --- CHANGELOG.md | 4 ++ src/argilla/server/apis/v1/handlers/users.py | 39 +++++++++++++ src/argilla/server/contexts/accounts.py | 10 ++++ src/argilla/server/policies.py | 6 ++ src/argilla/server/routes.py | 4 +- src/argilla/server/schemas/v1/workspaces.py | 5 ++ tests/server/api/v1/test_users.py | 58 ++++++++++++++++++++ 7 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 src/argilla/server/apis/v1/handlers/users.py create mode 100644 tests/server/api/v1/test_users.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ccd9537441..a2b154f4d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,10 @@ These are the section headers that we use: ## [Unreleased] +### Added + +- Added `GET /api/v1/users/{user_id}/workspaces` endpoint to list the workspaces to which a user belongs ([#3308](https://github.com/argilla-io/argilla/pull/3308)). +- ### Fixed - Fixed `sqlalchemy.error.OperationalError` being raised when running the unit tests if the local SQLite database file didn't exist and the migrations hadn't been applied ([#3307](https://github.com/argilla-io/argilla/pull/3307)). diff --git a/src/argilla/server/apis/v1/handlers/users.py b/src/argilla/server/apis/v1/handlers/users.py new file mode 100644 index 0000000000..81f7fac258 --- /dev/null +++ b/src/argilla/server/apis/v1/handlers/users.py @@ -0,0 +1,39 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from fastapi import APIRouter, Depends, Security +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla.server.contexts import accounts +from argilla.server.database import get_async_db +from argilla.server.models import User +from argilla.server.policies import UserPolicyV1, authorize +from argilla.server.schemas.v1.workspaces import Workspaces +from argilla.server.security import auth + +router = APIRouter(tags=["users"]) + + +@router.get("/users/{user_id}/workspaces", response_model=Workspaces) +async def list_user_workspaces( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.list_workspaces) + workspaces = await accounts.list_workspaces_by_user_id(db, user_id) + return Workspaces(items=workspaces) diff --git a/src/argilla/server/contexts/accounts.py b/src/argilla/server/contexts/accounts.py index 89d1cb6c9f..0b93d9ebb2 100644 --- a/src/argilla/server/contexts/accounts.py +++ b/src/argilla/server/contexts/accounts.py @@ -70,6 +70,16 @@ async def list_workspaces(db: "AsyncSession") -> List[Workspace]: return result.scalars().all() +async def list_workspaces_by_user_id(db: "AsyncSession", user_id: UUID) -> List[Workspace]: + result = await db.execute( + select(Workspace) + .join(WorkspaceUser) + .filter(WorkspaceUser.user_id == user_id) + .order_by(Workspace.inserted_at.asc()) + ) + return result.scalars().all() + + async def create_workspace(db: "AsyncSession", workspace_create: WorkspaceCreate) -> Workspace: workspace = Workspace(name=workspace_create.name) db.add(workspace) diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index fd1b18109c..b0d544378a 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -117,6 +117,12 @@ async def is_allowed(actor: User) -> bool: return is_allowed +class UserPolicyV1: + @classmethod + async def list_workspaces(cls, actor: User) -> bool: + return actor.is_owner + + class DatasetPolicy: @classmethod async def list(cls, user: User) -> bool: diff --git a/src/argilla/server/routes.py b/src/argilla/server/routes.py index 2759df5c0a..81988e527f 100644 --- a/src/argilla/server/routes.py +++ b/src/argilla/server/routes.py @@ -38,6 +38,7 @@ from argilla.server.apis.v1.handlers import questions as questions_v1 from argilla.server.apis.v1.handlers import records as records_v1 from argilla.server.apis.v1.handlers import responses as responses_v1 +from argilla.server.apis.v1.handlers import users as users_v1 from argilla.server.apis.v1.handlers import workspaces as workspaces_v1 from argilla.server.errors.base_errors import __ALL__ @@ -63,8 +64,9 @@ # API v1 api_router.include_router(datasets_v1.router, prefix="/v1") -api_router.include_router(workspaces_v1.router, prefix="/v1") api_router.include_router(fields_v1.router, prefix="/v1") api_router.include_router(questions_v1.router, prefix="/v1") api_router.include_router(records_v1.router, prefix="/v1") api_router.include_router(responses_v1.router, prefix="/v1") +api_router.include_router(users_v1.router, prefix="/v1") +api_router.include_router(workspaces_v1.router, prefix="/v1") diff --git a/src/argilla/server/schemas/v1/workspaces.py b/src/argilla/server/schemas/v1/workspaces.py index 53bbfe660e..7913f69609 100644 --- a/src/argilla/server/schemas/v1/workspaces.py +++ b/src/argilla/server/schemas/v1/workspaces.py @@ -13,6 +13,7 @@ # limitations under the License. from datetime import datetime +from typing import List from uuid import UUID from pydantic import BaseModel @@ -26,3 +27,7 @@ class Workspace(BaseModel): class Config: orm_mode = True + + +class Workspaces(BaseModel): + items: List[Workspace] diff --git a/tests/server/api/v1/test_users.py b/tests/server/api/v1/test_users.py new file mode 100644 index 0000000000..1a67dc5e72 --- /dev/null +++ b/tests/server/api/v1/test_users.py @@ -0,0 +1,58 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +import pytest +from argilla._constants import API_KEY_HEADER_NAME +from argilla.server.models import UserRole + +from tests.factories import UserFactory, WorkspaceFactory + +if TYPE_CHECKING: + from fastapi.testclient import TestClient + + +@pytest.mark.asyncio +class TestsUsersV1Endpoints: + async def test_list_user_workspaces(self, client: "TestClient", owner_auth_header: dict): + workspaces = await WorkspaceFactory.create_batch(3) + user = await UserFactory.create(workspaces=workspaces) + + response = client.get(f"/api/v1/users/{user.id}/workspaces", headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(workspace.id), + "name": workspace.name, + "inserted_at": workspace.inserted_at.isoformat(), + "updated_at": workspace.updated_at.isoformat(), + } + for workspace in workspaces + ] + } + + @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) + async def test_list_user_workspaces_as_restricted_user(self, client: "TestClient", role: UserRole): + workspaces = await WorkspaceFactory.create_batch(3) + user = await UserFactory.create(workspaces=workspaces) + requesting_user = await UserFactory.create(role=role) + + response = client.get( + f"/api/v1/users/{user.id}/workspaces", headers={API_KEY_HEADER_NAME: requesting_user.api_key} + ) + + assert response.status_code == 403