diff --git a/fractal_server/app/routes/auth/group.py b/fractal_server/app/routes/auth/group.py index e1038e8295..3f7f109735 100644 --- a/fractal_server/app/routes/auth/group.py +++ b/fractal_server/app/routes/auth/group.py @@ -15,9 +15,13 @@ from ...schemas.user_group import UserGroupRead from ...schemas.user_group import UserGroupUpdate from ._aux_auth import _get_single_group_with_user_ids +from ._aux_auth import _get_single_user_with_group_ids +from ._aux_auth import _user_or_404 from fractal_server.app.models import LinkUserGroup from fractal_server.app.models import UserGroup from fractal_server.app.models import UserOAuth +from fractal_server.app.schemas import UserRead +from fractal_server.app.schemas import UserUpdateNewGroups router_group = APIRouter() @@ -153,6 +157,50 @@ async def update_single_group( return updated_group +@router_group.patch( + "/group/user/{user_id}/", + response_model=UserRead, + status_code=status.HTTP_200_OK, +) +async def add_groups_to_user( + user_id: int, + user_update: UserUpdateNewGroups, + user: UserOAuth = Depends(current_active_superuser), + db: AsyncSession = Depends(get_async_db), +) -> UserGroupRead: + + user_to_patch = await _user_or_404(user_id, db) + if len(user_update.new_group_ids) > 0: + + # Check that all required groups exist + # Note: The reason for introducing `col` is as in + # https://sqlmodel.tiangolo.com/tutorial/where/#type-annotations-and-errors, + stm = select(UserGroup).where( + col(UserGroup.id).in_(user_update.new_group_ids) + ) + res = await db.execute(stm) + matching_groups = res.scalars().unique().all() + if not len(matching_groups) == len(user_update.new_group_ids): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=( + "Not all requested groups (IDs: " + f"{user_update.new_group_ids}) exist." + ), + ) + + for new_group_id in user_update.new_group_ids: + link = LinkUserGroup(user_id=user_id, group_id=new_group_id) + db.add(link) + await db.commit() + + patched_user_with_group_ids = await _get_single_user_with_group_ids( + user_to_patch, db + ) + + return patched_user_with_group_ids + + @router_group.delete( "/group/{group_id}/", status_code=status.HTTP_405_METHOD_NOT_ALLOWED ) diff --git a/fractal_server/app/schemas/user.py b/fractal_server/app/schemas/user.py index 2f967ddf5e..bbbb1e7fe4 100644 --- a/fractal_server/app/schemas/user.py +++ b/fractal_server/app/schemas/user.py @@ -16,6 +16,7 @@ "UserRead", "UserUpdate", "UserCreate", + "UserUpdateNewGroups", ) @@ -134,3 +135,14 @@ def slurm_accounts_validator(cls, value): _cache_dir = validator("cache_dir", allow_reuse=True)( val_absolute_path("cache_dir") ) + + +class UserUpdateNewGroups(BaseModel, extra=Extra.forbid): # FIXME RENAME + """ + Simple schema for `add_groups_to_user` endpoint. + + Attributes: + new_group_ids: IDs of groups to be added to user. + """ + + new_group_ids: list[int] = Field(default_factory=list) diff --git a/tests/no_version/test_auth_groups_api.py b/tests/no_version/test_auth_groups_api.py index 4bd6701600..f94f1aa4bc 100644 --- a/tests/no_version/test_auth_groups_api.py +++ b/tests/no_version/test_auth_groups_api.py @@ -261,3 +261,46 @@ async def test_get_user_optional_group_info( debug(user) assert user["group_names"] is None assert user["group_ids"] is None + + +async def test_add_groups_to_user_as_superuser(registered_superuser_client): + + # Create user + res = await registered_superuser_client.post( + f"{PREFIX}/register/", + json=dict( + email="test@fractal.xy", + password="12345", + slurm_accounts=["foo", "bar"], + ), + ) + assert res.status_code == 201 + user_id = res.json()["id"] + res = await registered_superuser_client.get(f"{PREFIX}/users/{user_id}/") + assert res.status_code == 200 + user = res.json() + debug(user) + assert user["group_ids"] == [] + + # Create group + res = await registered_superuser_client.post( + f"{PREFIX}/group/", + json=dict(name="groupname"), + ) + assert res.status_code == 201 + group_id = res.json()["id"] + + # Create user/group link and fail because of invalid `group_id`` + invalid_group_id = 999999 + res = await registered_superuser_client.patch( + f"{PREFIX}/group/user/{user_id}/", + json=dict(new_group_ids=[invalid_group_id]), + ) + assert res.status_code == 404 + + # Create user/group link and succeed + res = await registered_superuser_client.patch( + f"{PREFIX}/group/user/{user_id}/", json=dict(new_group_ids=[group_id]) + ) + assert res.status_code == 200 + assert res.json()["group_ids"] == [group_id]