From 5ab53f35e8890f73dd9f53b806ce3dc519367035 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 22 Oct 2024 17:18:01 -0700 Subject: [PATCH] workflow parameter validation --- skyvern/exceptions.py | 11 ++++++ .../forge/sdk/workflow/models/parameter.py | 38 ++++++++++++------- skyvern/forge/sdk/workflow/service.py | 18 ++++++++- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index d0c580ef2..49bace7af 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -510,3 +510,14 @@ def __init__(self, host: str) -> None: f"The host in your url is blocked: {host}", status_code=status.HTTP_400_BAD_REQUEST, ) + + +class InvalidWorkflowParameter(SkyvernHTTPException): + def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None: + message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}." + if workflow_permanent_id: + message += f" Workflow permanent id: {workflow_permanent_id}" + super().__init__( + message, + status_code=status.HTTP_400_BAD_REQUEST, + ) diff --git a/skyvern/forge/sdk/workflow/models/parameter.py b/skyvern/forge/sdk/workflow/models/parameter.py index 1ebd943cd..ae611db9f 100644 --- a/skyvern/forge/sdk/workflow/models/parameter.py +++ b/skyvern/forge/sdk/workflow/models/parameter.py @@ -2,10 +2,12 @@ import json from datetime import datetime from enum import StrEnum -from typing import Annotated, Literal, Union +from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, ConfigDict, Field +from skyvern.exceptions import InvalidWorkflowParameter + class ParameterType(StrEnum): WORKFLOW = "workflow" @@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum): JSON = "json" FILE_URL = "file_url" - def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None: + def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None: if value is None: return None - if self == WorkflowParameterType.STRING: - return value - elif self == WorkflowParameterType.INTEGER: - return int(value) - elif self == WorkflowParameterType.FLOAT: - return float(value) - elif self == WorkflowParameterType.BOOLEAN: - return value.lower() in ["true", "1"] - elif self == WorkflowParameterType.JSON: - return json.loads(value) - elif self == WorkflowParameterType.FILE_URL: - return value + try: + if self == WorkflowParameterType.STRING: + return str(value) + elif self == WorkflowParameterType.INTEGER: + return int(value) + elif self == WorkflowParameterType.FLOAT: + return float(value) + elif self == WorkflowParameterType.BOOLEAN: + if isinstance(value, bool): + return value + lower_case = str(value).lower() + if lower_case in ["true", "false", "1", "0"]: + raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value)) + return lower_case in ["true", "1"] + elif self == WorkflowParameterType.JSON: + return json.loads(value) + elif self == WorkflowParameterType.FILE_URL: + return value + except Exception: + raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value)) class WorkflowParameter(Parameter): diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index f008af3ce..4c0de4157 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -1,11 +1,18 @@ import json from datetime import datetime +from typing import Any import requests import structlog from skyvern import analytics -from skyvern.exceptions import FailedToSendWebhook, MissingValueForParameter, WorkflowNotFound, WorkflowRunNotFound +from skyvern.exceptions import ( + FailedToSendWebhook, + MissingValueForParameter, + WorkflowNotFound, + WorkflowParameterNotFound, + WorkflowRunNotFound, +) from skyvern.forge import app from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context @@ -566,8 +573,15 @@ async def create_workflow_run_parameter( self, workflow_run_id: str, workflow_parameter_id: str, - value: bool | int | float | str | dict | list, + value: Any, ) -> WorkflowRunParameter: + # get workflow parameter id first and validate the value according to the workflow_parameter.workflow_parameter_type + workflow_parameter = await app.DATABASE.get_workflow_parameter(workflow_parameter_id) + if not workflow_parameter: + raise WorkflowParameterNotFound(workflow_parameter_id) + # InvalidWorkflowParameter will be raised if the validation fails + workflow_parameter.workflow_parameter_type.convert_value(value) + return await app.DATABASE.create_workflow_run_parameter( workflow_run_id=workflow_run_id, workflow_parameter_id=workflow_parameter_id,