Skip to content

Commit

Permalink
For loop block updates (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
ykeremy authored Apr 10, 2024
1 parent 39d7d91 commit 8c12e2b
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
38 changes: 26 additions & 12 deletions skyvern/forge/sdk/workflow/models/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,25 @@ class ForLoopBlock(Block):

# TODO (kerem): Add support for ContextParameter
loop_over: PARAMETER_TYPE
loop_block: "BlockTypeVar"
loop_blocks: list["BlockTypeVar"]

def get_all_parameters(
self,
workflow_run_id: str,
) -> list[PARAMETER_TYPE]:
return self.loop_block.get_all_parameters(workflow_run_id) + [self.loop_over]
parameters = {self.loop_over}

for loop_block in self.loop_blocks:
for parameter in loop_block.get_all_parameters(workflow_run_id):
parameters.add(parameter)
return list(parameters)

def get_loop_block_context_parameters(self, workflow_run_id: str, loop_data: Any) -> list[ContextParameter]:
if not isinstance(loop_data, dict):
# TODO (kerem): Should we add support for other types?
raise ValueError("loop_data should be a dictionary")
raise ValueError("loop_data should be a dict")

loop_block_parameters = self.loop_block.get_all_parameters(workflow_run_id)
loop_block_parameters = self.get_all_parameters(workflow_run_id)
context_parameters = [
parameter for parameter in loop_block_parameters if isinstance(parameter, ContextParameter)
]
Expand Down Expand Up @@ -332,28 +337,37 @@ async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
num_loop_over_values=len(loop_over_values),
)
outputs_with_loop_values = []
block_outputs = []
for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
try:
block_output = await self.loop_block.execute(workflow_run_id=workflow_run_id)
block_outputs.append(block_output)
block_outputs = [
await loop_block.execute(workflow_run_id=workflow_run_id) for loop_block in self.loop_blocks
]
except Exception as e:
LOG.error("ForLoopBlock: Failed to execute loop block", exc_info=True)
raise e
if block_output.output_parameter:
outputs_with_loop_values.append(
outputs_with_loop_values.append(
[
{
"loop_value": loop_over_value,
"output_parameter": block_output.output_parameter,
"output_value": workflow_run_context.get_value(block_output.output_parameter.key),
}
)
for block_output in block_outputs
if block_output.output_parameter
]
)

# If all block outputs are successful, the loop is successful
success = all([block_output.success for block_output in block_outputs])
# If all block outputs are successful, the loop is successful
success = all([block_output.success for block_output in block_outputs])
if not success:
LOG.info(
"ForLoopBlock: Encountered an failure processing block, terminating early",
block_outputs=block_outputs,
)
break

if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
Expand Down
3 changes: 3 additions & 0 deletions skyvern/forge/sdk/workflow/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Parameter(BaseModel, abc.ABC):
key: str
description: str | None = None

def __hash__(self) -> int:
return hash(self.key)

@classmethod
def get_subclasses(cls) -> tuple[type["Parameter"], ...]:
return tuple(cls.__subclasses__())
Expand Down
2 changes: 1 addition & 1 deletion skyvern/forge/sdk/workflow/models/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class ForLoopBlockYAML(BlockYAML):
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP # type: ignore

loop_over_parameter_key: str
loop_block: "BLOCK_YAML_SUBCLASSES"
loop_blocks: list["BLOCK_YAML_SUBCLASSES"]


class CodeBlockYAML(BlockYAML):
Expand Down
7 changes: 5 additions & 2 deletions skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,12 +760,15 @@ async def block_yaml_to_block(block_yaml: BLOCK_YAML_TYPES, parameters: dict[str
max_retries=block_yaml.max_retries,
)
elif block_yaml.block_type == BlockType.FOR_LOOP:
loop_block = await WorkflowService.block_yaml_to_block(block_yaml.loop_block, parameters)
loop_blocks = [
await WorkflowService.block_yaml_to_block(loop_block, parameters)
for loop_block in block_yaml.loop_blocks
]
loop_over_parameter = parameters[block_yaml.loop_over_parameter_key]
return ForLoopBlock(
label=block_yaml.label,
loop_over=loop_over_parameter,
loop_block=loop_block,
loop_blocks=loop_blocks,
output_parameter=output_parameter,
)
elif block_yaml.block_type == BlockType.CODE:
Expand Down

0 comments on commit 8c12e2b

Please sign in to comment.