Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ykeremy/context parameter source parameters #200

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 80 additions & 7 deletions skyvern/forge/sdk/workflow/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from skyvern.forge.sdk.workflow.exceptions import OutputParameterKeyCollisionError
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
BitwardenLoginCredentialParameter,
ContextParameter,
OutputParameter,
Parameter,
ParameterType,
Expand All @@ -30,6 +32,7 @@ def __init__(
self,
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
workflow_output_parameters: list[OutputParameter],
context_parameters: list[ContextParameter],
) -> None:
self.parameters = {}
self.values = {}
Expand All @@ -50,6 +53,12 @@ def __init__(
raise OutputParameterKeyCollisionError(output_parameter.key)
self.parameters[output_parameter.key] = output_parameter

for context_parameter in context_parameters:
# All context parameters are registered with the context manager during initialization, but their values
# are calculated and set before and after each block execution.
# Values will sometimes be overwritten by the block execution itself.
self.parameters[context_parameter.key] = context_parameter

def get_parameter(self, key: str) -> Parameter:
    """Return the parameter registered under ``key``.

    Raises:
        KeyError: if no parameter has been registered with that key.
    """
    registered = self.parameters
    return registered[key]

Expand Down Expand Up @@ -175,9 +184,32 @@ async def register_parameter_value(
BitwardenService.logout()
LOG.error(f"Failed to get secret from Bitwarden. Error: {e}")
raise e
elif parameter.parameter_type == ParameterType.CONTEXT:
# ContextParameter values will be set within the blocks
return
elif isinstance(parameter, ContextParameter):
if isinstance(parameter.source, WorkflowParameter):
# TODO (kerem): set this while initializing the context manager
workflow_parameter_value = self.get_value(parameter.source.key)
if not isinstance(workflow_parameter_value, dict):
raise ValueError(f"ContextParameter source value is not a dict. Parameter key: {parameter.key}")
parameter.value = workflow_parameter_value.get(parameter.source.key)
self.parameters[parameter.key] = parameter
self.values[parameter.key] = parameter.value
elif isinstance(parameter.source, ContextParameter):
# TODO (kerem): update this anytime the source parameter value changes in values dict
context_parameter_value = self.get_value(parameter.source.key)
if not isinstance(context_parameter_value, dict):
raise ValueError(f"ContextParameter source value is not a dict. Parameter key: {parameter.key}")
parameter.value = context_parameter_value.get(parameter.source.key)
self.parameters[parameter.key] = parameter
self.values[parameter.key] = parameter.value
elif isinstance(parameter.source, OutputParameter):
# We don't set the value of the ContextParameter here when its source is an OutputParameter; it is set in the
# `register_output_parameter_value_post_execution` method instead.
pass
else:
raise NotImplementedError(
f"ContextParameter source has to be a WorkflowParameter, ContextParameter, or OutputParameter. "
f"{parameter.source.parameter_type} is not supported."
)
else:
raise ValueError(f"Unknown parameter type: {parameter.parameter_type}")

Expand All @@ -189,28 +221,66 @@ async def register_output_parameter_value_post_execution(
return

self.values[parameter.key] = value
await self.set_parameter_values_for_output_parameter_dependent_blocks(parameter, value)

async def set_parameter_values_for_output_parameter_dependent_blocks(
    self, output_parameter: OutputParameter, value: dict[str, Any] | list | str | None
) -> None:
    """Propagate a freshly produced output value to the context parameters that depend on it.

    Every registered ContextParameter whose source is ``output_parameter`` (matched by key)
    gets its value set to ``value.get(<context parameter key>)``, and that value is mirrored
    into ``self.values``.

    Raises:
        ValueError: if a dependent ContextParameter exists but ``value`` is not a dict.
    """
    for candidate in self.parameters.values():
        # Only context parameters sourced from this exact output parameter are affected.
        if not isinstance(candidate, ContextParameter):
            continue
        if not isinstance(candidate.source, OutputParameter):
            continue
        if candidate.source.key != output_parameter.key:
            continue

        # An existing (truthy) value is overwritten, but not silently.
        if candidate.value:
            LOG.warning(
                f"Context parameter {candidate.key} already has a value, overwriting",
                old_value=candidate.value,
                new_value=value,
            )

        # The context parameter's key is looked up inside the output value, so it must be a dict.
        if not isinstance(value, dict):
            raise ValueError(
                f"ContextParameter can't depend on an OutputParameter with a non-dict value. "
                f"ContextParameter key: {candidate.key}, "
                f"OutputParameter key: {output_parameter.key}, "
                f"OutputParameter value: {value}"
            )

        candidate.value = value.get(candidate.key)
        self.parameters[candidate.key] = candidate
        self.values[candidate.key] = candidate.value

async def register_block_parameters(
self,
aws_client: AsyncAWSClient,
parameters: list[PARAMETER_TYPE],
) -> None:
# Sort the parameters so that ContextParameter and BitwardenLoginCredentialParameter are processed last
# ContextParameter should be processed at the end since it requires the source parameter to be set
# BitwardenLoginCredentialParameter should be processed last since it requires the URL parameter to be set
parameters.sort(key=lambda x: x.parameter_type != ParameterType.BITWARDEN_LOGIN_CREDENTIAL)
# Python's tuple comparison works lexicographically, so we can sort the parameters by their type in a tuple
parameters.sort(
key=lambda x: (
isinstance(x, ContextParameter),
# This makes sure that ContextParameters with a ContextParameter source are processed after all other
# ContextParameters
isinstance(x.source, ContextParameter) if isinstance(x, ContextParameter) else False,
isinstance(x, BitwardenLoginCredentialParameter),
)
)

for parameter in parameters:
if parameter.key in self.parameters:
LOG.debug(f"Parameter {parameter.key} already registered, skipping")
continue

if parameter.parameter_type == ParameterType.WORKFLOW:
if isinstance(parameter, WorkflowParameter):
LOG.error(
f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
)
raise ValueError(
f"Workflow parameter {parameter.key} should have already been set through workflow run parameters"
)
elif parameter.parameter_type == ParameterType.OUTPUT:
elif isinstance(parameter, OutputParameter):
LOG.error(
f"Output parameter {parameter.key} should have already been set through workflow run context init"
)
Expand Down Expand Up @@ -244,8 +314,11 @@ def initialize_workflow_run_context(
workflow_run_id: str,
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
workflow_output_parameters: list[OutputParameter],
context_parameters: list[ContextParameter],
) -> WorkflowRunContext:
workflow_run_context = WorkflowRunContext(workflow_parameter_tuples, workflow_output_parameters)
workflow_run_context = WorkflowRunContext(
workflow_parameter_tuples, workflow_output_parameters, context_parameters
)
self.workflow_run_contexts[workflow_run_id] = workflow_run_context
return workflow_run_context

Expand Down
7 changes: 7 additions & 0 deletions skyvern/forge/sdk/workflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ def __init__(self, duplicate_keys: set[str]) -> None:
class InvalidEmailClientConfiguration(BaseWorkflowException):
def __init__(self, problems: list[str]) -> None:
super().__init__(f"Email client configuration is invalid. These parameters are missing or invalid: {problems}")


class ContextParameterSourceNotDefined(BaseWorkflowException):
    """Raised when a context parameter names a source parameter key that does not exist."""

    def __init__(self, context_parameter_key: str, source_key: str) -> None:
        message = f"Source parameter key {source_key} for context parameter {context_parameter_key} does not exist."
        super().__init__(message)
47 changes: 36 additions & 11 deletions skyvern/forge/sdk/workflow/models/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,17 @@ def get_all_parameters(
workflow_run_context = self.get_workflow_run_context(workflow_run_id)

if self.url and workflow_run_context.has_parameter(self.url):
parameters.append(workflow_run_context.get_parameter(self.url))
if workflow_run_context.has_value(self.url):
LOG.info(
"Task URL is parameterized, using parameter value",
task_url_parameter_value=workflow_run_context.get_value(self.url),
task_url_parameter_key=self.url,
)
self.url = workflow_run_context.get_value(self.url)
else:
# if the parameter is not resolved yet, we'll add it to the list of parameters to resolve
# parameterization of the url would happen when the task is executed
parameters.append(workflow_run_context.get_parameter(self.url))

return parameters

Expand Down Expand Up @@ -300,11 +310,18 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any
# TODO (kerem): Should we add support for other types?
raise ValueError("loop_data should be a dict")

loop_block_parameters = self.get_all_parameters(workflow_run_id)
context_parameters = [
parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter)
]
context_parameters = []
for loop_block in self.loop_blocks:
# todo: handle the case where the loop_block is a ForLoopBlock

all_parameters = loop_block.get_all_parameters(workflow_run_id)
for parameter in all_parameters:
if isinstance(parameter, ContextParameter):
context_parameters.append(parameter)

for context_parameter in context_parameters:
if context_parameter.source.key != self.loop_over.key:
continue
if context_parameter.key not in loop_data:
raise ContextParameterValueNotFound(
parameter_key=context_parameter.key,
Expand All @@ -318,15 +335,23 @@ def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any
def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter) or isinstance(self.loop_over, OutputParameter):
parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(parameter_value, list):
return parameter_value
else:
# TODO (kerem): Should we raise an error here?
return [parameter_value]
elif isinstance(self.loop_over, ContextParameter):
parameter_value = self.loop_over.value
if not parameter_value:
source_parameter_value = workflow_run_context.get_value(self.loop_over.source.key)
if isinstance(source_parameter_value, dict):
parameter_value = source_parameter_value.get(self.loop_over.key)
else:
raise ValueError("ContextParameter source value should be a dict")
else:
# TODO (kerem): Implement this for context parameters
raise NotImplementedError

if isinstance(parameter_value, list):
return parameter_value
else:
# TODO (kerem): Should we raise an error here?
return [parameter_value]

async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
Expand Down
2 changes: 1 addition & 1 deletion skyvern/forge/sdk/workflow/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class WorkflowParameter(Parameter):
class ContextParameter(Parameter):
parameter_type: Literal[ParameterType.CONTEXT] = ParameterType.CONTEXT

source: WorkflowParameter
source: "ParameterSubclasses"
# value will be populated by the context manager
value: str | int | float | bool | dict | list | None = None

Expand Down
2 changes: 2 additions & 0 deletions skyvern/forge/sdk/workflow/models/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
from skyvern.forge.sdk.workflow.models.block import BlockResult, BlockTypeVar
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE


class WorkflowRequestBody(BaseModel):
Expand All @@ -21,6 +22,7 @@ class RunWorkflowResponse(BaseModel):


class WorkflowDefinition(BaseModel):
parameters: list[PARAMETER_TYPE]
blocks: List[BlockTypeVar]

def validate(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion skyvern/forge/sdk/workflow/models/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ContextParameterYAML(ParameterYAML):
# This pattern already works in block.py but since the ParameterType is not defined in this file, mypy is not able
# to infer the type of the parameter_type attribute.
parameter_type: Literal[ParameterType.CONTEXT] = ParameterType.CONTEXT # type: ignore
source_workflow_parameter_key: str
source_parameter_key: str


class OutputParameterYAML(ParameterYAML):
Expand Down
35 changes: 29 additions & 6 deletions skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys
from skyvern.forge.sdk.workflow.exceptions import (
ContextParameterSourceNotDefined,
WorkflowDefinitionHasDuplicateParameterKeys,
)
from skyvern.forge.sdk.workflow.models.block import (
BlockResult,
BlockType,
Expand All @@ -34,6 +37,7 @@
UploadToS3Block,
)
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
AWSSecretParameter,
ContextParameter,
OutputParameter,
Expand Down Expand Up @@ -145,11 +149,17 @@ async def execute_workflow(
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)

# Get all context parameters from the workflow definition
context_parameters = [
parameter
for parameter in workflow.workflow_definition.parameters
if isinstance(parameter, ContextParameter)
]
# Get all <workflow parameter, workflow run parameter> tuples
wp_wps_tuples = await self.get_workflow_run_parameter_tuples(workflow_run_id=workflow_run.workflow_run_id)
workflow_output_parameters = await self.get_workflow_output_parameters(workflow_id=workflow.workflow_id)
app.WORKFLOW_CONTEXT_MANAGER.initialize_workflow_run_context(
workflow_run_id, wp_wps_tuples, workflow_output_parameters
workflow_run_id, wp_wps_tuples, workflow_output_parameters, context_parameters
)
# Execute workflow blocks
blocks = workflow.workflow_definition.blocks
Expand Down Expand Up @@ -649,10 +659,10 @@ async def create_workflow_from_request(self, organization_id: str, request: Work
organization_id=organization_id,
title=request.title,
description=request.description,
workflow_definition=WorkflowDefinition(blocks=[]),
workflow_definition=WorkflowDefinition(parameters=[], blocks=[]),
)
# Create parameters from the request
parameters = {}
parameters: dict[str, PARAMETER_TYPE] = {}
duplicate_parameter_keys = set()

# We're going to process context parameters after other parameters since they depend on the other parameters
Expand Down Expand Up @@ -701,10 +711,23 @@ async def create_workflow_from_request(self, organization_id: str, request: Work

# Now we can process the context parameters since all other parameters have been created
for context_parameter in context_parameter_yamls:
if context_parameter.source_parameter_key not in parameters:
raise ContextParameterSourceNotDefined(
context_parameter_key=context_parameter.key, source_key=context_parameter.source_parameter_key
)

if context_parameter.key in parameters:
LOG.error(f"Duplicate parameter key {context_parameter.key}")
duplicate_parameter_keys.add(context_parameter.key)
continue

# We're only adding the context parameter to the parameters dict, we're not creating it in the database
# It'll only be stored in the `workflow.workflow_definition`
# todo (kerem): should we have a database table for context parameters?
parameters[context_parameter.key] = ContextParameter(
key=context_parameter.key,
description=context_parameter.description,
source=parameters[context_parameter.source_workflow_parameter_key],
source=parameters[context_parameter.source_parameter_key],
# Context parameters don't have a default value, the value always depends on the source parameter
value=None,
)
Expand All @@ -720,7 +743,7 @@ async def create_workflow_from_request(self, organization_id: str, request: Work
block_label_mapping[block.label] = block

# Set the blocks for the workflow definition
workflow_definition = WorkflowDefinition(blocks=blocks)
workflow_definition = WorkflowDefinition(parameters=parameters.values(), blocks=blocks)
workflow = await self.update_workflow(
workflow_id=workflow.workflow_id,
workflow_definition=workflow_definition,
Expand Down
4 changes: 3 additions & 1 deletion skyvern/webeye/actions/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ async def handle_download_file_action(
# Start waiting for the download
async with page.expect_download() as download_info:
await asyncio.sleep(0.3)
await page.click(f"xpath={xpath}", timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS)
await page.click(
f"xpath={xpath}", timeout=SettingsManager.get_settings().BROWSER_ACTION_TIMEOUT_MS, modifiers=["Alt"]
)

download = await download_info.value

Expand Down
Loading