Skip to content

Commit

Permalink
Merge pull request #1780 from fractal-analytics-platform/user-setting…
Browse files Browse the repository at this point in the history
…s-api

User Settings API
  • Loading branch information
tcompa authored Sep 19, 2024
2 parents ff19c19 + de85cea commit 4234d4f
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 6 deletions.
11 changes: 10 additions & 1 deletion fractal_server/app/models/user_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional

from sqlalchemy import Column
from sqlalchemy.types import JSON
from sqlmodel import Field
from sqlmodel import SQLModel

Expand All @@ -20,9 +22,16 @@ class UserSettings(SQLModel, table=True):

id: Optional[int] = Field(default=None, primary_key=True)

# Actual settings columns
# SSH-SLURM
ssh_host: Optional[str] = None
ssh_username: Optional[str] = None
ssh_private_key_path: Optional[str] = None
ssh_tasks_dir: Optional[str] = None
ssh_jobs_dir: Optional[str] = None

# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: list[str] = Field(
sa_column=Column(JSON, server_default="[]", nullable=False)
)
cache_dir: Optional[str] = None
10 changes: 5 additions & 5 deletions fractal_server/app/routes/auth/_aux_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from ...models.linkusergroup import LinkUserGroup
from ...models.security import UserGroup
from ...models.security import UserOAuth
from ...schemas.user import UserRead
from ...schemas.user_group import UserGroupRead
from fractal_server.app.models.linkusergroup import LinkUserGroup
from fractal_server.app.models.security import UserGroup
from fractal_server.app.models.security import UserOAuth
from fractal_server.app.schemas.user import UserRead
from fractal_server.app.schemas.user_group import UserGroupRead


