diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ad78faabb..f8a03f22d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,10 @@ **Note**: Numbers like (\#1234) point to closed Pull Requests on the fractal-server repository. -# 2.3.6 (Unreleased) +# 2.3.6 +* API: + * When creating a WorkflowTask, do not pre-populate its top-level arguments based on JSON Schema default values (\#1688). * Dependencies: * Update `sqlmodel` to `^0.0.21` (\#1674). diff --git a/fractal_server/app/models/v2/task.py b/fractal_server/app/models/v2/task.py index 23be8a5b6f..b1e1a225a6 100644 --- a/fractal_server/app/models/v2/task.py +++ b/fractal_server/app/models/v2/task.py @@ -1,5 +1,3 @@ -import json -import logging from typing import Any from typing import Optional @@ -41,53 +39,3 @@ class TaskV2(SQLModel, table=True): input_types: dict[str, bool] = Field(sa_column=Column(JSON), default={}) output_types: dict[str, bool] = Field(sa_column=Column(JSON), default={}) - - @property - def default_args_non_parallel_from_args_schema(self) -> dict[str, Any]: - """ - Extract default arguments from args_schema - """ - # Return {} if there is no args_schema - if self.args_schema_non_parallel is None: - return {} - # Try to construct default_args - try: - default_args = {} - properties = self.args_schema_non_parallel["properties"] - for prop_name, prop_schema in properties.items(): - default_value = prop_schema.get("default", None) - if default_value is not None: - default_args[prop_name] = default_value - return default_args - except KeyError as e: - logging.warning( - "Cannot set default_args from args_schema_non_parallel=" - f"{json.dumps(self.args_schema_non_parallel)}\n" - f"Original KeyError: {str(e)}" - ) - return {} - - @property - def default_args_parallel_from_args_schema(self) -> dict[str, Any]: - """ - Extract default arguments from args_schema - """ - # Return {} if there is no args_schema - if self.args_schema_parallel is None: - return {} - # Try to construct default_args - try: - default_args = {} - properties = self.args_schema_parallel["properties"] - for prop_name, prop_schema in properties.items(): - default_value = prop_schema.get("default", None) - if default_value is not None: - default_args[prop_name] = default_value - return default_args - except KeyError as e: - logging.warning( - "Cannot set default_args from args_schema_parallel=" - f"{json.dumps(self.args_schema_parallel)}\n" - f"Original KeyError: {str(e)}" - ) - return {} diff --git a/fractal_server/app/routes/api/v2/_aux_functions.py b/fractal_server/app/routes/api/v2/_aux_functions.py index 264ac2784e..944379ffd7 100644 --- a/fractal_server/app/routes/api/v2/_aux_functions.py +++ b/fractal_server/app/routes/api/v2/_aux_functions.py @@ -422,6 +422,8 @@ async def _workflow_insert_task( # Get task from db, and extract default arguments via a Task property # method + # NOTE: this logic remains there for V1 tasks only. When we deprecate V1 + # tasks, we can simplify this block if is_legacy_task is True: db_task = await db.get(Task, task_id) if db_task is None: @@ -439,12 +441,8 @@ async def _workflow_insert_task( raise ValueError(f"TaskV2 {task_id} not found.") task_type = db_task.type - final_args_non_parallel = ( - db_task.default_args_non_parallel_from_args_schema.copy() - ) - final_args_parallel = ( - db_task.default_args_parallel_from_args_schema.copy() - ) + final_args_non_parallel = {} + final_args_parallel = {} final_meta_parallel = (db_task.meta_parallel or {}).copy() final_meta_non_parallel = (db_task.meta_non_parallel or {}).copy() diff --git a/fractal_server/app/routes/api/v2/workflowtask.py b/fractal_server/app/routes/api/v2/workflowtask.py index 6d76040315..2fd50ad91d 100644 --- a/fractal_server/app/routes/api/v2/workflowtask.py +++ b/fractal_server/app/routes/api/v2/workflowtask.py @@ -186,34 +186,17 @@ async def update_workflowtask( default_args = ( db_wf_task.task_legacy.default_args_from_args_schema ) + actual_args = deepcopy(default_args) + if value is not None: + for k, v in value.items(): + actual_args[k] = v else: - default_args = ( - db_wf_task.task.default_args_parallel_from_args_schema - ) - # Override default_args with args value items - actual_args = deepcopy(default_args) - if value is not None: - for k, v in value.items(): - actual_args[k] = v + actual_args = deepcopy(value) if not actual_args: actual_args = None setattr(db_wf_task, key, actual_args) elif key == "args_non_parallel": - # Get default arguments via a Task property method - if db_wf_task.is_legacy_task: - # This is only needed so that we don't have to modify the rest - # of this block, but legacy task cannot take any non-parallel - # args (see checks above). - default_args = {} - else: - default_args = deepcopy( - db_wf_task.task.default_args_non_parallel_from_args_schema - ) - # Override default_args with args value items - actual_args = default_args.copy() - if value is not None: - for k, v in value.items(): - actual_args[k] = v + actual_args = deepcopy(value) if not actual_args: actual_args = None setattr(db_wf_task, key, actual_args) diff --git a/tests/v2/02_models/test_models_task.py b/tests/v2/02_models/test_models_task.py deleted file mode 100644 index b0df398b82..0000000000 --- a/tests/v2/02_models/test_models_task.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging -from typing import Optional - -from pydantic import BaseModel - - -async def test_default_args_properties(task_factory_v2, caplog): - class Foo(BaseModel): - x: int = 42 - y: Optional[str] = None - - task = await task_factory_v2( - name="task1", - source="source1", - args_schema_non_parallel=Foo.schema(), - args_schema_parallel=Foo.schema(), - ) - assert task.default_args_non_parallel_from_args_schema == Foo().dict( - exclude_none=True - ) - assert task.default_args_parallel_from_args_schema == Foo().dict( - exclude_none=True - ) - - bugged_task = await task_factory_v2( - name="task2", - source="source2", - args_schema_non_parallel={"foo": "bar"}, - args_schema_parallel={"bar": "foo"}, - ) - - # Test KeyErrors - caplog.set_level(logging.WARNING) - caplog.clear() - assert caplog.text == "" - - assert bugged_task.default_args_non_parallel_from_args_schema == {} - assert ( - "Cannot set default_args from args_schema_non_parallel" in caplog.text - ) - assert ( - "Cannot set default_args from args_schema_parallel" not in caplog.text - ) - - assert bugged_task.default_args_parallel_from_args_schema == {} - assert ( - "Cannot set default_args from args_schema_non_parallel" in caplog.text - ) - assert "Cannot set default_args from args_schema_parallel" in caplog.text diff --git a/tests/v2/03_api/test_api_workflow.py b/tests/v2/03_api/test_api_workflow.py index 4473a20e8e..dd6157040f 100644 --- a/tests/v2/03_api/test_api_workflow.py +++ b/tests/v2/03_api/test_api_workflow.py @@ -826,7 +826,7 @@ class _Arguments(BaseModel): patched_workflow_task = res.json() debug(patched_workflow_task["args_non_parallel"]) assert patched_workflow_task["args_non_parallel"] == dict( - a=123, b="two", d=[1, 2, 3], e="something" + a=123, b="two", e="something" ) assert res.status_code == 200 @@ -838,10 +838,7 @@ class _Arguments(BaseModel): ) patched_workflow_task = res.json() debug(patched_workflow_task["args_non_parallel"]) - assert ( - patched_workflow_task["args_non_parallel"] - == task.default_args_non_parallel_from_args_schema - ) + assert patched_workflow_task["args_non_parallel"] is None assert res.status_code == 200