Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate sudo-slurm user settings #1785

Merged
merged 4 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions fractal_server/app/routes/aux/validate_user_settings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from fastapi import HTTPException
from fastapi import status
from pydantic import BaseModel
from pydantic import ValidationError

from fractal_server.app.db import AsyncSession
from fractal_server.app.models import UserOAuth
from fractal_server.app.models import UserSettings
from fractal_server.app.routes.api.v2._aux_functions import logger
from fractal_server.user_settings import SlurmSshUserSettings
from fractal_server.user_settings import SlurmSudoUserSettings


async def validate_user_settings(
Expand All @@ -26,16 +28,22 @@ async def validate_user_settings(
user_settings = await db.get(UserSettings, user.user_settings_id)

if backend == "slurm_ssh":
try:
SlurmSshUserSettings(**user_settings.model_dump())
except ValidationError as e:
error_msg = (
"User settings are not valid for "
f"FRACTAL_RUNNER_BACKEND='{backend}'. "
f"Original error: {str(e)}"
)
logger.warning(error_msg)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=error_msg,
)
UserSettingsModel = SlurmSshUserSettings
elif backend == "slurm":
UserSettingsModel = SlurmSudoUserSettings
else:
UserSettingsModel = BaseModel

try:
UserSettingsModel(**user_settings.model_dump())
except ValidationError as e:
error_msg = (
"User settings are not valid for "
f"FRACTAL_RUNNER_BACKEND='{backend}'. "
f"Original error: {str(e)}"
)
logger.warning(error_msg)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=error_msg,
)
15 changes: 15 additions & 0 deletions fractal_server/user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,18 @@ class SlurmSshUserSettings(BaseModel):
ssh_private_key_path: str
ssh_tasks_dir: str
ssh_jobs_dir: str


class SlurmSudoUserSettings(BaseModel):
"""
Subset of user settings which must be present for task collection and job
execution when using the Slurm-sudo runner.

Attributes:
slurm_user: User to be impersonated via `sudo -u`.
cache_dir:
"""

