Skip to content

Commit

Permalink
Merge pull request #1747 from fractal-analytics-platform/1737-establi…
Browse files Browse the repository at this point in the history
…sh-usergroup-links-via-users-endpoints

1737 establish usergroup links via users endpoints
  • Loading branch information
tcompa authored Sep 10, 2024
2 parents bf8aed0 + aab184d commit 8d8bb95
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 38 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
* Create `update-db-script` for current version, that adds all users to default group (\#1738).
* API:
* Added `/auth/group/` and `/auth/group-names/` routers (\#1738).
* Implement `/auth/users/{id}/` POST/PATCH routes in `fractal-server` (\#1738).
* Implement `/auth/users/{id}/` POST/PATCH routes in `fractal-server` (\#1738, \#1747).
* Introduce `UserUpdateWithNewGroupIds` schema for `PATCH /auth/users/{id}/` (\#1747).
* Add `UserManager.on_after_register` hook to add new users to default user group (\#1738).
* Database:
* Added new `usergroup` and `linkusergroup` tables (\#1738).
Expand Down
30 changes: 8 additions & 22 deletions fractal_server/app/routes/auth/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from sqlmodel import func
from sqlmodel import select

from . import current_active_superuser
Expand All @@ -19,7 +20,6 @@
from fractal_server.app.models import UserGroup
from fractal_server.app.models import UserOAuth


router_group = APIRouter()


Expand All @@ -31,9 +31,6 @@ async def get_list_user_groups(
user: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> list[UserGroupRead]:
"""
FIXME docstring
"""

# Get all groups
stm_all_groups = select(UserGroup)
Expand All @@ -46,7 +43,8 @@ async def get_list_user_groups(
res = await db.execute(stm_all_links)
links = res.scalars().all()

# FIXME GROUPS: this must be optimized
# TODO: possible optimizations for this construction are listed in
# https://github.com/fractal-analytics-platform/fractal-server/issues/1742
for ind, group in enumerate(groups):
groups[ind] = dict(
group.model_dump(),
Expand All @@ -68,9 +66,6 @@ async def get_single_user_group(
user: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserGroupRead:
"""
FIXME docstring
"""
group = await _get_single_group_with_user_ids(group_id=group_id, db=db)
return group

Expand All @@ -85,9 +80,6 @@ async def create_single_group(
user: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserGroupRead:
"""
FIXME docstring
"""

# Check that name is not already in use
existing_name_str = select(UserGroup).where(
Expand Down Expand Up @@ -119,24 +111,21 @@ async def update_single_group(
user: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserGroupRead:
"""
FIXME docstring
"""

# Check that all required users exist
# Note: The reason for introducing `col` is as in
# https://sqlmodel.tiangolo.com/tutorial/where/#type-annotations-and-errors,
stm = select(UserOAuth).where(
stm = select(func.count()).where(
col(UserOAuth.id).in_(group_update.new_user_ids)
)
res = await db.execute(stm)
matching_users = res.scalars().unique().all()
if not len(matching_users) == len(group_update.new_user_ids):
number_matching_users = res.scalar()
if number_matching_users != len(group_update.new_user_ids):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=(
f"At least user with IDs {group_update.new_user_ids} "
"does not exist."
f"Not all requested users (IDs {group_update.new_user_ids}) "
"exist."
),
)

Expand All @@ -161,9 +150,6 @@ async def delete_single_group(
user: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserGroupRead:
"""
FIXME docstring
"""
raise HTTPException(
status_code=status.HTTP_405_METHOD_NOT_ALLOWED,
detail=(
Expand Down
98 changes: 84 additions & 14 deletions fractal_server/app/routes/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@
from fastapi_users import schemas
from fastapi_users.router.common import ErrorCode
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from sqlmodel import func
from sqlmodel import select

from . import current_active_superuser
from ...db import get_async_db
from ...schemas.user import UserRead
from ...schemas.user import UserUpdate
from ...schemas.user import UserUpdateWithNewGroupIds
from ._aux_auth import _get_single_user_with_group_ids
from fractal_server.app.models import LinkUserGroup
from fractal_server.app.models import UserGroup
from fractal_server.app.models import UserOAuth
from fractal_server.app.routes.auth._aux_auth import _user_or_404
from fractal_server.app.security import get_user_manager
Expand All @@ -43,31 +47,96 @@ async def get_user(
@router_users.patch("/users/{user_id}/", response_model=UserRead)
async def patch_user(
user_id: int,
user_update: UserUpdate,
user_update: UserUpdateWithNewGroupIds,
current_superuser: UserOAuth = Depends(current_active_superuser),
user_manager: UserManager = Depends(get_user_manager),
db: AsyncSession = Depends(get_async_db),
):
"""
Custom version of the PATCH-user route from `fastapi-users`.
In order to keep the fastapi-users logic in place (which is convenient to
update user attributes), we split the endpoint into two branches. We either
go through the fastapi-users-based attribute-update branch, or through the
branch where we establish new user/group relationships.
Note that we prevent making both changes at the same time, since it would
be more complex to guarantee that endpoint error would leave the database
in the same state as before the API call.
"""

# We prevent simultaneous editing of both user attributes and user/group
# associations
user_update_dict_without_groups = user_update.dict(
exclude_unset=True, exclude={"new_group_ids"}
)
edit_attributes = user_update_dict_without_groups != {}
edit_groups = user_update.new_group_ids is not None
if edit_attributes and edit_groups:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
"Cannot modify both user attributes and group membership. "
"Please make two independent PATCH calls"
),
)

# Check that user exists
user_to_patch = await _user_or_404(user_id, db)

try:
user = await user_manager.update(
user_update, user_to_patch, safe=False, request=None
)
patched_user = schemas.model_validate(UserOAuth, user)
except exceptions.InvalidPasswordException as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": ErrorCode.UPDATE_USER_INVALID_PASSWORD,
"reason": e.reason,
},
if edit_groups:
# Establish new user/group relationships

# Check that all required groups exist
# Note: The reason for introducing `col` is as in
# https://sqlmodel.tiangolo.com/tutorial/where/#type-annotations-and-errors,
stm = select(func.count()).where(
col(UserGroup.id).in_(user_update.new_group_ids)
)
res = await db.execute(stm)
number_matching_groups = res.scalar()
if number_matching_groups != len(user_update.new_group_ids):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=(
"Not all requested groups (IDs: "
f"{user_update.new_group_ids}) exist."
),
)

for new_group_id in user_update.new_group_ids:
link = LinkUserGroup(user_id=user_id, group_id=new_group_id)
db.add(link)
await db.commit()

patched_user = user_to_patch

elif edit_attributes:
# Modify user attributes
try:
user_update_without_groups = UserUpdate(
**user_update_dict_without_groups
)
user = await user_manager.update(
user_update_without_groups,
user_to_patch,
safe=False,
request=None,
)
patched_user = schemas.model_validate(UserOAuth, user)
except exceptions.InvalidPasswordException as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"code": ErrorCode.UPDATE_USER_INVALID_PASSWORD,
"reason": e.reason,
},
)
else:
# Nothing to do, just continue
patched_user = user_to_patch

# Enrich user object with `group_ids` attribute
patched_user_with_group_ids = await _get_single_user_with_group_ids(
patched_user, db
)
Expand All @@ -92,7 +161,8 @@ async def list_users(
res = await db.execute(stm_all_links)
links = res.scalars().all()

# FIXME GROUPS: this must be optimized
# TODO: possible optimizations for this construction are listed in
# https://github.com/fractal-analytics-platform/fractal-server/issues/1742
for ind, user in enumerate(user_list):
user_list[ind] = dict(
user.model_dump(),
Expand Down
5 changes: 5 additions & 0 deletions fractal_server/app/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"UserRead",
"UserUpdate",
"UserCreate",
"UserUpdateWithNewGroupIds",
)


Expand Down Expand Up @@ -102,6 +103,10 @@ class UserUpdateStrict(BaseModel, extra=Extra.forbid):
)


class UserUpdateWithNewGroupIds(UserUpdate):
new_group_ids: Optional[list[int]] = None


class UserCreate(schemas.BaseUserCreate):
"""
Schema for `User` creation.
Expand Down
80 changes: 80 additions & 0 deletions tests/no_version/test_auth_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,83 @@ async def test_delete_user(registered_client, registered_superuser_client):
f"{PREFIX}/users/THIS-IS-NOT-AN-ID"
)
assert res.status_code == 404


async def test_add_groups_to_user_as_superuser(registered_superuser_client):

# Create user
res = await registered_superuser_client.post(
f"{PREFIX}/register/",
json=dict(
email="test@fractal.xy",
password="12345",
slurm_accounts=["foo", "bar"],
),
)
assert res.status_code == 201
user_id = res.json()["id"]
res = await registered_superuser_client.get(f"{PREFIX}/users/{user_id}/")
assert res.status_code == 200
user = res.json()
debug(user)
assert user["group_ids"] == []

# Create group
res = await registered_superuser_client.post(
f"{PREFIX}/group/",
json=dict(name="groupname"),
)
assert res.status_code == 201
group_id = res.json()["id"]

# Create user/group link and fail because of invalid `group_id``
invalid_group_id = 999999
res = await registered_superuser_client.patch(
f"{PREFIX}/users/{user_id}/",
json=dict(new_group_ids=[invalid_group_id]),
)
assert res.status_code == 404

# Create user/group link and succeed
res = await registered_superuser_client.patch(
f"{PREFIX}/users/{user_id}/",
json=dict(new_group_ids=[group_id]),
)
assert res.status_code == 200
assert res.json()["group_ids"] == [group_id]


async def test_edit_user_and_fail(registered_superuser_client):

# Create user
res = await registered_superuser_client.post(
f"{PREFIX}/register/",
json=dict(
email="test@fractal.xy",
password="12345",
slurm_accounts=["foo", "bar"],
),
)
assert res.status_code == 201
user_id = res.json()["id"]

# Patch both user attributes and user/group relationship, and fail
res = await registered_superuser_client.patch(
f"{PREFIX}/users/{user_id}/",
json=dict(
slurm_user="new-slurm-user",
new_group_ids=[],
),
)
assert res.status_code == 422
expected_detail = (
"Cannot modify both user attributes and group membership."
)
assert expected_detail in res.json()["detail"]

# Make a dummy patch to user, and succeed
res = await registered_superuser_client.patch(
f"{PREFIX}/users/{user_id}/",
json={},
)
assert res.status_code == 200
2 changes: 1 addition & 1 deletion tests/no_version/test_auth_groups_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def test_update_group(registered_superuser_client):
assert res.status_code == 200
assert res.json()["user_ids"] == []

# Patch an existing group by adding a valid users
# Patch an existing group by adding a valid user
res = await registered_superuser_client.patch(
f"{PREFIX}/group/{group_id}/",
json=dict(new_user_ids=[user_A_id]),
Expand Down

0 comments on commit 8d8bb95

Please sign in to comment.