async def _get_single_user_with_group_names(
Expand Down
36 changes: 36 additions & 0 deletions fractal_server/app/routes/auth/current_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from ...schemas.user import UserUpdateStrict
from ._aux_auth import _get_single_user_with_group_names
from fractal_server.app.models import UserOAuth
from fractal_server.app.models import UserSettings
from fractal_server.app.schemas import UserSettingsReadStrict
from fractal_server.app.schemas import UserSettingsUpdateStrict
from fractal_server.app.security import get_user_manager
from fractal_server.app.security import UserManager

Expand Down Expand Up @@ -62,3 +65,36 @@ async def patch_current_user(
patched_user, db
)
return patched_user_with_groups


@router_current_user.get(
"/current-user/settings/", response_model=UserSettingsReadStrict
)
async def get_current_user_settings(
current_user: UserOAuth = Depends(current_active_user),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsReadStrict:
user_settings = await db.get(UserSettings, current_user.user_settings_id)
return user_settings


@router_current_user.patch(
"/current-user/settings/", response_model=UserSettingsReadStrict
)
async def patch_current_user_settings(
settings_update: UserSettingsUpdateStrict,
current_user: UserOAuth = Depends(current_active_user),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsReadStrict:
current_user_settings = await db.get(
UserSettings, current_user.user_settings_id
)

for k, v in settings_update.dict(exclude_unset=True).items():
setattr(current_user_settings, k, v)

db.add(current_user_settings)
await db.commit()
await db.refresh(current_user_settings)

return current_user_settings
39 changes: 39 additions & 0 deletions fractal_server/app/routes/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from fractal_server.app.models import LinkUserGroup
from fractal_server.app.models import UserGroup
from fractal_server.app.models import UserOAuth
from fractal_server.app.models import UserSettings
from fractal_server.app.routes.auth._aux_auth import _user_or_404
from fractal_server.app.schemas import UserSettingsRead
from fractal_server.app.schemas import UserSettingsUpdate
from fractal_server.app.security import get_user_manager
from fractal_server.app.security import UserManager
from fractal_server.logger import set_logger
Expand Down Expand Up @@ -196,3 +199,39 @@ async def list_users(
)

return user_list


@router_users.get(
"/users/{user_id}/settings/", response_model=UserSettingsRead
)
async def get_user_settings(
user_id: int,
superuser: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsRead:

user = await _user_or_404(user_id=user_id, db=db)
user_settings = await db.get(UserSettings, user.user_settings_id)
return user_settings


@router_users.patch(
"/users/{user_id}/settings/", response_model=UserSettingsRead
)
async def patch_user_settings(
user_id: int,
settings_update: UserSettingsUpdate,
superuser: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsRead:
user = await _user_or_404(user_id=user_id, db=db)
user_settings = await db.get(UserSettings, user.user_settings_id)

for k, v in settings_update.dict(exclude_unset=True).items():
setattr(user_settings, k, v)

db.add(user_settings)
await db.commit()
await db.refresh(user_settings)

return user_settings
2 changes: 2 additions & 0 deletions fractal_server/app/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .user import * # noqa: F401, F403
from .user_group import * # noqa: F401, F403
from .user_settings import * # noqa: F401, F403
93 changes: 93 additions & 0 deletions fractal_server/app/schemas/user_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Optional

from pydantic import BaseModel
from pydantic import validator
from pydantic.types import StrictStr

from ._validators import val_absolute_path
from ._validators import val_unique_list
from ._validators import valstr
from fractal_server.string_tools import validate_cmd

__all__ = (
"UserSettingsRead",
"UserSettingsReadStrict",
"UserSettingsUpdate",
"UserSettingsUpdateStrict",
)


class UserSettingsRead(BaseModel):
id: int
# SSH-SLURM
ssh_host: Optional[str] = None
ssh_username: Optional[str] = None
ssh_private_key_path: Optional[str] = None
ssh_tasks_dir: Optional[str] = None
ssh_jobs_dir: Optional[str] = None
# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: list[str]
cache_dir: Optional[str] = None


class UserSettingsReadStrict(BaseModel):
# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: list[str]
cache_dir: Optional[str] = None


class UserSettingsUpdate(BaseModel):
# SSH-SLURM
ssh_host: Optional[str] = None
ssh_username: Optional[str] = None
ssh_private_key_path: Optional[str] = None
ssh_tasks_dir: Optional[str] = None
ssh_jobs_dir: Optional[str] = None
# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: Optional[list[StrictStr]] = None
cache_dir: Optional[str] = None

_ssh_host = validator("ssh_host", allow_reuse=True)(valstr("ssh_host"))
_ssh_username = validator("ssh_username", allow_reuse=True)(
valstr("ssh_username")
)
_ssh_private_key_path = validator(
"ssh_private_key_path", allow_reuse=True
)(val_absolute_path("ssh_private_key_path"))

_ssh_tasks_dir = validator("ssh_tasks_dir", allow_reuse=True)(
val_absolute_path("ssh_tasks_dir")
)
_ssh_jobs_dir = validator("ssh_jobs_dir", allow_reuse=True)(
val_absolute_path("ssh_jobs_dir")
)

_slurm_user = validator("slurm_user", allow_reuse=True)(
valstr("slurm_user")
)
_slurm_accounts = validator("slurm_accounts", allow_reuse=True)(
val_unique_list("slurm_accounts")
)

@validator("cache_dir")
def cache_dir_validator(cls, value):
validate_cmd(value)
return val_absolute_path("cache_dir")(value)


class UserSettingsUpdateStrict(BaseModel):
# SUDO-SLURM
slurm_accounts: Optional[list[StrictStr]] = None
cache_dir: Optional[str] = None

_slurm_accounts = validator("slurm_accounts", allow_reuse=True)(
val_unique_list("slurm_accounts")
)

@validator("cache_dir")
def cache_dir_validator(cls, value):
validate_cmd(value)
return val_absolute_path("cache_dir")(value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""new user settings columns
Revision ID: e1575a65e853
Revises: dfbe4f3a7bc4
Create Date: 2024-09-19 12:14:38.481210
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op


# revision identifiers, used by Alembic.
revision = "e1575a65e853"
down_revision = "dfbe4f3a7bc4"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user_settings", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"slurm_user", sqlmodel.sql.sqltypes.AutoString(), nullable=True
)
)
batch_op.add_column(
sa.Column(
"slurm_accounts",
sa.JSON(),
server_default="[]",
nullable=False,
)
)
batch_op.add_column(
sa.Column(
"cache_dir", sqlmodel.sql.sqltypes.AutoString(), nullable=True
)
)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user_settings", schema=None) as batch_op:
batch_op.drop_column("cache_dir")
batch_op.drop_column("slurm_accounts")
batch_op.drop_column("slurm_user")

# ### end Alembic commands ###
86 changes: 86 additions & 0 deletions tests/no_version/test_auth_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,89 @@ async def test_oauth_accounts_list(
f"{PREFIX}/current-user/", json=dict(cache_dir="/foo/bar")
)
assert len(res.json()["oauth_accounts"]) == 1


async def test_get_and_patch_user_settings(registered_superuser_client):

# Register new user
res = await registered_superuser_client.post(
f"{PREFIX}/register/", json=dict(email="a@b.c", password="1234")
)
assert res.status_code == 201
user_id = res.json()["id"]

# Get user settings
res = await registered_superuser_client.get(
f"{PREFIX}/users/{user_id}/settings/",
)
assert res.status_code == 200
for k, v in res.json().items():
if k == "id":
pass
elif k == "slurm_accounts":
assert v == []
else:
assert v is None

# Path user settings
patch = dict(
ssh_host="127.0.0.1",
ssh_username="fractal",
ssh_private_key_path="/tmp/fractal",
ssh_tasks_dir="/tmp/tasks",
# missing "ssh_jobs_dir"
# missing "slurm_user"
slurm_accounts=["foo", "bar"],
cache_dir="/tmp/cache",
)
res = await registered_superuser_client.patch(
f"{PREFIX}/users/{user_id}/settings/", json=patch
)
debug(res.json())
assert res.status_code == 200

# Assert patch was successful
res = await registered_superuser_client.get(
f"{PREFIX}/users/{user_id}/settings/",
)
for k, v in res.json().items():
if k in patch:
assert v == patch[k]
elif k == "id":
pass
else:
assert v is None

# Get non-existing-user settings
res = await registered_superuser_client.get(f"{PREFIX}/users/42/settings/")
assert res.status_code == 404
# Patch non-existing-user settings
res = await registered_superuser_client.patch(
f"{PREFIX}/users/42/settings/", json=dict()
)
assert res.status_code == 404


async def test_get_and_patch_current_user_settings(registered_client):

res = await registered_client.get(f"{PREFIX}/current-user/settings/")
assert res.status_code == 200
for k, v in res.json().items():
if k == "slurm_accounts":
assert v == []
else:
assert v is None

patch = dict(slurm_accounts=["foo", "bar"], cache_dir="/tmp/foo_cache")
res = await registered_client.patch(
f"{PREFIX}/current-user/settings/", json=patch
)
assert res.status_code == 200

# Assert patch was successful
res = await registered_client.get(f"{PREFIX}/current-user/settings/")
for k, v in res.json().items():
if k in patch:
assert v == patch[k]
else:
assert v is None

0 comments on commit 4234d4f

Please sign in to comment.