Skip to content

Commit

Permalink
First version of add-groups-to-user endpoint (ref #1737)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcompa committed Sep 10, 2024
1 parent 82ff4e3 commit e9dbdbd
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
48 changes: 48 additions & 0 deletions fractal_server/app/routes/auth/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
from ...schemas.user_group import UserGroupRead
from ...schemas.user_group import UserGroupUpdate
from ._aux_auth import _get_single_group_with_user_ids
from ._aux_auth import _get_single_user_with_group_ids
from ._aux_auth import _user_or_404
from fractal_server.app.models import LinkUserGroup
from fractal_server.app.models import UserGroup
from fractal_server.app.models import UserOAuth
from fractal_server.app.schemas import UserRead
from fractal_server.app.schemas import UserUpdateNewGroups


router_group = APIRouter()
Expand Down Expand Up @@ -153,6 +157,50 @@ async def update_single_group(
return updated_group


@router_group.patch(
"/group/user/{user_id}/",
response_model=UserRead,
status_code=status.HTTP_200_OK,
)
async def add_groups_to_user(
user_id: int,
user_update: UserUpdateNewGroups,
user: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserGroupRead:

user_to_patch = await _user_or_404(user_id, db)
if len(user_update.new_group_ids) > 0:

# Check that all required groups exist
# Note: The reason for introducing `col` is as in
# https://sqlmodel.tiangolo.com/tutorial/where/#type-annotations-and-errors,
stm = select(UserGroup).where(
col(UserGroup.id).in_(user_update.new_group_ids)
)
res = await db.execute(stm)
matching_groups = res.scalars().unique().all()
if not len(matching_groups) == len(user_update.new_group_ids):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=(
"Not all requested groups (IDs: "
f"{user_update.new_group_ids}) exist."
),
)

for new_group_id in user_update.new_group_ids:
link = LinkUserGroup(user_id=user_id, group_id=new_group_id)
db.add(link)
await db.commit()

patched_user_with_group_ids = await _get_single_user_with_group_ids(
user_to_patch, db
)

return patched_user_with_group_ids


@router_group.delete(
"/group/{group_id}/", status_code=status.HTTP_405_METHOD_NOT_ALLOWED
)
Expand Down
12 changes: 12 additions & 0 deletions fractal_server/app/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"UserRead",
"UserUpdate",
"UserCreate",
"UserUpdateNewGroups",
)


Expand Down Expand Up @@ -134,3 +135,14 @@ def slurm_accounts_validator(cls, value):
_cache_dir = validator("cache_dir", allow_reuse=True)(
val_absolute_path("cache_dir")
)


class UserUpdateNewGroups(BaseModel, extra=Extra.forbid): # FIXME RENAME
"""
Simple schema for `add_groups_to_user` endpoint.
Attributes:
new_group_ids: IDs of groups to be added to user.
"""

new_group_ids: list[int] = Field(default_factory=list)
43 changes: 43 additions & 0 deletions tests/no_version/test_auth_groups_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,46 @@ async def test_get_user_optional_group_info(
debug(user)
assert user["group_names"] is None
assert user["group_ids"] is None


async def test_add_groups_to_user_as_superuser(registered_superuser_client):

# Create user
res = await registered_superuser_client.post(
f"{PREFIX}/register/",
json=dict(
email="test@fractal.xy",
password="12345",
slurm_accounts=["foo", "bar"],
),
)
assert res.status_code == 201
user_id = res.json()["id"]
res = await registered_superuser_client.get(f"{PREFIX}/users/{user_id}/")
assert res.status_code == 200
user = res.json()
debug(user)
assert user["group_ids"] == []

# Create group
res = await registered_superuser_client.post(
f"{PREFIX}/group/",
json=dict(name="groupname"),
)
assert res.status_code == 201
group_id = res.json()["id"]

# Create user/group link and fail because of invalid `group_id``
invalid_group_id = 999999
res = await registered_superuser_client.patch(
f"{PREFIX}/group/user/{user_id}/",
json=dict(new_group_ids=[invalid_group_id]),
)
assert res.status_code == 404

# Create user/group link and succeed
res = await registered_superuser_client.patch(
f"{PREFIX}/group/user/{user_id}/", json=dict(new_group_ids=[group_id])
)
assert res.status_code == 200
assert res.json()["group_ids"] == [group_id]

0 comments on commit e9dbdbd

Please sign in to comment.