Skip to content

Commit

Permalink
Merge pull request #1688 from fractal-analytics-platform/1686-tbd-dep…
Browse files Browse the repository at this point in the history
…recate-pre-populating-default-arguments-based-on-json-schemas

Do not set default argument values from JSON Schemas
  • Loading branch information
tcompa committed Jul 25, 2024
2 parents 3478d3d + 82037f5 commit 69ba50d
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 136 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
**Note**: Numbers like (\#1234) point to closed Pull Requests on the fractal-server repository.


# 2.3.6 (Unreleased)
# 2.3.6

* API:
* When creating a WorkflowTask, do not pre-populate its top-level arguments based on JSON Schema default values (\#1688).
* Dependencies:
* Update `sqlmodel` to `^0.0.21` (\#1674).

Expand Down
52 changes: 0 additions & 52 deletions fractal_server/app/models/v2/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json
import logging
from typing import Any
from typing import Optional

Expand Down Expand Up @@ -41,53 +39,3 @@ class TaskV2(SQLModel, table=True):

input_types: dict[str, bool] = Field(sa_column=Column(JSON), default={})
output_types: dict[str, bool] = Field(sa_column=Column(JSON), default={})

@property
def default_args_non_parallel_from_args_schema(self) -> dict[str, Any]:
"""
Extract default arguments from args_schema
"""
# Return {} if there is no args_schema
if self.args_schema_non_parallel is None:
return {}
# Try to construct default_args
try:
default_args = {}
properties = self.args_schema_non_parallel["properties"]
for prop_name, prop_schema in properties.items():
default_value = prop_schema.get("default", None)
if default_value is not None:
default_args[prop_name] = default_value
return default_args
except KeyError as e:
logging.warning(
"Cannot set default_args from args_schema_non_parallel="
f"{json.dumps(self.args_schema_non_parallel)}\n"
f"Original KeyError: {str(e)}"
)
return {}

@property
def default_args_parallel_from_args_schema(self) -> dict[str, Any]:
"""
Extract default arguments from args_schema
"""
# Return {} if there is no args_schema
if self.args_schema_parallel is None:
return {}
# Try to construct default_args
try:
default_args = {}
properties = self.args_schema_parallel["properties"]
for prop_name, prop_schema in properties.items():
default_value = prop_schema.get("default", None)
if default_value is not None:
default_args[prop_name] = default_value
return default_args
except KeyError as e:
logging.warning(
"Cannot set default_args from args_schema_parallel="
f"{json.dumps(self.args_schema_parallel)}\n"
f"Original KeyError: {str(e)}"
)
return {}
10 changes: 4 additions & 6 deletions fractal_server/app/routes/api/v2/_aux_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ async def _workflow_insert_task(

# Get task from db, and extract default arguments via a Task property
# method
# NOTE: this logic remains there for V1 tasks only. When we deprecate V1
# tasks, we can simplify this block
if is_legacy_task is True:
db_task = await db.get(Task, task_id)
if db_task is None:
Expand All @@ -439,12 +441,8 @@ async def _workflow_insert_task(
raise ValueError(f"TaskV2 {task_id} not found.")
task_type = db_task.type

final_args_non_parallel = (
db_task.default_args_non_parallel_from_args_schema.copy()
)
final_args_parallel = (
db_task.default_args_parallel_from_args_schema.copy()
)
final_args_non_parallel = {}
final_args_parallel = {}
final_meta_parallel = (db_task.meta_parallel or {}).copy()
final_meta_non_parallel = (db_task.meta_non_parallel or {}).copy()

Expand Down
29 changes: 6 additions & 23 deletions fractal_server/app/routes/api/v2/workflowtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,34 +186,17 @@ async def update_workflowtask(
default_args = (
db_wf_task.task_legacy.default_args_from_args_schema
)
actual_args = deepcopy(default_args)
if value is not None:
for k, v in value.items():
actual_args[k] = v
else:
default_args = (
db_wf_task.task.default_args_parallel_from_args_schema
)
# Override default_args with args value items
actual_args = deepcopy(default_args)
if value is not None:
for k, v in value.items():
actual_args[k] = v
actual_args = deepcopy(value)
if not actual_args:
actual_args = None
setattr(db_wf_task, key, actual_args)
elif key == "args_non_parallel":
# Get default arguments via a Task property method
if db_wf_task.is_legacy_task:
# This is only needed so that we don't have to modify the rest
# of this block, but legacy task cannot take any non-parallel
# args (see checks above).
default_args = {}
else:
default_args = deepcopy(
db_wf_task.task.default_args_non_parallel_from_args_schema
)
# Override default_args with args value items
actual_args = default_args.copy()
if value is not None:
for k, v in value.items():
actual_args[k] = v
actual_args = deepcopy(value)
if not actual_args:
actual_args = None
setattr(db_wf_task, key, actual_args)
Expand Down
49 changes: 0 additions & 49 deletions tests/v2/02_models/test_models_task.py

This file was deleted.

7 changes: 2 additions & 5 deletions tests/v2/03_api/test_api_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ class _Arguments(BaseModel):
patched_workflow_task = res.json()
debug(patched_workflow_task["args_non_parallel"])
assert patched_workflow_task["args_non_parallel"] == dict(
a=123, b="two", d=[1, 2, 3], e="something"
a=123, b="two", e="something"
)
assert res.status_code == 200

Expand All @@ -838,10 +838,7 @@ class _Arguments(BaseModel):
)
patched_workflow_task = res.json()
debug(patched_workflow_task["args_non_parallel"])
assert (
patched_workflow_task["args_non_parallel"]
== task.default_args_non_parallel_from_args_schema
)
assert patched_workflow_task["args_non_parallel"] is None
assert res.status_code == 200


Expand Down

0 comments on commit 69ba50d

Please sign in to comment.