Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User Settings API #1780

Merged
merged 15 commits into from
Sep 19, 2024
11 changes: 10 additions & 1 deletion fractal_server/app/models/user_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional

from sqlalchemy import Column
from sqlalchemy.types import JSON
from sqlmodel import Field
from sqlmodel import SQLModel

Expand All @@ -9,9 +11,16 @@ class UserSettings(SQLModel, table=True):

id: Optional[int] = Field(default=None, primary_key=True)

# Actual settings columns
# SSH-SLURM
ssh_host: Optional[str] = None
ssh_username: Optional[str] = None
ssh_private_key_path: Optional[str] = None
ssh_tasks_dir: Optional[str] = None
ssh_jobs_dir: Optional[str] = None

# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: list[str] = Field(
sa_column=Column(JSON, server_default="[]", nullable=False)
)
cache_dir: Optional[str] = None
10 changes: 5 additions & 5 deletions fractal_server/app/routes/auth/_aux_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from ...models.linkusergroup import LinkUserGroup
from ...models.security import UserGroup
from ...models.security import UserOAuth
from ...schemas.user import UserRead
from ...schemas.user_group import UserGroupRead
from fractal_server.app.models.linkusergroup import LinkUserGroup
from fractal_server.app.models.security import UserGroup
from fractal_server.app.models.security import UserOAuth
from fractal_server.app.schemas.user import UserRead
from fractal_server.app.schemas.user_group import UserGroupRead


