Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ykeremy/workflow prompt block #124

Merged
merged 3 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ curlify = "^2.2.1"
typer = "^0.9.0"
types-toml = "^0.10.8.7"
apscheduler = "^3.10.4"
httpx = "^0.27.0"


[tool.poetry.group.dev.dependencies]
Expand Down
4 changes: 3 additions & 1 deletion skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,9 @@ async def update_workflow(
) -> Workflow:
try:
async with self.Session() as session:
if workflow := await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id).first()):
if workflow := (
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
).first():
if title:
workflow.title = title
if description:
Expand Down
18 changes: 18 additions & 0 deletions skyvern/forge/sdk/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,21 @@ def load_prompt(self, template: str, **kwargs: Any) -> str:
except Exception:
LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise

def load_prompt_from_string(self, template: str, **kwargs: Any) -> str:
"""
Load and populate the specified template from a string.

Args:
template (str): The template string to load.
**kwargs: The arguments to populate the template with.

Returns:
str: The populated template.
"""
try:
jinja_template = self.env.from_string(template)
return jinja_template.render(**kwargs)
except Exception:
LOG.error("Failed to load prompt from string.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise
74 changes: 73 additions & 1 deletion skyvern/forge/sdk/workflow/models/block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import json
from enum import StrEnum
from typing import Annotated, Any, Literal, Union

Expand All @@ -12,6 +13,8 @@
UnexpectedTaskStatus,
)
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import (
Expand All @@ -28,6 +31,7 @@ class BlockType(StrEnum):
TASK = "task"
FOR_LOOP = "for_loop"
CODE = "code"
TEXT_PROMPT = "text_prompt"


class Block(BaseModel, abc.ABC):
Expand Down Expand Up @@ -345,5 +349,73 @@ async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter
return None


BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
class TextPromptBlock(Block):
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT

llm_key: str
prompt: str
parameters: list[PARAMETER_TYPE] = []
json_schema: dict[str, Any] | None = None

def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.parameters

async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(self.llm_key)
if not self.json_schema:
self.json_schema = {
"type": "object",
"properties": {
"llm_response": {
"type": "string",
"description": "Your response to the prompt",
}
},
}

prompt = prompt_engine.load_prompt_from_string(prompt, **parameter_values)
prompt += (
"\n\n"
+ "Please respond to the prompt above using the following JSON definition:\n\n"
+ "```json\n"
+ json.dumps(self.json_schema, indent=2)
+ "\n```\n\n"
)
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
response = await llm_api_handler(prompt=prompt)
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response

async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
parameter_values = {}
for parameter in self.parameters:
value = workflow_run_context.get_value(parameter.key)
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
if secret_value is not None:
parameter_values[parameter.key] = secret_value
else:
parameter_values[parameter.key] = value

response = await self.send_prompt(self.prompt, parameter_values)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=response,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=response,
)
return self.output_parameter

return None


BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
15 changes: 14 additions & 1 deletion skyvern/forge/sdk/workflow/models/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,23 @@ class CodeBlockYAML(BlockYAML):
parameter_keys: list[str] | None = None


class TextPromptBlockYAML(BlockYAML):
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
# Parameter 1 of Literal[...] cannot be of type "Any"
# This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able
# to infer the type of the parameter_type attribute.
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT # type: ignore

llm_key: str
prompt: str
parameter_keys: list[str] | None = None
json_schema: dict[str, Any] | None = None


PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]

BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]


Expand Down
20 changes: 19 additions & 1 deletion skyvern/forge/sdk/workflow/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys
from skyvern.forge.sdk.workflow.models.block import BlockType, BlockTypeVar, CodeBlock, ForLoopBlock, TaskBlock
from skyvern.forge.sdk.workflow.models.block import (
BlockType,
BlockTypeVar,
CodeBlock,
ForLoopBlock,
TaskBlock,
TextPromptBlock,
)
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
OutputParameter,
Expand Down Expand Up @@ -714,4 +721,15 @@ async def block_yaml_to_block(block_yaml: BLOCK_YAML_TYPES, parameters: dict[str
else [],
output_parameter=output_parameter,
)
elif block_yaml.block_type == BlockType.TEXT_PROMPT:
return TextPromptBlock(
label=block_yaml.label,
llm_key=block_yaml.llm_key,
prompt=block_yaml.prompt,
parameters=[parameters[parameter_key] for parameter_key in block_yaml.parameter_keys]
if block_yaml.parameter_keys
else [],
json_schema=block_yaml.json_schema,
output_parameter=output_parameter,
)
raise ValueError(f"Invalid block type {block_yaml.block_type}")
Loading