diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a8d380918..ad4faf2bf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ These are the section headers that we use: ### Added -- Added `GET /api/v1/users/{user_id}/workspaces` endpoint to list the workspaces to which a user belongs ([#3308](https://github.com/argilla-io/argilla/pull/3308)). +- Added `GET /api/v1/users/{user_id}/workspaces` endpoint to list the workspaces to which a user belongs ([#3308](https://github.com/argilla-io/argilla/pull/3308) and [#3343](https://github.com/argilla-io/argilla/pull/3343)). - Added `HuggingFaceDatasetMixIn` for internal usage, to detach the `FeedbackDataset` integrations from the class itself, and use MixIns instead ([#3326](https://github.com/argilla-io/argilla/pull/3326)). - Added `GET /api/v1/records/{record_id}/suggestions` API endpoint to get the list of suggestions for the responses associated to a record ([#3304](https://github.com/argilla-io/argilla/pull/3304)). - Added `POST /api/v1/records/{record_id}/suggestions` API endpoint to create a suggestion for a response associated to a record ([#3304](https://github.com/argilla-io/argilla/pull/3304)). diff --git a/src/argilla/server/apis/v1/handlers/users.py b/src/argilla/server/apis/v1/handlers/users.py index 81f7fac258..1bf5428f23 100644 --- a/src/argilla/server/apis/v1/handlers/users.py +++ b/src/argilla/server/apis/v1/handlers/users.py @@ -14,12 +14,12 @@ from uuid import UUID -from fastapi import APIRouter, Depends, Security +from fastapi import APIRouter, Depends, HTTPException, Security, status from sqlalchemy.ext.asyncio import AsyncSession from argilla.server.contexts import accounts from argilla.server.database import get_async_db -from argilla.server.models import User +from argilla.server.models import User, UserRole from argilla.server.policies import UserPolicyV1, authorize from argilla.server.schemas.v1.workspaces import Workspaces from argilla.server.security import auth @@ -35,5 +35,17 @@ async def list_user_workspaces( current_user: User = Security(auth.get_current_user), ): await authorize(current_user, UserPolicyV1.list_workspaces) - workspaces = await accounts.list_workspaces_by_user_id(db, user_id) + + user = await accounts.get_user_by_id(db, user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with id `{user_id}` not found", + ) + + if user.is_owner: + workspaces = await accounts.list_workspaces(db) + else: + workspaces = await accounts.list_workspaces_by_user_id(db, user_id) + return Workspaces(items=workspaces) diff --git a/src/argilla/server/contexts/accounts.py b/src/argilla/server/contexts/accounts.py index 0b93d9ebb2..f77db8a9a6 100644 --- a/src/argilla/server/contexts/accounts.py +++ b/src/argilla/server/contexts/accounts.py @@ -94,7 +94,7 @@ async def delete_workspace(db: "AsyncSession", workspace: Workspace): return workspace -async def get_user_by_id(db: Session, user_id: UUID) -> Union[User, None]: +async def get_user_by_id(db: "AsyncSession", user_id: UUID) -> Union[User, None]: return await db.get(User, user_id) diff --git a/tests/server/api/v1/test_users.py b/tests/server/api/v1/test_users.py index 1a67dc5e72..2c822ca19c 100644 --- a/tests/server/api/v1/test_users.py +++ b/tests/server/api/v1/test_users.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import uuid from typing import TYPE_CHECKING import pytest from argilla._constants import API_KEY_HEADER_NAME -from argilla.server.models import UserRole +from argilla.server.models import User, UserRole from tests.factories import UserFactory, WorkspaceFactory @@ -45,6 +45,24 @@ async def test_list_user_workspaces(self, client: "TestClient", owner_auth_heade ] } + async def test_list_user_workspaces_for_owner(self, client: "TestClient"): + workspaces = await WorkspaceFactory.create_batch(5) + owner = await UserFactory.create(role=UserRole.owner) + + response = client.get(f"/api/v1/users/{owner.id}/workspaces", headers={API_KEY_HEADER_NAME: owner.api_key}) + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(workspace.id), + "name": workspace.name, + "inserted_at": workspace.inserted_at.isoformat(), + "updated_at": workspace.updated_at.isoformat(), + } + for workspace in workspaces + ] + } + @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) async def test_list_user_workspaces_as_restricted_user(self, client: "TestClient", role: UserRole): workspaces = await WorkspaceFactory.create_batch(3) @@ -56,3 +74,7 @@ async def test_list_user_workspaces_as_restricted_user(self, client: "TestClient ) assert response.status_code == 403 + + async def test_list_user_for_non_existing_user(self, client: "TestClient", owner_auth_header: dict): + response = client.get(f"/api/v1/users/{uuid.uuid4()}/workspaces", headers=owner_auth_header) + assert response.status_code == 404