diff --git a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py index 8ea228653e5..27b43cea55d 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/container_tasks/events.py @@ -3,7 +3,7 @@ from typing import TypeAlias, Union from distributed.worker import get_worker -from pydantic import BaseModel, Extra, NonNegativeFloat +from pydantic import BaseModel, Extra, validator class BaseTaskEvent(BaseModel, ABC): @@ -20,7 +20,7 @@ class Config: class TaskProgressEvent(BaseTaskEvent): - progress: NonNegativeFloat + progress: float @staticmethod def topic_name() -> str: @@ -44,6 +44,13 @@ class Config(BaseTaskEvent.Config): ] } + @validator("progress", always=True) + @classmethod + def ensure_between_0_1(cls, v): + if 0 <= v <= 1: + return v + return min(max(0, v), 1) + LogMessageStr: TypeAlias = str LogLevelInt: TypeAlias = int diff --git a/packages/models-library/src/models_library/projects_nodes.py b/packages/models-library/src/models_library/projects_nodes.py index 7a68324aada..de76af683e9 100644 --- a/packages/models-library/src/models_library/projects_nodes.py +++ b/packages/models-library/src/models_library/projects_nodes.py @@ -85,6 +85,12 @@ class NodeState(BaseModel): description="the node's current state", alias="currentStatus", ) + progress: float | None = Field( + default=0, + ge=0.0, + le=1.0, + description="current progress of the task if available (None if not started or not a computational task)", + ) class Config: extra = Extra.forbid @@ -133,7 +139,11 @@ class Node(BaseModel): ..., description="The short name of the node", examples=["JupyterLab"] ) progress: float | None = Field( - default=None, ge=0, le=100, description="the node progress value" + default=None, + ge=0, + le=100, + description="the node progress value", + deprecated=True, ) thumbnail: HttpUrlWithCustomMinLength | None = Field( default=None, diff --git a/packages/models-library/src/models_library/projects_pipeline.py b/packages/models-library/src/models_library/projects_pipeline.py index 0b6d29e2440..ce142118e54 100644 --- a/packages/models-library/src/models_library/projects_pipeline.py +++ b/packages/models-library/src/models_library/projects_pipeline.py @@ -1,4 +1,3 @@ -from typing import Dict, List, Optional from uuid import UUID from pydantic import BaseModel, Field, PositiveInt @@ -9,11 +8,17 @@ class PipelineDetails(BaseModel): - adjacency_list: Dict[NodeID, List[NodeID]] = Field( + adjacency_list: dict[NodeID, list[NodeID]] = Field( ..., description="The adjacency list of the current pipeline in terms of {NodeID: [successor NodeID]}", ) - node_states: Dict[NodeID, NodeState] = Field( + progress: float | None = Field( + ..., + ge=0, + le=1.0, + description="the progress of the pipeline (None if there are no computational tasks)", + ) + node_states: dict[NodeID, NodeState] = Field( ..., description="The states of each of the computational nodes in the pipeline" ) @@ -24,17 +29,15 @@ class PipelineDetails(BaseModel): class ComputationTask(BaseModel): id: TaskID = Field(..., description="the id of the computation task") state: RunningState = Field(..., description="the state of the computational task") - result: Optional[str] = Field( - None, description="the result of the computational task" - ) + result: str | None = Field(None, description="the result of the computational task") pipeline_details: PipelineDetails = Field( ..., description="the details of the generated pipeline" ) - iteration: Optional[PositiveInt] = Field( + iteration: PositiveInt | None = Field( ..., description="the iteration id of the computation task (none if no task ran yet)", ) - cluster_id: Optional[ClusterID] = Field( + cluster_id: ClusterID | None = Field( ..., description="the cluster on which the computaional task runs/ran (none if no task ran yet)", ) @@ -56,14 +59,17 @@ class Config: "2fb4808a-e403-4a46-b52c-892560d27862": { "modified": True, "dependencies": [], + "progress": 0.0, }, "19a40c7b-0a40-458a-92df-c77a5df7c886": { "modified": False, "dependencies": [ "2fb4808a-e403-4a46-b52c-892560d27862" ], + "progress": 0.0, }, }, + "progress": 0.0, }, "iteration": None, "cluster_id": None, @@ -82,14 +88,17 @@ class Config: "2fb4808a-e403-4a46-b52c-892560d27862": { "modified": False, "dependencies": [], + "progress": 1.0, }, "19a40c7b-0a40-458a-92df-c77a5df7c886": { "modified": False, "dependencies": [ "2fb4808a-e403-4a46-b52c-892560d27862" ], + "progress": 1.0, }, }, + "progress": 1.0, }, "iteration": 2, "cluster_id": 0, diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/0c084cb1091c_add_progress_to_comp_tasks.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/0c084cb1091c_add_progress_to_comp_tasks.py new file mode 100644 index 00000000000..6c13987346a --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/0c084cb1091c_add_progress_to_comp_tasks.py @@ -0,0 +1,30 @@ +"""add progress to comp_tasks + +Revision ID: 0c084cb1091c +Revises: 432aa859098b +Create Date: 2023-05-05 08:00:18.951040+00:00 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0c084cb1091c" +down_revision = "432aa859098b" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "comp_tasks", + sa.Column("progress", sa.Numeric(precision=3, scale=2), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("comp_tasks", "progress") + # ### end Alembic commands ### diff --git a/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py b/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py index c1fd22b7893..e90c178d7bb 100644 --- a/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py +++ b/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py @@ -56,7 +56,7 @@ class NodeClass(enum.Enum): sa.Enum(StateType), nullable=False, server_default=StateType.NOT_STARTED.value, - doc="Current state in the task lifecicle", + doc="Current state in the task lifecycle", ), sa.Column( "errors", @@ -65,6 +65,12 @@ class NodeClass(enum.Enum): doc="List[models_library.errors.ErrorDict] with error information" " for a failing state, otherwise set to None", ), + sa.Column( + "progress", + sa.Numeric(precision=3, scale=2), # numbers from 0.00 and 1.00 + nullable=True, + doc="current progress of the task if available", + ), # utc timestamps for submission/start/end sa.Column("submit", sa.DateTime, doc="UTC timestamp for task submission"), sa.Column("start", sa.DateTime, doc="UTC timestamp when task started"), diff --git a/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py b/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py index 10614f4dd62..59cfbd90ab5 100644 --- a/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py +++ b/packages/pytest-simcore/src/pytest_simcore/db_entries_mocks.py @@ -39,13 +39,13 @@ def creator(**user_kwargs) -> dict[str, Any]: ) # this is needed to get the primary_gid correctly result = con.execute( - sa.select([users]).where(users.c.id == user_config["id"]) + sa.select(users).where(users.c.id == user_config["id"]) ) user = result.first() assert user print(f"--> created {user=}") created_user_ids.append(user["id"]) - return dict(user) + return dict(user._asdict()) yield creator @@ -81,7 +81,7 @@ def creator(user: dict[str, Any], **overrides) -> ProjectAtDB: .returning(sa.literal_column("*")) ) - inserted_project = ProjectAtDB.parse_obj(result.first()) + inserted_project = ProjectAtDB.from_orm(result.first()) print(f"--> created {inserted_project=}") created_project_ids.append(f"{inserted_project.uuid}") return inserted_project diff --git a/packages/pytest-simcore/src/pytest_simcore/docker_swarm.py b/packages/pytest-simcore/src/pytest_simcore/docker_swarm.py index a3645a95e9d..3e6806df687 100644 --- a/packages/pytest-simcore/src/pytest_simcore/docker_swarm.py +++ b/packages/pytest-simcore/src/pytest_simcore/docker_swarm.py @@ -44,7 +44,7 @@ def _is_docker_swarm_init(docker_client: docker.client.DockerClient) -> bool: @retry( - wait=wait_fixed(5), + wait=wait_fixed(1), stop=stop_after_delay(8 * MINUTE), before_sleep=before_sleep_log(log, logging.WARNING), reraise=True, @@ -183,25 +183,26 @@ def docker_swarm( retry=retry_if_exception_type(AssertionError), stop=stop_after_delay(30), ) -def _wait_for_new_task_to_be_started(service: Any, old_task_ids: set[str]) -> None: - service.reload() - new_task_ids = {t["ID"] for t in service.tasks()} - assert len(new_task_ids.difference(old_task_ids)) == 1 +def _wait_for_migration_service_to_be_removed( + docker_client: docker.client.DockerClient, +) -> None: + for service in docker_client.services.list(): + if "migration" in service.name: # type: ignore + raise TryAgain -def _force_restart_migration_service(docker_client: docker.client.DockerClient) -> None: +def _force_remove_migration_service(docker_client: docker.client.DockerClient) -> None: for migration_service in ( service for service in docker_client.services.list() - if "migration" in service.name + if "migration" in service.name # type: ignore ): print( - "WARNING: migration service detected before updating stack, it will be force-updated" + "WARNING: migration service detected before updating stack, it will be force-removed now and re-deployed to ensure DB update" ) - before_update_task_ids = {t["ID"] for t in migration_service.tasks()} - migration_service.force_update() - _wait_for_new_task_to_be_started(migration_service, before_update_task_ids) - print(f"forced updated {migration_service.name}.") + migration_service.remove() # type: ignore + _wait_for_migration_service_to_be_removed(docker_client) + print(f"forced updated {migration_service.name}.") # type: ignore def _deploy_stack(compose_file: Path, stack_name: str) -> None: @@ -273,7 +274,7 @@ def docker_stack( # NOTE: if the migration service was already running prior to this call it must # be force updated so that it does its job. else it remains and tests will fail - _force_restart_migration_service(docker_client) + _force_remove_migration_service(docker_client) # make up-version stacks_deployed: dict[str, dict] = {} diff --git a/packages/pytest-simcore/src/pytest_simcore/services_api_mocks_for_aiohttp_clients.py b/packages/pytest-simcore/src/pytest_simcore/services_api_mocks_for_aiohttp_clients.py index 676b1b48553..7d596347085 100644 --- a/packages/pytest-simcore/src/pytest_simcore/services_api_mocks_for_aiohttp_clients.py +++ b/packages/pytest-simcore/src/pytest_simcore/services_api_mocks_for_aiohttp_clients.py @@ -70,7 +70,6 @@ def create_computation_cb(url, **kwargs) -> CallbackResult: - assert "json" in kwargs, f"missing body in call to {url}" body = kwargs["json"] for param in ["user_id", "project_id"]: @@ -113,6 +112,7 @@ def create_computation_cb(url, **kwargs) -> CallbackResult: "pipeline_details": { "adjacency_list": pipeline, "node_states": node_states, + "progress": 0, }, }, ) @@ -131,6 +131,7 @@ def get_computation_cb(url, **kwargs) -> CallbackResult: "pipeline_details": { "adjacency_list": pipeline, "node_states": node_states, + "progress": 0, }, "iteration": 2, "cluster_id": 23, @@ -350,7 +351,6 @@ def get_upload_link_cb(url: URL, **kwargs) -> CallbackResult: scheme = {LinkType.PRESIGNED: "http", LinkType.S3: "s3"} if file_size := kwargs["params"].get("file_size") is not None: - upload_schema = FileUploadSchema( chunk_size=parse_obj_as(ByteSize, "5GiB"), urls=[parse_obj_as(AnyUrl, f"{scheme[link_type]}://{file_id}")], diff --git a/packages/service-library/src/servicelib/long_running_tasks/_models.py b/packages/service-library/src/servicelib/long_running_tasks/_models.py index d0a035b9e18..b90fbaf787b 100644 --- a/packages/service-library/src/servicelib/long_running_tasks/_models.py +++ b/packages/service-library/src/servicelib/long_running_tasks/_models.py @@ -2,13 +2,13 @@ import urllib.parse from asyncio import Task from datetime import datetime -from typing import Any, Awaitable, Callable, Coroutine, Optional +from typing import Any, Awaitable, Callable, Coroutine from pydantic import ( BaseModel, + ConstrainedFloat, Field, PositiveFloat, - confloat, validate_arguments, validator, ) @@ -20,7 +20,13 @@ TaskType = Callable[..., Coroutine[Any, Any, Any]] ProgressMessage = str -ProgressPercent = confloat(ge=0.0, le=1.0) + + +class ProgressPercent(ConstrainedFloat): + ge = 0.0 + le = 1.0 + + ProgressCallback = Callable[[ProgressMessage, ProgressPercent, TaskId], Awaitable[None]] @@ -41,8 +47,8 @@ class TaskProgress(BaseModel): def update( self, *, - message: Optional[ProgressMessage] = None, - percent: Optional[ProgressPercent] = None, + message: ProgressMessage | None = None, + percent: ProgressPercent | None = None, ) -> None: """`percent` must be between 0.0 and 1.0 otherwise ValueError is raised""" if message: @@ -77,7 +83,7 @@ class TrackedTask(BaseModel): ) started: datetime = Field(default_factory=datetime.utcnow) - last_status_check: Optional[datetime] = Field( + last_status_check: datetime | None = Field( default=None, description=( "used to detect when if the task is not actively " @@ -96,8 +102,8 @@ class TaskStatus(BaseModel): class TaskResult(BaseModel): - result: Optional[Any] - error: Optional[Any] + result: Any | None + error: Any | None class ClientConfiguration(BaseModel): diff --git a/services/api-server/src/simcore_service_api_server/api/routes/solvers_jobs.py b/services/api-server/src/simcore_service_api_server/api/routes/solvers_jobs.py index 84829fd5de8..d2120a8335a 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/solvers_jobs.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/solvers_jobs.py @@ -46,7 +46,7 @@ def _compose_job_resource_name(solver_key, solver_version, job_id) -> str: ) -## JOBS --------------- +# JOBS --------------- # # - Similar to docker container's API design (container = job and image = solver) # @@ -224,7 +224,7 @@ async def inspect_job( job_id: UUID, user_id: PositiveInt = Depends(get_current_user_id), director2_api: DirectorV2Api = Depends(get_api_client(DirectorV2Api)), -): +) -> JobStatus: job_name = _compose_job_resource_name(solver_key, version, job_id) _logger.debug("Inspecting Job '%s'", job_name) diff --git a/services/api-server/src/simcore_service_api_server/utils/solver_job_models_converters.py b/services/api-server/src/simcore_service_api_server/utils/solver_job_models_converters.py index c4e0c7befd6..3791cfe03a2 100644 --- a/services/api-server/src/simcore_service_api_server/utils/solver_job_models_converters.py +++ b/services/api-server/src/simcore_service_api_server/utils/solver_job_models_converters.py @@ -20,7 +20,14 @@ StudyUI, ) from ..models.schemas.files import File -from ..models.schemas.jobs import ArgumentType, Job, JobInputs, JobStatus, TaskStates +from ..models.schemas.jobs import ( + ArgumentType, + Job, + JobInputs, + JobStatus, + PercentageInt, + TaskStates, +) from ..models.schemas.solvers import Solver, SolverKeyId from ..plugins.director_v2 import ComputationTaskGet from .typing_extra import get_types @@ -52,13 +59,11 @@ def now_str() -> str: def create_node_inputs_from_job_inputs(inputs: JobInputs) -> dict[InputID, InputTypes]: - # map Job inputs with solver inputs # TODO: ArgumentType -> InputTypes dispatcher node_inputs: dict[InputID, InputTypes] = {} for name, value in inputs.values.items(): - assert isinstance(value, get_types(ArgumentType)) # nosec if isinstance(value, File): @@ -84,7 +89,6 @@ def create_job_inputs_from_node_inputs(inputs: dict[InputID, InputTypes]) -> Job """ input_values: dict[str, ArgumentType] = {} for name, value in inputs.items(): - assert isinstance(name, get_types(InputID)) # nosec assert isinstance(value, get_types(InputTypes)) # nosec @@ -244,11 +248,10 @@ def create_job_from_project( def create_jobstatus_from_task(task: ComputationTaskGet) -> JobStatus: - job_status = JobStatus( job_id=task.id, state=task.state, - progress=task.guess_progress(), + progress=PercentageInt((task.pipeline_details.progress or 0) * 100.0), submitted_at=datetime.utcnow(), ) diff --git a/services/api-server/tests/unit/api_solvers/test_api_routers_solvers_jobs.py b/services/api-server/tests/unit/api_solvers/test_api_routers_solvers_jobs.py index a83b17a0af7..9a4ca2f8a28 100644 --- a/services/api-server/tests/unit/api_solvers/test_api_routers_solvers_jobs.py +++ b/services/api-server/tests/unit/api_solvers/test_api_routers_solvers_jobs.py @@ -57,7 +57,6 @@ def presigned_download_link( bucket_name: str, mocked_s3_server_url: HttpUrl, ) -> Iterator[AnyUrl]: - s3_client = boto3.client( "s3", endpoint_url=mocked_s3_server_url, @@ -102,7 +101,6 @@ def mocked_directorv2_service_api( assert_all_called=False, assert_all_mocked=True, # IMPORTANT: KEEP always True! ) as respx_mock: - # check that what we emulate, actually still exists path = "/v2/computations/{project_id}/tasks/-/logfile" assert path in oas["paths"] @@ -277,6 +275,7 @@ async def test_run_solver_job( "currentStatus": "NOT_STARTED", }, }, + "progress": 0.0, }, "iteration": 1, "cluster_id": 0, diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py index 7e6e45654c6..1d769c91c16 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/dask_utils.py @@ -16,6 +16,7 @@ from distributed.worker import get_worker from distributed.worker_state_machine import TaskState from servicelib.logging_utils import LogLevelInt, LogMessageStr +from servicelib.logging_utils import log_catch logger = logging.getLogger(__name__) @@ -143,12 +144,14 @@ def publish_task_logs( log_level: LogLevelInt, ) -> None: logger.info("[%s - %s]: %s", message_prefix, log_type.name, message) - if log_type == LogType.PROGRESS: - publish_event( - progress_pub, - TaskProgressEvent.from_dask_worker(progress=float(message)), - ) - else: - publish_event( - logs_pub, TaskLogEvent.from_dask_worker(log=message, log_level=log_level) - ) + with log_catch(logger, reraise=False): + if log_type == LogType.PROGRESS: + publish_event( + progress_pub, + TaskProgressEvent.from_dask_worker(progress=float(message)), + ) + else: + publish_event( + logs_pub, + TaskLogEvent.from_dask_worker(log=message, log_level=log_level), + ) diff --git a/services/director-v2/setup.cfg b/services/director-v2/setup.cfg index de804c8e71c..cdcd8fc3b4b 100644 --- a/services/director-v2/setup.cfg +++ b/services/director-v2/setup.cfg @@ -11,3 +11,4 @@ commit_args = --no-verify asyncio_mode = auto markers = testit: "marks test to run during development" + acceptance_test: "marks tests as 'acceptance tests' i.e. does the system do what the user expects? Typically those are workflows." diff --git a/services/director-v2/src/simcore_service_director_v2/models/domains/comp_tasks.py b/services/director-v2/src/simcore_service_director_v2/models/domains/comp_tasks.py index 490b0076a27..625da219ae6 100644 --- a/services/director-v2/src/simcore_service_director_v2/models/domains/comp_tasks.py +++ b/services/director-v2/src/simcore_service_director_v2/models/domains/comp_tasks.py @@ -107,7 +107,9 @@ class CompTaskAtDB(BaseModel): job_id: str | None = Field(default=None, description="The worker job ID") node_schema: NodeSchema = Field(..., alias="schema") inputs: InputsDict | None = Field(..., description="the inputs payload") - outputs: OutputsDict | None = Field({}, description="the outputs payload") + outputs: OutputsDict | None = Field( + default_factory=dict, description="the outputs payload" + ) run_hash: str | None = Field( default=None, description="the hex digest of the resolved inputs +outputs hash at the time when the last outputs were generated", @@ -121,6 +123,12 @@ class CompTaskAtDB(BaseModel): internal_id: PositiveInt node_class: NodeClass errors: list[ErrorDict] | None = Field(default=None) + progress: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="current progress of the task if available", + ) @validator("state", pre=True) @classmethod @@ -186,6 +194,7 @@ class Config: "submit": "2021-03-01 13:07:34.19161", "node_class": "INTERACTIVE", "state": "NOT_STARTED", + "progress": 0.44, } for image_example in Image.Config.schema_extra["examples"] ] diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py index a7d842491dc..7833b7555d4 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/base_scheduler.py @@ -231,6 +231,7 @@ async def _set_states_following_failed_to_aborted( project_id, [NodeID(n) for n in tasks_to_set_aborted], RunningState.ABORTED, + optional_progress=1.0, ) return tasks @@ -293,30 +294,36 @@ async def _update_states_from_comp_backend( cluster_id: ClusterID, project_id: ProjectID, pipeline_dag: nx.DiGraph, - ): + ) -> None: all_tasks = await self._get_pipeline_tasks(project_id, pipeline_dag) - processing_tasks = [ + if processing_tasks := [ t for t in all_tasks.values() if t.state in PROCESSING_STATES - ] - changed_tasks = await self._get_changed_tasks_from_backend( - user_id, cluster_id, processing_tasks - ) - - await self._publish_service_started_metrics(user_id, project_id, changed_tasks) + ]: + changed_tasks = await self._get_changed_tasks_from_backend( + user_id, cluster_id, processing_tasks + ) - completed_tasks = [ - current for _, current in changed_tasks if current.state in COMPLETED_STATES - ] - incomplete_tasks = [ - current - for _, current in changed_tasks - if current.state not in COMPLETED_STATES - ] + await self._publish_service_started_metrics( + user_id, project_id, changed_tasks + ) - if completed_tasks: - await self._process_completed_tasks(user_id, cluster_id, completed_tasks) - if incomplete_tasks: - await self._process_incomplete_tasks(incomplete_tasks) + completed_tasks = [ + current + for _, current in changed_tasks + if current.state in COMPLETED_STATES + ] + incomplete_tasks = [ + current + for _, current in changed_tasks + if current.state not in COMPLETED_STATES + ] + + if completed_tasks: + await self._process_completed_tasks( + user_id, cluster_id, completed_tasks + ) + if incomplete_tasks: + await self._process_incomplete_tasks(incomplete_tasks) @abstractmethod async def _start_tasks( @@ -471,7 +478,10 @@ async def _schedule_tasks_to_start( self.db_engine, CompTasksRepository ) await comp_tasks_repo.set_project_tasks_state( - project_id, list(tasks_ready_to_start.keys()), RunningState.PENDING + project_id, + list(tasks_ready_to_start.keys()), + RunningState.PENDING, + optional_progress=0, ) # we pass the tasks to the dask-client in a gather such that each task can be stopped independently @@ -502,6 +512,7 @@ async def _schedule_tasks_to_start( [r.node_id], RunningState.FAILED, r.get_errors(), + optional_progress=1.0, ) elif isinstance( r, @@ -523,6 +534,7 @@ async def _schedule_tasks_to_start( project_id, list(tasks_ready_to_start.keys()), RunningState.PUBLISHED, + optional_progress=0, ), ) elif isinstance(r, Exception): @@ -536,7 +548,7 @@ async def _schedule_tasks_to_start( "".join(traceback.format_tb(r.__traceback__)), ) await comp_tasks_repo.set_project_tasks_state( - project_id, [t], RunningState.FAILED + project_id, [t], RunningState.FAILED, optional_progress=1.0 ) def _wake_up_scheduler_now(self) -> None: diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py index 6ed37f73534..6d7d9b87632 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/dask_scheduler.py @@ -201,13 +201,22 @@ async def _process_task_result( await self.rabbitmq_client.publish(message.channel_name, message) await CompTasksRepository(self.db_engine).set_project_tasks_state( - task.project_id, [task.node_id], task_final_state, errors=errors + task.project_id, + [task.node_id], + task_final_state, + errors=errors, + optional_progress=1, ) async def _task_progress_change_handler(self, event: str) -> None: task_progress_event = TaskProgressEvent.parse_raw(event) logger.debug("received task progress update: %s", task_progress_event) *_, user_id, project_id, node_id = parse_dask_job_id(task_progress_event.job_id) + + await CompTasksRepository(self.db_engine).set_project_task_progress( + project_id, node_id, task_progress_event.progress + ) + message = ProgressRabbitMessageNode.construct( user_id=user_id, project_id=project_id, diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/clusters.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/clusters.py index bbb560c1dda..44e64b325f1 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/clusters.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/clusters.py @@ -95,7 +95,7 @@ async def _compute_user_access_rights( conn: connection.SAConnection, user_id: UserID, cluster: Cluster ) -> ClusterAccessRights: result = await conn.execute( - sa.select([user_to_groups.c.gid, groups.c.type]) + sa.select(user_to_groups.c.gid, groups.c.type) .where(user_to_groups.c.uid == user_id) .order_by(groups.c.type) .join(groups) @@ -121,7 +121,7 @@ class ClustersRepository(BaseRepository): async def create_cluster(self, user_id, new_cluster: ClusterCreate) -> Cluster: async with self.db_engine.acquire() as conn: user_primary_gid = await conn.scalar( - sa.select([users.c.primary_gid]).where(users.c.id == user_id) + sa.select(users.c.primary_gid).where(users.c.id == user_id) ) new_cluster.owner = user_primary_gid new_cluster_id = await conn.scalar( @@ -135,11 +135,12 @@ async def create_cluster(self, user_id, new_cluster: ClusterCreate) -> Cluster: async def list_clusters(self, user_id: UserID) -> list[Cluster]: async with self.db_engine.acquire() as conn: result = await conn.execute( - sa.select([clusters.c.id], distinct=True) + sa.select(clusters.c.id) + .distinct() .where( cluster_to_groups.c.gid.in_( # get the groups of the user where he/she has read access - sa.select([groups.c.gid]) + sa.select(groups.c.gid) .where(user_to_groups.c.uid == user_id) .order_by(groups.c.gid) .select_from(groups.join(user_to_groups)) @@ -148,8 +149,12 @@ async def list_clusters(self, user_id: UserID) -> list[Cluster]: ) .join(cluster_to_groups) ) - cluster_ids = await result.fetchall() - return await _clusters_from_cluster_ids(conn, {c.id for c in cluster_ids}) + retrieved_clusters = [] + if cluster_ids := await result.fetchall(): + retrieved_clusters = await _clusters_from_cluster_ids( + conn, {c.id for c in cluster_ids} + ) + return retrieved_clusters async def get_cluster(self, user_id: UserID, cluster_id: ClusterID) -> Cluster: async with self.db_engine.acquire() as conn: diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_pipelines.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_pipelines.py index 8b21e592555..7af66309154 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_pipelines.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_pipelines.py @@ -19,7 +19,7 @@ class CompPipelinesRepository(BaseRepository): async def get_pipeline(self, project_id: ProjectID) -> CompPipelineAtDB: async with self.db_engine.acquire() as conn: result = await conn.execute( - sa.select([comp_pipeline]).where( + sa.select(comp_pipeline).where( comp_pipeline.c.project_id == str(project_id) ) ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py index a11c7c3d45b..80e19b53f03 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py @@ -38,7 +38,7 @@ async def get( """ async with self.db_engine.acquire() as conn: result = await conn.execute( - sa.select([comp_runs]) + sa.select(comp_runs) .where( (comp_runs.c.user_id == user_id) & (comp_runs.c.project_uuid == f"{project_id}") @@ -60,7 +60,7 @@ async def list( runs_in_db: deque[CompRunsAtDB] = deque() async with self.db_engine.acquire() as conn: async for row in conn.execute( - sa.select([comp_runs]).where( + sa.select(comp_runs).where( or_( *[ comp_runs.c.result == RUNNING_STATE_TO_DB[s] @@ -84,7 +84,7 @@ async def create( if iteration is None: # let's get the latest if it exists last_iteration = await conn.scalar( - sa.select([comp_runs.c.iteration]) + sa.select(comp_runs.c.iteration) .where( (comp_runs.c.user_id == user_id) & (comp_runs.c.project_uuid == f"{project_id}") diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks.py index d97b6d800ac..a7790791a0b 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks.py @@ -151,11 +151,13 @@ async def _generate_tasks_list_from_project( assert node.state is not None # nosec task_state = node.state.current_status + task_progress = node.state.progress if ( node_id in published_nodes and to_node_class(node.key) == NodeClass.COMPUTATIONAL ): task_state = RunningState.PUBLISHED + task_progress = 0 task_db = CompTaskAtDB( project_id=project.uuid, @@ -172,6 +174,7 @@ async def _generate_tasks_list_from_project( state=task_state, internal_id=internal_id, node_class=to_node_class(node.key), + progress=task_progress, ) list_comp_tasks.append(task_db) @@ -186,9 +189,7 @@ async def get_all_tasks( tasks: list[CompTaskAtDB] = [] async with self.db_engine.acquire() as conn: async for row in conn.execute( - sa.select([comp_tasks]).where( - comp_tasks.c.project_id == f"{project_id}" - ) + sa.select(comp_tasks).where(comp_tasks.c.project_id == f"{project_id}") ): task_db = CompTaskAtDB.from_orm(row) tasks.append(task_db) @@ -202,7 +203,7 @@ async def get_comp_tasks( tasks: list[CompTaskAtDB] = [] async with self.db_engine.acquire() as conn: async for row in conn.execute( - sa.select([comp_tasks]).where( + sa.select(comp_tasks).where( (comp_tasks.c.project_id == f"{project_id}") & (comp_tasks.c.node_class == NodeClass.COMPUTATIONAL) ) @@ -214,7 +215,7 @@ async def get_comp_tasks( async def check_task_exists(self, project_id: ProjectID, node_id: NodeID) -> bool: async with self.db_engine.acquire() as conn: nid: str | None = await conn.scalar( - sa.select([comp_tasks.c.node_id]).where( + sa.select(comp_tasks.c.node_id).where( (comp_tasks.c.project_id == f"{project_id}") & (comp_tasks.c.node_id == f"{node_id}") ) @@ -244,23 +245,22 @@ async def upsert_tasks_from_project( async with self.db_engine.acquire() as conn: # get current tasks result = await conn.execute( - sa.select([comp_tasks.c.node_id]).where( + sa.select(comp_tasks.c.node_id).where( comp_tasks.c.project_id == str(project.uuid) ) ) # remove the tasks that were removed from project workbench - node_ids_to_delete = [ - t.node_id - for t in await result.fetchall() - if t.node_id not in project.workbench - ] - for node_id in node_ids_to_delete: - await conn.execute( - sa.delete(comp_tasks).where( - (comp_tasks.c.project_id == str(project.uuid)) - & (comp_tasks.c.node_id == node_id) + if all_nodes := await result.fetchall(): + node_ids_to_delete = [ + t.node_id for t in all_nodes if t.node_id not in project.workbench + ] + for node_id in node_ids_to_delete: + await conn.execute( + sa.delete(comp_tasks).where( + (comp_tasks.c.project_id == str(project.uuid)) + & (comp_tasks.c.node_id == node_id) + ) ) - ) # insert or update the remaining tasks # NOTE: comp_tasks DB only trigger a notification to the webserver if an UPDATE on comp_tasks.outputs or comp_tasks.state is done @@ -270,7 +270,7 @@ async def upsert_tasks_from_project( insert_stmt = insert(comp_tasks).values(**comp_task_db.to_db_model()) exclusion_rule = ( - {"state"} + {"state", "progress"} if str(comp_task_db.node_id) not in published_nodes else set() ) @@ -302,7 +302,7 @@ async def mark_project_published_tasks_as_aborted( & (comp_tasks.c.node_class == NodeClass.COMPUTATIONAL) & (comp_tasks.c.state == StateType.PUBLISHED) ) - .values(state=StateType.ABORTED) + .values(state=StateType.ABORTED, progress=1.0) ) logger.debug("marked project %s published tasks as aborted", f"{project_id=}") @@ -331,15 +331,20 @@ async def set_project_tasks_state( tasks: list[NodeID], state: RunningState, errors: list[ErrorDict] | None = None, + *, + optional_progress: float | None = None, ) -> None: async with self.db_engine.acquire() as conn: + update_values = {"state": RUNNING_STATE_TO_DB[state], "errors": errors} + if optional_progress is not None: + update_values["progress"] = optional_progress await conn.execute( sa.update(comp_tasks) .where( (comp_tasks.c.project_id == f"{project_id}") & (comp_tasks.c.node_id.in_([str(t) for t in tasks])) ) - .values(state=RUNNING_STATE_TO_DB[state], errors=errors) + .values(**update_values) ) logger.debug( "set project %s tasks %s with state %s", @@ -348,6 +353,26 @@ async def set_project_tasks_state( f"{state=}", ) + async def set_project_task_progress( + self, project_id: ProjectID, node_id: NodeID, progress: float + ) -> None: + async with self.db_engine.acquire() as conn: + await conn.execute( + sa.update(comp_tasks) + .where( + (comp_tasks.c.project_id == f"{project_id}") + & (comp_tasks.c.node_id == f"{node_id}") + ) + .values(progress=progress) + ) + + logger.debug( + "set project %s task %s with progress %s", + f"{project_id=}", + f"{node_id=}", + f"{progress=}", + ) + async def delete_tasks_from_project(self, project: ProjectAtDB) -> None: async with self.db_engine.acquire() as conn: await conn.execute( diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects.py index 176d3ef8a06..43734c7c3b2 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects.py @@ -17,7 +17,7 @@ async def get_project(self, project_id: ProjectID) -> ProjectAtDB: async with self.db_engine.acquire() as conn: row: RowProxy | None = await ( await conn.execute( - sa.select([projects]).where(projects.c.uuid == str(project_id)) + sa.select(projects).where(projects.c.uuid == str(project_id)) ) ).first() if not row: diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects_networks.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects_networks.py index b9e296a68e4..4e234b43c02 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects_networks.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/projects_networks.py @@ -14,9 +14,9 @@ class ProjectsNetworksRepository(BaseRepository): async def get_projects_networks(self, project_id: ProjectID) -> ProjectsNetworks: async with self.db_engine.acquire() as conn: - row: RowProxy = await ( + row: RowProxy | None = await ( await conn.execute( - sa.select([projects_networks]).where( + sa.select(projects_networks).where( projects_networks.c.project_uuid == f"{project_id}" ) ) diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dags.py b/services/director-v2/src/simcore_service_director_v2/utils/dags.py index f84077232d3..edf35fbb096 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dags.py +++ b/services/director-v2/src/simcore_service_director_v2/utils/dags.py @@ -5,7 +5,7 @@ import networkx as nx from models_library.projects import NodesDict from models_library.projects_nodes import NodeID, NodeState -from models_library.projects_nodes_io import PortLink +from models_library.projects_nodes_io import NodeIDStr, PortLink from models_library.projects_pipeline import PipelineDetails from models_library.projects_state import RunningState from models_library.utils.nodes import compute_node_hash @@ -14,21 +14,18 @@ from ..modules.db.tables import NodeClass from .computations import to_node_class -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) -def _is_node_computational(node_key: str) -> bool: - try: - result: bool = to_node_class(node_key) == NodeClass.COMPUTATIONAL - return result - except ValueError: - return False +kNODE_MODIFIED_STATE = "modified_state" +kNODE_DEPENDENCIES_TO_COMPUTE = "dependencies_state" def create_complete_dag(workbench: NodesDict) -> nx.DiGraph: """creates a complete graph out of the project workbench""" dag_graph = nx.DiGraph() for node_id, node in workbench.items(): + assert node.state # nosec dag_graph.add_node( node_id, name=node.label, @@ -38,11 +35,13 @@ def create_complete_dag(workbench: NodesDict) -> nx.DiGraph: run_hash=node.run_hash, outputs=node.outputs, state=node.state.current_status, + node_class=to_node_class(node.key), ) - for input_node_id in node.input_nodes: - predecessor_node = workbench.get(str(input_node_id)) - if predecessor_node: - dag_graph.add_edge(str(input_node_id), node_id) + if node.input_nodes: + for input_node_id in node.input_nodes: + predecessor_node = workbench.get(NodeIDStr(input_node_id)) + if predecessor_node: + dag_graph.add_edge(str(input_node_id), node_id) return dag_graph @@ -51,7 +50,7 @@ def create_complete_dag_from_tasks(tasks: list[CompTaskAtDB]) -> nx.DiGraph: dag_graph = nx.DiGraph() for task in tasks: dag_graph.add_node( - str(task.node_id), + f"{task.node_id}", name=task.job_id, key=task.image.name, version=task.image.tag, @@ -59,17 +58,20 @@ def create_complete_dag_from_tasks(tasks: list[CompTaskAtDB]) -> nx.DiGraph: run_hash=task.run_hash, outputs=task.outputs, state=task.state, + node_class=task.node_class, + progress=task.progress, ) - for input_data in task.inputs.values(): - if isinstance(input_data, PortLink): - dag_graph.add_edge(str(input_data.node_uuid), str(task.node_id)) + if task.inputs: + for input_data in task.inputs.values(): + if isinstance(input_data, PortLink): + dag_graph.add_edge(str(input_data.node_uuid), f"{task.node_id}") return dag_graph -async def compute_node_modified_state( - nodes_data_view: nx.classes.reportviews.NodeDataView, node_id: NodeID +async def _compute_node_modified_state( + graph_data: nx.classes.reportviews.NodeDataView, node_id: NodeID ) -> bool: - node = nodes_data_view[str(node_id)] + node = graph_data[f"{node_id}"] # if the node state is in the modified state already if node["state"] in [ None, @@ -86,7 +88,7 @@ async def compute_node_modified_state( # maybe our inputs changed? let's compute the node hash and compare with the saved one async def get_node_io_payload_cb(node_id: NodeID) -> dict[str, Any]: - result: dict[str, Any] = nodes_data_view[str(node_id)] + result: dict[str, Any] = graph_data[f"{node_id}"] return result computed_hash = await compute_node_hash(node_id, get_node_io_payload_cb) @@ -95,38 +97,32 @@ async def get_node_io_payload_cb(node_id: NodeID) -> dict[str, Any]: return False -async def compute_node_dependencies_state(nodes_data_view, node_id) -> set[NodeID]: - node = nodes_data_view[str(node_id)] +async def _compute_node_dependencies_state(graph_data, node_id) -> set[NodeID]: + node = graph_data[f"{node_id}"] # check if the previous node is outdated or waits for dependencies... in which case this one has to wait non_computed_dependencies: set[NodeID] = set() for input_port in node.get("inputs", {}).values(): if isinstance(input_port, PortLink): - if node_needs_computation(nodes_data_view, input_port.node_uuid): + if _node_needs_computation(graph_data, input_port.node_uuid): non_computed_dependencies.add(input_port.node_uuid) # all good. ready return non_computed_dependencies -kNODE_MODIFIED_STATE = "modified_state" -kNODE_DEPENDENCIES_TO_COMPUTE = "dependencies_state" - - -async def compute_node_states( - nodes_data_view: nx.classes.reportviews.NodeDataView, node_id: NodeID -): - node = nodes_data_view[str(node_id)] - node[kNODE_MODIFIED_STATE] = await compute_node_modified_state( - nodes_data_view, node_id - ) - node[kNODE_DEPENDENCIES_TO_COMPUTE] = await compute_node_dependencies_state( - nodes_data_view, node_id +async def _compute_node_states( + graph_data: nx.classes.reportviews.NodeDataView, node_id: NodeID +) -> None: + node = graph_data[f"{node_id}"] + node[kNODE_MODIFIED_STATE] = await _compute_node_modified_state(graph_data, node_id) + node[kNODE_DEPENDENCIES_TO_COMPUTE] = await _compute_node_dependencies_state( + graph_data, node_id ) -def node_needs_computation( - nodes_data_view: nx.classes.reportviews.NodeDataView, node_id: NodeID +def _node_needs_computation( + graph_data: nx.classes.reportviews.NodeDataView, node_id: NodeID ) -> bool: - node = nodes_data_view[str(node_id)] + node = graph_data[f"{node_id}"] needs_computation: bool = node.get(kNODE_MODIFIED_STATE, False) or node.get( kNODE_DEPENDENCIES_TO_COMPUTE, None ) @@ -134,20 +130,20 @@ def node_needs_computation( async def _set_computational_nodes_states(complete_dag: nx.DiGraph) -> None: - nodes_data_view: nx.classes.reportviews.NodeDataView = complete_dag.nodes.data() - for node in nx.topological_sort(complete_dag): - if _is_node_computational(nodes_data_view[node].get("key", "")): - await compute_node_states(nodes_data_view, node) + graph_data: nx.classes.reportviews.NodeDataView = complete_dag.nodes.data() + for node_id in nx.algorithms.dag.topological_sort(complete_dag): + if graph_data[node_id]["node_class"] is NodeClass.COMPUTATIONAL: + await _compute_node_states(graph_data, node_id) async def create_minimal_computational_graph_based_on_selection( complete_dag: nx.DiGraph, selected_nodes: list[NodeID], force_restart: bool ) -> nx.DiGraph: - nodes_data_view: nx.classes.reportviews.NodeDataView = complete_dag.nodes.data() + graph_data: nx.classes.reportviews.NodeDataView = complete_dag.nodes.data() try: # first pass, traversing in topological order to correctly get the dependencies, set the nodes states await _set_computational_nodes_states(complete_dag) - except nx.NetworkXUnfeasible: + except nx.exception.NetworkXUnfeasible: # not acyclic, return an empty graph return nx.DiGraph() @@ -158,9 +154,9 @@ async def create_minimal_computational_graph_based_on_selection( minimal_nodes_selection.update( { n - for n, _ in nodes_data_view - if _is_node_computational(nodes_data_view[n]["key"]) - and (force_restart or node_needs_computation(nodes_data_view, n)) + for n, _ in graph_data + if graph_data[n]["node_class"] is NodeClass.COMPUTATIONAL + and (force_restart or _node_needs_computation(graph_data, n)) } ) else: @@ -170,12 +166,13 @@ async def create_minimal_computational_graph_based_on_selection( { n for n in nx.bfs_tree(complete_dag, f"{node}", reverse=True) - if _is_node_computational(nodes_data_view[n]["key"]) - and node_needs_computation(nodes_data_view, n) + if graph_data[n]["node_class"] is NodeClass.COMPUTATIONAL + and _node_needs_computation(graph_data, n) } ) - if force_restart and _is_node_computational( - nodes_data_view[f"{node}"]["key"] + if ( + force_restart + and graph_data[f"{node}"]["node_class"] is NodeClass.COMPUTATIONAL ): minimal_nodes_selection.add(f"{node}") @@ -186,25 +183,36 @@ async def compute_pipeline_details( complete_dag: nx.DiGraph, pipeline_dag: nx.DiGraph, comp_tasks: list[CompTaskAtDB] ) -> PipelineDetails: try: - # FIXME: this problem of cyclic graphs for control loops create all kinds of issues that must be fixed + # NOTE: this problem of cyclic graphs for control loops create all kinds of issues that must be fixed # first pass, traversing in topological order to correctly get the dependencies, set the nodes states await _set_computational_nodes_states(complete_dag) - except nx.NetworkXUnfeasible: + except nx.exception.NetworkXUnfeasible: # not acyclic pass + + # NOTE: the latest progress is available in comp_tasks only + node_id_to_comp_task: dict[NodeIDStr, CompTaskAtDB] = { + NodeIDStr(f"{task.node_id}"): task for task in comp_tasks + } + pipeline_progress = None + if len(pipeline_dag.nodes) > 0: + pipeline_progress = 0.0 + for node_id in pipeline_dag.nodes: + if node_progress := node_id_to_comp_task[node_id].progress: + pipeline_progress += node_progress / len(pipeline_dag.nodes) + return PipelineDetails( - adjacency_list=nx.to_dict_of_lists(pipeline_dag), + adjacency_list=nx.convert.to_dict_of_lists(pipeline_dag), + progress=pipeline_progress, node_states={ node_id: NodeState( modified=node_data.get(kNODE_MODIFIED_STATE, False), dependencies=node_data.get(kNODE_DEPENDENCIES_TO_COMPUTE, set()), - currentStatus=next( - (task.state for task in comp_tasks if str(task.node_id) == node_id), - RunningState.UNKNOWN, - ), + currentStatus=node_id_to_comp_task[node_id].state, + progress=node_id_to_comp_task[node_id].progress, ) for node_id, node_data in complete_dag.nodes.data() - if _is_node_computational(node_data.get("key", "")) + if node_data["node_class"] is NodeClass.COMPUTATIONAL }, ) @@ -212,8 +220,11 @@ async def compute_pipeline_details( def find_computational_node_cycles(dag: nx.DiGraph) -> list[list[str]]: """returns a list of nodes part of a cycle and computational, which is currently forbidden.""" computational_node_cycles = [] - list_potential_cycles = nx.simple_cycles(dag) + list_potential_cycles = nx.algorithms.cycles.simple_cycles(dag) for cycle in list_potential_cycles: - if any(_is_node_computational(dag.nodes[node_id]["key"]) for node_id in cycle): + if any( + dag.nodes[node_id]["node_class"] is NodeClass.COMPUTATIONAL + for node_id in cycle + ): computational_node_cycles.append(deepcopy(cycle)) return computational_node_cycles diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dask.py b/services/director-v2/src/simcore_service_director_v2/utils/dask.py index bf2d31e8abd..65bf54b8ac1 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dask.py +++ b/services/director-v2/src/simcore_service_director_v2/utils/dask.py @@ -32,6 +32,7 @@ from models_library.users import UserID from pydantic import AnyUrl, ByteSize, ValidationError from servicelib.json_serialization import json_dumps +from servicelib.logging_utils import log_catch, log_context from simcore_sdk import node_ports_v2 from simcore_sdk.node_ports_common.exceptions import ( S3InvalidPathError, @@ -375,10 +376,10 @@ async def clean_task_output_and_log_files_if_invalid( ) -async def dask_sub_consumer( +async def _dask_sub_consumer( dask_sub: distributed.Sub, handler: Callable[[str], Awaitable[None]], -): +) -> None: async for dask_event in dask_sub: logger.debug( "received dask event '%s' of topic %s", @@ -386,7 +387,6 @@ async def dask_sub_consumer( dask_sub.name, ) await handler(dask_event) - await asyncio.sleep(0.010) _REST_TIMEOUT_S: Final[int] = 1 @@ -397,19 +397,12 @@ async def dask_sub_consumer_task( handler: Callable[[str], Awaitable[None]], ) -> NoReturn: while True: - try: - logger.info("starting dask consumer task for topic '%s'", dask_sub.name) - await dask_sub_consumer(dask_sub, handler) - except asyncio.CancelledError: - logger.info("stopped dask consumer task for topic '%s'", dask_sub.name) - raise - except Exception: # pylint: disable=broad-except - logger.exception( - "unknown exception in dask consumer task for topic '%s', restarting task in %s sec...", - dask_sub.name, - _REST_TIMEOUT_S, - ) - await asyncio.sleep(_REST_TIMEOUT_S) + with log_catch(logger, reraise=False), log_context( + logger, level=logging.DEBUG, msg=f"dask sub task for topic {dask_sub.name}" + ): + await _dask_sub_consumer(dask_sub, handler) + # we sleep a bit before restarting + await asyncio.sleep(_REST_TIMEOUT_S) def from_node_reqs_to_dask_resources( diff --git a/services/director-v2/tests/helpers/__init__.py b/services/director-v2/tests/helpers/__init__.py new file mode 100644 index 00000000000..c81922d7d4a --- /dev/null +++ b/services/director-v2/tests/helpers/__init__.py @@ -0,0 +1,4 @@ +import pytest + +# NOTE: this ensures that pytest rewrites the assertion so that comparison look nice in the console +pytest.register_assert_rewrite("helpers.shared_comp_utils") diff --git a/services/director-v2/tests/integration/shared_comp_utils.py b/services/director-v2/tests/helpers/shared_comp_utils.py similarity index 91% rename from services/director-v2/tests/integration/shared_comp_utils.py rename to services/director-v2/tests/helpers/shared_comp_utils.py index 645f0a389c0..04df1f50979 100644 --- a/services/director-v2/tests/integration/shared_comp_utils.py +++ b/services/director-v2/tests/helpers/shared_comp_utils.py @@ -1,6 +1,5 @@ import json import time -from typing import Optional from uuid import UUID import httpx @@ -26,8 +25,8 @@ async def assert_computation_task_out_obj( project: ProjectAtDB, exp_task_state: RunningState, exp_pipeline_details: PipelineDetails, - iteration: Optional[PositiveInt], - cluster_id: Optional[ClusterID], + iteration: PositiveInt | None, + cluster_id: ClusterID | None, ): assert task_out.id == project.uuid assert task_out.state == exp_task_state @@ -44,7 +43,9 @@ async def assert_computation_task_out_obj( assert task_out.iteration == iteration assert task_out.cluster_id == cluster_id # check pipeline details contents - assert task_out.pipeline_details.dict() == exp_pipeline_details.dict() + received_task_out_pipeline = task_out.pipeline_details.dict() + expected_task_out_pipeline = exp_pipeline_details.dict() + assert received_task_out_pipeline == expected_task_out_pipeline async def assert_and_wait_for_pipeline_status( @@ -52,7 +53,7 @@ async def assert_and_wait_for_pipeline_status( url: AnyHttpUrl, user_id: UserID, project_uuid: UUID, - wait_for_states: Optional[list[RunningState]] = None, + wait_for_states: list[RunningState] | None = None, ) -> ComputationGet: if not wait_for_states: wait_for_states = [ diff --git a/services/director-v2/tests/integration/01/test_computation_api.py b/services/director-v2/tests/integration/01/test_computation_api.py index 698c161322b..dfc27eb6cdf 100644 --- a/services/director-v2/tests/integration/01/test_computation_api.py +++ b/services/director-v2/tests/integration/01/test_computation_api.py @@ -15,18 +15,19 @@ import httpx import pytest import sqlalchemy as sa +from helpers.shared_comp_utils import ( + assert_and_wait_for_pipeline_status, + assert_computation_task_out_obj, +) from models_library.clusters import DEFAULT_CLUSTER_ID from models_library.projects import ProjectAtDB from models_library.projects_nodes import NodeState from models_library.projects_nodes_io import NodeID from models_library.projects_pipeline import PipelineDetails from models_library.projects_state import RunningState +from models_library.users import UserID from pytest import MonkeyPatch from settings_library.rabbit import RabbitSettings -from shared_comp_utils import ( - assert_and_wait_for_pipeline_status, - assert_computation_task_out_obj, -) from simcore_service_director_v2.models.schemas.comp_tasks import ComputationGet from starlette import status from starlette.testclient import TestClient @@ -97,7 +98,7 @@ def fake_workbench_computational_pipeline_details( adjacency_list = json.loads(fake_workbench_computational_adjacency_file.read_text()) node_states = json.loads(fake_workbench_node_states_file.read_text()) return PipelineDetails.parse_obj( - {"adjacency_list": adjacency_list, "node_states": node_states} + {"adjacency_list": adjacency_list, "node_states": node_states, "progress": 0} ) @@ -110,6 +111,8 @@ def fake_workbench_computational_pipeline_details_completed( node_state.modified = False node_state.dependencies = set() node_state.current_status = RunningState.SUCCESS + node_state.progress = 1 + completed_pipeline_details.progress = 1 return completed_pipeline_details @@ -216,18 +219,22 @@ class PartialComputationParams: "modified": True, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 2: { "modified": True, "dependencies": [1], + "progress": 0, }, 3: { "modified": True, "dependencies": [], + "progress": 0, }, 4: { "modified": True, "dependencies": [2, 3], + "progress": 0, }, }, exp_node_states_after_run={ @@ -235,18 +242,22 @@ class PartialComputationParams: "modified": False, "dependencies": [], "currentStatus": RunningState.SUCCESS, + "progress": 1, }, 2: { "modified": True, "dependencies": [], + "progress": 0, }, 3: { "modified": True, "dependencies": [], + "progress": 0, }, 4: { "modified": True, "dependencies": [2, 3], + "progress": 0, }, }, exp_pipeline_adj_list_after_force_run={1: []}, @@ -255,21 +266,25 @@ class PartialComputationParams: "modified": False, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 2: { "modified": True, "dependencies": [], "currentStatus": RunningState.NOT_STARTED, + "progress": 0, }, 3: { "modified": True, "dependencies": [], "currentStatus": RunningState.NOT_STARTED, + "progress": 0, }, 4: { "modified": True, "dependencies": [2, 3], "currentStatus": RunningState.NOT_STARTED, + "progress": 0, }, }, ), @@ -284,21 +299,25 @@ class PartialComputationParams: "modified": True, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 2: { "modified": True, "dependencies": [1], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 3: { "modified": True, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 4: { "modified": True, "dependencies": [2, 3], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, }, exp_node_states_after_run={ @@ -306,21 +325,25 @@ class PartialComputationParams: "modified": False, "dependencies": [], "currentStatus": RunningState.SUCCESS, + "progress": 1, }, 2: { "modified": False, "dependencies": [], "currentStatus": RunningState.SUCCESS, + "progress": 1, }, 3: { "modified": False, "dependencies": [], "currentStatus": RunningState.SUCCESS, + "progress": 1, }, 4: { "modified": False, "dependencies": [], "currentStatus": RunningState.SUCCESS, + "progress": 1, }, }, exp_pipeline_adj_list_after_force_run={1: [2], 2: [4], 4: []}, @@ -329,21 +352,25 @@ class PartialComputationParams: "modified": False, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 2: { "modified": False, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, 3: { "modified": False, "dependencies": [], "currentStatus": RunningState.SUCCESS, + "progress": 1, }, 4: { "modified": False, "dependencies": [], "currentStatus": RunningState.PUBLISHED, + "progress": 0, }, }, ), @@ -352,6 +379,7 @@ class PartialComputationParams: ], ) async def test_run_partial_computation( + catalog_ready: Callable[[UserID, str], Awaitable[None]], minimal_configuration: None, async_client: httpx.AsyncClient, registered_user: Callable, @@ -363,6 +391,7 @@ async def test_run_partial_computation( create_pipeline: Callable[..., Awaitable[ComputationGet]], ): user = registered_user() + await catalog_ready(user["id"], osparc_product_name) sleepers_project: ProjectAtDB = project( user, workbench=fake_workbench_without_outputs ) @@ -385,11 +414,18 @@ def _convert_to_pipeline_details( NodeID(workbench_node_uuids[dep_n]) for dep_n in s["dependencies"] }, currentStatus=s.get("currentStatus", RunningState.NOT_STARTED), + progress=s.get("progress"), ) for n, s in exp_node_states.items() } + pipeline_progress = 0 + for node_id in converted_adj_list: + node = converted_node_states[node_id] + pipeline_progress += (node.progress or 0) / len(converted_adj_list) return PipelineDetails( - adjacency_list=converted_adj_list, node_states=converted_node_states + adjacency_list=converted_adj_list, + node_states=converted_node_states, + progress=pipeline_progress, ) # convert the ids to the node uuids from the project @@ -495,6 +531,7 @@ def _convert_to_pipeline_details( async def test_run_computation( + catalog_ready: Callable[[UserID, str], Awaitable[None]], minimal_configuration: None, async_client: httpx.AsyncClient, registered_user: Callable, @@ -507,6 +544,7 @@ async def test_run_computation( create_pipeline: Callable[..., Awaitable[ComputationGet]], ): user = registered_user() + await catalog_ready(user["id"], osparc_product_name) sleepers_project = project(user, workbench=fake_workbench_without_outputs) # send a valid project with sleepers task_out = await create_pipeline( @@ -576,6 +614,10 @@ async def test_run_computation( node_id ].current_status ) + node_data.progress = fake_workbench_computational_pipeline_details.node_states[ + node_id + ].progress + expected_pipeline_details_forced.progress = 0 task_out = await create_pipeline( async_client, project=sleepers_project, @@ -672,7 +714,7 @@ async def test_abort_computation( ), f"response code is {response.status_code}, error: {response.text}" task_out = ComputationGet.parse_obj(response.json()) assert task_out.url.path == f"/v2/computations/{sleepers_project.uuid}:stop" - assert task_out.stop_url == None + assert task_out.stop_url is None # check that the pipeline is aborted/stopped task_out = await assert_and_wait_for_pipeline_status( diff --git a/services/director-v2/tests/integration/02/conftest.py b/services/director-v2/tests/integration/02/conftest.py index 11be5a52c78..a2e8868158f 100644 --- a/services/director-v2/tests/integration/02/conftest.py +++ b/services/director-v2/tests/integration/02/conftest.py @@ -1,6 +1,7 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument +from typing import AsyncIterator from uuid import uuid4 import aiodocker @@ -20,7 +21,9 @@ def network_name() -> str: @pytest.fixture -async def ensure_swarm_and_networks(network_name: str, docker_swarm: None): +async def ensure_swarm_and_networks( + network_name: str, docker_swarm: None +) -> AsyncIterator[None]: """ Make sure to always have a docker swarm network. If one is not present crete one. There can not be more then one. diff --git a/services/director-v2/tests/integration/02/test_dynamic_sidecar_nodeports_integration.py b/services/director-v2/tests/integration/02/test_dynamic_sidecar_nodeports_integration.py index 88d9f9922e3..4f0e2ab44bf 100644 --- a/services/director-v2/tests/integration/02/test_dynamic_sidecar_nodeports_integration.py +++ b/services/director-v2/tests/integration/02/test_dynamic_sidecar_nodeports_integration.py @@ -1,6 +1,7 @@ # pylint: disable=protected-access # pylint: disable=redefined-outer-name # pylint: disable=unused-argument +# pylint:disable=too-many-arguments import asyncio import hashlib @@ -13,10 +14,11 @@ from typing import ( Any, AsyncIterable, + AsyncIterator, Awaitable, Callable, + Coroutine, Iterable, - Iterator, Optional, cast, ) @@ -31,8 +33,18 @@ from aiodocker.containers import DockerContainer from aiopg.sa import Engine from fastapi import FastAPI +from helpers.shared_comp_utils import ( + assert_and_wait_for_pipeline_status, + assert_computation_task_out_obj, +) from models_library.clusters import DEFAULT_CLUSTER_ID -from models_library.projects import Node, NodesDict, ProjectAtDB, ProjectID +from models_library.projects import ( + Node, + NodesDict, + ProjectAtDB, + ProjectID, + ProjectIDStr, +) from models_library.projects_networks import ( PROJECT_NETWORK_PREFIX, ContainerAliases, @@ -43,8 +55,8 @@ from models_library.projects_pipeline import PipelineDetails from models_library.projects_state import RunningState from models_library.users import UserID +from pydantic import AnyHttpUrl, parse_obj_as from pytest import MonkeyPatch -from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.utils_docker import get_localhost_ip from servicelib.fastapi.long_running_tasks.client import ( Client, @@ -56,10 +68,6 @@ from servicelib.progress_bar import ProgressBarData from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings -from shared_comp_utils import ( - assert_and_wait_for_pipeline_status, - assert_computation_task_out_obj, -) from simcore_postgres_database.models.comp_pipeline import comp_pipeline from simcore_postgres_database.models.comp_tasks import comp_tasks from simcore_postgres_database.models.projects_networks import projects_networks @@ -131,7 +139,8 @@ @pytest.fixture -def minimal_configuration( # pylint:disable=too-many-arguments +async def minimal_configuration( + catalog_ready: Callable[[UserID, str], Awaitable[None]], sleeper_service: dict, dy_static_file_server_dynamic_sidecar_service: dict, dy_static_file_server_dynamic_sidecar_compose_spec_service: dict, @@ -144,8 +153,10 @@ def minimal_configuration( # pylint:disable=too-many-arguments dask_scheduler_service: str, dask_sidecar_service: None, ensure_swarm_and_networks: None, + current_user: dict[str, Any], osparc_product_name: str, -) -> Iterator[None]: +) -> AsyncIterator[None]: + await catalog_ready(current_user["id"], osparc_product_name) with postgres_db.connect() as conn: # pylint: disable=no-value-for-parameter conn.execute(comp_tasks.delete()) @@ -216,13 +227,6 @@ def fake_dy_success(mocks_dir: Path) -> dict[str, Any]: return json.loads(fake_dy_status_success.read_text()) -@pytest.fixture -def fake_dy_published(mocks_dir: Path) -> dict[str, Any]: - fake_dy_status_published = mocks_dir / "fake_dy_status_published.json" - assert fake_dy_status_published.exists() - return json.loads(fake_dy_status_published.read_text()) - - @pytest.fixture def services_node_uuids( fake_dy_workbench: dict[str, Any], @@ -444,7 +448,7 @@ async def _get_mapped_nodeports_values( for node_uuid in workbench: PORTS: Nodeports = await node_ports_v2.ports( user_id=user_id, - project_id=project_id, + project_id=ProjectIDStr(project_id), node_uuid=NodeIDStr(node_uuid), db_manager=db_manager, ) @@ -696,7 +700,7 @@ async def _start_and_wait_for_dynamic_services_ready( for service_uuid in workbench_dynamic_services: dynamic_service_url = await patch_dynamic_service_url( # pylint: disable=protected-access - app=director_v2_client._transport.app, + app=director_v2_client._transport.app, # type: ignore node_uuid=service_uuid, ) dynamic_services_urls[service_uuid] = dynamic_service_url @@ -718,7 +722,7 @@ async def _wait_for_dy_services_to_fully_stop( ) -> None: # pylint: disable=protected-access to_observe = ( - director_v2_client._transport.app.state.dynamic_sidecar_scheduler._scheduler._to_observe + director_v2_client._transport.app.state.dynamic_sidecar_scheduler._scheduler._to_observe # type: ignore ) # TODO: ANE please use tenacity for i in range(TIMEOUT_DETECT_DYNAMIC_SERVICES_STOPPED): @@ -792,7 +796,7 @@ async def _debug_progress_callback( Client( app=initialized_app, async_client=director_v2_client, - base_url=director_v2_client.base_url, + base_url=parse_obj_as(AnyHttpUrl, f"{director_v2_client.base_url}"), ), task_id, task_timeout=60, @@ -832,14 +836,18 @@ async def _assert_retrieve_completed( container_id ) - logs = " ".join(await container.log(stdout=True, stderr=True)) + logs = " ".join( + await cast( + Coroutine[Any, Any, list[str]], + container.log(stdout=True, stderr=True), + ) + ) assert ( _CONTROL_TESTMARK_DY_SIDECAR_NODEPORT_UPLOADED_MESSAGE in logs ), "TIP: Message missing suggests that the data was never uploaded: look in services/dynamic-sidecar/src/simcore_service_dynamic_sidecar/modules/nodeports.py" async def test_nodeports_integration( - # pylint: disable=too-many-arguments minimal_configuration: None, cleanup_services_and_networks: None, projects_networks_db: None, @@ -854,9 +862,7 @@ async def test_nodeports_integration( workbench_dynamic_services: dict[str, Node], services_node_uuids: ServicesNodeUUIDs, fake_dy_success: dict[str, Any], - fake_dy_published: dict[str, Any], tmp_path: Path, - mocker: MockerFixture, osparc_product_name: str, create_pipeline: Callable[..., Awaitable[ComputationGet]], ) -> None: @@ -908,16 +914,6 @@ async def test_nodeports_integration( product_name=osparc_product_name, ) - # check the contents is correct: a pipeline that just started gets PUBLISHED - await assert_computation_task_out_obj( - task_out, - project=current_study, - exp_task_state=RunningState.PUBLISHED, - exp_pipeline_details=PipelineDetails.parse_obj(fake_dy_published), - iteration=1, - cluster_id=DEFAULT_CLUSTER_ID, - ) - # wait for the computation to finish (either by failing, success or abort) task_out = await assert_and_wait_for_pipeline_status( async_client, task_out.url, current_user["id"], current_study.uuid @@ -989,7 +985,7 @@ async def test_nodeports_integration( # STEP 4 - app_settings: AppSettings = async_client._transport.app.state.settings + app_settings: AppSettings = async_client._transport.app.state.settings # type: ignore r_clone_settings: RCloneSettings = ( app_settings.DYNAMIC_SERVICES.DYNAMIC_SIDECAR.DYNAMIC_SIDECAR_R_CLONE_SETTINGS ) diff --git a/services/director-v2/tests/integration/conftest.py b/services/director-v2/tests/integration/conftest.py index d40193d5c62..deb21d43623 100644 --- a/services/director-v2/tests/integration/conftest.py +++ b/services/director-v2/tests/integration/conftest.py @@ -12,6 +12,11 @@ from simcore_postgres_database.models.projects import projects from simcore_service_director_v2.models.schemas.comp_tasks import ComputationGet from starlette import status +from tenacity import retry +from tenacity.retry import retry_if_exception_type +from tenacity.stop import stop_after_delay +from tenacity.wait import wait_fixed +from yarl import URL @pytest.fixture @@ -24,6 +29,7 @@ def updator(project_uuid: str): projects.select().where(projects.c.uuid == project_uuid) ) prj_row = result.first() + assert prj_row prj_workbench = prj_row.workbench result = con.execute( @@ -109,3 +115,47 @@ def mock_projects_repository(mocker: MockerFixture) -> None: f"{module_base}.ProjectsRepository.is_node_present_in_workbench", return_value=mocked_obj, ) + + +@pytest.fixture +async def catalog_ready( + services_endpoint: dict[str, URL] +) -> Callable[[UserID, str], Awaitable[None]]: + async def _waiter(user_id: UserID, product_name: str) -> None: + catalog_endpoint = list( + filter( + lambda service_endpoint: "catalog" in service_endpoint[0], + services_endpoint.items(), + ) + ) + assert ( + len(catalog_endpoint) == 1 + ), f"no catalog service found! {services_endpoint=}" + catalog_endpoint = catalog_endpoint[0][1] + print(f"--> found catalog endpoint at {catalog_endpoint=}") + client = httpx.AsyncClient() + + @retry( + wait=wait_fixed(1), + stop=stop_after_delay(60), + retry=retry_if_exception_type(AssertionError), + ) + async def _ensure_catalog_services_answers() -> None: + print("--> checking catalog is up and ready...") + response = await client.get( + f"{catalog_endpoint}/v0/services", + params={"details": False, "user_id": user_id}, + headers={"x-simcore-products-name": product_name}, + ) + assert ( + response.status_code == status.HTTP_200_OK + ), f"catalog is not ready {response.status_code}:{response.text}, TIP: migration not completed or catalog broken?" + services = response.json() + assert services != [], "catalog is not ready: no services available" + print( + f"<-- catalog is up and ready, received {response.status_code}:{response.text}" + ) + + await _ensure_catalog_services_answers() + + return _waiter diff --git a/services/director-v2/tests/mocks/fake_dy_status_published.json b/services/director-v2/tests/mocks/fake_dy_status_published.json index 84bca67c1ae..a5d9d396a83 100644 --- a/services/director-v2/tests/mocks/fake_dy_status_published.json +++ b/services/director-v2/tests/mocks/fake_dy_status_published.json @@ -6,7 +6,9 @@ "e6becb37-4699-47f5-81ef-e58fbdf8a9e5": { "modified": true, "dependencies": [], - "currentStatus": "PUBLISHED" + "currentStatus": "PUBLISHED", + "progress": 0 } - } + }, + "progress": 0 } diff --git a/services/director-v2/tests/mocks/fake_dy_status_success.json b/services/director-v2/tests/mocks/fake_dy_status_success.json index c183993fd9a..0b47992af9a 100644 --- a/services/director-v2/tests/mocks/fake_dy_status_success.json +++ b/services/director-v2/tests/mocks/fake_dy_status_success.json @@ -6,7 +6,9 @@ "e6becb37-4699-47f5-81ef-e58fbdf8a9e5": { "modified": false, "dependencies": [], - "currentStatus": "SUCCESS" + "currentStatus": "SUCCESS", + "progress": 1.0 } - } + }, + "progress": 1.0 } diff --git a/services/director-v2/tests/mocks/fake_workbench_computational_node_states.json b/services/director-v2/tests/mocks/fake_workbench_computational_node_states.json index d8ed46e44f7..c36d2fad5f2 100644 --- a/services/director-v2/tests/mocks/fake_workbench_computational_node_states.json +++ b/services/director-v2/tests/mocks/fake_workbench_computational_node_states.json @@ -2,19 +2,22 @@ "3a710d8b-565c-5f46-870b-b45ebe195fc7": { "modified": true, "dependencies": [], - "currentStatus": "PUBLISHED" + "currentStatus": "PUBLISHED", + "progress": 0 }, "e1e2ea96-ce8f-5abc-8712-b8ed312a782c": { "modified": true, "dependencies": [], - "currentStatus": "PUBLISHED" + "currentStatus": "PUBLISHED", + "progress": 0 }, "415fefd1-d08b-53c1-adb0-16bed3a687ef": { "modified": true, "dependencies": [ "3a710d8b-565c-5f46-870b-b45ebe195fc7" ], - "currentStatus": "PUBLISHED" + "currentStatus": "PUBLISHED", + "progress": 0 }, "6ede1209-b459-5735-91fc-761aa584808d": { "modified": true, @@ -22,6 +25,7 @@ "e1e2ea96-ce8f-5abc-8712-b8ed312a782c", "415fefd1-d08b-53c1-adb0-16bed3a687ef" ], - "currentStatus": "PUBLISHED" + "currentStatus": "PUBLISHED", + "progress": 0 } } diff --git a/services/director-v2/tests/unit/_helpers.py b/services/director-v2/tests/unit/_helpers.py index 10f9c9f9e45..b8b333e9d07 100644 --- a/services/director-v2/tests/unit/_helpers.py +++ b/services/director-v2/tests/unit/_helpers.py @@ -1,15 +1,12 @@ import asyncio from dataclasses import dataclass -from typing import Any, Dict, Iterator, List +from typing import Any import aiopg -from models_library.projects import ProjectAtDB, ProjectID +import aiopg.sa +from models_library.projects import ProjectAtDB from models_library.projects_nodes_io import NodeID -from models_library.projects_state import RunningState -from models_library.users import UserID -from pydantic.tools import parse_obj_as from simcore_postgres_database.models.comp_pipeline import StateType -from simcore_postgres_database.models.comp_runs import comp_runs from simcore_postgres_database.models.comp_tasks import comp_tasks from simcore_service_director_v2.models.domains.comp_pipelines import CompPipelineAtDB from simcore_service_director_v2.models.domains.comp_runs import CompRunsAtDB @@ -23,7 +20,7 @@ class PublishedProject: project: ProjectAtDB pipeline: CompPipelineAtDB - tasks: List[CompTaskAtDB] + tasks: list[CompTaskAtDB] @dataclass @@ -31,64 +28,18 @@ class RunningProject(PublishedProject): runs: CompRunsAtDB -async def assert_comp_run_state( - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore - user_id: UserID, - project_uuid: ProjectID, - exp_state: RunningState, -): - # check the database is correctly updated, the run is published - async with aiopg_engine.acquire() as conn: # type: ignore - result = await conn.execute( - comp_runs.select().where( - (comp_runs.c.user_id == user_id) - & (comp_runs.c.project_uuid == f"{project_uuid}") - ) # there is only one entry - ) - run_entry = CompRunsAtDB.parse_obj(await result.first()) - assert ( - run_entry.result == exp_state - ), f"comp_runs: expected state '{exp_state}, found '{run_entry.result}'" - - -async def assert_comp_tasks_state( - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore - project_uuid: ProjectID, - task_ids: List[NodeID], - exp_state: RunningState, -): - # check the database is correctly updated, the run is published - async with aiopg_engine.acquire() as conn: # type: ignore - result = await conn.execute( - comp_tasks.select().where( - (comp_tasks.c.project_id == f"{project_uuid}") - & (comp_tasks.c.node_id.in_([f"{n}" for n in task_ids])) - ) # there is only one entry - ) - tasks = parse_obj_as(List[CompTaskAtDB], await result.fetchall()) - assert all( # pylint: disable=use-a-generator - [t.state == exp_state for t in tasks] - ), f"expected state: {exp_state}, found: {[t.state for t in tasks]}" - - -async def trigger_comp_scheduler(scheduler: BaseCompScheduler): +async def trigger_comp_scheduler(scheduler: BaseCompScheduler) -> None: # trigger the scheduler scheduler._wake_up_scheduler_now() # pylint: disable=protected-access # let the scheduler be actually triggered await asyncio.sleep(1) -async def manually_run_comp_scheduler(scheduler: BaseCompScheduler): - # trigger the scheduler - await scheduler.schedule_all_pipelines() - - async def set_comp_task_state( - aiopg_engine: Iterator[aiopg.sa.engine.Engine], node_id: str, state: StateType # type: ignore -): - async with aiopg_engine.acquire() as conn: # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, node_id: str, state: StateType +) -> None: + async with aiopg_engine.acquire() as conn: await conn.execute( - # pylint: disable=no-value-for-parameter comp_tasks.update() .where(comp_tasks.c.node_id == node_id) .values(state=state) @@ -96,11 +47,13 @@ async def set_comp_task_state( async def set_comp_task_outputs( - aiopg_engine: aiopg.sa.engine.Engine, node_id: NodeID, outputs_schema: Dict[str, Any], outputs: Dict[str, Any] # type: ignore -): - async with aiopg_engine.acquire() as conn: # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, + node_id: NodeID, + outputs_schema: dict[str, Any], + outputs: dict[str, Any], +) -> None: + async with aiopg_engine.acquire() as conn: await conn.execute( - # pylint: disable=no-value-for-parameter comp_tasks.update() .where(comp_tasks.c.node_id == f"{node_id}") .values(outputs=outputs, schema={"outputs": outputs_schema, "inputs": {}}) @@ -108,11 +61,13 @@ async def set_comp_task_outputs( async def set_comp_task_inputs( - aiopg_engine: aiopg.sa.engine.Engine, node_id: NodeID, inputs_schema: Dict[str, Any], inputs: Dict[str, Any] # type: ignore -): - async with aiopg_engine.acquire() as conn: # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, + node_id: NodeID, + inputs_schema: dict[str, Any], + inputs: dict[str, Any], +) -> None: + async with aiopg_engine.acquire() as conn: await conn.execute( - # pylint: disable=no-value-for-parameter comp_tasks.update() .where(comp_tasks.c.node_id == f"{node_id}") .values(inputs=inputs, schema={"outputs": {}, "inputs": inputs_schema}) diff --git a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_observer.py b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_observer.py index 0a990409b95..d7bbf13e4e9 100644 --- a/services/director-v2/tests/unit/test_modules_dynamic_sidecar_observer.py +++ b/services/director-v2/tests/unit/test_modules_dynamic_sidecar_observer.py @@ -2,6 +2,7 @@ # pylint:disable=redefined-outer-name # pylint:disable=unused-argument +from typing import AsyncIterator from unittest.mock import AsyncMock import pytest @@ -90,13 +91,15 @@ def mock_env( @pytest.fixture def mocked_app(mock_env: None) -> FastAPI: app = FastAPI() - app.state.settings = AppSettings() + app.state.settings = AppSettings.create_from_envs() app.state.rabbitmq_client = AsyncMock() return app @pytest.fixture -async def dynamic_sidecar_scheduler(mocked_app: FastAPI) -> DynamicSidecarsScheduler: +async def dynamic_sidecar_scheduler( + mocked_app: FastAPI, +) -> AsyncIterator[DynamicSidecarsScheduler]: await setup_scheduler(mocked_app) await setup(mocked_app) diff --git a/services/director-v2/tests/unit/test_utils_dags.py b/services/director-v2/tests/unit/test_utils_dags.py index 27f54ac2e6d..0a6a5a671d5 100644 --- a/services/director-v2/tests/unit/test_utils_dags.py +++ b/services/director-v2/tests/unit/test_utils_dags.py @@ -12,6 +12,7 @@ import pytest from models_library.projects import NodesDict from models_library.projects_nodes_io import NodeID +from simcore_postgres_database.models.comp_tasks import NodeClass from simcore_service_director_v2.utils.dags import ( create_complete_dag, create_minimal_computational_graph_based_on_selection, @@ -214,9 +215,18 @@ async def test_create_minimal_graph(fake_workbench: NodesDict, graph: MinimalGra pytest.param( {"node_1": ["node_2", "node_3"], "node_2": ["node_3"], "node_3": []}, { - "node_1": {"key": "simcore/services/comp/fake"}, - "node_2": {"key": "simcore/services/comp/fake"}, - "node_3": {"key": "simcore/services/comp/fake"}, + "node_1": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_2": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_3": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, }, [], id="cycle less dag expect no cycle", @@ -228,9 +238,18 @@ async def test_create_minimal_graph(fake_workbench: NodesDict, graph: MinimalGra "node_3": ["node_1"], }, { - "node_1": {"key": "simcore/services/comp/fake"}, - "node_2": {"key": "simcore/services/comp/fake"}, - "node_3": {"key": "simcore/services/comp/fake"}, + "node_1": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_2": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_3": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, }, [["node_1", "node_2", "node_3"]], id="dag with 1 cycle", @@ -242,9 +261,18 @@ async def test_create_minimal_graph(fake_workbench: NodesDict, graph: MinimalGra "node_3": ["node_1"], }, { - "node_1": {"key": "simcore/services/comp/fake"}, - "node_2": {"key": "simcore/services/comp/fake"}, - "node_3": {"key": "simcore/services/comp/fake"}, + "node_1": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_2": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_3": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, }, [["node_1", "node_2", "node_3"], ["node_1", "node_2"]], id="dag with 2 cycles", @@ -256,9 +284,18 @@ async def test_create_minimal_graph(fake_workbench: NodesDict, graph: MinimalGra "node_3": ["node_1"], }, { - "node_1": {"key": "simcore/services/comp/fake"}, - "node_2": {"key": "simcore/services/comp/fake"}, - "node_3": {"key": "simcore/services/dynamic/fake"}, + "node_1": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_2": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_3": { + "key": "simcore/services/dynamic/fake", + "node_class": NodeClass.INTERACTIVE, + }, }, [["node_1", "node_2", "node_3"]], id="dag with 1 cycle and 1 dynamic services should fail", @@ -270,9 +307,18 @@ async def test_create_minimal_graph(fake_workbench: NodesDict, graph: MinimalGra "node_3": ["node_1"], }, { - "node_1": {"key": "simcore/services/dynamic/fake"}, - "node_2": {"key": "simcore/services/comp/fake"}, - "node_3": {"key": "simcore/services/dynamic/fake"}, + "node_1": { + "key": "simcore/services/dynamic/fake", + "node_class": NodeClass.INTERACTIVE, + }, + "node_2": { + "key": "simcore/services/comp/fake", + "node_class": NodeClass.COMPUTATIONAL, + }, + "node_3": { + "key": "simcore/services/dynamic/fake", + "node_class": NodeClass.INTERACTIVE, + }, }, [["node_1", "node_2", "node_3"]], id="dag with 1 cycle and 2 dynamic services should fail", @@ -284,9 +330,18 @@ async def test_create_minimal_graph(fake_workbench: NodesDict, graph: MinimalGra "node_3": ["node_1"], }, { - "node_1": {"key": "simcore/services/dynamic/fake"}, - "node_2": {"key": "simcore/services/dynamic/fake"}, - "node_3": {"key": "simcore/services/dynamic/fake"}, + "node_1": { + "key": "simcore/services/dynamic/fake", + "node_class": NodeClass.INTERACTIVE, + }, + "node_2": { + "key": "simcore/services/dynamic/fake", + "node_class": NodeClass.INTERACTIVE, + }, + "node_3": { + "key": "simcore/services/dynamic/fake", + "node_class": NodeClass.INTERACTIVE, + }, }, [], id="dag with 1 cycle and 3 dynamic services should be ok", diff --git a/services/director-v2/tests/unit/with_dbs/conftest.py b/services/director-v2/tests/unit/with_dbs/conftest.py index 5ba454df101..fad9e6f198d 100644 --- a/services/director-v2/tests/unit/with_dbs/conftest.py +++ b/services/director-v2/tests/unit/with_dbs/conftest.py @@ -272,6 +272,8 @@ def running_project( project_id=f"{created_project.uuid}", dag_adjacency_list=fake_workbench_adjacency, ), - tasks=tasks(user=user, project=created_project, state=StateType.RUNNING), + tasks=tasks( + user=user, project=created_project, state=StateType.RUNNING, progress=0.0 + ), runs=runs(user=user, project=created_project, result=StateType.RUNNING), ) diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py b/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py index c7311051a1b..a6c65523297 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py +++ b/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py @@ -436,7 +436,9 @@ async def test_get_computation_from_empty_project( expected_computation = ComputationGet( id=proj.uuid, state=RunningState.UNKNOWN, - pipeline_details=PipelineDetails(adjacency_list={}, node_states={}), + pipeline_details=PipelineDetails( + adjacency_list={}, node_states={}, progress=None + ), url=parse_obj_as( AnyHttpUrl, f"{async_client.base_url.join(get_computation_url)}" ), @@ -485,10 +487,12 @@ async def test_get_computation_from_not_started_computation_task( adjacency_list=parse_obj_as( dict[NodeID, list[NodeID]], fake_workbench_adjacency ), + progress=0, node_states={ t.node_id: NodeState( modified=True, currentStatus=RunningState.NOT_STARTED, + progress=None, dependencies={ NodeID(node) for node, next_nodes in fake_workbench_adjacency.items() @@ -528,7 +532,7 @@ async def test_get_computation_from_published_computation_task( project_id=proj.uuid, dag_adjacency_list=fake_workbench_adjacency, ) - comp_tasks = tasks(user=user, project=proj, state=StateType.PUBLISHED) + comp_tasks = tasks(user=user, project=proj, state=StateType.PUBLISHED, progress=0) comp_runs = runs(user=user, project=proj, result=StateType.PUBLISHED) get_computation_url = httpx.URL( f"/v2/computations/{proj.uuid}?user_id={user['id']}" @@ -556,10 +560,12 @@ async def test_get_computation_from_published_computation_task( for node, next_nodes in fake_workbench_adjacency.items() if f"{t.node_id}" in next_nodes }, + progress=0, ) for t in comp_tasks if t.node_class == NodeClass.COMPUTATIONAL }, + progress=0, ), url=parse_obj_as( AnyHttpUrl, f"{async_client.base_url.join(get_computation_url)}" diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index b8f074580d5..ef6883190fe 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -8,35 +8,32 @@ # pylint: disable=too-many-statements +from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Iterator, Union +from typing import Any, Callable, cast from unittest import mock import aiopg +import aiopg.sa import httpx import pytest -from _helpers import ( - PublishedProject, - RunningProject, - assert_comp_run_state, - assert_comp_tasks_state, - manually_run_comp_scheduler, - set_comp_task_state, -) +from _helpers import PublishedProject, RunningProject from dask.distributed import SpecCluster from dask_task_models_library.container_tasks.errors import TaskCancelledError +from dask_task_models_library.container_tasks.events import TaskProgressEvent from dask_task_models_library.container_tasks.io import TaskOutputData from fastapi.applications import FastAPI from models_library.clusters import DEFAULT_CLUSTER_ID -from models_library.projects import ProjectAtDB +from models_library.projects import ProjectAtDB, ProjectID +from models_library.projects_nodes_io import NodeID from models_library.projects_state import RunningState +from pydantic import parse_obj_as from pytest import MonkeyPatch from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict from settings_library.rabbit import RabbitSettings -from simcore_postgres_database.models.comp_pipeline import StateType from simcore_postgres_database.models.comp_runs import comp_runs -from simcore_postgres_database.models.comp_tasks import NodeClass +from simcore_postgres_database.models.comp_tasks import NodeClass, comp_tasks from simcore_service_director_v2.core.application import init_app from simcore_service_director_v2.core.errors import ( ComputationalBackendNotConnectedError, @@ -50,10 +47,15 @@ from simcore_service_director_v2.core.settings import AppSettings from simcore_service_director_v2.models.domains.comp_pipelines import CompPipelineAtDB from simcore_service_director_v2.models.domains.comp_runs import CompRunsAtDB +from simcore_service_director_v2.models.domains.comp_tasks import CompTaskAtDB from simcore_service_director_v2.modules.comp_scheduler import background_task from simcore_service_director_v2.modules.comp_scheduler.base_scheduler import ( BaseCompScheduler, ) +from simcore_service_director_v2.modules.comp_scheduler.dask_scheduler import ( + DaskScheduler, +) +from simcore_service_director_v2.utils.dask_client_utils import TaskHandlers from simcore_service_director_v2.utils.scheduler import COMPLETED_STATES from starlette.testclient import TestClient @@ -63,6 +65,72 @@ ] +def _assert_dask_client_correctly_initialized( + mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler +) -> None: + mocked_dask_client.create.assert_called_once_with( + app=mock.ANY, + settings=mock.ANY, + endpoint=mock.ANY, + authentication=mock.ANY, + tasks_file_link_type=mock.ANY, + ) + mocked_dask_client.register_handlers.assert_called_once_with( + TaskHandlers( + cast(DaskScheduler, scheduler)._task_progress_change_handler, + cast(DaskScheduler, scheduler)._task_log_change_handler, + ) + ) + + +async def _assert_comp_run_db( + aiopg_engine: aiopg.sa.engine.Engine, + pub_project: PublishedProject, + expected_state: RunningState, +) -> None: + # check the database is correctly updated, the run is published + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + comp_runs.select().where( + (comp_runs.c.user_id == pub_project.project.prj_owner) + & (comp_runs.c.project_uuid == f"{pub_project.project.uuid}") + ) # there is only one entry + ) + run_entry = CompRunsAtDB.parse_obj(await result.first()) + assert ( + run_entry.result == expected_state + ), f"comp_runs: expected state '{expected_state}, found '{run_entry.result}'" + + +async def _assert_comp_tasks_db( + aiopg_engine: aiopg.sa.engine.Engine, + project_uuid: ProjectID, + task_ids: list[NodeID], + *, + expected_state: RunningState, + expected_progress: float | None, +) -> None: + # check the database is correctly updated, the run is published + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + comp_tasks.select().where( + (comp_tasks.c.project_id == f"{project_uuid}") + & (comp_tasks.c.node_id.in_([f"{n}" for n in task_ids])) + ) # there is only one entry + ) + tasks = parse_obj_as(list[CompTaskAtDB], await result.fetchall()) + assert all( + t.state == expected_state for t in tasks + ), f"expected state: {expected_state}, found: {[t.state for t in tasks]}" + assert all( + t.progress == expected_progress for t in tasks + ), f"{expected_progress=}, found: {[t.progress for t in tasks]}" + + +async def run_comp_scheduler(scheduler: BaseCompScheduler) -> None: + await scheduler.schedule_all_pipelines() + + @pytest.fixture def minimal_dask_scheduler_config( mock_env: EnvVarsDict, @@ -87,7 +155,7 @@ def minimal_dask_scheduler_config( @pytest.fixture def scheduler( minimal_dask_scheduler_config: None, - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, # dask_spec_local_cluster: SpecCluster, minimal_app: FastAPI, ) -> BaseCompScheduler: @@ -106,10 +174,9 @@ def mocked_dask_client(mocker: MockerFixture) -> mock.MagicMock: @pytest.fixture -def mocked_node_ports(mocker: MockerFixture): - mocker.patch( +def mocked_parse_output_data_fct(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( "simcore_service_director_v2.modules.comp_scheduler.dask_scheduler.parse_output_data", - return_value=None, autospec=True, ) @@ -124,7 +191,7 @@ def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.MagicMock: @pytest.fixture -def mocked_scheduler_task(mocker: MockerFixture) -> None: +def with_disabled_scheduler_task(mocker: MockerFixture) -> None: """disables the scheduler task, note that it needs to be triggered manually then""" mocker.patch.object(background_task, "scheduler_task") @@ -136,12 +203,20 @@ async def minimal_app(async_client: httpx.AsyncClient) -> FastAPI: # a new thread on which it creates a new loop # causing issues downstream with coroutines not # being created on the same loop - return async_client._transport.app + return async_client._transport.app # type: ignore + + +@pytest.fixture +def mocked_clean_task_output_and_log_files_if_invalid(mocker: MockerFixture) -> None: + mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler.dask_scheduler.clean_task_output_and_log_files_if_invalid", + autospec=True, + ) async def test_scheduler_gracefully_starts_and_stops( minimal_dask_scheduler_config: None, - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, dask_spec_local_cluster: SpecCluster, minimal_app: FastAPI, ): @@ -158,7 +233,7 @@ async def test_scheduler_gracefully_starts_and_stops( ) def test_scheduler_raises_exception_for_missing_dependencies( minimal_dask_scheduler_config: None, - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, dask_spec_local_cluster: SpecCluster, monkeypatch: MonkeyPatch, missing_dependency: str, @@ -175,13 +250,12 @@ def test_scheduler_raises_exception_for_missing_dependencies( async def test_empty_pipeline_is_not_scheduled( - mocked_scheduler_task: None, + with_disabled_scheduler_task: None, scheduler: BaseCompScheduler, - minimal_app: FastAPI, registered_user: Callable[..., dict[str, Any]], project: Callable[..., ProjectAtDB], pipeline: Callable[..., CompPipelineAtDB], - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, ): user = registered_user() empty_project = project(user) @@ -194,7 +268,7 @@ async def test_empty_pipeline_is_not_scheduled( cluster_id=DEFAULT_CLUSTER_ID, ) # create the empty pipeline now - _empty_pipeline = pipeline(project_id=f"{empty_project.uuid}") + pipeline(project_id=f"{empty_project.uuid}") # creating a run with an empty pipeline is useless, check the scheduler is not kicking in await scheduler.run_new_pipeline( @@ -204,35 +278,34 @@ async def test_empty_pipeline_is_not_scheduled( ) assert len(scheduler.scheduled_pipelines) == 0 assert ( - scheduler.wake_up_event.is_set() == False + scheduler.wake_up_event.is_set() is False ), "the scheduler was woken up on an empty pipeline!" # check the database is empty - async with aiopg_engine.acquire() as conn: # type: ignore + async with aiopg_engine.acquire() as conn: result = await conn.scalar( comp_runs.select().where( (comp_runs.c.user_id == user["id"]) & (comp_runs.c.project_uuid == f"{empty_project.uuid}") ) # there is only one entry ) - assert result == None + assert result is None async def test_misconfigured_pipeline_is_not_scheduled( - mocked_scheduler_task: None, + with_disabled_scheduler_task: None, scheduler: BaseCompScheduler, - minimal_app: FastAPI, registered_user: Callable[..., dict[str, Any]], project: Callable[..., ProjectAtDB], pipeline: Callable[..., CompPipelineAtDB], fake_workbench_without_outputs: dict[str, Any], fake_workbench_adjacency: dict[str, Any], - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, ): """A pipeline which comp_tasks are missing should not be scheduled. It shall be aborted and shown as such in the comp_runs db""" user = registered_user() sleepers_project = project(user, workbench=fake_workbench_without_outputs) - sleepers_pipeline = pipeline( + pipeline( project_id=f"{sleepers_project.uuid}", dag_adjacency_list=fake_workbench_adjacency, ) @@ -244,15 +317,15 @@ async def test_misconfigured_pipeline_is_not_scheduled( ) assert len(scheduler.scheduled_pipelines) == 1 assert ( - scheduler.wake_up_event.is_set() == True + scheduler.wake_up_event.is_set() is True ), "the scheduler was NOT woken up on the scheduled pipeline!" for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): assert u_id == user["id"] assert p_id == sleepers_project.uuid assert it > 0 - assert params.mark_for_cancellation == False + assert params.mark_for_cancellation is False # check the database was properly updated - async with aiopg_engine.acquire() as conn: # type: ignore + async with aiopg_engine.acquire() as conn: result = await conn.execute( comp_runs.select().where( (comp_runs.c.user_id == user["id"]) @@ -262,11 +335,11 @@ async def test_misconfigured_pipeline_is_not_scheduled( run_entry = CompRunsAtDB.parse_obj(await result.first()) assert run_entry.result == RunningState.PUBLISHED # let the scheduler kick in - await manually_run_comp_scheduler(scheduler) + await run_comp_scheduler(scheduler) # check the scheduled pipelines is again empty since it's misconfigured assert len(scheduler.scheduled_pipelines) == 0 # check the database entry is correctly updated - async with aiopg_engine.acquire() as conn: # type: ignore + async with aiopg_engine.acquire() as conn: result = await conn.execute( comp_runs.select().where( (comp_runs.c.user_id == user["id"]) @@ -277,15 +350,11 @@ async def test_misconfigured_pipeline_is_not_scheduled( assert run_entry.result == RunningState.ABORTED -async def test_proper_pipeline_is_scheduled( - mocked_scheduler_task: None, - mocked_dask_client: mock.MagicMock, - scheduler: BaseCompScheduler, - minimal_app: FastAPI, - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore - published_project: PublishedProject, -): - # This calls adds starts the scheduling of a pipeline +async def _assert_start_pipeline( + aiopg_engine, published_project: PublishedProject, scheduler: BaseCompScheduler +) -> list[CompTaskAtDB]: + exp_published_tasks = deepcopy(published_project.tasks) + assert published_project.project.prj_owner await scheduler.run_new_pipeline( user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, @@ -293,49 +362,58 @@ async def test_proper_pipeline_is_scheduled( ) assert len(scheduler.scheduled_pipelines) == 1, "the pipeline is not scheduled!" assert ( - scheduler.wake_up_event.is_set() == True + scheduler.wake_up_event.is_set() is True ), "the scheduler was NOT woken up on the scheduled pipeline!" for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): assert u_id == published_project.project.prj_owner assert p_id == published_project.project.uuid assert it > 0 - assert params.mark_for_cancellation == False + assert params.mark_for_cancellation is False + # check the database is correctly updated, the run is published - await assert_comp_run_state( + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + await _assert_comp_tasks_db( aiopg_engine, - published_project.project.prj_owner, published_project.project.uuid, - exp_state=RunningState.PUBLISHED, + [p.node_id for p in exp_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, ) - published_tasks = [ - published_project.tasks[1], - published_project.tasks[3], + return exp_published_tasks + + +async def _assert_schedule_pipeline_PENDING( + aiopg_engine, + published_project: PublishedProject, + published_tasks: list[CompTaskAtDB], + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, +) -> list[CompTaskAtDB]: + expected_pending_tasks = [ + published_tasks[1], + published_tasks[3], ] - # trigger the scheduler - await manually_run_comp_scheduler(scheduler) - # the client should be created here - mocked_dask_client.create.assert_called_once_with( - app=mock.ANY, - settings=mock.ANY, - endpoint=mock.ANY, - authentication=mock.ANY, - tasks_file_link_type=mock.ANY, - ) - # the tasks are set to pending, so they are ready to be taken, and the dask client is triggered - await assert_comp_tasks_state( + for p in expected_pending_tasks: + published_tasks.remove(p) + await run_comp_scheduler(scheduler) + _assert_dask_client_correctly_initialized(mocked_dask_client, scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [p.node_id for p in published_tasks], - exp_state=RunningState.PENDING, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=0, ) - # the other tasks are published - await assert_comp_tasks_state( + # the other tasks are still waiting in published state + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [p.node_id for p in published_project.tasks if p not in published_tasks], - exp_state=RunningState.PUBLISHED, + [p.node_id for p in published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct ) - + # tasks were send to the backend mocked_dask_client.send_computation_tasks.assert_has_calls( calls=[ mock.call( @@ -345,161 +423,320 @@ async def test_proper_pipeline_is_scheduled( tasks={f"{p.node_id}": p.image}, callback=scheduler._wake_up_scheduler_now, ) - for p in published_tasks + for p in expected_pending_tasks ], any_order=True, ) mocked_dask_client.send_computation_tasks.reset_mock() - - # trigger the scheduler - await manually_run_comp_scheduler(scheduler) - # let the scheduler kick in, it should switch to the run state to PENDING state, to reflect the tasks states - await assert_comp_run_state( + mocked_dask_client.get_tasks_status.assert_not_called() + mocked_dask_client.get_task_result.assert_not_called() + # there is a second run of the scheduler to move comp_runs to pending, the rest does not change + await run_comp_scheduler(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) + await _assert_comp_tasks_db( aiopg_engine, - published_project.project.prj_owner, published_project.project.uuid, - exp_state=RunningState.PENDING, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=0, ) - # no change here - await assert_comp_tasks_state( + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, [p.node_id for p in published_tasks], - exp_state=RunningState.PENDING, + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct ) mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_has_calls( + calls=[mock.call([p.job_id for p in expected_pending_tasks])], any_order=True + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + return expected_pending_tasks - # change 1 task to RUNNING - running_task_id = published_tasks[0].node_id - await set_comp_task_state( - aiopg_engine, - node_id=f"{running_task_id}", - state=StateType.RUNNING, + +@pytest.mark.acceptance_test +async def test_proper_pipeline_is_scheduled( + with_disabled_scheduler_task: None, + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + mocked_parse_output_data_fct: mock.Mock, + mocked_clean_task_output_and_log_files_if_invalid: None, +): + expected_published_tasks = await _assert_start_pipeline( + aiopg_engine, published_project, scheduler ) - # trigger the scheduler, comp_run is now STARTED, as is the task - await manually_run_comp_scheduler(scheduler) - await assert_comp_run_state( + # ------------------------------------------------------------------------------- + # 1. first run will move comp_tasks to PENDING so the worker can take them + expected_pending_tasks = await _assert_schedule_pipeline_PENDING( + aiopg_engine, + published_project, + expected_published_tasks, + mocked_dask_client, + scheduler, + ) + + # ------------------------------------------------------------------------------- + # 3. the "worker" starts processing a task + exp_started_task = expected_pending_tasks[0] + expected_pending_tasks.remove(exp_started_task) + + async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]: + return [ + RunningState.STARTED + if job_id == exp_started_task.job_id + else RunningState.PENDING + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running + await run_comp_scheduler(scheduler) + # comp_run, the comp_task switch to STARTED + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( aiopg_engine, - published_project.project.prj_owner, published_project.project.uuid, - RunningState.STARTED, + [exp_started_task.node_id], + expected_state=RunningState.STARTED, + expected_progress=0, ) - await assert_comp_tasks_state( + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [running_task_id], - exp_state=RunningState.STARTED, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=0, ) - mocked_dask_client.send_computation_tasks.assert_not_called() - - # change the task to SUCCESS - await set_comp_task_state( + await _assert_comp_tasks_db( aiopg_engine, - node_id=f"{running_task_id}", - state=StateType.SUCCESS, + published_project.project.uuid, + [p.node_id for p in expected_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct ) - # trigger the scheduler, the run state is still STARTED, the task is completed - await manually_run_comp_scheduler(scheduler) - await assert_comp_run_state( + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in ([exp_started_task] + expected_pending_tasks)], + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + + # ------------------------------------------------------------------------------- + # 4. the "worker" completed the task successfully + async def _return_1st_task_success(job_ids: list[str]) -> list[RunningState]: + return [ + RunningState.SUCCESS + if job_id == exp_started_task.job_id + else RunningState.PENDING + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_success + + async def _return_random_task_result(job_id) -> TaskOutputData: + return TaskOutputData.parse_obj({"out_1": None, "out_2": 45}) + + mocked_dask_client.get_task_result.side_effect = _return_random_task_result + await run_comp_scheduler(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( aiopg_engine, - published_project.project.prj_owner, published_project.project.uuid, - RunningState.STARTED, - ) - await assert_comp_tasks_state( + [exp_started_task.node_id], + expected_state=RunningState.SUCCESS, + expected_progress=1, + ) + completed_tasks = [exp_started_task] + next_pending_task = published_project.tasks[2] + expected_pending_tasks.append(next_pending_task) + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [running_task_id], - exp_state=RunningState.SUCCESS, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=0, ) - next_published_task = published_project.tasks[2] - await assert_comp_tasks_state( + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [next_published_task.node_id], - exp_state=RunningState.PENDING, + [ + p.node_id + for p in published_project.tasks + if p not in expected_pending_tasks + completed_tasks + ], + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct ) mocked_dask_client.send_computation_tasks.assert_called_once_with( user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, cluster_id=DEFAULT_CLUSTER_ID, tasks={ - f"{next_published_task.node_id}": next_published_task.image, + f"{next_pending_task.node_id}": next_pending_task.image, }, callback=scheduler._wake_up_scheduler_now, ) mocked_dask_client.send_computation_tasks.reset_mock() - - # change 1 task to RUNNING - await set_comp_task_state( - aiopg_engine, - node_id=f"{next_published_task.node_id}", - state=StateType.RUNNING, + mocked_dask_client.get_tasks_status.assert_has_calls( + calls=[ + mock.call([p.job_id for p in completed_tasks + expected_pending_tasks[:1]]) + ], + any_order=True, ) - # trigger the scheduler, run state should keep to STARTED, task should be as well - await manually_run_comp_scheduler(scheduler) - await assert_comp_run_state( - aiopg_engine, - published_project.project.prj_owner, - published_project.project.uuid, - RunningState.STARTED, + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_called_once_with( + completed_tasks[0].job_id ) - await assert_comp_tasks_state( - aiopg_engine, - published_project.project.uuid, - [next_published_task.node_id], - exp_state=RunningState.STARTED, + mocked_dask_client.get_task_result.reset_mock() + mocked_parse_output_data_fct.assert_called_once_with( + mock.ANY, + completed_tasks[0].job_id, + await _return_random_task_result(completed_tasks[0].job_id), ) - mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_parse_output_data_fct.reset_mock() - # now change the task to FAILED - await set_comp_task_state( - aiopg_engine, - node_id=f"{next_published_task.node_id}", - state=StateType.FAILED, - ) - # trigger the scheduler, it should keep to STARTED state until it finishes - await manually_run_comp_scheduler(scheduler) - await assert_comp_run_state( + # ------------------------------------------------------------------------------- + # 6. the "worker" starts processing a task + exp_started_task = next_pending_task + + async def _return_2nd_task_running(job_ids: list[str]) -> list[RunningState]: + return [ + RunningState.STARTED + if job_id == exp_started_task.job_id + else RunningState.PENDING + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_running + # trigger the scheduler, run state should keep to STARTED, task should be as well + await run_comp_scheduler(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( aiopg_engine, - published_project.project.prj_owner, published_project.project.uuid, - RunningState.STARTED, + [exp_started_task.node_id], + expected_state=RunningState.STARTED, + expected_progress=0, ) - await assert_comp_tasks_state( + mocked_dask_client.send_computation_tasks.assert_not_called() + expected_pending_tasks.reverse() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in expected_pending_tasks] + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + + # ------------------------------------------------------------------------------- + # 7. the task fails + async def _return_2nd_task_failed(job_ids: list[str]) -> list[RunningState]: + return [ + RunningState.FAILED + if job_id == exp_started_task.job_id + else RunningState.PENDING + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_failed + mocked_dask_client.get_task_result.side_effect = None + await run_comp_scheduler(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [next_published_task.node_id], - exp_state=RunningState.FAILED, + [exp_started_task.node_id], + expected_state=RunningState.FAILED, + expected_progress=1, ) mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in expected_pending_tasks] + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_called_once_with(exp_started_task.job_id) + mocked_dask_client.get_task_result.reset_mock() + mocked_parse_output_data_fct.assert_not_called() + expected_pending_tasks.remove(exp_started_task) + + # ------------------------------------------------------------------------------- + # 8. the last task shall succeed + exp_started_task = expected_pending_tasks[0] + + async def _return_3rd_task_success(job_ids: list[str]) -> list[RunningState]: + return [ + RunningState.SUCCESS + if job_id == exp_started_task.job_id + else RunningState.PENDING + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_3rd_task_success + mocked_dask_client.get_task_result.side_effect = _return_random_task_result - # now change the other task to SUCCESS - other_task = published_tasks[1] - await set_comp_task_state( - aiopg_engine, - node_id=f"{other_task.node_id}", - state=StateType.SUCCESS, - ) # trigger the scheduler, it should switch to FAILED, as we are done - await manually_run_comp_scheduler(scheduler) - await assert_comp_run_state( - aiopg_engine, - published_project.project.prj_owner, - published_project.project.uuid, - RunningState.FAILED, - ) - await assert_comp_tasks_state( + await run_comp_scheduler(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) + + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, - [other_task.node_id], - exp_state=RunningState.SUCCESS, + [exp_started_task.node_id], + expected_state=RunningState.SUCCESS, + expected_progress=1, ) mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in expected_pending_tasks] + ) + mocked_dask_client.get_task_result.assert_called_once_with(exp_started_task.job_id) # the scheduled pipeline shall be removed assert scheduler.scheduled_pipelines == {} +async def test_task_progress_triggers( + with_disabled_scheduler_task: None, + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + mocked_parse_output_data_fct: None, + mocked_clean_task_output_and_log_files_if_invalid: None, +): + expected_published_tasks = await _assert_start_pipeline( + aiopg_engine, published_project, scheduler + ) + # ------------------------------------------------------------------------------- + # 1. first run will move comp_tasks to PENDING so the worker can take them + expected_pending_tasks = await _assert_schedule_pipeline_PENDING( + aiopg_engine, + published_project, + expected_published_tasks, + mocked_dask_client, + scheduler, + ) + + # send some progress + started_task = expected_pending_tasks[0] + assert started_task.job_id + for progress in [-1, 0, 0.3, 0.5, 1, 1.5, 0.7, 0, 20]: + progress_event = TaskProgressEvent( + job_id=started_task.job_id, progress=progress + ) + await cast(DaskScheduler, scheduler)._task_progress_change_handler( + progress_event.json() + ) + # NOTE: not sure whether it should switch to STARTED.. it would make sense + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [started_task.node_id], + expected_state=RunningState.PENDING, + expected_progress=min(max(0, progress), 1), + ) + + @pytest.mark.parametrize( "backend_error", [ @@ -511,11 +748,10 @@ async def test_proper_pipeline_is_scheduled( ], ) async def test_handling_of_disconnected_dask_scheduler( - mocked_scheduler_task: None, + with_disabled_scheduler_task: None, dask_spec_local_cluster: SpecCluster, scheduler: BaseCompScheduler, - minimal_app: FastAPI, - aiopg_engine: Iterator[aiopg.sa.engine.Engine], # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, mocker: MockerFixture, published_project: PublishedProject, backend_error: SchedulerError, @@ -525,8 +761,10 @@ async def test_handling_of_disconnected_dask_scheduler( "simcore_service_director_v2.modules.comp_scheduler.dask_scheduler.DaskClient.send_computation_tasks", side_effect=backend_error, ) + assert mocked_dask_client_send_task # running the pipeline will now raise and the tasks are set back to PUBLISHED + assert published_project.project.prj_owner await scheduler.run_new_pipeline( user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, @@ -535,17 +773,14 @@ async def test_handling_of_disconnected_dask_scheduler( # since there is no cluster, there is no dask-scheduler, # the tasks shall all still be in PUBLISHED state now - await assert_comp_run_state( - aiopg_engine, - published_project.project.prj_owner, - published_project.project.uuid, - RunningState.PUBLISHED, - ) - await assert_comp_tasks_state( + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, [t.node_id for t in published_project.tasks], - exp_state=RunningState.PUBLISHED, + expected_state=RunningState.PUBLISHED, + expected_progress=None, ) # on the next iteration of the pipeline it will try to re-connect # now try to abort the tasks since we are wondering what is happening, this should auto-trigger the scheduler @@ -554,9 +789,9 @@ async def test_handling_of_disconnected_dask_scheduler( project_id=published_project.project.uuid, ) # we ensure the scheduler was run - await manually_run_comp_scheduler(scheduler) + await run_comp_scheduler(scheduler) # after this step the tasks are marked as ABORTED - await assert_comp_tasks_state( + await _assert_comp_tasks_db( aiopg_engine, published_project.project.uuid, [ @@ -564,25 +799,23 @@ async def test_handling_of_disconnected_dask_scheduler( for t in published_project.tasks if t.node_class == NodeClass.COMPUTATIONAL ], - exp_state=RunningState.ABORTED, + expected_state=RunningState.ABORTED, + expected_progress=1, ) # then we have another scheduler run - await manually_run_comp_scheduler(scheduler) + await run_comp_scheduler(scheduler) # now the run should be ABORTED - await assert_comp_run_state( - aiopg_engine, - published_project.project.prj_owner, - published_project.project.uuid, - RunningState.ABORTED, - ) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.ABORTED) -@dataclass +@dataclass(frozen=True, kw_only=True) class RebootState: task_status: RunningState - task_result: Union[Exception, TaskOutputData] + task_result: Exception | TaskOutputData expected_task_state_group1: RunningState + expected_task_progress_group1: float expected_task_state_group2: RunningState + expected_task_progress_group2: float expected_run_state: RunningState @@ -591,64 +824,75 @@ class RebootState: [ pytest.param( RebootState( - RunningState.UNKNOWN, - ComputationalBackendTaskNotFoundError(job_id="fake_job_id"), - RunningState.FAILED, - RunningState.ABORTED, - RunningState.FAILED, + task_status=RunningState.UNKNOWN, + task_result=ComputationalBackendTaskNotFoundError(job_id="fake_job_id"), + expected_task_state_group1=RunningState.FAILED, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.ABORTED, + expected_task_progress_group2=1, + expected_run_state=RunningState.FAILED, ), id="reboot with lost tasks", ), pytest.param( RebootState( - RunningState.ABORTED, - TaskCancelledError(job_id="fake_job_id"), - RunningState.ABORTED, - RunningState.ABORTED, - RunningState.ABORTED, + task_status=RunningState.ABORTED, + task_result=TaskCancelledError(job_id="fake_job_id"), + expected_task_state_group1=RunningState.ABORTED, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.ABORTED, + expected_task_progress_group2=1, + expected_run_state=RunningState.ABORTED, ), id="reboot with aborted tasks", ), pytest.param( RebootState( - RunningState.FAILED, - ValueError("some error during the call"), - RunningState.FAILED, - RunningState.ABORTED, - RunningState.FAILED, + task_status=RunningState.FAILED, + task_result=ValueError("some error during the call"), + expected_task_state_group1=RunningState.FAILED, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.ABORTED, + expected_task_progress_group2=1, + expected_run_state=RunningState.FAILED, ), id="reboot with failed tasks", ), pytest.param( RebootState( - RunningState.STARTED, - ComputationalBackendTaskResultsNotReadyError(job_id="fake_job_id"), - RunningState.STARTED, - RunningState.STARTED, - RunningState.STARTED, + task_status=RunningState.STARTED, + task_result=ComputationalBackendTaskResultsNotReadyError( + job_id="fake_job_id" + ), + expected_task_state_group1=RunningState.STARTED, + expected_task_progress_group1=0, + expected_task_state_group2=RunningState.STARTED, + expected_task_progress_group2=0, + expected_run_state=RunningState.STARTED, ), id="reboot with running tasks", ), pytest.param( RebootState( - RunningState.SUCCESS, - TaskOutputData.parse_obj({"whatever_output": 123}), - RunningState.SUCCESS, - RunningState.SUCCESS, - RunningState.SUCCESS, + task_status=RunningState.SUCCESS, + task_result=TaskOutputData.parse_obj({"whatever_output": 123}), + expected_task_state_group1=RunningState.SUCCESS, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.SUCCESS, + expected_task_progress_group2=1, + expected_run_state=RunningState.SUCCESS, ), id="reboot with completed tasks", ), ], ) async def test_handling_scheduling_after_reboot( - mocked_scheduler_task: None, + with_disabled_scheduler_task: None, mocked_dask_client: mock.MagicMock, - aiopg_engine: aiopg.sa.engine.Engine, # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, running_project: RunningProject, scheduler: BaseCompScheduler, - minimal_app: FastAPI, - mocked_node_ports: None, + mocked_parse_output_data_fct: mock.MagicMock, mocked_clean_task_output_fct: mock.MagicMock, reboot_state: RebootState, ): @@ -668,7 +912,7 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: mocked_dask_client.get_task_result.side_effect = mocked_get_task_result - await manually_run_comp_scheduler(scheduler) + await run_comp_scheduler(scheduler) # the status will be called once for all RUNNING tasks mocked_dask_client.get_tasks_status.assert_called_once() if reboot_state.expected_run_state in COMPLETED_STATES: @@ -700,7 +944,7 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: else: mocked_clean_task_output_fct.assert_not_called() - await assert_comp_tasks_state( + await _assert_comp_tasks_db( aiopg_engine, running_project.project.uuid, [ @@ -708,17 +952,17 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: running_project.tasks[2].node_id, running_project.tasks[3].node_id, ], - exp_state=reboot_state.expected_task_state_group1, + expected_state=reboot_state.expected_task_state_group1, + expected_progress=reboot_state.expected_task_progress_group1, ) - await assert_comp_tasks_state( + await _assert_comp_tasks_db( aiopg_engine, running_project.project.uuid, [running_project.tasks[4].node_id], - exp_state=reboot_state.expected_task_state_group2, + expected_state=reboot_state.expected_task_state_group2, + expected_progress=reboot_state.expected_task_progress_group2, ) - await assert_comp_run_state( - aiopg_engine, - running_project.project.prj_owner, - running_project.project.uuid, - exp_state=reboot_state.expected_run_state, + assert running_project.project.prj_owner + await _assert_comp_run_db( + aiopg_engine, running_project, reboot_state.expected_run_state ) diff --git a/services/director-v2/tests/unit/with_dbs/test_utils_dask.py b/services/director-v2/tests/unit/with_dbs/test_utils_dask.py index f3a1d491703..edc7bd02836 100644 --- a/services/director-v2/tests/unit/with_dbs/test_utils_dask.py +++ b/services/director-v2/tests/unit/with_dbs/test_utils_dask.py @@ -12,6 +12,7 @@ from unittest import mock import aiopg +import aiopg.sa import httpx import pytest from _helpers import PublishedProject, set_comp_task_inputs, set_comp_task_outputs @@ -224,7 +225,7 @@ def fake_task_output_data( async def test_parse_output_data( - aiopg_engine: aiopg.sa.engine.Engine, # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, published_project: PublishedProject, user_id: UserID, fake_io_schema: dict[str, dict[str, str]], @@ -279,7 +280,7 @@ def app_with_db( async def test_compute_input_data( app_with_db: None, - aiopg_engine: aiopg.sa.engine.Engine, # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, async_client: httpx.AsyncClient, user_id: UserID, published_project: PublishedProject, @@ -348,7 +349,7 @@ def tasks_file_link_scheme(tasks_file_link_type: FileLinkType) -> tuple: async def test_compute_output_data_schema( app_with_db: None, - aiopg_engine: aiopg.sa.engine.Engine, # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, async_client: httpx.AsyncClient, user_id: UserID, published_project: PublishedProject, @@ -390,7 +391,7 @@ async def test_compute_output_data_schema( @pytest.mark.parametrize("entry_exists_returns", [True, False]) async def test_clean_task_output_and_log_files_if_invalid( - aiopg_engine: aiopg.sa.engine.Engine, # type: ignore + aiopg_engine: aiopg.sa.engine.Engine, user_id: UserID, published_project: PublishedProject, mocked_node_ports_filemanager_fcts: dict[str, mock.MagicMock],