async def _get_single_user_with_group_names(
Expand Down
36 changes: 36 additions & 0 deletions fractal_server/app/routes/auth/current_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from ...schemas.user import UserUpdateStrict
from ._aux_auth import _get_single_user_with_group_names
from fractal_server.app.models import UserOAuth
from fractal_server.app.models import UserSettings
from fractal_server.app.schemas import UserSettingsReadStrict
from fractal_server.app.schemas import UserSettingsUpdateStrict
from fractal_server.app.security import get_user_manager
from fractal_server.app.security import UserManager

Expand Down Expand Up @@ -62,3 +65,36 @@ async def patch_current_user(
patched_user, db
)
return patched_user_with_groups


@router_current_user.get(
"/current-user/settings/", response_model=UserSettingsReadStrict
)
async def get_current_user_settings(
current_user: UserOAuth = Depends(current_active_user),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsReadStrict:
user_settings = await db.get(UserSettings, current_user.user_settings_id)
return user_settings


@router_current_user.patch(
"/current-user/settings/", response_model=UserSettingsReadStrict
)
async def patch_current_user_settings(
settings_update: UserSettingsUpdateStrict,
current_user: UserOAuth = Depends(current_active_user),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsReadStrict:
current_user_settings = await db.get(
UserSettings, current_user.user_settings_id
)

for k, v in settings_update.dict(exclude_unset=True).items():
setattr(current_user_settings, k, v)

db.add(current_user_settings)
await db.commit()
await db.refresh(current_user_settings)

return current_user_settings
39 changes: 39 additions & 0 deletions fractal_server/app/routes/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from fractal_server.app.models import LinkUserGroup
from fractal_server.app.models import UserGroup
from fractal_server.app.models import UserOAuth
from fractal_server.app.models import UserSettings
from fractal_server.app.routes.auth._aux_auth import _user_or_404
from fractal_server.app.schemas import UserSettingsRead
from fractal_server.app.schemas import UserSettingsUpdate
from fractal_server.app.security import get_user_manager
from fractal_server.app.security import UserManager
from fractal_server.logger import set_logger
Expand Down Expand Up @@ -196,3 +199,39 @@ async def list_users(
)

return user_list


@router_users.get(
"/users/{user_id}/settings/", response_model=UserSettingsRead
)
async def get_user_settings(
user_id: int,
superuser: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsRead:

user = await _user_or_404(user_id=user_id, db=db)
user_settings = await db.get(UserSettings, user.user_settings_id)
return user_settings


@router_users.patch(
"/users/{user_id}/settings/", response_model=UserSettingsRead
)
async def patch_user_settings(
user_id: int,
settings_update: UserSettingsUpdate,
superuser: UserOAuth = Depends(current_active_superuser),
db: AsyncSession = Depends(get_async_db),
) -> UserSettingsRead:
user = await _user_or_404(user_id=user_id, db=db)
user_settings = await db.get(UserSettings, user.user_settings_id)

for k, v in settings_update.dict(exclude_unset=True).items():
setattr(user_settings, k, v)

db.add(user_settings)
await db.commit()
await db.refresh(user_settings)

return user_settings
2 changes: 2 additions & 0 deletions fractal_server/app/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .user import * # noqa: F401, F403
from .user_group import * # noqa: F401, F403
from .user_settings import * # noqa: F401, F403
93 changes: 93 additions & 0 deletions fractal_server/app/schemas/user_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Optional

from pydantic import BaseModel
from pydantic import validator
from pydantic.types import StrictStr

from ._validators import val_absolute_path
from ._validators import val_unique_list
from ._validators import valstr
from fractal_server.string_tools import validate_cmd

__all__ = (
"UserSettingsRead",
"UserSettingsReadStrict",
"UserSettingsUpdate",
"UserSettingsUpdateStrict",
)


class UserSettingsRead(BaseModel):
id: int
# SSH-SLURM
ssh_host: Optional[str] = None
ssh_username: Optional[str] = None
ssh_private_key_path: Optional[str] = None
ssh_tasks_dir: Optional[str] = None
ssh_jobs_dir: Optional[str] = None
# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: list[str]
cache_dir: Optional[str] = None


class UserSettingsReadStrict(BaseModel):
# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: list[str]
cache_dir: Optional[str] = None


class UserSettingsUpdate(BaseModel):
# SSH-SLURM
ssh_host: Optional[str] = None
ssh_username: Optional[str] = None
ssh_private_key_path: Optional[str] = None
ssh_tasks_dir: Optional[str] = None
ssh_jobs_dir: Optional[str] = None
# SUDO-SLURM
slurm_user: Optional[str] = None
slurm_accounts: Optional[list[StrictStr]] = None
cache_dir: Optional[str] = None

_ssh_host = validator("ssh_host", allow_reuse=True)(valstr("ssh_host"))
_ssh_username = validator("ssh_username", allow_reuse=True)(
valstr("ssh_username")
)
_ssh_private_key_path = validator(
"ssh_private_key_path", allow_reuse=True
)(val_absolute_path("ssh_private_key_path"))

_ssh_tasks_dir = validator("ssh_tasks_dir", allow_reuse=True)(
val_absolute_path("ssh_tasks_dir")
)
_ssh_jobs_dir = validator("ssh_jobs_dir", allow_reuse=True)(
val_absolute_path("ssh_jobs_dir")
)

_slurm_user = validator("slurm_user", allow_reuse=True)(
valstr("slurm_user")
)
_slurm_accounts = validator("slurm_accounts", allow_reuse=True)(
val_unique_list("slurm_accounts")
)

@validator("cache_dir")
def cache_dir_validator(cls, value):
validate_cmd(value)
return val_absolute_path("cache_dir")(value)


class UserSettingsUpdateStrict(BaseModel):
# SUDO-SLURM
slurm_accounts: Optional[list[StrictStr]] = None
cache_dir: Optional[str] = None

_slurm_accounts = validator("slurm_accounts", allow_reuse=True)(
val_unique_list("slurm_accounts")
)

@validator("cache_dir")
def cache_dir_validator(cls, value):
validate_cmd(value)
return val_absolute_path("cache_dir")(value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""new user settings columns

Revision ID: e1575a65e853
Revises: dfbe4f3a7bc4
Create Date: 2024-09-19 12:14:38.481210

"""
import sqlalchemy as sa
import sqlmodel
from alembic import op


# revision identifiers, used by Alembic.
revision = "e1575a65e853"
down_revision = "dfbe4f3a7bc4"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user_settings", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"slurm_user", sqlmodel.sql.sqltypes.AutoString(), nullable=True
)
)
batch_op.add_column(
sa.Column(
"slurm_accounts",
sa.JSON(),
server_default="[]",
nullable=False,
)
)
batch_op.add_column(
sa.Column(
"cache_dir", sqlmodel.sql.sqltypes.AutoString(), nullable=True
)
)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user_settings", schema=None) as batch_op:
batch_op.drop_column("cache_dir")
batch_op.drop_column("slurm_accounts")
batch_op.drop_column("slurm_user")

# ### end Alembic commands ###
86 changes: 86 additions & 0 deletions tests/no_version/test_auth_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,89 @@ async def test_oauth_accounts_list(
f"{PREFIX}/current-user/", json=dict(cache_dir="/foo/bar")
)
assert len(res.json()["oauth_accounts"]) == 1


async def test_get_and_patch_user_settings(registered_superuser_client):

# Register new user
res = await registered_superuser_client.post(
f"{PREFIX}/register/", json=dict(email="a@b.c", password="1234")
)
assert res.status_code == 201
user_id = res.json()["id"]

# Get user settings
res = await registered_superuser_client.get(
f"{PREFIX}/users/{user_id}/settings/",
)
assert res.status_code == 200
for k, v in res.json().items():
if k == "id":
pass
elif k == "slurm_accounts":
assert v == []
else:
assert v is None

# Path user settings
patch = dict(
ssh_host="127.0.0.1",
ssh_username="fractal",
ssh_private_key_path="/tmp/fractal",
ssh_tasks_dir="/tmp/tasks",
# missing "ssh_jobs_dir"
# missing "slurm_user"
slurm_accounts=["foo", "bar"],
cache_dir="/tmp/cache",
)
res = await registered_superuser_client.patch(
f"{PREFIX}/users/{user_id}/settings/", json=patch
)
debug(res.json())
assert res.status_code == 200

# Assert patch was successful
res = await registered_superuser_client.get(
f"{PREFIX}/users/{user_id}/settings/",
)
for k, v in res.json().items():
if k in patch:
assert v == patch[k]
elif k == "id":
pass
else:
assert v is None

# Get non-existing-user settings
res = await registered_superuser_client.get(f"{PREFIX}/users/42/settings/")
assert res.status_code == 404
# Patch non-existing-user settings
res = await registered_superuser_client.patch(
f"{PREFIX}/users/42/settings/", json=dict()
)
assert res.status_code == 404


async def test_get_and_patch_current_user_settings(registered_client):

res = await registered_client.get(f"{PREFIX}/current-user/settings/")
assert res.status_code == 200
for k, v in res.json().items():
if k == "slurm_accounts":
assert v == []
else:
assert v is None

patch = dict(slurm_accounts=["foo", "bar"], cache_dir="/tmp/foo_cache")
res = await registered_client.patch(
f"{PREFIX}/current-user/settings/", json=patch
)
assert res.status_code == 200

# Assert patch was successful
res = await registered_client.get(f"{PREFIX}/current-user/settings/")
for k, v in res.json().items():
if k in patch:
assert v == patch[k]
else:
assert v is None
Loading