diff --git a/CHANGELOG.md b/CHANGELOG.md index 99161fb9cc..cdadf42859 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ **Note**: Numbers like (\#1234) point to closed Pull Requests on the fractal-server repository. +# 2.3.5 + +> WARNING: The `pre_submission_commands` SLURM configuration is included as an +> experimental feature, since it is still not useful for its main intended +> goal (calling `module load` before running `sbatch`). + +* SLURM runners + * Expose `gpus` SLURM parameter (\#1678). + * For SSH executor, add `pre_submission_commands` (\#1678). + * Removed obsolete arguments from `get_slurm_config` function (\#1678). +* SSH features: + * Add `FractalSSH.write_remote_file` method (\#1678). + # 2.3.4 * SSH SLURM runner: diff --git a/benchmarks/runner/benchmark_runner.py b/benchmarks/runner/benchmark_runner.py index 3ac2f3b23e..e7ec31167a 100644 --- a/benchmarks/runner/benchmark_runner.py +++ b/benchmarks/runner/benchmark_runner.py @@ -59,12 +59,12 @@ def mock_venv(tmp_path: str) -> dict: args[ "command_non_parallel" ] = f"{python} {src_dir / task['executable_non_parallel']}" - args["meta_non_paralell"] = task.get("meta_non_paralell") + args["meta_non_parallel"] = task.get("meta_non_parallel") if task.get("executable_parallel"): args[ "command_parallel" ] = f"{python} {src_dir / task['executable_parallel']}" - args["meta_paralell"] = task.get("meta_paralell") + args["meta_parallel"] = task.get("meta_parallel") t = TaskV2Mock( id=ind, diff --git a/benchmarks/runner/mocks.py b/benchmarks/runner/mocks.py index ea7a81de93..c7da9e0eb2 100644 --- a/benchmarks/runner/mocks.py +++ b/benchmarks/runner/mocks.py @@ -38,8 +38,8 @@ class TaskV2Mock(BaseModel): command_non_parallel: Optional[str] = None command_parallel: Optional[str] = None - meta_non_paralell: Optional[dict[str, Any]] = Field(default_factory=dict) - meta_paralell: Optional[dict[str, Any]] = Field(default_factory=dict) + meta_non_parallel: Optional[dict[str, Any]] = Field(default_factory=dict) + meta_parallel: Optional[dict[str, Any]] = Field(default_factory=dict) type: Optional[str] @root_validator(pre=False) diff --git a/fractal_server/app/runner/executors/slurm/_slurm_config.py b/fractal_server/app/runner/executors/slurm/_slurm_config.py index 5cb19398dc..1dc32276ac 100644 --- a/fractal_server/app/runner/executors/slurm/_slurm_config.py +++ b/fractal_server/app/runner/executors/slurm/_slurm_config.py @@ -62,6 +62,8 @@ class _SlurmConfigSet(BaseModel, extra=Extra.forbid): time: Optional[str] account: Optional[str] extra_lines: Optional[list[str]] + pre_submission_commands: Optional[list[str]] + gpus: Optional[str] class _BatchingConfigSet(BaseModel, extra=Extra.forbid): @@ -219,6 +221,7 @@ class SlurmConfig(BaseModel, extra=Extra.forbid): constraint: Corresponds to SLURM option. gres: Corresponds to SLURM option. account: Corresponds to SLURM option. + gpus: Corresponds to SLURM option. time: Corresponds to SLURM option (WARNING: not fully supported). prefix: Prefix of configuration lines in SLURM submission scripts. shebang_line: Shebang line for SLURM submission scripts. @@ -240,6 +243,8 @@ class SlurmConfig(BaseModel, extra=Extra.forbid): Key-value pairs to be included as `export`-ed variables in SLURM submission script, after prepending values with the user's cache directory. + pre_submission_commands: List of commands to be prepended to the sbatch + command. 
""" # Required SLURM parameters (note that the integer attributes are those @@ -254,6 +259,7 @@ class SlurmConfig(BaseModel, extra=Extra.forbid): job_name: Optional[str] = None constraint: Optional[str] = None gres: Optional[str] = None + gpus: Optional[str] = None time: Optional[str] = None account: Optional[str] = None @@ -274,6 +280,8 @@ class SlurmConfig(BaseModel, extra=Extra.forbid): target_num_jobs: int max_num_jobs: int + pre_submission_commands: list[str] = Field(default_factory=list) + def _sorted_extra_lines(self) -> list[str]: """ Return a copy of `self.extra_lines`, where lines starting with @@ -340,7 +348,14 @@ def to_sbatch_preamble( f"{self.prefix} --cpus-per-task={self.cpus_per_task}", f"{self.prefix} --mem={mem_per_job_MB}M", ] - for key in ["job_name", "constraint", "gres", "time", "account"]: + for key in [ + "job_name", + "constraint", + "gres", + "gpus", + "time", + "account", + ]: value = getattr(self, key) if value is not None: # Handle the `time` parameter diff --git a/fractal_server/app/runner/executors/slurm/ssh/executor.py b/fractal_server/app/runner/executors/slurm/ssh/executor.py index 04b8659eac..6771607758 100644 --- a/fractal_server/app/runner/executors/slurm/ssh/executor.py +++ b/fractal_server/app/runner/executors/slurm/ssh/executor.py @@ -869,9 +869,22 @@ def _submit_job(self, job: SlurmJob) -> tuple[Future, str]: # Submit job to SLURM, and get jobid sbatch_command = f"sbatch --parsable {job.slurm_script_remote}" - sbatch_stdout = self.fractal_ssh.run_command( - cmd=sbatch_command, - ) + pre_submission_cmds = job.slurm_config.pre_submission_commands + if len(pre_submission_cmds) == 0: + sbatch_stdout = self.fractal_ssh.run_command(cmd=sbatch_command) + else: + logger.debug(f"Now using {pre_submission_cmds=}") + script_lines = pre_submission_cmds + [sbatch_command] + script_content = "\n".join(script_lines) + script_content = f"{script_content}\n" + script_path_remote = ( + f"{job.slurm_script_remote.as_posix()}_wrapper.sh" + ) + self.fractal_ssh.write_remote_file( + path=script_path_remote, content=script_content + ) + cmd = f"bash {script_path_remote}" + sbatch_stdout = self.fractal_ssh.run_command(cmd=cmd) # Extract SLURM job ID from stdout try: @@ -881,7 +894,9 @@ def _submit_job(self, job: SlurmJob) -> tuple[Future, str]: error_msg = ( f"Submit command `{sbatch_command}` returned " f"`{stdout=}` which cannot be cast to an integer " - f"SLURM-job ID. Original error:\n{str(e)}" + f"SLURM-job ID.\n" + f"Note that {pre_submission_cmds=}.\n" + f"Original error:\n{str(e)}" ) logger.error(error_msg) raise JobExecutionError(info=error_msg) diff --git a/fractal_server/app/runner/executors/slurm/sudo/executor.py b/fractal_server/app/runner/executors/slurm/sudo/executor.py index 93d395ad46..e0cb2bda44 100644 --- a/fractal_server/app/runner/executors/slurm/sudo/executor.py +++ b/fractal_server/app/runner/executors/slurm/sudo/executor.py @@ -1121,6 +1121,12 @@ def _start( slurm_err_path=str(job.slurm_stderr), ) + # Print warning for ignored parameter + if len(job.slurm_config.pre_submission_commands) > 0: + logger.warning( + f"Ignoring {job.slurm_config.pre_submission_commands=}." 
+ ) + # Submit job via sbatch, and retrieve jobid # Write script content to a job.slurm_script diff --git a/fractal_server/app/runner/v2/__init__.py b/fractal_server/app/runner/v2/__init__.py index 8c8ad8d47e..45644aae5c 100644 --- a/fractal_server/app/runner/v2/__init__.py +++ b/fractal_server/app/runner/v2/__init__.py @@ -36,8 +36,8 @@ from ._local_experimental import ( process_workflow as local_experimental_process_workflow, ) -from ._slurm import process_workflow as slurm_sudo_process_workflow from ._slurm_ssh import process_workflow as slurm_ssh_process_workflow +from ._slurm_sudo import process_workflow as slurm_sudo_process_workflow from .handle_failed_job import assemble_filters_failed_job from .handle_failed_job import assemble_history_failed_job from .handle_failed_job import assemble_images_failed_job diff --git a/fractal_server/app/runner/v2/_slurm/get_slurm_config.py b/fractal_server/app/runner/v2/_slurm/get_slurm_config.py deleted file mode 100644 index 015d065e2e..0000000000 --- a/fractal_server/app/runner/v2/_slurm/get_slurm_config.py +++ /dev/null @@ -1,182 +0,0 @@ -from pathlib import Path -from typing import Literal -from typing import Optional - -from fractal_server.app.models.v2 import WorkflowTaskV2 -from fractal_server.app.runner.executors.slurm._slurm_config import ( - _parse_mem_value, -) -from fractal_server.app.runner.executors.slurm._slurm_config import ( - load_slurm_config_file, -) -from fractal_server.app.runner.executors.slurm._slurm_config import logger -from fractal_server.app.runner.executors.slurm._slurm_config import SlurmConfig -from fractal_server.app.runner.executors.slurm._slurm_config import ( - SlurmConfigError, -) - - -def get_slurm_config( - wftask: WorkflowTaskV2, - workflow_dir_local: Path, - workflow_dir_remote: Path, - which_type: Literal["non_parallel", "parallel"], - config_path: Optional[Path] = None, -) -> SlurmConfig: - """ - Prepare a `SlurmConfig` configuration object - - The argument `which_type` determines whether we use `wftask.meta_parallel` - or `wftask.meta_non_parallel`. In the following descritpion, let us assume - that `which_type="parallel"`. - - The sources for `SlurmConfig` attributes, in increasing priority order, are - - 1. The general content of the Fractal SLURM configuration file. - 2. The GPU-specific content of the Fractal SLURM configuration file, if - appropriate. - 3. Properties in `wftask.meta_parallel` (which typically include those in - `wftask.task.meta_parallel`). Note that `wftask.meta_parallel` may be - `None`. - - Arguments: - wftask: - WorkflowTask for which the SLURM configuration is is to be - prepared. - workflow_dir_local: - Server-owned directory to store all task-execution-related relevant - files (inputs, outputs, errors, and all meta files related to the - job execution). Note: users cannot write directly to this folder. - workflow_dir_remote: - User-side directory with the same scope as `workflow_dir_local`, - and where a user can write. - config_path: - Path of a Fractal SLURM configuration file; if `None`, use - `FRACTAL_SLURM_CONFIG_FILE` variable from settings. - which_type: - Determines whether to use `meta_parallel` or `meta_non_parallel`. - - Returns: - slurm_config: - The SlurmConfig object - """ - - if which_type == "non_parallel": - wftask_meta = wftask.meta_non_parallel - elif which_type == "parallel": - wftask_meta = wftask.meta_parallel - else: - raise ValueError( - f"get_slurm_config received invalid argument {which_type=}." 
- ) - - logger.debug( - "[get_slurm_config] WorkflowTask meta attribute: {wftask_meta=}" - ) - - # Incorporate slurm_env.default_slurm_config - slurm_env = load_slurm_config_file(config_path=config_path) - slurm_dict = slurm_env.default_slurm_config.dict( - exclude_unset=True, exclude={"mem"} - ) - if slurm_env.default_slurm_config.mem: - slurm_dict["mem_per_task_MB"] = slurm_env.default_slurm_config.mem - - # Incorporate slurm_env.batching_config - for key, value in slurm_env.batching_config.dict().items(): - slurm_dict[key] = value - - # Incorporate slurm_env.user_local_exports - slurm_dict["user_local_exports"] = slurm_env.user_local_exports - - logger.debug( - "[get_slurm_config] Fractal SLURM configuration file: " - f"{slurm_env.dict()=}" - ) - - # GPU-related options - # Notes about priority: - # 1. This block of definitions takes priority over other definitions from - # slurm_env which are not under the `needs_gpu` subgroup - # 2. This block of definitions has lower priority than whatever comes next - # (i.e. from WorkflowTask.meta). - if wftask_meta is not None: - needs_gpu = wftask_meta.get("needs_gpu", False) - else: - needs_gpu = False - logger.debug(f"[get_slurm_config] {needs_gpu=}") - if needs_gpu: - for key, value in slurm_env.gpu_slurm_config.dict( - exclude_unset=True, exclude={"mem"} - ).items(): - slurm_dict[key] = value - if slurm_env.gpu_slurm_config.mem: - slurm_dict["mem_per_task_MB"] = slurm_env.gpu_slurm_config.mem - - # Number of CPUs per task, for multithreading - if wftask_meta is not None and "cpus_per_task" in wftask_meta: - cpus_per_task = int(wftask_meta["cpus_per_task"]) - slurm_dict["cpus_per_task"] = cpus_per_task - - # Required memory per task, in MB - if wftask_meta is not None and "mem" in wftask_meta: - raw_mem = wftask_meta["mem"] - mem_per_task_MB = _parse_mem_value(raw_mem) - slurm_dict["mem_per_task_MB"] = mem_per_task_MB - - # Job name - if wftask.is_legacy_task: - job_name = wftask.task_legacy.name.replace(" ", "_") - else: - job_name = wftask.task.name.replace(" ", "_") - slurm_dict["job_name"] = job_name - - # Optional SLURM arguments and extra lines - if wftask_meta is not None: - account = wftask_meta.get("account", None) - if account is not None: - error_msg = ( - f"Invalid {account=} property in WorkflowTask `meta` " - "attribute.\n" - "SLURM account must be set in the request body of the " - "apply-workflow endpoint, or by modifying the user properties." - ) - logger.error(error_msg) - raise SlurmConfigError(error_msg) - for key in ["time", "gres", "constraint"]: - value = wftask_meta.get(key, None) - if value: - slurm_dict[key] = value - if wftask_meta is not None: - extra_lines = wftask_meta.get("extra_lines", []) - else: - extra_lines = [] - extra_lines = slurm_dict.get("extra_lines", []) + extra_lines - if len(set(extra_lines)) != len(extra_lines): - logger.debug( - "[get_slurm_config] Removing repeated elements " - f"from {extra_lines=}." 
- ) - extra_lines = list(set(extra_lines)) - slurm_dict["extra_lines"] = extra_lines - - # Job-batching parameters (if None, they will be determined heuristically) - if wftask_meta is not None: - tasks_per_job = wftask_meta.get("tasks_per_job", None) - parallel_tasks_per_job = wftask_meta.get( - "parallel_tasks_per_job", None - ) - else: - tasks_per_job = None - parallel_tasks_per_job = None - slurm_dict["tasks_per_job"] = tasks_per_job - slurm_dict["parallel_tasks_per_job"] = parallel_tasks_per_job - - # Put everything together - logger.debug( - "[get_slurm_config] Now create a SlurmConfig object based " - f"on {slurm_dict=}" - ) - slurm_config = SlurmConfig(**slurm_dict) - - return slurm_config diff --git a/fractal_server/app/runner/v2/_slurm_common/__init__.py b/fractal_server/app/runner/v2/_slurm_common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/fractal_server/app/runner/v2/_slurm_ssh/get_slurm_config.py b/fractal_server/app/runner/v2/_slurm_common/get_slurm_config.py similarity index 90% rename from fractal_server/app/runner/v2/_slurm_ssh/get_slurm_config.py rename to fractal_server/app/runner/v2/_slurm_common/get_slurm_config.py index 80d5663c30..362873cd96 100644 --- a/fractal_server/app/runner/v2/_slurm_ssh/get_slurm_config.py +++ b/fractal_server/app/runner/v2/_slurm_common/get_slurm_config.py @@ -18,8 +18,6 @@ def get_slurm_config( wftask: WorkflowTaskV2, - workflow_dir_local: Path, - workflow_dir_remote: Path, which_type: Literal["non_parallel", "parallel"], config_path: Optional[Path] = None, ) -> SlurmConfig: @@ -43,13 +41,6 @@ def get_slurm_config( wftask: WorkflowTask for which the SLURM configuration is is to be prepared. - workflow_dir_local: - Server-owned directory to store all task-execution-related relevant - files (inputs, outputs, errors, and all meta files related to the - job execution). Note: users cannot write directly to this folder. - workflow_dir_remote: - User-side directory with the same scope as `workflow_dir_local`, - and where a user can write. config_path: Path of a Fractal SLURM configuration file; if `None`, use `FRACTAL_SLURM_CONFIG_FILE` variable from settings. @@ -99,13 +90,13 @@ def get_slurm_config( # 1. This block of definitions takes priority over other definitions from # slurm_env which are not under the `needs_gpu` subgroup # 2. This block of definitions has lower priority than whatever comes next - # (i.e. from WorkflowTask.meta). + # (i.e. from WorkflowTask.meta_parallel). 
if wftask_meta is not None: needs_gpu = wftask_meta.get("needs_gpu", False) else: needs_gpu = False logger.debug(f"[get_slurm_config] {needs_gpu=}") - if needs_gpu and slurm_env.gpu_slurm_config is not None: # FIXME + if needs_gpu: for key, value in slurm_env.gpu_slurm_config.dict( exclude_unset=True, exclude={"mem"} ).items(): @@ -143,9 +134,9 @@ def get_slurm_config( ) logger.error(error_msg) raise SlurmConfigError(error_msg) - for key in ["time", "gres", "constraint"]: + for key in ["time", "gres", "gpus", "constraint"]: value = wftask_meta.get(key, None) - if value: + if value is not None: slurm_dict[key] = value if wftask_meta is not None: extra_lines = wftask_meta.get("extra_lines", []) diff --git a/fractal_server/app/runner/v2/_slurm_ssh/_submit_setup.py b/fractal_server/app/runner/v2/_slurm_ssh/_submit_setup.py index 5738a7d332..8cc394e910 100644 --- a/fractal_server/app/runner/v2/_slurm_ssh/_submit_setup.py +++ b/fractal_server/app/runner/v2/_slurm_ssh/_submit_setup.py @@ -17,8 +17,10 @@ from typing import Literal from ...task_files import get_task_file_paths -from .get_slurm_config import get_slurm_config from fractal_server.app.models.v2 import WorkflowTaskV2 +from fractal_server.app.runner.v2._slurm_common.get_slurm_config import ( + get_slurm_config, +) def _slurm_submit_setup( @@ -62,8 +64,6 @@ def _slurm_submit_setup( # Get SlurmConfig object slurm_config = get_slurm_config( wftask=wftask, - workflow_dir_local=workflow_dir_local, - workflow_dir_remote=workflow_dir_remote, which_type=which_type, ) diff --git a/fractal_server/app/runner/v2/_slurm/__init__.py b/fractal_server/app/runner/v2/_slurm_sudo/__init__.py similarity index 100% rename from fractal_server/app/runner/v2/_slurm/__init__.py rename to fractal_server/app/runner/v2/_slurm_sudo/__init__.py diff --git a/fractal_server/app/runner/v2/_slurm/_submit_setup.py b/fractal_server/app/runner/v2/_slurm_sudo/_submit_setup.py similarity index 95% rename from fractal_server/app/runner/v2/_slurm/_submit_setup.py rename to fractal_server/app/runner/v2/_slurm_sudo/_submit_setup.py index 5738a7d332..8cc394e910 100644 --- a/fractal_server/app/runner/v2/_slurm/_submit_setup.py +++ b/fractal_server/app/runner/v2/_slurm_sudo/_submit_setup.py @@ -17,8 +17,10 @@ from typing import Literal from ...task_files import get_task_file_paths -from .get_slurm_config import get_slurm_config from fractal_server.app.models.v2 import WorkflowTaskV2 +from fractal_server.app.runner.v2._slurm_common.get_slurm_config import ( + get_slurm_config, +) def _slurm_submit_setup( @@ -62,8 +64,6 @@ def _slurm_submit_setup( # Get SlurmConfig object slurm_config = get_slurm_config( wftask=wftask, - workflow_dir_local=workflow_dir_local, - workflow_dir_remote=workflow_dir_remote, which_type=which_type, ) diff --git a/fractal_server/ssh/_fabric.py b/fractal_server/ssh/_fabric.py index 79332b9cbf..d426905435 100644 --- a/fractal_server/ssh/_fabric.py +++ b/fractal_server/ssh/_fabric.py @@ -306,6 +306,28 @@ def remove_folder( cmd = f"rm -r {folder}" self.run_command(cmd=cmd) + def write_remote_file( + self, + *, + path: str, + content: str, + lock_timeout: Optional[float] = None, + ) -> None: + """ + Open a remote file via SFTP and write it. 
+ + Args: + path: Absolute path + contents: File contents + lock_timeout: + """ + actual_lock_timeout = self.default_lock_timeout + if lock_timeout is not None: + actual_lock_timeout = lock_timeout + with self.acquire_timeout(timeout=actual_lock_timeout): + with self.sftp().open(filename=path, mode="w") as f: + f.write(content) + def get_ssh_connection( *, diff --git a/tests/fixtures_docker.py b/tests/fixtures_docker.py index 0a8e39fee8..54b68d05b1 100644 --- a/tests/fixtures_docker.py +++ b/tests/fixtures_docker.py @@ -187,6 +187,6 @@ def fractal_ssh( forward_agent=False, connect_kwargs={"key_filename": ssh_private_key}, ) as connection: - fractal_conn = FractalSSH(connection=connection) - fractal_conn.check_connection() - yield fractal_conn + fractal_ssh_object = FractalSSH(connection=connection) + fractal_ssh_object.check_connection() + yield fractal_ssh_object diff --git a/tests/v1/05_backend/test_unit_slurm_config.py b/tests/v1/05_backend/test_unit_slurm_config_v1.py similarity index 100% rename from tests/v1/05_backend/test_unit_slurm_config.py rename to tests/v1/05_backend/test_unit_slurm_config_v1.py diff --git a/tests/v2/00_ssh/test_FractalSSH.py b/tests/v2/00_ssh/test_FractalSSH.py index 89b48f3061..c386e901c7 100644 --- a/tests/v2/00_ssh/test_FractalSSH.py +++ b/tests/v2/00_ssh/test_FractalSSH.py @@ -258,3 +258,14 @@ def test_remove_folder_input_validation(): safe_root="/actual_root", ) print(e.value) + + +def test_write_remote_file(fractal_ssh: FractalSSH, tmp777_path: Path): + path = tmp777_path / "file" + content = "this is what goes into the file" + fractal_ssh.write_remote_file( + path=path.as_posix(), content=content, lock_timeout=100 + ) + assert path.exists() + with path.open("r") as f: + assert f.read() == content diff --git a/tests/v2/00_ssh/test_executor.py b/tests/v2/00_ssh/test_executor.py index 95a4fe43d0..50380cd038 100644 --- a/tests/v2/00_ssh/test_executor.py +++ b/tests/v2/00_ssh/test_executor.py @@ -79,3 +79,35 @@ def test_slurm_ssh_executor_map( res = executor.map(lambda x: x * 2, [1, 2, 3]) results = list(res) assert results == [2, 4, 6] + + +def test_slurm_ssh_executor_submit_with_pre_sbatch( + fractal_ssh, + tmp_path: Path, + tmp777_path: Path, + override_settings_factory, + current_py_version: str, +): + override_settings_factory( + FRACTAL_SLURM_WORKER_PYTHON=f"/usr/bin/python{current_py_version}" + ) + from fractal_server.app.runner.executors.slurm._slurm_config import ( + get_default_slurm_config, + ) + + auxfile = tmp777_path / "auxfile" + slurm_config = get_default_slurm_config() + slurm_config.pre_submission_commands = [f"touch {auxfile.as_posix()}"] + debug(slurm_config) + + with MockFractalSSHSlurmExecutor( + workflow_dir_local=tmp_path / "job_dir", + workflow_dir_remote=(tmp777_path / "remote_job_dir"), + slurm_poll_interval=1, + fractal_ssh=fractal_ssh, + ) as executor: + fut = executor.submit(lambda: 1, slurm_config=slurm_config) + debug(fut) + debug(fut.result()) + + assert auxfile.exists() diff --git a/tests/v2/04_runner/v2_mock_models.py b/tests/v2/04_runner/v2_mock_models.py index 4f9e018fbe..21135ab763 100644 --- a/tests/v2/04_runner/v2_mock_models.py +++ b/tests/v2/04_runner/v2_mock_models.py @@ -38,8 +38,8 @@ class TaskV2Mock(BaseModel): command_non_parallel: Optional[str] = None command_parallel: Optional[str] = None - meta_paralell: Optional[dict[str, Any]] = Field(default_factory=dict) - meta_non_paralell: Optional[dict[str, Any]] = Field(default_factory=dict) + meta_parallel: Optional[dict[str, Any]] = Field(default_factory=dict) + 
meta_non_parallel: Optional[dict[str, Any]] = Field(default_factory=dict) type: Optional[str] @root_validator(pre=False) diff --git a/tests/v2/09_backends/test_slurm_config.py b/tests/v2/09_backends/test_slurm_config.py new file mode 100644 index 0000000000..bec3d893cf --- /dev/null +++ b/tests/v2/09_backends/test_slurm_config.py @@ -0,0 +1,450 @@ +import json +from pathlib import Path +from typing import Any +from typing import Optional + +import pytest +from devtools import debug +from pydantic import BaseModel +from pydantic import Extra +from pydantic import Field +from pydantic import root_validator + +from fractal_server.app.runner.executors.slurm._slurm_config import ( + SlurmConfigError, +) +from fractal_server.app.runner.v2._slurm_common.get_slurm_config import ( + get_slurm_config, +) +from fractal_server.app.runner.v2._slurm_sudo._submit_setup import ( + _slurm_submit_setup, +) + + +class TaskV1Mock(BaseModel, extra=Extra.forbid): + id: int = 1 + name: str = "name_t1" + command: str = "cmd_t1" + source: str = "source_t1" + input_type: str + output_type: str + meta: Optional[dict[str, Any]] = Field(default_factory=dict) + + +class TaskV2Mock(BaseModel, extra=Extra.forbid): + id: int = 1 + name: str = "name_t2" + source: str = "source_t2" + input_types: dict[str, bool] = Field(default_factory=dict) + output_types: dict[str, bool] = Field(default_factory=dict) + + command_non_parallel: Optional[str] = "cmd_t2_non_parallel" + command_parallel: Optional[str] = None + meta_parallel: Optional[dict[str, Any]] = Field(default_factory=dict) + meta_non_parallel: Optional[dict[str, Any]] = Field(default_factory=dict) + type: Optional[str] + + +class WorkflowTaskV2Mock(BaseModel, extra=Extra.forbid): + args_non_parallel: dict[str, Any] = Field(default_factory=dict) + args_parallel: dict[str, Any] = Field(default_factory=dict) + meta_non_parallel: dict[str, Any] = Field(default_factory=dict) + meta_parallel: dict[str, Any] = Field(default_factory=dict) + is_legacy_task: Optional[bool] + meta_parallel: Optional[dict[str, Any]] = Field() + meta_non_parallel: Optional[dict[str, Any]] = Field() + task: Optional[TaskV2Mock] = None + task_legacy: Optional[TaskV1Mock] = None + is_legacy_task: bool = False + input_filters: dict[str, Any] = Field(default_factory=dict) + order: int = 0 + id: int = 1 + workflow_id: int = 0 + task_legacy_id: Optional[int] + task_id: Optional[int] + + @root_validator(pre=False) + def _legacy_or_not(cls, values): + is_legacy_task = values["is_legacy_task"] + task = values.get("task") + task_legacy = values.get("task_legacy") + if is_legacy_task: + if task_legacy is None or task is not None: + raise ValueError(f"Invalid WorkflowTaskV2Mock with {values=}") + values["task_legacy_id"] = task_legacy.id + else: + if task is None or task_legacy is not None: + raise ValueError(f"Invalid WorkflowTaskV2Mock with {values=}") + values["task_id"] = task.id + return values + + @root_validator(pre=False) + def merge_meta(cls, values): + if values["is_legacy_task"]: + task_meta = values["task"].meta + if task_meta: + values["meta"] = { + **task_meta, + **values["meta"], + } + else: + task_meta_parallel = values["task"].meta_parallel + if task_meta_parallel: + values["meta_parallel"] = { + **task_meta_parallel, + **values["meta_parallel"], + } + task_meta_non_parallel = values["task"].meta_non_parallel + if task_meta_non_parallel: + values["meta_non_parallel"] = { + **task_meta_non_parallel, + **values["meta_non_parallel"], + } + return values + + +def test_get_slurm_config(tmp_path: 
Path): + """ + Testing that: + 1. WorkflowTask.meta overrides WorkflowTask.Task.meta + 2. needs_gpu=True triggers other changes + 3. If WorkflowTask.meta includes (e.g.) "gres", then this is the actual + value that is set (even for needs_gpu=True). + """ + + # Write gloabl variables into JSON config file + GPU_PARTITION = "gpu-partition" + GPU_DEFAULT_GRES = "gpu-default-gres" + GPU_DEFAULT_CONSTRAINT = "gpu-default-constraint" + DEFAULT_ACCOUNT = "default-account" + DEFAULT_EXTRA_LINES = ["#SBATCH --option=value", "export VAR1=VALUE1"] + USER_LOCAL_EXPORTS = {"SOME_CACHE_DIR": "SOME_CACHE_DIR"} + + original_slurm_config = { + "default_slurm_config": { + "partition": "main", + "mem": "1G", + "account": DEFAULT_ACCOUNT, + "extra_lines": DEFAULT_EXTRA_LINES, + }, + "gpu_slurm_config": { + "partition": GPU_PARTITION, + "mem": "1G", + "gres": GPU_DEFAULT_GRES, + "constraint": GPU_DEFAULT_CONSTRAINT, + }, + "batching_config": { + "target_cpus_per_job": 10, + "max_cpus_per_job": 12, + "target_mem_per_job": 10, + "max_mem_per_job": 12, + "target_num_jobs": 5, + "max_num_jobs": 10, + }, + "user_local_exports": USER_LOCAL_EXPORTS, + } + + config_path = tmp_path / "slurm_config.json" + with config_path.open("w") as f: + json.dump(original_slurm_config, f) + + # Create Task + CPUS_PER_TASK = 1 + MEM = 1 + CUSTOM_GRES = "my-custom-gres-from-task" + meta_non_parallel = dict( + cpus_per_task=CPUS_PER_TASK, + mem=MEM, + needs_gpu=False, + gres=CUSTOM_GRES, + extra_lines=["a", "b", "c", "d"], + ) + mytask = TaskV2Mock( + name="My beautiful task", + command_non_parallel="python something.py", + meta_non_parallel=meta_non_parallel, + ) + + # Create WorkflowTask + CPUS_PER_TASK_OVERRIDE = 2 + CUSTOM_CONSTRAINT = "my-custom-constraint-from-wftask" + CUSTOM_EXTRA_LINES = ["export VAR1=VALUE1", "export VAR2=VALUE2"] + MEM_OVERRIDE = "1G" + MEM_OVERRIDE_MB = 1000 + meta_non_parallel = dict( + cpus_per_task=CPUS_PER_TASK_OVERRIDE, + mem=MEM_OVERRIDE, + needs_gpu=True, + constraint=CUSTOM_CONSTRAINT, + extra_lines=CUSTOM_EXTRA_LINES, + ) + mywftask = WorkflowTaskV2Mock( + task=mytask, + args_non_parallel=dict(message="test"), + meta_non_parallel=meta_non_parallel, + ) + + # Call get_slurm_config + slurm_config = get_slurm_config( + wftask=mywftask, + config_path=config_path, + which_type="non_parallel", + ) + + # Check that WorkflowTask.meta takes priority over WorkflowTask.Task.meta + assert slurm_config.cpus_per_task == CPUS_PER_TASK_OVERRIDE + assert slurm_config.mem_per_task_MB == MEM_OVERRIDE_MB + assert slurm_config.partition == GPU_PARTITION + + # Check that both WorkflowTask.meta and WorkflowTask.Task.meta take + # priority over the "if_needs_gpu" key-value pair in slurm_config.json + assert slurm_config.gres == CUSTOM_GRES + assert slurm_config.constraint == CUSTOM_CONSTRAINT + + # Check that some optional attributes are set/unset correctly + assert slurm_config.job_name + assert " " not in slurm_config.job_name + assert slurm_config.account == DEFAULT_ACCOUNT + assert "time" not in slurm_config.dict(exclude_unset=True).keys() + # Check that extra_lines from WorkflowTask.meta and config_path + # are combined together, and that repeated elements were removed + assert len(slurm_config.extra_lines) == 3 + assert len(slurm_config.extra_lines) == len(set(slurm_config.extra_lines)) + # Check value of user_local_exports + assert slurm_config.user_local_exports == USER_LOCAL_EXPORTS + + +def test_get_slurm_config_fail(tmp_path): + slurm_config = { + "default_slurm_config": { + "partition": "main", + 
"cpus_per_task": 1, + "mem": "1G", + }, + "gpu_slurm_config": { + "partition": "main", + }, + "batching_config": { + "target_cpus_per_job": 10, + "max_cpus_per_job": 12, + "target_mem_per_job": 10, + "max_mem_per_job": 12, + "target_num_jobs": 5, + "max_num_jobs": 10, + }, + } + + # Valid + config_path_valid = tmp_path / "slurm_config_valid.json" + with config_path_valid.open("w") as f: + json.dump(slurm_config, f) + get_slurm_config( + wftask=WorkflowTaskV2Mock( + task=TaskV2Mock(), + meta_non_parallel={}, + ), + config_path=config_path_valid, + which_type="non_parallel", + ) + + # Invalid + slurm_config["INVALID_KEY"] = "something" + config_path_invalid = tmp_path / "slurm_config_invalid.json" + with config_path_invalid.open("w") as f: + json.dump(slurm_config, f) + with pytest.raises( + SlurmConfigError, match="extra fields not permitted" + ) as e: + get_slurm_config( + wftask=WorkflowTaskV2Mock( + task=TaskV2Mock(), + meta_non_parallel={}, + ), + config_path=config_path_invalid, + which_type="non_parallel", + ) + debug(e.value) + + +def test_get_slurm_config_wftask_meta_none(tmp_path): + """ + Similar to test_get_slurm_config, but wftask has meta=None. + """ + + # Write gloabl variables into JSON config file + GPU_PARTITION = "gpu-partition" + GPU_DEFAULT_GRES = "gpu-default-gres" + GPU_DEFAULT_CONSTRAINT = "gpu-default-constraint" + DEFAULT_ACCOUNT = "default-account" + DEFAULT_EXTRA_LINES = ["#SBATCH --option=value", "export VAR1=VALUE1"] + USER_LOCAL_EXPORTS = {"SOME_CACHE_DIR": "SOME_CACHE_DIR"} + + slurm_config = { + "default_slurm_config": { + "partition": "main", + "mem": "1G", + "account": DEFAULT_ACCOUNT, + "extra_lines": DEFAULT_EXTRA_LINES, + }, + "gpu_slurm_config": { + "partition": GPU_PARTITION, + "mem": "1G", + "gres": GPU_DEFAULT_GRES, + "constraint": GPU_DEFAULT_CONSTRAINT, + }, + "batching_config": { + "target_cpus_per_job": 10, + "max_cpus_per_job": 12, + "target_mem_per_job": 10, + "max_mem_per_job": 12, + "target_num_jobs": 5, + "max_num_jobs": 10, + }, + "user_local_exports": USER_LOCAL_EXPORTS, + } + config_path = tmp_path / "slurm_config.json" + with config_path.open("w") as f: + json.dump(slurm_config, f) + + # Create WorkflowTask + CPUS_PER_TASK_OVERRIDE = 2 + CUSTOM_CONSTRAINT = "my-custom-constraint-from-wftask" + CUSTOM_EXTRA_LINES = ["export VAR1=VALUE1", "export VAR2=VALUE2"] + MEM_OVERRIDE = "1G" + MEM_OVERRIDE_MB = 1000 + meta_non_parallel = dict( + cpus_per_task=CPUS_PER_TASK_OVERRIDE, + mem=MEM_OVERRIDE, + needs_gpu=True, + constraint=CUSTOM_CONSTRAINT, + extra_lines=CUSTOM_EXTRA_LINES, + ) + mywftask = WorkflowTaskV2Mock( + task=TaskV2Mock(meta_non_parallel=None), + args_non_parallel=dict(message="test"), + meta_non_parallel=meta_non_parallel, + ) + debug(mywftask) + + # Call get_slurm_config + slurm_config = get_slurm_config( + wftask=mywftask, + config_path=config_path, + which_type="non_parallel", + ) + debug(slurm_config) + + # Check that WorkflowTask.meta takes priority over WorkflowTask.Task.meta + assert slurm_config.cpus_per_task == CPUS_PER_TASK_OVERRIDE + assert slurm_config.mem_per_task_MB == MEM_OVERRIDE_MB + assert slurm_config.partition == GPU_PARTITION + # Check that both WorkflowTask.meta and WorkflowTask.Task.meta take + # priority over the "if_needs_gpu" key-value pair in slurm_config.json + assert slurm_config.constraint == CUSTOM_CONSTRAINT + # Check that some optional attributes are set/unset correctly + assert slurm_config.job_name + assert " " not in slurm_config.job_name + assert slurm_config.account == DEFAULT_ACCOUNT + 
assert "time" not in slurm_config.dict(exclude_unset=True).keys() + # Check that extra_lines from WorkflowTask.meta and config_path + # are combined together, and that repeated elements were removed + assert len(slurm_config.extra_lines) == 3 + assert len(slurm_config.extra_lines) == len(set(slurm_config.extra_lines)) + # Check value of user_local_exports + assert slurm_config.user_local_exports == USER_LOCAL_EXPORTS + + +def test_slurm_submit_setup( + tmp_path: Path, testdata_path: Path, override_settings_factory +): + override_settings_factory( + FRACTAL_SLURM_CONFIG_FILE=testdata_path / "slurm_config.json" + ) + + # No account in `wftask.meta` --> OK + wftask = WorkflowTaskV2Mock(task=TaskV2Mock()) + slurm_config = _slurm_submit_setup( + wftask=wftask, + workflow_dir_local=tmp_path, + workflow_dir_remote=tmp_path, + which_type="non_parallel", + ) + debug(slurm_config) + assert slurm_config["slurm_config"].account is None + + # Account in `wftask.meta_non_parallel` --> fail + wftask = WorkflowTaskV2Mock( + meta_non_parallel=dict(key="value", account="MyFakeAccount"), + task=TaskV2Mock(), + ) + with pytest.raises(SlurmConfigError) as e: + _slurm_submit_setup( + wftask=wftask, + workflow_dir_local=tmp_path, + workflow_dir_remote=tmp_path, + which_type="non_parallel", + ) + debug(e.value) + assert "SLURM account" in str(e.value) + + +def test_get_slurm_config_gpu_options(tmp_path: Path): + """ + Test that GPU-related options are only read when `needs_gpu=True`. + """ + STANDARD_PARTITION = "main" + GPU_PARTITION = "gpupartition" + GPU_MEM = "20G" + GPU_MEM_PER_TASK_MB = 20000 + GPUS = "1" + PRE_SUBMISSION_COMMANDS = ["module load gpu"] + + slurm_config_dict = { + "default_slurm_config": { + "partition": STANDARD_PARTITION, + "mem": "1G", + "cpus_per_task": 1, + }, + "gpu_slurm_config": { + "partition": GPU_PARTITION, + "mem": GPU_MEM, + "gpus": GPUS, + "pre_submission_commands": PRE_SUBMISSION_COMMANDS, + }, + "batching_config": { + "target_cpus_per_job": 10, + "max_cpus_per_job": 12, + "target_mem_per_job": 10, + "max_mem_per_job": 12, + "target_num_jobs": 5, + "max_num_jobs": 10, + }, + } + config_path = tmp_path / "slurm_config.json" + with config_path.open("w") as f: + json.dump(slurm_config_dict, f) + + # In absence of `needs_gpu`, parameters in `gpu_slurm_config` are not used + mywftask = WorkflowTaskV2Mock(task=TaskV2Mock()) + slurm_config = get_slurm_config( + wftask=mywftask, + config_path=config_path, + which_type="non_parallel", + ) + assert slurm_config.partition == STANDARD_PARTITION + assert slurm_config.gpus is None + assert slurm_config.pre_submission_commands == [] + + # When `needs_gpu` is set, parameters in `gpu_slurm_config` are used + mywftask = WorkflowTaskV2Mock( + meta_non_parallel=dict(needs_gpu=True), task=TaskV2Mock() + ) + slurm_config = get_slurm_config( + wftask=mywftask, + config_path=config_path, + which_type="non_parallel", + ) + assert slurm_config.partition == GPU_PARTITION + assert slurm_config.gpus == GPUS + assert slurm_config.mem_per_task_MB == GPU_MEM_PER_TASK_MB + assert slurm_config.pre_submission_commands == PRE_SUBMISSION_COMMANDS