Skip to content

Commit

Permalink
workflow parameter validation
Browse files Browse the repository at this point in the history
  • Loading branch information
wintonzheng committed Oct 23, 2024
1 parent 7cba401 commit 5ab53f3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
11 changes: 11 additions & 0 deletions skyvern/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,14 @@ def __init__(self, host: str) -> None:
f"The host in your url is blocked: {host}",
status_code=status.HTTP_400_BAD_REQUEST,
)


class InvalidWorkflowParameter(SkyvernHTTPException):
def __init__(self, expected_parameter_type: str, value: str, workflow_permanent_id: str | None = None) -> None:
message = f"Invalid workflow parameter. Excpected parameter type: {expected_parameter_type}. Value: {value}."
if workflow_permanent_id:
message += f" Workflow permanent id: {workflow_permanent_id}"
super().__init__(
message,
status_code=status.HTTP_400_BAD_REQUEST,
)
38 changes: 24 additions & 14 deletions skyvern/forge/sdk/workflow/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Literal, Union
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, ConfigDict, Field

from skyvern.exceptions import InvalidWorkflowParameter


class ParameterType(StrEnum):
WORKFLOW = "workflow"
Expand Down Expand Up @@ -114,21 +116,29 @@ class WorkflowParameterType(StrEnum):
JSON = "json"
FILE_URL = "file_url"

def convert_value(self, value: str | None) -> str | int | float | bool | dict | list | None:
def convert_value(self, value: Any) -> str | int | float | bool | dict | list | None:
if value is None:
return None
if self == WorkflowParameterType.STRING:
return value
elif self == WorkflowParameterType.INTEGER:
return int(value)
elif self == WorkflowParameterType.FLOAT:
return float(value)
elif self == WorkflowParameterType.BOOLEAN:
return value.lower() in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
elif self == WorkflowParameterType.FILE_URL:
return value
try:
if self == WorkflowParameterType.STRING:
return str(value)
elif self == WorkflowParameterType.INTEGER:
return int(value)
elif self == WorkflowParameterType.FLOAT:
return float(value)
elif self == WorkflowParameterType.BOOLEAN:
if isinstance(value, bool):
return value
lower_case = str(value).lower()
if lower_case in ["true", "false", "1", "0"]:
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))
return lower_case in ["true", "1"]
elif self == WorkflowParameterType.JSON:
return json.loads(value)
elif self == WorkflowParameterType.FILE_URL:
return value
except Exception:
raise InvalidWorkflowParameter(expected_parameter_type=self, value=str(value))


class WorkflowParameter(Parameter):
Expand Down
18 changes: 16 additions & 2 deletions skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import json
from datetime import datetime
from typing import Any

import requests
import structlog

from skyvern import analytics
from skyvern.exceptions import FailedToSendWebhook, MissingValueForParameter, WorkflowNotFound, WorkflowRunNotFound
from skyvern.exceptions import (
FailedToSendWebhook,
MissingValueForParameter,
WorkflowNotFound,
WorkflowParameterNotFound,
WorkflowRunNotFound,
)
from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
Expand Down Expand Up @@ -566,8 +573,15 @@ async def create_workflow_run_parameter(
self,
workflow_run_id: str,
workflow_parameter_id: str,
value: bool | int | float | str | dict | list,
value: Any,
) -> WorkflowRunParameter:
# get workflow parameter id first and validate the value according to the workflow_parameter.workflow_parameter_type
workflow_parameter = await app.DATABASE.get_workflow_parameter(workflow_parameter_id)
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id)
# InvalidWorkflowParameter will be raised if the validation fails
workflow_parameter.workflow_parameter_type.convert_value(value)

return await app.DATABASE.create_workflow_run_parameter(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id,
Expand Down

0 comments on commit 5ab53f3

Please sign in to comment.