slurm_user: str
cache_dir: str
slurm_accounts: list[str]
33 changes: 25 additions & 8 deletions tests/no_version/test_unit_user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,48 @@ async def test_validate_user_settings(db):
ssh_tasks_dir="/x",
ssh_username="x",
)
user_with_valid_settings = UserOAuth(
user_with_valid_ssh_settings = UserOAuth(
email="c@c.c",
**common_attributes,
settings=valid_settings,
)
db.add(user_with_valid_settings)
db.add(user_with_valid_ssh_settings)
await db.commit()
await db.refresh(user_with_valid_settings)

debug(user_without_settings)
debug(user_with_invalid_settings)
debug(user_with_valid_settings)
await db.refresh(user_with_valid_ssh_settings)

# User with no settings
with pytest.raises(HTTPException, match="has no settings"):
await validate_user_settings(
user=user_without_settings, backend="slurm_ssh", db=db
)

# User with empty settings: backend="local"
await validate_user_settings(
user=user_with_invalid_settings, backend="local", db=db
)
# User with empty settings: backend="slurm_ssh"
with pytest.raises(
HTTPException, match="validation errors for SlurmSshUserSettings"
):
await validate_user_settings(
user=user_with_invalid_settings, backend="slurm_ssh", db=db
)
# User with empty settings: backend="slurm"
with pytest.raises(
HTTPException, match="validation errors for SlurmSudoUserSettings"
):
await validate_user_settings(
user=user_with_invalid_settings, backend="slurm", db=db
)

# User with valid SSH settings: backend="slurm_ssh"
await validate_user_settings(
user=user_with_valid_settings, backend="slurm_ssh", db=db
user=user_with_valid_ssh_settings, backend="slurm_ssh", db=db
)
# User with valid SSH settings: backend="slurm"
with pytest.raises(
HTTPException, match="validation errors for SlurmSudoUserSettings"
):
await validate_user_settings(
user=user_with_valid_ssh_settings, backend="slurm", db=db
)
30 changes: 18 additions & 12 deletions tests/v2/03_api/test_api_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,22 +182,20 @@ async def test_project_apply_missing_user_attributes(
override_settings_factory,
):
"""
When using the slurm backend, user.slurm_user and user.cache_dir become
required attributes. If they are missing, the apply endpoint fails with a
422 error.
When using the slurm backend, some user.settings attributes are required.
If they are missing, the apply endpoint fails with a 422 error.
"""

override_settings_factory(FRACTAL_RUNNER_BACKEND="slurm")

async with MockCurrentUser(user_kwargs=dict(is_verified=True)) as user:
# Make sure that user.cache_dir was not set
debug(user)
assert user.cache_dir is None
async with MockCurrentUser(
user_kwargs=dict(is_verified=True),
user_settings_dict=dict(something="else"),
) as user:

# Create project, datasets, workflow, task, workflowtask
project = await project_factory_v2(user)
dataset = await dataset_factory_v2(project_id=project.id, name="ds")

workflow = await workflow_factory_v2(project_id=project.id)
task = await task_factory_v2()
await _workflow_insert_task(
Expand All @@ -212,10 +210,14 @@ async def test_project_apply_missing_user_attributes(
)
debug(res.json())
assert res.status_code == 422
assert "user.cache_dir=None" in res.json()["detail"]
assert "User settings are not valid" in res.json()["detail"]
assert (
"validation errors for SlurmSudoUserSettings"
in res.json()["detail"]
)

user.cache_dir = "/tmp"
user.slurm_user = None
user.settings.cache_dir = "/tmp"
user.settings.slurm_user = None
await db.commit()

res = await client.post(
Expand All @@ -225,7 +227,11 @@ async def test_project_apply_missing_user_attributes(
)
debug(res.json())
assert res.status_code == 422
assert "user.slurm_user=None" in res.json()["detail"]
assert "User settings are not valid" in res.json()["detail"]
assert (
"validation error for SlurmSudoUserSettings"
in res.json()["detail"]
)


async def test_project_apply_workflow_subset(
Expand Down
16 changes: 12 additions & 4 deletions tests/v2/08_full_workflow/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ async def full_workflow(
dataset_factory_v2,
tasks: dict[str, TaskV2],
user_kwargs: Optional[dict] = None,
user_settings_dict: Optional[dict] = None,
):
if user_kwargs is None:
user_kwargs = {}

async with MockCurrentUser(
user_kwargs={"is_verified": True, **user_kwargs}
user_kwargs={"is_verified": True, **user_kwargs},
user_settings_dict=user_settings_dict,
) as user:
project = await project_factory_v2(user)
project_id = project.id
Expand Down Expand Up @@ -200,14 +202,16 @@ async def full_workflow_TaskExecutionError(
dataset_factory_v2,
tasks: dict[str, TaskV2],
user_kwargs: Optional[dict] = None,
user_settings_dict: Optional[dict] = None,
):

if user_kwargs is None:
user_kwargs = {}

EXPECTED_STATUSES = {}
async with MockCurrentUser(
user_kwargs={"is_verified": True, **user_kwargs}
user_kwargs={"is_verified": True, **user_kwargs},
user_settings_dict=user_settings_dict,
) as user:
project = await project_factory_v2(user)
project_id = project.id
Expand Down Expand Up @@ -329,12 +333,14 @@ async def non_executable_task_command(
dataset_factory_v2,
task_factory_v2,
user_kwargs: Optional[dict] = None,
user_settings_dict: Optional[dict] = None,
):
if user_kwargs is None:
user_kwargs = {}

async with MockCurrentUser(
user_kwargs={"is_verified": True, **user_kwargs}
user_kwargs={"is_verified": True, **user_kwargs},
user_settings_dict=user_settings_dict,
) as user:
# Create task
task = await task_factory_v2(
Expand Down Expand Up @@ -401,13 +407,15 @@ async def failing_workflow_UnknownError(
task_factory,
task_factory_v2,
user_kwargs: Optional[dict] = None,
user_settings_dict: Optional[dict] = None,
):
if user_kwargs is None:
user_kwargs = {}

EXPECTED_STATUSES = {}
async with MockCurrentUser(
user_kwargs={"is_verified": True, **user_kwargs}
user_kwargs={"is_verified": True, **user_kwargs},
user_settings_dict=user_settings_dict,
) as user:
project = await project_factory_v2(user)
project_id = project.id
Expand Down
29 changes: 28 additions & 1 deletion tests/v2/08_full_workflow/test_full_workflow_slurm_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ async def test_full_workflow_slurm(
await full_workflow(
MockCurrentUser=MockCurrentUser,
user_kwargs={"cache_dir": str(tmp777_path / "user_cache_dir-slurm")},
user_settings_dict=dict(
slurm_user=SLURM_USER,
slurm_accounts=[],
cache_dir=str(tmp777_path / "user_cache_dir-slurm"),
),
project_factory_v2=project_factory_v2,
dataset_factory_v2=dataset_factory_v2,
workflow_factory_v2=workflow_factory_v2,
Expand Down Expand Up @@ -75,6 +80,11 @@ async def test_full_workflow_TaskExecutionError_slurm(
await full_workflow_TaskExecutionError(
MockCurrentUser=MockCurrentUser,
user_kwargs={"cache_dir": str(tmp777_path / "user_cache_dir-slurm")},
user_settings_dict=dict(
slurm_user=SLURM_USER,
slurm_accounts=[],
cache_dir=str(tmp777_path / "user_cache_dir-slurm"),
),
project_factory_v2=project_factory_v2,
dataset_factory_v2=dataset_factory_v2,
workflow_factory_v2=workflow_factory_v2,
Expand Down Expand Up @@ -109,7 +119,14 @@ async def test_failing_workflow_JobExecutionError(

user_cache_dir = str(tmp777_path / "user_cache_dir-slurm")
user_kwargs = dict(cache_dir=user_cache_dir, is_verified=True)
async with MockCurrentUser(user_kwargs=user_kwargs) as user:
async with MockCurrentUser(
user_kwargs=user_kwargs,
user_settings_dict=dict(
slurm_user=SLURM_USER,
slurm_accounts=[],
cache_dir=str(tmp777_path / "user_cache_dir-slurm"),
),
) as user:
project = await project_factory_v2(user)
project_id = project.id
dataset = await dataset_factory_v2(
Expand Down Expand Up @@ -244,6 +261,11 @@ async def test_non_executable_task_command_slurm(
await non_executable_task_command(
MockCurrentUser=MockCurrentUser,
user_kwargs={"cache_dir": str(tmp777_path / "user_cache_dir-slurm")},
user_settings_dict=dict(
slurm_user=SLURM_USER,
slurm_accounts=[],
cache_dir=str(tmp777_path / "user_cache_dir-slurm"),
),
client=client,
testdata_path=testdata_path,
project_factory_v2=project_factory_v2,
Expand Down Expand Up @@ -283,6 +305,11 @@ async def test_failing_workflow_UnknownError_slurm(
await failing_workflow_UnknownError(
MockCurrentUser=MockCurrentUser,
user_kwargs={"cache_dir": str(tmp777_path / "user_cache_dir-slurm")},
user_settings_dict=dict(
slurm_user=SLURM_USER,
slurm_accounts=[],
cache_dir=str(tmp777_path / "user_cache_dir-slurm"),
),
client=client,
monkeypatch=monkeypatch,
project_factory_v2=project_factory_v2,
Expand Down
Loading