diff --git a/CHANGELOG.md b/CHANGELOG.md index 3111c59d64..5c4e252702 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,30 @@ **Note**: Numbers like (\#1234) point to closed Pull Requests on the fractal-server repository. -# 2.3.1 (unreleased) +# 2.3.2 +> **WARNING**: The remove-remote-venv-folder operation in the SSH task collection is broken (see issue 1633). Do not deploy this version in an SSH-based `fractal-server` instance. + +* API: + * Fix incorrect zipping of structured job-log folders (\#1648). + +# 2.3.1 + +This release includes a bugfix for task names with special characters. + +> **WARNING**: The remove-remote-venv-folder operation in the SSH task collection is broken (see issue 1633). Do not deploy this version in an SSH-based `fractal-server` instance. + +* Runner: + * Improve sanitization of subfolder names (commits from 3d89d6ba104d1c6f11812bc9de5cbdff25f81aa2 to 426fa3522cf2eef90d8bd2da3b2b8a5b646b9bf4). +* API: + * Improve error message when task-collection Python is not defined (\#1640). + * Use a single endpoint for standard and SSH task collection (\#1640). * SSH features: - * Remove remote venv folder upon failed task collection in SSH mode (\#1634). + * Remove remote venv folder upon failed task collection in SSH mode (\#1634, \#1640). + * Refactor `FractalSSH` (\#1635). + * Set `fabric.Connection.forward_agent=False` (\#1639). * Testing: + * Improved testing of SSH task-collection API (\#1640). + * Improved testing of `FractalSSH` methods (\#1635). * Stop testing SQLite database for V1 in CI (\#1630). 
# 2.3.0 diff --git a/fractal_server/__init__.py b/fractal_server/__init__.py index 8f838724f8..2763d367fb 100644 --- a/fractal_server/__init__.py +++ b/fractal_server/__init__.py @@ -1 +1 @@ -__VERSION__ = "2.3.0" +__VERSION__ = "2.3.2" diff --git a/fractal_server/app/routes/admin/v1.py b/fractal_server/app/routes/admin/v1.py index 7809b6871f..163c255b67 100644 --- a/fractal_server/app/routes/admin/v1.py +++ b/fractal_server/app/routes/admin/v1.py @@ -387,9 +387,7 @@ async def download_job_logs( # Create and return byte stream for zipped log folder PREFIX_ZIP = Path(job.working_dir).name zip_filename = f"{PREFIX_ZIP}_archive.zip" - byte_stream = _zip_folder_to_byte_stream( - folder=job.working_dir, zip_filename=zip_filename - ) + byte_stream = _zip_folder_to_byte_stream(folder=job.working_dir) return StreamingResponse( iter([byte_stream.getvalue()]), media_type="application/x-zip-compressed", diff --git a/fractal_server/app/routes/admin/v2.py b/fractal_server/app/routes/admin/v2.py index 6ba0ac6ea0..50c0d9a057 100644 --- a/fractal_server/app/routes/admin/v2.py +++ b/fractal_server/app/routes/admin/v2.py @@ -274,9 +274,7 @@ async def download_job_logs( # Create and return byte stream for zipped log folder PREFIX_ZIP = Path(job.working_dir).name zip_filename = f"{PREFIX_ZIP}_archive.zip" - byte_stream = _zip_folder_to_byte_stream( - folder=job.working_dir, zip_filename=zip_filename - ) + byte_stream = _zip_folder_to_byte_stream(folder=job.working_dir) return StreamingResponse( iter([byte_stream.getvalue()]), media_type="application/x-zip-compressed", diff --git a/fractal_server/app/routes/api/v1/job.py b/fractal_server/app/routes/api/v1/job.py index c9402faf18..55f79ba400 100644 --- a/fractal_server/app/routes/api/v1/job.py +++ b/fractal_server/app/routes/api/v1/job.py @@ -128,9 +128,7 @@ async def download_job_logs( # Create and return byte stream for zipped log folder PREFIX_ZIP = Path(job.working_dir).name zip_filename = f"{PREFIX_ZIP}_archive.zip" - byte_stream = 
_zip_folder_to_byte_stream( - folder=job.working_dir, zip_filename=zip_filename - ) + byte_stream = _zip_folder_to_byte_stream(folder=job.working_dir) return StreamingResponse( iter([byte_stream.getvalue()]), media_type="application/x-zip-compressed", diff --git a/fractal_server/app/routes/api/v1/task_collection.py b/fractal_server/app/routes/api/v1/task_collection.py index 2733974530..11d838b983 100644 --- a/fractal_server/app/routes/api/v1/task_collection.py +++ b/fractal_server/app/routes/api/v1/task_collection.py @@ -25,8 +25,8 @@ from ....security import current_active_user from ....security import current_active_verified_user from ....security import User +from fractal_server.string_tools import slugify_task_name_for_source from fractal_server.tasks.utils import get_collection_log -from fractal_server.tasks.utils import slugify_task_name from fractal_server.tasks.v1._TaskCollectPip import _TaskCollectPip from fractal_server.tasks.v1.background_operations import ( background_collect_pip, @@ -159,7 +159,7 @@ async def collect_tasks_pip( # Check that tasks are not already in the DB for new_task in task_pkg.package_manifest.task_list: - new_task_name_slug = slugify_task_name(new_task.name) + new_task_name_slug = slugify_task_name_for_source(new_task.name) new_task_source = f"{task_pkg.package_source}:{new_task_name_slug}" stm = select(Task).where(Task.source == new_task_source) res = await db.execute(stm) diff --git a/fractal_server/app/routes/api/v2/__init__.py b/fractal_server/app/routes/api/v2/__init__.py index b92ebf0a98..9231a17bef 100644 --- a/fractal_server/app/routes/api/v2/__init__.py +++ b/fractal_server/app/routes/api/v2/__init__.py @@ -12,7 +12,6 @@ from .task import router as task_router_v2 from .task_collection import router as task_collection_router_v2 from .task_collection_custom import router as task_collection_router_v2_custom -from .task_collection_ssh import router as task_collection_router_v2_ssh from .task_legacy import router as 
task_legacy_router_v2 from .workflow import router as workflow_router_v2 from .workflowtask import router as workflowtask_router_v2 @@ -30,21 +29,14 @@ settings = Inject(get_settings) -if settings.FRACTAL_RUNNER_BACKEND == "slurm_ssh": - router_api_v2.include_router( - task_collection_router_v2_ssh, - prefix="/task", - tags=["V2 Task Collection"], - ) -else: - router_api_v2.include_router( - task_collection_router_v2, prefix="/task", tags=["V2 Task Collection"] - ) - router_api_v2.include_router( - task_collection_router_v2_custom, - prefix="/task", - tags=["V2 Task Collection"], - ) +router_api_v2.include_router( + task_collection_router_v2, prefix="/task", tags=["V2 Task Collection"] +) +router_api_v2.include_router( + task_collection_router_v2_custom, + prefix="/task", + tags=["V2 Task Collection"], +) router_api_v2.include_router(task_router_v2, prefix="/task", tags=["V2 Task"]) router_api_v2.include_router( task_legacy_router_v2, prefix="/task-legacy", tags=["V2 Task Legacy"] diff --git a/fractal_server/app/routes/api/v2/job.py b/fractal_server/app/routes/api/v2/job.py index d6851bd902..02da7354bd 100644 --- a/fractal_server/app/routes/api/v2/job.py +++ b/fractal_server/app/routes/api/v2/job.py @@ -131,9 +131,7 @@ async def download_job_logs( # Create and return byte stream for zipped log folder PREFIX_ZIP = Path(job.working_dir).name zip_filename = f"{PREFIX_ZIP}_archive.zip" - byte_stream = _zip_folder_to_byte_stream( - folder=job.working_dir, zip_filename=zip_filename - ) + byte_stream = _zip_folder_to_byte_stream(folder=job.working_dir) return StreamingResponse( iter([byte_stream.getvalue()]), media_type="application/x-zip-compressed", diff --git a/fractal_server/app/routes/api/v2/task_collection.py b/fractal_server/app/routes/api/v2/task_collection.py index 83354328f9..9303211b19 100644 --- a/fractal_server/app/routes/api/v2/task_collection.py +++ b/fractal_server/app/routes/api/v2/task_collection.py @@ -7,6 +7,7 @@ from fastapi import BackgroundTasks 
from fastapi import Depends from fastapi import HTTPException +from fastapi import Request from fastapi import Response from fastapi import status from pydantic.error_wrappers import ValidationError @@ -27,10 +28,10 @@ from ....security import current_active_user from ....security import current_active_verified_user from ....security import User +from fractal_server.string_tools import slugify_task_name_for_source from fractal_server.tasks.utils import get_absolute_venv_path from fractal_server.tasks.utils import get_collection_log from fractal_server.tasks.utils import get_collection_path -from fractal_server.tasks.utils import slugify_task_name from fractal_server.tasks.v2._TaskCollectPip import _TaskCollectPip from fractal_server.tasks.v2.background_operations import ( background_collect_pip, @@ -38,6 +39,7 @@ from fractal_server.tasks.v2.endpoint_operations import create_package_dir_pip from fractal_server.tasks.v2.endpoint_operations import download_package from fractal_server.tasks.v2.endpoint_operations import inspect_package +from fractal_server.tasks.v2.utils import get_python_interpreter_v2 router = APIRouter() @@ -66,6 +68,7 @@ async def collect_tasks_pip( task_collect: TaskCollectPipV2, background_tasks: BackgroundTasks, response: Response, + request: Request, user: User = Depends(current_active_verified_user), db: AsyncSession = Depends(get_async_db), ) -> CollectionStateReadV2: @@ -76,17 +79,26 @@ async def collect_tasks_pip( of a package and the collection of tasks as advertised in the manifest. 
""" - logger = set_logger(logger_name="collect_tasks_pip") + # Get settings + settings = Inject(get_settings) - # Set default python version + # Set/check python version if task_collect.python_version is None: - settings = Inject(get_settings) task_collect.python_version = ( settings.FRACTAL_TASKS_PYTHON_DEFAULT_VERSION ) + try: + get_python_interpreter_v2(python_version=task_collect.python_version) + except ValueError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"Python version {task_collect.python_version} is " + "not available for Fractal task collection." + ), + ) - # Validate payload as _TaskCollectPip, which has more strict checks than - # TaskCollectPip + # Validate payload try: task_pkg = _TaskCollectPip(**task_collect.dict(exclude_unset=True)) except ValidationError as e: @@ -95,6 +107,37 @@ async def collect_tasks_pip( detail=f"Invalid task-collection object. Original error: {e}", ) + # END of SSH/non-SSH common part + + if settings.FRACTAL_RUNNER_BACKEND == "slurm_ssh": + + from fractal_server.tasks.v2.background_operations_ssh import ( + background_collect_pip_ssh, + ) + + # Construct and return state + state = CollectionStateV2( + data=dict( + status=CollectionStatusV2.PENDING, package=task_collect.package + ) + ) + db.add(state) + await db.commit() + + background_tasks.add_task( + background_collect_pip_ssh, + state.id, + task_pkg, + request.app.state.fractal_ssh, + ) + + response.status_code = status.HTTP_201_CREATED + return state + + # Actual non-SSH endpoint + + logger = set_logger(logger_name="collect_tasks_pip") + with TemporaryDirectory() as tmpdir: try: # Copy or download the package wheel file to tmpdir @@ -197,7 +240,7 @@ async def collect_tasks_pip( # Check that tasks are not already in the DB for new_task in task_pkg.package_manifest.task_list: - new_task_name_slug = slugify_task_name(new_task.name) + new_task_name_slug = slugify_task_name_for_source(new_task.name) new_task_source = 
f"{task_pkg.package_source}:{new_task_name_slug}" stm = select(TaskV2).where(TaskV2.source == new_task_source) res = await db.execute(stm) @@ -253,6 +296,7 @@ async def check_collection_status( """ Check status of background task collection """ + logger = set_logger(logger_name="check_collection_status") logger.debug(f"Querying state for state.id={state_id}") state = await db.get(CollectionStateV2, state_id) if not state: await db.close() raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"No task collection info with id={state_id}", ) - # In some cases (i.e. a successful or ongoing task collection), - # state.data.log is not set; if so, we collect the current logs. - if verbose and not state.data.get("log"): - if "venv_path" not in state.data.keys(): - await db.close() - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"No 'venv_path' in CollectionStateV2[{state_id}].data", + settings = Inject(get_settings) + if settings.FRACTAL_RUNNER_BACKEND == "slurm_ssh": + # FIXME SSH: add logic for when state.data["log"] is empty + pass + else: + # Non-SSH mode + # In some cases (i.e. a successful or ongoing task collection), + # state.data.log is not set; if so, we collect the current logs. 
+ if verbose and not state.data.get("log"): + if "venv_path" not in state.data.keys(): + await db.close() + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + f"No 'venv_path' in CollectionStateV2[{state_id}].data" + ), + ) + state.data["log"] = get_collection_log( + Path(state.data["venv_path"]) ) - state.data["log"] = get_collection_log(Path(state.data["venv_path"])) - state.data["venv_path"] = str(state.data["venv_path"]) + state.data["venv_path"] = str(state.data["venv_path"]) + reset_logger_handlers(logger) await db.close() return state diff --git a/fractal_server/app/routes/api/v2/task_collection_ssh.py b/fractal_server/app/routes/api/v2/task_collection_ssh.py deleted file mode 100644 index 481969da8a..0000000000 --- a/fractal_server/app/routes/api/v2/task_collection_ssh.py +++ /dev/null @@ -1,125 +0,0 @@ -from fastapi import APIRouter -from fastapi import BackgroundTasks -from fastapi import Depends -from fastapi import HTTPException -from fastapi import Request -from fastapi import Response -from fastapi import status -from pydantic.error_wrappers import ValidationError - -from .....config import get_settings -from .....logger import reset_logger_handlers -from .....logger import set_logger -from .....syringe import Inject -from .....tasks.v2._TaskCollectPip import _TaskCollectPip -from ....db import AsyncSession -from ....db import get_async_db -from ....models.v2 import CollectionStateV2 -from ....schemas.v2 import CollectionStateReadV2 -from ....schemas.v2 import CollectionStatusV2 -from ....schemas.v2 import TaskCollectPipV2 -from ....security import current_active_user -from ....security import current_active_verified_user -from ....security import User -from fractal_server.tasks.v2.background_operations_ssh import ( - background_collect_pip_ssh, -) - -router = APIRouter() - -logger = set_logger(__name__) - - -@router.post( - "/collect/pip/", - response_model=CollectionStateReadV2, - responses={ - 201: dict( - 
description=( - "Task collection successfully started in the background" - ) - ), - 200: dict( - description=( - "Package already collected. Returning info on already " - "available tasks" - ) - ), - }, -) -async def collect_tasks_pip( - task_collect: TaskCollectPipV2, - background_tasks: BackgroundTasks, - response: Response, - request: Request, - user: User = Depends(current_active_verified_user), - db: AsyncSession = Depends(get_async_db), -) -> CollectionStateReadV2: - """ - Task collection endpoint (SSH version) - """ - - # Set default python version - if task_collect.python_version is None: - settings = Inject(get_settings) - task_collect.python_version = ( - settings.FRACTAL_TASKS_PYTHON_DEFAULT_VERSION - ) - - # Validate payload as _TaskCollectPip, which has more strict checks than - # TaskCollectPip - try: - task_pkg = _TaskCollectPip(**task_collect.dict(exclude_unset=True)) - except ValidationError as e: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Invalid task-collection object. 
Original error: {e}", - ) - - # Note: we don't use TaskCollectStatusV2 here for the JSON column `data` - state = CollectionStateV2( - data=dict( - status=CollectionStatusV2.PENDING, package=task_collect.package - ) - ) - db.add(state) - await db.commit() - - background_tasks.add_task( - background_collect_pip_ssh, - state.id, - task_pkg, - request.app.state.fractal_ssh, - ) - - response.status_code = status.HTTP_201_CREATED - return state - - -# FIXME SSH: check_collection_status code is almost identical to the -# one in task_collection.py -@router.get("/collect/{state_id}/", response_model=CollectionStateReadV2) -async def check_collection_status( - state_id: int, - verbose: bool = False, - user: User = Depends(current_active_user), - db: AsyncSession = Depends(get_async_db), -) -> CollectionStateReadV2: - """ - Check status of background task collection - """ - logger = set_logger(logger_name="check_collection_status") - logger.debug(f"Querying state for state.id={state_id}") - state = await db.get(CollectionStateV2, state_id) - if not state: - await db.close() - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No task collection info with id={state_id}", - ) - - # FIXME SSH: add logic for when data.state["log"] is empty - - reset_logger_handlers(logger) - await db.close() - return state diff --git a/fractal_server/app/routes/aux/_job.py b/fractal_server/app/routes/aux/_job.py index dcbef0399d..31188f0f4a 100644 --- a/fractal_server/app/routes/aux/_job.py +++ b/fractal_server/app/routes/aux/_job.py @@ -1,3 +1,4 @@ +import os from io import BytesIO from pathlib import Path from typing import Union @@ -25,19 +26,20 @@ def _write_shutdown_file(*, job: Union[ApplyWorkflow, JobV2]): f.write(f"Trigger executor shutdown for {job.id=}.") -def _zip_folder_to_byte_stream(*, folder: str, zip_filename: str) -> BytesIO: +def _zip_folder_to_byte_stream(*, folder: str) -> BytesIO: """ Get byte stream with the zipped log folder of a job. 
Args: folder: the folder to zip - zip_filename: name of the zipped archive """ - working_dir_path = Path(folder) byte_stream = BytesIO() with ZipFile(byte_stream, mode="w", compression=ZIP_DEFLATED) as zipfile: - for fpath in working_dir_path.glob("*"): - zipfile.write(filename=str(fpath), arcname=str(fpath.name)) + for root, dirs, files in os.walk(folder): + for file in files: + file_path = os.path.join(root, file) + archive_path = os.path.relpath(file_path, folder) + zipfile.write(file_path, archive_path) return byte_stream diff --git a/fractal_server/app/runner/executors/slurm/ssh/executor.py b/fractal_server/app/runner/executors/slurm/ssh/executor.py index 4bf6070b9b..d06da83c44 100644 --- a/fractal_server/app/runner/executors/slurm/ssh/executor.py +++ b/fractal_server/app/runner/executors/slurm/ssh/executor.py @@ -44,7 +44,6 @@ from fractal_server.config import get_settings from fractal_server.logger import set_logger from fractal_server.ssh._fabric import FractalSSH -from fractal_server.ssh._fabric import run_command_over_ssh from fractal_server.syringe import Inject logger = set_logger(__name__) @@ -852,7 +851,7 @@ def _put_subfolder_sftp(self, jobs: list[SlurmJob]) -> None: "fractal_server.app.runner.extract_archive " f"{tarfile_path_remote}" ) - run_command_over_ssh(cmd=tar_command, fractal_ssh=self.fractal_ssh) + self.fractal_ssh.run_command(cmd=tar_command) # Remove local version t_0_rm = time.perf_counter() @@ -874,9 +873,8 @@ def _submit_job(self, job: SlurmJob) -> tuple[Future, str]: # Submit job to SLURM, and get jobid sbatch_command = f"sbatch --parsable {job.slurm_script_remote}" - sbatch_stdout = run_command_over_ssh( + sbatch_stdout = self.fractal_ssh.run_command( cmd=sbatch_command, - fractal_ssh=self.fractal_ssh, ) # Extract SLURM job ID from stdout @@ -1226,9 +1224,7 @@ def _get_subfolder_sftp(self, jobs: list[SlurmJob]) -> None: "-m fractal_server.app.runner.compress_folder " f"{(self.workflow_dir_remote / subfolder_name).as_posix()}" ) - 
stdout = run_command_over_ssh( - cmd=tar_command, fractal_ssh=self.fractal_ssh - ) + stdout = self.fractal_ssh.run_command(cmd=tar_command) print(stdout) # Fetch tarfile @@ -1352,9 +1348,7 @@ def shutdown(self, wait=True, *, cancel_futures=False): scancel_string = " ".join(slurm_jobs_to_scancel) logger.warning(f"Now scancel-ing SLURM jobs {scancel_string}") scancel_command = f"scancel {scancel_string}" - run_command_over_ssh( - cmd=scancel_command, fractal_ssh=self.fractal_ssh - ) + self.fractal_ssh.run_command(cmd=scancel_command) logger.debug("Executor shutdown: end") def __exit__(self, *args, **kwargs): @@ -1379,10 +1373,7 @@ def run_squeue(self, job_ids): ) job_ids = ",".join([str(j) for j in job_ids]) squeue_command = squeue_command.replace("__JOBS__", job_ids) - stdout = run_command_over_ssh( - cmd=squeue_command, - fractal_ssh=self.fractal_ssh, - ) + stdout = self.fractal_ssh.run_command(cmd=squeue_command) return stdout def _jobs_finished(self, job_ids: list[str]) -> set[str]: @@ -1462,7 +1453,7 @@ def handshake(self) -> dict: logger.info("[FractalSlurmSSHExecutor.ssh_handshake] START") cmd = f"{self.python_remote} -m fractal_server.app.runner.versions" - stdout = run_command_over_ssh(cmd=cmd, fractal_ssh=self.fractal_ssh) + stdout = self.fractal_ssh.run_command(cmd=cmd) remote_versions = json.loads(stdout.strip("\n")) # Check compatibility with local versions diff --git a/fractal_server/app/runner/task_files.py b/fractal_server/app/runner/task_files.py index 92a0b50c2e..33cb84b2aa 100644 --- a/fractal_server/app/runner/task_files.py +++ b/fractal_server/app/runner/task_files.py @@ -2,18 +2,7 @@ from typing import Optional from typing import Union -from fractal_server.tasks.utils import slugify_task_name - - -def sanitize_component(value: str) -> str: - """ - Remove {" ", "/", "."} form a string, e.g. going from - 'plate.zarr/B/03/0' to 'plate_zarr_B_03_0'. 
- - Args: - value: Input strig - """ - return value.replace(" ", "_").replace("/", "_").replace(".", "_") +from fractal_server.string_tools import sanitize_string def task_subfolder_name(order: Union[int, str], task_name: str) -> str: @@ -24,7 +13,7 @@ def task_subfolder_name(order: Union[int, str], task_name: str) -> str: order: task_name: """ - task_name_slug = slugify_task_name(task_name) + task_name_slug = sanitize_string(task_name) return f"{order}_{task_name_slug}" @@ -93,7 +82,7 @@ def __init__( self.component = component if self.component is not None: - component_safe = sanitize_component(str(self.component)) + component_safe = sanitize_string(str(self.component)) component_safe = f"_par_{component_safe}" else: component_safe = "" diff --git a/fractal_server/app/runner/v2/__init__.py b/fractal_server/app/runner/v2/__init__.py index 9b706ddbb7..8c8ad8d47e 100644 --- a/fractal_server/app/runner/v2/__init__.py +++ b/fractal_server/app/runner/v2/__init__.py @@ -189,11 +189,8 @@ async def submit_workflow( / WORKFLOW_DIR_LOCAL.name ) # FIXME SSH: move mkdir to executor, likely within handshake - - from ....ssh._fabric import _mkdir_over_ssh - - _mkdir_over_ssh( - folder=str(WORKFLOW_DIR_REMOTE), fractal_ssh=fractal_ssh + fractal_ssh.mkdir( + folder=str(WORKFLOW_DIR_REMOTE), ) logging.info(f"Created {str(WORKFLOW_DIR_REMOTE)} via SSH.") else: diff --git a/fractal_server/ssh/_fabric.py b/fractal_server/ssh/_fabric.py index b8b953406e..c888f21984 100644 --- a/fractal_server/ssh/_fabric.py +++ b/fractal_server/ssh/_fabric.py @@ -1,8 +1,11 @@ +import logging import time from contextlib import contextmanager from pathlib import Path from threading import Lock from typing import Any +from typing import Generator +from typing import Literal from typing import Optional import paramiko.sftp_client @@ -16,66 +19,106 @@ from fractal_server.config import get_settings from fractal_server.syringe import Inject -logger = set_logger(__name__) - -MAX_ATTEMPTS = 5 - -class 
TimeoutException(Exception): +class FractalSSHTimeoutError(RuntimeError): pass -@contextmanager -def acquire_timeout(lock: Lock, timeout: int) -> Any: - logger.debug(f"Trying to acquire lock, with {timeout=}") - result = lock.acquire(timeout=timeout) - try: - if not result: - raise TimeoutException( - f"Failed to acquire lock within {timeout} seconds" - ) - logger.debug("Lock was acquired.") - yield result - finally: - if result: - lock.release() - logger.debug("Lock was released") +logger = set_logger(__name__) class FractalSSH(object): - lock: Lock - connection: Connection - default_timeout: int - # FIXME SSH: maybe extend the actual_timeout logic to other methods + """ + FIXME SSH: Fix docstring - def __init__(self, connection: Connection, default_timeout: int = 250): - self.lock = Lock() - self.conn = connection - self.default_timeout = default_timeout + Attributes: + _lock: + connection: + default_lock_timeout: + logger_name: + """ + + _lock: Lock + _connection: Connection + default_lock_timeout: float + default_max_attempts: int + default_base_interval: float + logger_name: str + + def __init__( + self, + connection: Connection, + default_timeout: float = 250, + default_max_attempts: int = 5, + default_base_interval: float = 3.0, + logger_name: str = __name__, + ): + self._lock = Lock() + self._connection = connection + self.default_lock_timeout = default_timeout + self.default_base_interval = default_base_interval + self.default_max_attempts = default_max_attempts + self.logger_name = logger_name + set_logger(self.logger_name) + + @contextmanager + def acquire_timeout( + self, timeout: float + ) -> Generator[Literal[True], Any, None]: + self.logger.debug(f"Trying to acquire lock, with {timeout=}") + result = self._lock.acquire(timeout=timeout) + try: + if not result: + self.logger.error("Lock was *NOT* acquired.") + raise FractalSSHTimeoutError( + f"Failed to acquire lock within {timeout} seconds" + ) + self.logger.debug("Lock was acquired.") + yield result 
+ finally: + if result: + self._lock.release() + self.logger.debug("Lock was released") @property def is_connected(self) -> bool: - return self.conn.is_connected - - def put(self, *args, timeout: Optional[int] = None, **kwargs) -> Result: - actual_timeout = timeout or self.default_timeout - with acquire_timeout(self.lock, timeout=actual_timeout): - return self.conn.put(*args, **kwargs) - - def get(self, *args, **kwargs) -> Result: - with acquire_timeout(self.lock, timeout=self.default_timeout): - return self.conn.get(*args, **kwargs) + return self._connection.is_connected - def run(self, *args, **kwargs) -> Any: - with acquire_timeout(self.lock, timeout=self.default_timeout): - return self.conn.run(*args, **kwargs) - - def close(self) -> None: - return self.conn.close() + @property + def logger(self) -> logging.Logger: + return get_logger(self.logger_name) + + def put( + self, *args, lock_timeout: Optional[float] = None, **kwargs + ) -> Result: + actual_lock_timeout = self.default_lock_timeout + if lock_timeout is not None: + actual_lock_timeout = lock_timeout + with self.acquire_timeout(timeout=actual_lock_timeout): + return self._connection.put(*args, **kwargs) + + def get( + self, *args, lock_timeout: Optional[float] = None, **kwargs + ) -> Result: + actual_lock_timeout = self.default_lock_timeout + if lock_timeout is not None: + actual_lock_timeout = lock_timeout + with self.acquire_timeout(timeout=actual_lock_timeout): + return self._connection.get(*args, **kwargs) + + def run( + self, *args, lock_timeout: Optional[float] = None, **kwargs + ) -> Any: + + actual_lock_timeout = self.default_lock_timeout + if lock_timeout is not None: + actual_lock_timeout = lock_timeout + with self.acquire_timeout(timeout=actual_lock_timeout): + return self._connection.run(*args, **kwargs) def sftp(self) -> paramiko.sftp_client.SFTPClient: - return self.conn.sftp() + return self._connection.sftp() def check_connection(self) -> None: """ @@ -85,13 +128,184 @@ def 
check_connection(self) -> None: `connection`, so that we can provide a meaningful error in case the SSH connection cannot be opened. """ - if not self.conn.is_connected: + if not self._connection.is_connected: try: - self.conn.open() + self._connection.open() except Exception as e: raise RuntimeError( - f"Cannot open SSH connection (original error: '{str(e)}')." + f"Cannot open SSH connection. Original error:\n{str(e)}" + ) + + def close(self) -> None: + return self._connection.close() + + def run_command( + self, + *, + cmd: str, + max_attempts: Optional[int] = None, + base_interval: Optional[int] = None, + lock_timeout: Optional[int] = None, + ) -> str: + """ + Run a command within an open SSH connection. + + Args: + cmd: Command to be run + max_attempts: + base_interval: + lock_timeout: + + Returns: + Standard output of the command, if successful. + """ + actual_max_attempts = self.default_max_attempts + if max_attempts is not None: + actual_max_attempts = max_attempts + + actual_base_interval = self.default_base_interval + if base_interval is not None: + actual_base_interval = base_interval + + actual_lock_timeout = self.default_lock_timeout + if lock_timeout is not None: + actual_lock_timeout = lock_timeout + + t_0 = time.perf_counter() + ind_attempt = 0 + while ind_attempt <= actual_max_attempts: + ind_attempt += 1 + prefix = f"[attempt {ind_attempt}/{actual_max_attempts}]" + self.logger.info(f"{prefix} START running '{cmd}' over SSH.") + try: + # Case 1: Command runs successfully + res = self.run( + cmd, lock_timeout=actual_lock_timeout, hide=True + ) + t_1 = time.perf_counter() + self.logger.info( + f"{prefix} END running '{cmd}' over SSH, " + f"elapsed {t_1-t_0:.3f}" ) + self.logger.debug(f"STDOUT: {res.stdout}") + self.logger.debug(f"STDERR: {res.stderr}") + return res.stdout + except NoValidConnectionsError as e: + # Case 2: Command fails with a connection error + self.logger.warning( + f"{prefix} Running command `{cmd}` over SSH failed.\n" + f"Original 
NoValidConnectionError:\n{str(e)}.\n" + f"{e.errors=}\n" + ) + if ind_attempt < actual_max_attempts: + sleeptime = actual_base_interval**ind_attempt + self.logger.warning( + f"{prefix} Now sleep {sleeptime:.3f} " + "seconds and continue." + ) + time.sleep(sleeptime) + else: + self.logger.error(f"{prefix} Reached last attempt") + break + except UnexpectedExit as e: + # Case 3: Command fails with an actual error + error_msg = ( + f"{prefix} Running command `{cmd}` over SSH failed.\n" + f"Original error:\n{str(e)}." + ) + self.logger.error(error_msg) + raise RuntimeError(error_msg) + except Exception as e: + self.logger.error( + f"Running command `{cmd}` over SSH failed.\n" + f"Original Error:\n{str(e)}." + ) + raise e + + raise RuntimeError( + f"Reached last attempt ({max_attempts=}) for running " + f"'{cmd}' over SSH" + ) + + def send_file( + self, + *, + local: str, + remote: str, + logger_name: Optional[str] = None, + lock_timeout: Optional[float] = None, + ) -> None: + """ + Transfer a file via SSH + + Args: + local: Local path to file + remote: Target path on remote host + logger_name: Name of the logger + lock_timeout: Timeout for acquiring the internal connection lock + + """ + try: + self.put(local=local, remote=remote, lock_timeout=lock_timeout) + except Exception as e: + logger = get_logger(logger_name=logger_name) + logger.error( + f"Transferring {local=} to {remote=} over SSH failed.\n" + f"Original Error:\n{str(e)}." + ) + raise e + + def mkdir(self, *, folder: str, parents: bool = True) -> None: + """ + Create a folder remotely via SSH. + + Args: + folder: + parents: + """ + # FIXME SSH: try using `mkdir` method of `paramiko.SFTPClient` + if parents: + cmd = f"mkdir -p {folder}" + else: + cmd = f"mkdir {folder}" + self.run_command(cmd=cmd) + + def remove_folder( + self, + *, + folder: str, + safe_root: str, + ) -> None: + """ + Removes a folder remotely via SSH. + + This function calls `rm -r`, after a few checks on `folder`. 
+ + Args: + folder: Absolute path to a folder that should be removed. + safe_root: If `folder` is not a subfolder of the absolute + `safe_root` path, raise an error. + """ + invalid_characters = {" ", "\n", ";", "$", "`"} + + if ( + not isinstance(folder, str) + or not isinstance(safe_root, str) + or len(invalid_characters.intersection(folder)) > 0 + or len(invalid_characters.intersection(safe_root)) > 0 + or not Path(folder).is_absolute() + or not Path(safe_root).is_absolute() + or not Path(folder).resolve().is_relative_to(safe_root) + ): + raise ValueError( + f"{folder=} argument is invalid or it is not " + f"relative to {safe_root=}." + ) + else: + cmd = f"rm -r {folder}" + self.run_command(cmd=cmd) def get_ssh_connection( @@ -123,161 +337,7 @@ def get_ssh_connection( connection = Connection( host=host, user=user, + forward_agent=False, connect_kwargs={"key_filename": key_filename}, ) - logger.debug(f"Now created {connection=}.") return connection - - -def run_command_over_ssh( - *, - cmd: str, - fractal_ssh: FractalSSH, - max_attempts: int = MAX_ATTEMPTS, - base_interval: float = 3.0, -) -> str: - """ - Run a command within an open SSH connection. - - Args: - cmd: Command to be run - fractal_ssh: FractalSSH connection object with custom lock - - Returns: - Standard output of the command, if successful. 
- """ - t_0 = time.perf_counter() - ind_attempt = 0 - while ind_attempt <= max_attempts: - ind_attempt += 1 - prefix = f"[attempt {ind_attempt}/{max_attempts}]" - logger.info(f"{prefix} START running '{cmd}' over SSH.") - try: - # Case 1: Command runs successfully - res = fractal_ssh.run(cmd, hide=True) - t_1 = time.perf_counter() - logger.info( - f"{prefix} END running '{cmd}' over SSH, " - f"elapsed {t_1-t_0:.3f}" - ) - logger.debug(f"STDOUT: {res.stdout}") - logger.debug(f"STDERR: {res.stderr}") - return res.stdout - except NoValidConnectionsError as e: - # Case 2: Command fails with a connection error - logger.warning( - f"{prefix} Running command `{cmd}` over SSH failed.\n" - f"Original NoValidConnectionError:\n{str(e)}.\n" - f"{e.errors=}\n" - ) - if ind_attempt < max_attempts: - sleeptime = base_interval**ind_attempt - logger.warning( - f"{prefix} Now sleep {sleeptime:.3f} seconds and continue." - ) - time.sleep(sleeptime) - continue - else: - logger.error(f"{prefix} Reached last attempt") - break - except UnexpectedExit as e: - # Case 3: Command fails with an actual error - error_msg = ( - f"{prefix} Running command `{cmd}` over SSH failed.\n" - f"Original error:\n{str(e)}." - ) - logger.error(error_msg) - raise ValueError(error_msg) - except Exception as e: - logger.error( - f"Running command `{cmd}` over SSH failed.\n" - f"Original Error:\n{str(e)}." 
- ) - raise e - - raise ValueError( - f"Reached last attempt ({max_attempts=}) for running '{cmd}' over SSH" - ) - - -def put_over_ssh( - *, - local: str, - remote: str, - fractal_ssh: FractalSSH, - logger_name: Optional[str] = None, -) -> None: - """ - Transfer a file via SSH - - Args: - local: Local path to file - remote: Target path on remote host - fractal_ssh: FractalSSH connection object with custom lock - logger_name: Name of the logger - - """ - try: - fractal_ssh.put(local=local, remote=remote) - except Exception as e: - logger = get_logger(logger_name=logger_name) - logger.error( - f"Transferring {local=} to {remote=} over SSH failed.\n" - f"Original Error:\n{str(e)}." - ) - raise e - - -def _mkdir_over_ssh( - *, folder: str, fractal_ssh: FractalSSH, parents: bool = True -) -> None: - """ - Create a folder remotely via SSH. - - Args: - folder: - fractal_ssh: - parents: - """ - # FIXME SSH: try using `mkdir` method of `paramiko.SFTPClient` - if parents: - cmd = f"mkdir -p {folder}" - else: - cmd = f"mkdir {folder}" - run_command_over_ssh(cmd=cmd, fractal_ssh=fractal_ssh) - - -def remove_folder_over_ssh( - *, - folder: str, - safe_root: str, - fractal_ssh: FractalSSH, -) -> None: - """ - Removes a folder remotely via SSH. - - This functions calls `rm -r`, after a few checks on `folder`. - - Args: - folder: Absolute path to a folder that should be removed. - safe_root: If `folder` is not a subfolder of the absolute - `safe_root` path, raise an error. - fractal_ssh: - """ - invalid_characters = {" ", "\n", ";", "$", "`"} - - if ( - not isinstance(folder, str) - or len(invalid_characters.intersection(folder)) > 0 - or not Path(folder).is_absolute() - or not Path(safe_root).is_absolute() - or not Path(folder).resolve().is_relative_to(safe_root) - ): - raise ValueError( - f"{folder=} argument is invalid or it is not " - f"relative to {safe_root=}." 
- ) - else: - - cmd = f"rm -r {folder}" - run_command_over_ssh(cmd=cmd, fractal_ssh=fractal_ssh) diff --git a/fractal_server/string_tools.py b/fractal_server/string_tools.py new file mode 100644 index 0000000000..f23bef29f7 --- /dev/null +++ b/fractal_server/string_tools.py @@ -0,0 +1,45 @@ +import string + +__SPECIAL_CHARACTERS__ = f"{string.punctuation}{string.whitespace}" + + +def sanitize_string(value: str) -> str: + """ + Make string safe to be used in file/folder names and subprocess commands. + + Make the string lower-case, and replace any special character with an + underscore, where special characters are: + + + >>> string.punctuation + '!"#$%&\'()*+,-./:;<=>?@[\\\\]^_`{|}~' + >>> string.whitespace + ' \\t\\n\\r\\x0b\\x0c' + + Args: + value: Input string + + Returns: + Sanitized value + """ + new_value = value.lower() + for character in __SPECIAL_CHARACTERS__: + new_value = new_value.replace(character, "_") + return new_value + + +def slugify_task_name_for_source(task_name: str) -> str: + """ + NOTE: this function is used upon creation of tasks' sources, therefore + for the moment we cannot replace it with its more comprehensive version + from `fractal_server.string_tools.sanitize_string`, nor we can remove it. + As 2.3.1, we are renaming it to `slugify_task_name_for_source`, to make + it clear that it should not be used for other purposes. + + Args: + task_name: + + Return: + Slug-ified task name. + """ + return task_name.replace(" ", "_").lower() diff --git a/fractal_server/tasks/utils.py b/fractal_server/tasks/utils.py index 7ece7d42c0..46a3d63294 100644 --- a/fractal_server/tasks/utils.py +++ b/fractal_server/tasks/utils.py @@ -9,10 +9,6 @@ COLLECTION_FREEZE_FILENAME = "collection_freeze.txt" -def slugify_task_name(task_name: str) -> str: - return task_name.replace(" ", "_").lower() - - def get_absolute_venv_path(venv_path: Path) -> Path: """ If a path is not absolute, make it a relative path of FRACTAL_TASKS_DIR. 
diff --git a/fractal_server/tasks/v1/background_operations.py b/fractal_server/tasks/v1/background_operations.py index e548edc0bd..9192307f1b 100644 --- a/fractal_server/tasks/v1/background_operations.py +++ b/fractal_server/tasks/v1/background_operations.py @@ -6,11 +6,11 @@ from pathlib import Path from shutil import rmtree as shell_rmtree +from ...string_tools import slugify_task_name_for_source from ..utils import _normalize_package_name from ..utils import get_collection_log from ..utils import get_collection_path from ..utils import get_log_path -from ..utils import slugify_task_name from ._TaskCollectPip import _TaskCollectPip from .utils import _init_venv_v1 from fractal_server.app.db import DBSyncSession @@ -215,7 +215,7 @@ async def create_package_environment_pip( # Fill in attributes for TaskCreate task_executable = package_root / t.executable cmd = f"{python_bin.as_posix()} {task_executable.as_posix()}" - task_name_slug = slugify_task_name(t.name) + task_name_slug = slugify_task_name_for_source(t.name) task_source = f"{task_pkg.package_source}:{task_name_slug}" if not task_executable.exists(): raise FileNotFoundError( diff --git a/fractal_server/tasks/v2/background_operations.py b/fractal_server/tasks/v2/background_operations.py index f2f99d8aa6..293bbdf75b 100644 --- a/fractal_server/tasks/v2/background_operations.py +++ b/fractal_server/tasks/v2/background_operations.py @@ -10,12 +10,12 @@ from sqlalchemy.orm import Session as DBSyncSession from sqlalchemy.orm.attributes import flag_modified +from ...string_tools import slugify_task_name_for_source from ..utils import get_absolute_venv_path from ..utils import get_collection_freeze from ..utils import get_collection_log from ..utils import get_collection_path from ..utils import get_log_path -from ..utils import slugify_task_name from ._TaskCollectPip import _TaskCollectPip from fractal_server.app.db import get_sync_db from fractal_server.app.models.v2 import CollectionStateV2 @@ -177,7 +177,7 @@ def 
_prepare_tasks_metadata( task_attributes = {} if package_version is not None: task_attributes["version"] = package_version - task_name_slug = slugify_task_name(_task.name) + task_name_slug = slugify_task_name_for_source(_task.name) task_attributes["source"] = f"{package_source}:{task_name_slug}" if package_manifest.has_args_schemas: task_attributes[ diff --git a/fractal_server/tasks/v2/background_operations_ssh.py b/fractal_server/tasks/v2/background_operations_ssh.py index c96c18ddf1..d7e181f1be 100644 --- a/fractal_server/tasks/v2/background_operations_ssh.py +++ b/fractal_server/tasks/v2/background_operations_ssh.py @@ -18,9 +18,6 @@ from fractal_server.logger import get_logger from fractal_server.logger import set_logger from fractal_server.ssh._fabric import FractalSSH -from fractal_server.ssh._fabric import put_over_ssh -from fractal_server.ssh._fabric import remove_folder_over_ssh -from fractal_server.ssh._fabric import run_command_over_ssh from fractal_server.syringe import Inject from fractal_server.tasks.v2.utils import get_python_interpreter_v2 @@ -95,17 +92,15 @@ def _customize_and_run_template( f"script_{abs(hash(tmpdir))}{script_filename}", ) logger.debug(f"Now transfer {script_path_local=} over SSH.") - put_over_ssh( + fractal_ssh.send_file( local=script_path_local, remote=script_path_remote, - fractal_ssh=fractal_ssh, - logger_name=logger_name, ) # Execute script remotely cmd = f"bash {script_path_remote}" logger.debug(f"Now run '{cmd}' over SSH.") - stdout = run_command_over_ssh(cmd=cmd, fractal_ssh=fractal_ssh) + stdout = fractal_ssh.run_command(cmd=cmd) logger.debug(f"Standard output of '{cmd}':\n{stdout}") logger.debug(f"_customize_and_run_template {script_filename} - END") @@ -127,6 +122,7 @@ def background_collect_pip_ssh( starlette/fastapi handling of background tasks (see https://github.com/encode/starlette/blob/master/starlette/background.py). 
""" + # Work within a temporary folder, where also logs will be placed with TemporaryDirectory() as tmpdir: LOGGER_NAME = "task_collection_ssh" @@ -316,17 +312,16 @@ def background_collect_pip_ssh( exception=e, db=db, ) - logger.info(f"Now delete remote folder {package_env_dir}") try: - remove_folder_over_ssh( - remote_dir=package_env_dir, + logger.info(f"Now delete remote folder {package_env_dir}") + fractal_ssh.remove_folder( + folder=package_env_dir, safe_root=settings.FRACTAL_SLURM_SSH_WORKING_BASE_DIR, - fractal_ssh=fractal_ssh, ) logger.info(f"Deleted remoted folder {package_env_dir}") except Exception as e: logger.error( - f"Deleting remote folder {package_env_dir} failed.\n" + f"Deleting remote folder failed.\n" f"Original error:\n{str(e)}" ) return diff --git a/pyproject.toml b/pyproject.toml index 28dc6e8239..d1fd68d2da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fractal-server" -version = "2.3.0" +version = "2.3.2" description = "Server component of the Fractal analytics platform" authors = [ "Tommaso Comparin ", @@ -91,7 +91,7 @@ filterwarnings = [ ] [tool.bumpver] -current_version = "2.3.0" +current_version = "2.3.2" version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]" commit_message = "bump version {old_version} -> {new_version}" commit = true diff --git a/tests/fixtures_docker.py b/tests/fixtures_docker.py index b09103e7d9..14aeb90226 100644 --- a/tests/fixtures_docker.py +++ b/tests/fixtures_docker.py @@ -1,3 +1,4 @@ +import io import logging import shlex import shutil @@ -5,11 +6,16 @@ import sys import time from pathlib import Path +from typing import Any +from typing import Generator import pytest +from fabric.connection import Connection from pytest import TempPathFactory from pytest_docker.plugin import containers_scope +from fractal_server.ssh._fabric import FractalSSH + HAS_LOCAL_SBATCH = bool(shutil.which("sbatch")) @@ -185,3 +191,27 @@ def ssh_alive(slurmlogin_ip, slurmlogin_container) -> None: return 
time.sleep(interval) raise RuntimeError(f"[ssh_alive] SSH not active on {slurmlogin_container}") + + +@pytest.fixture +def fractal_ssh( + slurmlogin_ip, + ssh_alive, + ssh_keys, + monkeypatch, +) -> Generator[FractalSSH, Any, None]: + ssh_private_key = ssh_keys["private"] + + # https://github.com/fabric/fabric/issues/1979 + # https://github.com/fabric/fabric/issues/2005#issuecomment-525664468 + monkeypatch.setattr("sys.stdin", io.StringIO("")) + + with Connection( + host=slurmlogin_ip, + user="fractal", + forward_agent=False, + connect_kwargs={"key_filename": ssh_private_key}, + ) as connection: + fractal_conn = FractalSSH(connection=connection) + fractal_conn.check_connection() + yield fractal_conn diff --git a/tests/no_version/test_string_tools.py b/tests/no_version/test_string_tools.py new file mode 100644 index 0000000000..5eb18157c9 --- /dev/null +++ b/tests/no_version/test_string_tools.py @@ -0,0 +1,12 @@ +from fractal_server.string_tools import __SPECIAL_CHARACTERS__ +from fractal_server.string_tools import sanitize_string + + +def test_unit_sanitize_string(): + for value in __SPECIAL_CHARACTERS__: + sanitized_value = sanitize_string(value) + assert sanitized_value == "_" + + value = "/some (rm) \t path *!" 
+ expected_value = "_some__rm____path___" + assert sanitize_string(value) == expected_value diff --git a/tests/no_version/test_unit_test_zip_folder_to_byte_stream.py b/tests/no_version/test_unit_test_zip_folder_to_byte_stream.py new file mode 100644 index 0000000000..86dd156f34 --- /dev/null +++ b/tests/no_version/test_unit_test_zip_folder_to_byte_stream.py @@ -0,0 +1,39 @@ +from pathlib import Path +from zipfile import ZipFile + +from devtools import debug + +from fractal_server.app.routes.aux._job import _zip_folder_to_byte_stream + + +def test_zip_folder_to_byte_stream(tmp_path: Path): + debug(tmp_path) + + # Prepare file/folder structure + (tmp_path / "file1").touch() + (tmp_path / "file2").touch() + (tmp_path / "folder").mkdir() + (tmp_path / "folder/file3").touch() + (tmp_path / "folder/file4").touch() + + output = _zip_folder_to_byte_stream(folder=tmp_path.as_posix()) + + # Write BytesIO to file + archive_path = tmp_path / "zipped_folder.zip" + with archive_path.open("wb") as f: + f.write(output.getbuffer()) + + # Unzip the log archive + unzipped_archived_path = tmp_path / "unzipped_folder" + unzipped_archived_path.mkdir() + with ZipFile(archive_path.as_posix(), mode="r") as zipfile: + zipfile.extractall(path=unzipped_archived_path.as_posix()) + + # Verify that all expected items are present + glob_list = [file.name for file in unzipped_archived_path.rglob("*")] + debug(glob_list) + assert "file1" in glob_list + assert "file2" in glob_list + assert "folder" in glob_list + assert "file3" in glob_list + assert "file4" in glob_list diff --git a/tests/v1/07_full_workflow/test_full_workflow.py b/tests/v1/07_full_workflow/test_full_workflow.py index 59cef69253..8ff8654e37 100644 --- a/tests/v1/07_full_workflow/test_full_workflow.py +++ b/tests/v1/07_full_workflow/test_full_workflow.py @@ -728,10 +728,10 @@ async def test_non_python_task( assert f in glob_list # Check that stderr and stdout are as expected - with open(f"{working_dir}/0_non-python/0.out", "r") as f: 
+ with open(f"{working_dir}/0_non_python/0.out", "r") as f: out = f.read() assert "This goes to standard output" in out - with open(f"{working_dir}/0_non-python/0.err", "r") as f: + with open(f"{working_dir}/0_non_python/0.err", "r") as f: err = f.read() assert "This goes to standard error" in err @@ -841,24 +841,24 @@ async def test_metadiff( "0.args.json", "0.err", "0.out", - "1_par_A.args.json", - "1_par_A.err", - "1_par_A.out", - "1_par_B.args.json", - "1_par_B.err", - "1_par_B.out", + "1_par_a.args.json", + "1_par_a.err", + "1_par_a.out", + "1_par_b.args.json", + "1_par_b.err", + "1_par_b.out", "2.args.json", "2.err", "2.out", "2.metadiff.json", - "3_par_A.args.json", - "3_par_A.err", - "3_par_A.out", - "3_par_B.args.json", - "3_par_B.err", - "3_par_B.out", - "3_par_A.metadiff.json", - "3_par_B.metadiff.json", + "3_par_a.args.json", + "3_par_a.err", + "3_par_a.out", + "3_par_b.args.json", + "3_par_b.err", + "3_par_b.out", + "3_par_a.metadiff.json", + "3_par_b.metadiff.json", WORKFLOW_LOG_FILENAME, ] for f in must_exist: diff --git a/tests/v2/00_ssh/test_FractalSSH.py b/tests/v2/00_ssh/test_FractalSSH.py new file mode 100644 index 0000000000..89b48f3061 --- /dev/null +++ b/tests/v2/00_ssh/test_FractalSSH.py @@ -0,0 +1,260 @@ +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pytest +from fabric import Connection +from paramiko.ssh_exception import NoValidConnectionsError + +from fractal_server.logger import set_logger +from fractal_server.ssh._fabric import FractalSSH +from fractal_server.ssh._fabric import FractalSSHTimeoutError + + +logger = set_logger(__file__) + + +def test_acquire_lock(): + """ + Test that the lock cannot be acquired twice. 
+ """ + fake_fractal_ssh = FractalSSH(connection=Connection("localhost")) + fake_fractal_ssh._lock.acquire(timeout=0) + with pytest.raises(FractalSSHTimeoutError) as e: + with fake_fractal_ssh.acquire_timeout(timeout=0.1): + pass + print(e) + + +def test_run_command(fractal_ssh: FractalSSH): + """ + Basic working of `run_command` method. + """ + + # Successful remote execution + stdout = fractal_ssh.run_command( + cmd="whoami", + max_attempts=1, + base_interval=1.0, + lock_timeout=1.0, + ) + assert stdout.strip("\n") == "fractal" + + # When the remotely-executed command fails, a RuntimeError is raised. + with pytest.raises( + RuntimeError, match="Encountered a bad command exit code" + ): + fractal_ssh.run_command( + cmd="ls --invalid-option", + max_attempts=1, + base_interval=1.0, + lock_timeout=1.0, + ) + + +def test_run_command_concurrency(fractal_ssh: FractalSSH): + """ + Test locking feature for `run_command` method. + """ + + # Useful auxiliary function + def _run_sleep(label: str, lock_timeout: float): + logger.info(f"Start running with {label=} and {lock_timeout=}") + fractal_ssh.run_command(cmd="sleep 1", lock_timeout=lock_timeout) + + # Submit two commands to be run, with a large timeout for lock acquisition + with ThreadPoolExecutor(max_workers=2) as executor: + results_iterator = executor.map(_run_sleep, ["A", "B"], [2.0, 2.0]) + list(results_iterator) + + # Submit two commands to be run, with a small timeout for lock acquisition + with ThreadPoolExecutor(max_workers=2) as executor: + results_iterator = executor.map(_run_sleep, ["C", "D"], [0.1, 0.1]) + with pytest.raises( + FractalSSHTimeoutError, match="Failed to acquire lock" + ): + list(results_iterator) + + +def test_run_command_retries(fractal_ssh: FractalSSH): + """ + Test the multiple-attempts logic of `run_command`. + """ + + class MockFractalSSH(FractalSSH): + """ + Mock FractalSSH object, such that the first call to `run` always fails. 
+ """ + + please_raise: bool + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.please_raise = True + + def run(self, *args, **kwargs): + if self.please_raise: + # Set `please_raise=False`, so that next call will go through + self.please_raise = False + # Construct a NoValidConnectionsError. Note that we prepare an + # `errors` attribute with the appropriate type, but with no + # meaningful content + errors = {("str", 1): ("str", 1, 1, 1)} + raise NoValidConnectionsError(errors=errors) + return super().run(*args, **kwargs) + + mocked_fractal_ssh = MockFractalSSH(connection=fractal_ssh._connection) + + # Call with max_attempts=1 fails + with pytest.raises(RuntimeError, match="Reached last attempt"): + mocked_fractal_ssh.run_command(cmd="whoami", max_attempts=1) + + # Call with max_attempts=2 goes through (note that we have to reset + # `please_raise`) + mocked_fractal_ssh.please_raise = True + stdout = mocked_fractal_ssh.run_command( + cmd="whoami", max_attempts=2, base_interval=0.1 + ) + assert stdout.strip() == "fractal" + + +def test_file_transfer(fractal_ssh: FractalSSH, tmp_path: Path): + """ + Test basic working of `send_file` and `get` methods. 
+ """ + local_file_old = (tmp_path / "local_old").as_posix() + local_file_new = (tmp_path / "local_new").as_posix() + with open(local_file_old, "w") as f: + f.write("hi there\n") + + # Send file + fractal_ssh.send_file(local=local_file_old, remote="remote_file") + + # Get back file (note: we include the `lock_timeout` argument only + # for coverage of the corresponding conditional branch) + fractal_ssh.get( + remote="remote_file", local=local_file_new, lock_timeout=1.0 + ) + assert Path(local_file_new).is_file() + + +def test_send_file_concurrency(fractal_ssh: FractalSSH, tmp_path: Path): + local_file = (tmp_path / "local").as_posix() + with open(local_file, "w") as f: + f.write("x" * 10_000) + + def _send_file(remote: str, lock_timeout: float): + logger.info(f"Send file to {remote=}.") + fractal_ssh.send_file( + local=local_file, + remote=remote, + lock_timeout=lock_timeout, + ) + + # Try running two concurrent runs, with long lock timeout + with ThreadPoolExecutor(max_workers=2) as executor: + results_iterator = executor.map( + _send_file, ["remote1", "remote2"], [1.0, 1.0] + ) + list(results_iterator) + + # Try running two concurrent runs and fail, due to short lock timeout + with ThreadPoolExecutor(max_workers=2) as executor: + results_iterator = executor.map( + _send_file, ["remote3", "remote4"], [0.0, 0.0] + ) + with pytest.raises(FractalSSHTimeoutError) as e: + list(results_iterator) + assert "Failed to acquire lock" in str(e.value) + + +def test_folder_utils(tmp777_path, fractal_ssh: FractalSSH): + """ + Test basic working of `mkdir` and `remove_folder` methods. 
+ """ + + # Define folder + folder = (tmp777_path / "nested/folder").as_posix() + + # Check that folder does not exist + with pytest.raises(RuntimeError) as e: + fractal_ssh.run_command(cmd=f"ls {folder}") + print(e.value) + + # Try to create folder, without parents options + with pytest.raises(RuntimeError) as e: + fractal_ssh.mkdir(folder=folder, parents=False) + print(e.value) + + # Create folder + fractal_ssh.mkdir(folder=folder, parents=True) + + # Check that folder exists + stdout = fractal_ssh.run_command(cmd=f"ls {folder}") + print(stdout) + print() + + # Remove folder + fractal_ssh.remove_folder(folder=folder, safe_root="/tmp") + + # Check that folder does not exist + with pytest.raises(RuntimeError) as e: + fractal_ssh.run_command(cmd=f"ls {folder}") + print(e.value) + + # Check that removing a missing folder fails + with pytest.raises(RuntimeError) as e: + fractal_ssh.remove_folder( + folder="/invalid/something", + safe_root="/invalid", + ) + print(e.value) + + +def test_remove_folder_input_validation(): + """ + Test input validation of `remove_folder` method. 
+ """ + fake_fractal_ssh = FractalSSH(connection=Connection(host="localhost")) + + # Folders which are just invalid + invalid_folders = [ + None, + " /somewhere", + "/ somewhere", + "somewhere", + "$(pwd)", + "`pwd`", + ] + for folder in invalid_folders: + with pytest.raises(ValueError) as e: + fake_fractal_ssh.remove_folder(folder=folder, safe_root="/") + print(e.value) + + # Folders which are just invalid + invalid_folders = [ + None, + " /somewhere", + "/ somewhere", + "somewhere", + "$(pwd)", + "`pwd`", + ] + for safe_root in invalid_folders: + with pytest.raises(ValueError) as e: + fake_fractal_ssh.remove_folder( + folder="/tmp/something", + safe_root=safe_root, + ) + print(e.value) + + # Folders which are not relative to the accepted root + with pytest.raises(ValueError) as e: + fake_fractal_ssh.remove_folder(folder="/", safe_root="/tmp") + print(e.value) + + with pytest.raises(ValueError) as e: + fake_fractal_ssh.remove_folder( + folder="/actual_root/../something", + safe_root="/actual_root", + ) + print(e.value) diff --git a/tests/v2/00_ssh/test_executor.py b/tests/v2/00_ssh/test_executor.py index 5a03e86d8f..1f354aef59 100644 --- a/tests/v2/00_ssh/test_executor.py +++ b/tests/v2/00_ssh/test_executor.py @@ -1,83 +1,14 @@ -import io -import json import logging -import random from pathlib import Path -import pytest -from devtools import debug # noqa: F401 -from fabric.connection import Connection +from devtools import debug -from fractal_server.app.runner.exceptions import TaskExecutionError from fractal_server.app.runner.executors.slurm.ssh.executor import ( FractalSlurmSSHExecutor, ) # noqa from fractal_server.ssh._fabric import FractalSSH -def test_versions( - slurmlogin_ip, - ssh_alive, - slurmlogin_container, - monkeypatch, - ssh_keys: dict[str, str], -): - """ - Check the Python and fractal-server versions available on the cluster. 
- NOTE: This will later become a preliminary-check as part of the app - startup phase: check that Python has the same Major.Minor versions - and fractal-server has the same Major.Minor.Patch. - """ - monkeypatch.setattr("sys.stdin", io.StringIO("")) - - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"password": "fractal"}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - command = "/usr/bin/python3.9 --version" - print(f"COMMAND:\n{command}") - res = fractal_conn.run(command, hide=True) - print(f"STDOUT:\n{res.stdout}") - print(f"STDERR:\n{res.stderr}") - - python_command = "import fractal_server as fs; print(fs.__VERSION__);" - command = f"/usr/bin/python3.9 -c '{python_command}'" - - print(f"COMMAND:\n{command}") - res = fractal_conn.run(command, hide=True) - print(f"STDOUT:\n{res.stdout}") - print(f"STDERR:\n{res.stderr}") - - print("NOW AGAIN BUT USING KEY") - ssh_private_key = ssh_keys["private"] - debug(ssh_private_key) - debug(slurmlogin_ip) - - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"key_filename": ssh_private_key}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - command = "/usr/bin/python3.9 --version" - print(f"COMMAND:\n{command}") - res = fractal_conn.run(command, hide=True) - print(f"STDOUT:\n{res.stdout}") - print(f"STDERR:\n{res.stderr}") - - python_command = "import fractal_server as fs; print(fs.__VERSION__);" - command = f"/usr/bin/python3.9 -c '{python_command}'" - - print(f"COMMAND:\n{command}") - res = fractal_conn.run(command, hide=True) - print(f"STDOUT:\n{res.stdout}") - print(f"STDERR:\n{res.stderr}") - - # -o "StrictHostKeyChecking no" - - class MockFractalSSHSlurmExecutor(FractalSlurmSSHExecutor): """ When running from outside Fractal runner, task-specific subfolders @@ -107,180 +38,38 @@ def __init__(self, *args, **kwargs): def test_slurm_ssh_executor_submit( - slurmlogin_ip, - ssh_alive, - monkeypatch, + fractal_ssh, tmp_path: 
Path, tmp777_path: Path, - ssh_keys: dict[str, str], override_settings_factory, ): override_settings_factory(FRACTAL_SLURM_WORKER_PYTHON="/usr/bin/python3.9") - monkeypatch.setattr("sys.stdin", io.StringIO("")) - - ssh_private_key = ssh_keys["private"] - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"key_filename": ssh_private_key}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / "job_dir", - workflow_dir_remote=(tmp777_path / "remote_job_dir"), - slurm_poll_interval=1, - fractal_ssh=fractal_conn, - ) as executor: - fut = executor.submit(lambda: 1) - debug(fut) - debug(fut.result()) + with MockFractalSSHSlurmExecutor( + workflow_dir_local=tmp_path / "job_dir", + workflow_dir_remote=(tmp777_path / "remote_job_dir"), + slurm_poll_interval=1, + fractal_ssh=fractal_ssh, + ) as executor: + fut = executor.submit(lambda: 1) + debug(fut) + debug(fut.result()) def test_slurm_ssh_executor_map( - slurmlogin_ip, - ssh_alive, - monkeypatch, + fractal_ssh: FractalSSH, tmp_path: Path, tmp777_path: Path, - ssh_keys: dict[str, str], override_settings_factory, ): override_settings_factory(FRACTAL_SLURM_WORKER_PYTHON="/usr/bin/python3.9") - monkeypatch.setattr("sys.stdin", io.StringIO("")) - - ssh_private_key = ssh_keys["private"] - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"key_filename": ssh_private_key}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / "job_dir", - workflow_dir_remote=(tmp777_path / "remote_job_dir"), - slurm_poll_interval=1, - fractal_ssh=fractal_conn, - ) as executor: - res = executor.map(lambda x: x * 2, [1, 2, 3]) - results = list(res) - assert results == [2, 4, 6] - - -@pytest.mark.skip( - reason=( - "This is not up-to-date with the new FractalSlurmSSHExecutor " - "(switching from config kwargs to a single connection)." 
- ) -) -def test_slurm_ssh_executor_no_docker( - monkeypatch, - tmp_path, - testdata_path, - override_settings_factory, -): - """ - This test requires a configuration file pointing to a SLURM cluster - that can be reached via SSH. - """ - - # Define functions locally, to play well with cloudpickle - - def compute_square(x): - return x**2 - - def raise_error_for_even_argument(x): - if x % 2 == 0: - raise ValueError(f"The argument {x} is even. Fail.") - - ssh_config_file = testdata_path / "ssh_config.json" - if not ssh_config_file.exists(): - logging.warning(f"Missing {ssh_config_file} -- skip test.") - return - - random.seed(tmp_path.as_posix()) - random_id = random.randrange(0, 999999) - - monkeypatch.setattr("sys.stdin", io.StringIO("")) - - with ssh_config_file.open("r") as f: - config = json.load(f)["uzh2"] - debug(config) - - remote_python = config.pop("remote_python") - root_dir_remote = Path(config.pop("root_dir_remote")) - override_settings_factory( - FRACTAL_SLURM_WORKER_PYTHON=remote_python, - ) - from fractal_server.app.runner.executors.slurm._slurm_config import ( - get_default_slurm_config, - ) - - slurm_config = get_default_slurm_config() - slurm_config.partition = config.pop("partition") - slurm_config.mem_per_task_MB = config.pop("mem_per_task_MB") - - debug(slurm_config) - - # submit method - label = f"{random_id}_0_submit" - with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / f"local_job_dir_{label}", - workflow_dir_remote=root_dir_remote / f"remote_job_dir_{label}", - slurm_poll_interval=1, - **config, - ) as executor: - arg = 2 - fut = executor.submit(compute_square, arg, slurm_config=slurm_config) - debug(fut) - assert fut.result() == compute_square(arg) - - # map method (few values) - label = f"{random_id}_1_map_few" - with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / f"local_job_dir_{label}", - workflow_dir_remote=root_dir_remote / f"remote_job_dir_{label}", - slurm_poll_interval=1, - **config, - ) as executor: - 
inputs = list(range(3)) - slurm_res = executor.map(compute_square, inputs) - assert list(slurm_res) == list(map(compute_square, inputs)) - - # map method (few values) - label = f"{random_id}_2_map_many" - with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / f"local_job_dir_{label}", - workflow_dir_remote=root_dir_remote / f"remote_job_dir_{label}", - slurm_poll_interval=1, - **config, - ) as executor: - inputs = list(range(200)) - slurm_res = executor.map(compute_square, inputs) - assert list(slurm_res) == list(map(compute_square, inputs)) - - # submit method (fail) - label = f"{random_id}_3_submit_fail" - with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / f"local_job_dir_{label}", - workflow_dir_remote=root_dir_remote / f"remote_job_dir_{label}", - slurm_poll_interval=1, - **config, - ) as executor: - future = executor.submit(raise_error_for_even_argument, 2) - with pytest.raises(TaskExecutionError): - future.result() - - # map method (fail) - label = f"{random_id}_4_map_fail" with MockFractalSSHSlurmExecutor( - workflow_dir_local=tmp_path / f"local_job_dir_{label}", - workflow_dir_remote=root_dir_remote / f"remote_job_dir_{label}", + workflow_dir_local=tmp_path / "job_dir", + workflow_dir_remote=(tmp777_path / "remote_job_dir"), slurm_poll_interval=1, - **config, + fractal_ssh=fractal_ssh, ) as executor: - inputs = [1, 3, 5, 6, 2, 7, 4] - slurm_res = executor.map(raise_error_for_even_argument, inputs) - with pytest.raises(TaskExecutionError): - list(slurm_res) + res = executor.map(lambda x: x * 2, [1, 2, 3]) + results = list(res) + assert results == [2, 4, 6] diff --git a/tests/v2/00_ssh/test_setup.py b/tests/v2/00_ssh/test_setup.py new file mode 100644 index 0000000000..675b725c2e --- /dev/null +++ b/tests/v2/00_ssh/test_setup.py @@ -0,0 +1,37 @@ +import pytest +from fabric.connection import Connection + +import fractal_server +from fractal_server.ssh._fabric import FractalSSH + + +def test_check_connection_failure(): + + with 
Connection( + host="localhost", + user="invalid", + forward_agent=False, + connect_kwargs={"password": "invalid"}, + ) as connection: + this_fractal_ssh = FractalSSH(connection=connection) + with pytest.raises(RuntimeError): + this_fractal_ssh.check_connection() + + +def test_versions(fractal_ssh: FractalSSH): + """ + Check the Python and fractal-server versions available on the cluster. + """ + + command = "/usr/bin/python3.9 --version" + print(f"COMMAND:\n{command}") + stdout = fractal_ssh.run_command(cmd=command) + print(f"STDOUT:\n{stdout}") + + python_command = "import fractal_server as fs; print(fs.__VERSION__);" + command = f"/usr/bin/python3.9 -c '{python_command}'" + + print(f"COMMAND:\n{command}") + stdout = fractal_ssh.run_command(cmd=command) + print(f"STDOUT:\n{stdout}") + assert stdout.strip() == str(fractal_server.__VERSION__) diff --git a/tests/v2/00_ssh/test_task_collection_ssh.py b/tests/v2/00_ssh/test_task_collection_ssh.py index e5f5e57427..514b138a7a 100644 --- a/tests/v2/00_ssh/test_task_collection_ssh.py +++ b/tests/v2/00_ssh/test_task_collection_ssh.py @@ -1,12 +1,8 @@ -import io from pathlib import Path -import pytest -from devtools import debug # noqa: F401 -from fabric.connection import Connection +from devtools import debug from fractal_server.app.models.v2.collection_state import CollectionStateV2 -from fractal_server.ssh._fabric import _mkdir_over_ssh from fractal_server.ssh._fabric import FractalSSH from fractal_server.tasks.v2._TaskCollectPip import _TaskCollectPip from fractal_server.tasks.v2.background_operations_ssh import ( @@ -14,27 +10,8 @@ ) -@pytest.fixture -def fractal_ssh( - slurmlogin_ip, - ssh_alive, - ssh_keys, - monkeypatch, -): - ssh_private_key = ssh_keys["private"] - monkeypatch.setattr("sys.stdin", io.StringIO("")) - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"key_filename": ssh_private_key}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - 
fractal_conn.check_connection() - yield fractal_conn - - async def test_task_collection_ssh( - fractal_ssh, + fractal_ssh: FractalSSH, db, override_settings_factory, tmp777_path: Path, @@ -43,8 +20,9 @@ async def test_task_collection_ssh( remote_basedir = (tmp777_path / "WORKING_BASE_DIR").as_posix() debug(remote_basedir) - _mkdir_over_ssh( - folder=remote_basedir, fractal_ssh=fractal_ssh, parents=True + fractal_ssh.mkdir( + folder=remote_basedir, + parents=True, ) override_settings_factory( @@ -82,7 +60,7 @@ async def test_task_collection_ssh( async def test_task_collection_ssh_failure( - fractal_ssh, + fractal_ssh: FractalSSH, db, override_settings_factory, tmp777_path: Path, @@ -91,9 +69,7 @@ async def test_task_collection_ssh_failure( remote_basedir = (tmp777_path / "WORKING_BASE_DIR").as_posix() debug(remote_basedir) - _mkdir_over_ssh( - folder=remote_basedir, fractal_ssh=fractal_ssh, parents=True - ) + fractal_ssh.mkdir(folder=remote_basedir, parents=True) override_settings_factory( FRACTAL_SLURM_WORKER_PYTHON="/usr/bin/python3.9", @@ -127,4 +103,4 @@ async def test_task_collection_ssh_failure( # host machine, because /tmp is shared with the container) venv_dir = Path(remote_basedir) / ".fractal/fractal-tasks-core99.99.99" debug(venv_dir) - assert venv_dir.is_dir() + assert not venv_dir.is_dir() diff --git a/tests/v2/00_ssh/test_unit_fabric_connection.py b/tests/v2/00_ssh/test_unit_fabric_connection.py deleted file mode 100644 index 8aab78f15a..0000000000 --- a/tests/v2/00_ssh/test_unit_fabric_connection.py +++ /dev/null @@ -1,51 +0,0 @@ -import io - -import pytest -from fabric.connection import Connection - -from fractal_server.ssh._fabric import FractalSSH - - -def test_unit_fabric_connection( - slurmlogin_ip, ssh_alive, slurmlogin_container, monkeypatch -): - """ - Test both the pytest-docker setup and the use of a `fabric` connection, by - running the `hostname` over SSH. 
- """ - print(f"{slurmlogin_ip=}") - - command = "hostname" - print(f"Now run {command=} at {slurmlogin_ip=}") - - # https://github.com/fabric/fabric/issues/1979 - # https://github.com/fabric/fabric/issues/2005#issuecomment-525664468 - monkeypatch.setattr("sys.stdin", io.StringIO("")) - - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"password": "fractal"}, - ) as connection: - - res = connection.run(command, hide=True) - print(f"STDOUT:\n{res.stdout}") - print(f"STDERR:\n{res.stderr}") - assert res.stdout.strip("\n") == "slurmhead" - - # Test also FractalSSH - fractal_conn = FractalSSH(connection=connection) - assert fractal_conn.is_connected - fractal_conn.check_connection() - res = fractal_conn.run(command, hide=True) - assert res.stdout.strip("\n") == "slurmhead" - - with Connection( - host=slurmlogin_ip, - user="x", - connect_kwargs={"password": "x"}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - # raise error if there is not a connection available - with pytest.raises(RuntimeError): - fractal_conn.check_connection() diff --git a/tests/v2/00_ssh/test_unit_remove_folder_over_ssh.py b/tests/v2/00_ssh/test_unit_remove_folder_over_ssh.py deleted file mode 100644 index 0672cefd03..0000000000 --- a/tests/v2/00_ssh/test_unit_remove_folder_over_ssh.py +++ /dev/null @@ -1,102 +0,0 @@ -import io - -import pytest -from fabric.connection import Connection - -from fractal_server.ssh._fabric import FractalSSH -from fractal_server.ssh._fabric import remove_folder_over_ssh -from fractal_server.ssh._fabric import run_command_over_ssh - - -@pytest.fixture -def fractal_ssh( - slurmlogin_ip, - ssh_alive, - ssh_keys, - monkeypatch, -): - ssh_private_key = ssh_keys["private"] - monkeypatch.setattr("sys.stdin", io.StringIO("")) - with Connection( - host=slurmlogin_ip, - user="fractal", - connect_kwargs={"key_filename": ssh_private_key}, - ) as connection: - fractal_conn = FractalSSH(connection=connection) - 
fractal_conn.check_connection() - yield fractal_conn - - -def test_unit_remove_folder_over_ssh_failures(): - # Folders which are just invalid - invalid_folders = [ - None, - " /somewhere", - "/ somewhere", - "somewhere", - "$(pwd)", - "`pwd`", - ] - for folder in invalid_folders: - with pytest.raises(ValueError) as e: - remove_folder_over_ssh( - folder=folder, safe_root="/", fractal_ssh=None - ) - print(e.value) - - # Folders which are not relative to the accepted root - with pytest.raises(ValueError) as e: - remove_folder_over_ssh(folder="/", safe_root="/tmp", fractal_ssh=None) - print(e.value) - - with pytest.raises(ValueError) as e: - remove_folder_over_ssh( - folder="/actual_root/../something", - safe_root="/actual_root", - fractal_ssh=None, - ) - print(e.value) - - -def test_unit_remove_folder_over_ssh(tmp777_path, fractal_ssh): - - assert fractal_ssh.is_connected - - # Define folder - folder = (tmp777_path / "folder").as_posix() - - # Check that folder does not exist - with pytest.raises(ValueError) as e: - run_command_over_ssh(cmd=f"ls {folder}", fractal_ssh=fractal_ssh) - print(e.value) - - # Create folder - stdout = run_command_over_ssh( - cmd=f"mkdir -p {folder}", fractal_ssh=fractal_ssh - ) - print(stdout) - print() - - # Check that folder exists - stdout = run_command_over_ssh(cmd=f"ls {folder}", fractal_ssh=fractal_ssh) - print(stdout) - print() - - # Remove folder - remove_folder_over_ssh( - folder=folder, safe_root="/tmp", fractal_ssh=fractal_ssh - ) - - # Check that folder does not exist - with pytest.raises(ValueError) as e: - run_command_over_ssh(cmd=f"ls {folder}", fractal_ssh=fractal_ssh) - print(e.value) - - # Check that removing a missing folder fails - with pytest.raises(ValueError) as e: - remove_folder_over_ssh( - folder="/invalid/something", - safe_root="/invalid", - fractal_ssh=fractal_ssh, - ) - print(e.value) diff --git a/tests/v2/03_api/test_api_task_collection_failures.py b/tests/v2/03_api/test_api_task_collection_failures.py index 
2326b8ae9a..d33c0da350 100644 --- a/tests/v2/03_api/test_api_task_collection_failures.py +++ b/tests/v2/03_api/test_api_task_collection_failures.py @@ -285,3 +285,22 @@ async def test_remove_directory( ) assert res.status_code == 201 assert os.path.isdir(DIRECTORY) is True + + +async def test_invalid_python_version( + client, + MockCurrentUser, + override_settings_factory, +): + override_settings_factory( + FRACTAL_TASKS_PYTHON_3_9=None, + ) + + async with MockCurrentUser(user_kwargs=dict(is_verified=True)): + res = await client.post( + f"{PREFIX}/collect/pip/", + json=dict(package="invalid-task-package", python_version="3.9"), + ) + assert res.status_code == 422 + assert "Python version 3.9 is not available" in res.json()["detail"] + debug(res.json()["detail"]) diff --git a/tests/v2/03_api/test_api_task_collection_ssh.py b/tests/v2/03_api/test_api_task_collection_ssh.py new file mode 100644 index 0000000000..3bffc87c47 --- /dev/null +++ b/tests/v2/03_api/test_api_task_collection_ssh.py @@ -0,0 +1,98 @@ +from pathlib import Path + +import pytest + +from fractal_server.app.schemas.v2 import CollectionStatusV2 +from fractal_server.ssh._fabric import FractalSSH + + +PREFIX = "api/v2/task" + + +async def test_task_collection_ssh_from_pypi( + db, + app, + client, + MockCurrentUser, + override_settings_factory, + tmp_path: Path, + tmp777_path: Path, + fractal_ssh: FractalSSH, +): + + # Define and create remote working directory + WORKING_BASE_DIR = (tmp777_path / "working_dir").as_posix() + fractal_ssh.mkdir(folder=WORKING_BASE_DIR) + + # Assign FractalSSH object to app state + app.state.fractal_ssh = fractal_ssh + + # Override settins with Python/SSH configurations + override_settings_factory( + FRACTAL_TASKS_PYTHON_DEFAULT_VERSION="3.9", + FRACTAL_TASKS_PYTHON_3_9="/usr/bin/python3.9", + FRACTAL_RUNNER_BACKEND="slurm_ssh", + FRACTAL_SLURM_SSH_WORKING_BASE_DIR=WORKING_BASE_DIR, + ) + + async with MockCurrentUser(user_kwargs=dict(is_verified=True)): + + # CASE 1: 
successful collection + + # Trigger task collection + PACKAGE_VERSION = "1.0.2" + res = await client.post( + f"{PREFIX}/collect/pip/", + json=dict( + package="fractal-tasks-core", + package_version=PACKAGE_VERSION, + python_version="3.9", + ), + ) + assert res.status_code == 201 + assert res.json()["data"]["status"] == CollectionStatusV2.PENDING + state_id = res.json()["id"] + + # Get collection info + res = await client.get(f"{PREFIX}/collect/{state_id}/") + assert res.status_code == 200 + data = res.json()["data"] + assert data["status"] == CollectionStatusV2.OK + assert f"fractal-tasks-core=={PACKAGE_VERSION}" in data["freeze"] + remote_folder = ( + Path(WORKING_BASE_DIR) + / ".fractal" + / f"fractal-tasks-core{PACKAGE_VERSION}" + ).as_posix() + fractal_ssh.run_command(cmd=f"ls {remote_folder}") + + # CASE 2: Failure due to invalid version + + # Trigger task collection + PACKAGE_VERSION = "9.9.9" + res = await client.post( + f"{PREFIX}/collect/pip/", + json=dict( + package="fractal-tasks-core", + package_version=PACKAGE_VERSION, + python_version="3.9", + ), + ) + assert res.status_code == 201 + assert res.json()["data"]["status"] == CollectionStatusV2.PENDING + state_id = res.json()["id"] + + # Get collection info + res = await client.get(f"{PREFIX}/collect/{state_id}/") + assert res.status_code == 200 + data = res.json()["data"] + assert data["status"] == CollectionStatusV2.FAIL + assert "No matching distribution found" in data["log"] + assert f"fractal-tasks-core=={PACKAGE_VERSION}" in data["log"] + remote_folder = ( + Path(WORKING_BASE_DIR) + / ".fractal" + / f"fractal-tasks-core{PACKAGE_VERSION}" + ).as_posix() + with pytest.raises(RuntimeError, match="No such file or directory"): + fractal_ssh.run_command(cmd=f"ls {remote_folder}") diff --git a/tests/v2/08_full_workflow/common_functions.py b/tests/v2/08_full_workflow/common_functions.py index 52393b2109..fec87c32c3 100644 --- a/tests/v2/08_full_workflow/common_functions.py +++ 
b/tests/v2/08_full_workflow/common_functions.py @@ -607,7 +607,7 @@ async def workflow_with_non_python_task( raise ValueError(f"{f} must exist, but {glob_list=}") # Check that stderr and stdout are as expected - with open(f"{working_dir}/0_non-python/0.log", "r") as f: + with open(f"{working_dir}/0_non_python/0.log", "r") as f: log = f.read() assert "This goes to standard output" in log assert "This goes to standard error" in log