Skip to content

Commit

Permalink
workflow apis (#326)
Browse files Browse the repository at this point in the history
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
  • Loading branch information
ykeremy and wintonzheng authored May 16, 2024
1 parent 50026f3 commit 72d25cd
Show file tree
Hide file tree
Showing 9 changed files with 364 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""add proxy_location and webhook_callback_url to workflows table
Revision ID: 04bf06540db6
Revises: baec12642d77
Create Date: 2024-05-16 17:29:55.083124+00:00
"""
from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "04bf06540db6"
down_revision: Union[str, None] = "baec12642d77"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"workflows",
sa.Column(
"proxy_location",
sa.Enum(
"US_CA",
"US_NY",
"US_TX",
"US_FL",
"US_WA",
"RESIDENTIAL",
"RESIDENTIAL_ES",
"NONE",
name="proxylocation",
),
nullable=True,
),
)
op.add_column("workflows", sa.Column("webhook_callback_url", sa.String(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflows", "webhook_callback_url")
op.drop_column("workflows", "proxy_location")
# ### end Alembic commands ###
18 changes: 16 additions & 2 deletions skyvern/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,22 @@ def __init__(self, block_type: str) -> None:


class WorkflowNotFound(SkyvernHTTPException):
def __init__(self, workflow_id: str) -> None:
super().__init__(f"Workflow {workflow_id} not found", status_code=status.HTTP_404_NOT_FOUND)
def __init__(
self,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> None:
workflow_repr = ""
if workflow_id:
workflow_repr = f"workflow_id={workflow_id}"
if workflow_permanent_id:
if version:
workflow_repr = f"workflow_permanent_id={workflow_permanent_id}, version={version}"
else:
workflow_repr = f"workflow_permanent_id={workflow_permanent_id}"

super().__init__(f"Workflow not found. {workflow_repr}", status_code=status.HTTP_404_NOT_FOUND)


class WorkflowRunNotFound(SkyvernException):
Expand Down
116 changes: 111 additions & 5 deletions skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Sequence

import structlog
from sqlalchemy import and_, delete, select
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

Expand Down Expand Up @@ -723,18 +723,28 @@ async def get_latest_task_by_workflow_id(

async def create_workflow(
self,
organization_id: str,
title: str,
workflow_definition: dict[str, Any],
organization_id: str | None = None,
description: str | None = None,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
) -> Workflow:
async with self.Session() as session:
workflow = WorkflowModel(
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition,
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
)
if workflow_permanent_id:
workflow.workflow_permanent_id = workflow_permanent_id
if version:
workflow.version = version
session.add(workflow)
await session.commit()
await session.refresh(workflow)
Expand All @@ -743,7 +753,9 @@ async def create_workflow(
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
try:
async with self.Session() as session:
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
get_workflow_query = (
select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
Expand All @@ -753,17 +765,88 @@ async def get_workflow(self, workflow_id: str, organization_id: str | None = Non
LOG.error("SQLAlchemyError", exc_info=True)
raise

async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
) -> Workflow | None:
try:
get_workflow_query = (
select(WorkflowModel)
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if version:
get_workflow_query = get_workflow_query.filter_by(version=version)
get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
async with self.Session() as session:
if workflow := (await session.scalars(get_workflow_query)).first():
return convert_to_workflow(workflow, self.debug_enabled)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise

async def get_workflows_by_organization_id(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
db_page = page - 1
try:
async with self.Session() as session:
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(WorkflowModel.organization_id, WorkflowModel.workflow_permanent_id)
.subquery()
)
main_query = (
select(WorkflowModel)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.order_by(WorkflowModel.created_at.desc()) # Example ordering by creation date
.limit(page_size)
.offset(db_page * page_size)
)
workflows = (await session.scalars(main_query)).all()
return [convert_to_workflow(workflow, self.debug_enabled) for workflow in workflows]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise

async def update_workflow(
self,
workflow_id: str,
organization_id: str | None = None,
title: str | None = None,
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
version: int | None = None,
) -> Workflow:
try:
async with self.Session() as session:
get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
get_workflow_query = (
select(WorkflowModel).filter_by(workflow_id=workflow_id).filter(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
Expand All @@ -773,6 +856,8 @@ async def update_workflow(
workflow.description = description
if workflow_definition:
workflow.workflow_definition = workflow_definition
if version:
workflow.version = version
await session.commit()
await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
Expand All @@ -789,8 +874,29 @@ async def update_workflow(
LOG.error("UnexpectedError", exc_info=True)
raise

async def soft_delete_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> None:
async with self.Session() as session:
# soft delete the workflow by setting the deleted_at field
update_deleted_at_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
update_deleted_at_query = update_deleted_at_query.filter_by(organization_id=organization_id)
update_deleted_at_query = update_deleted_at_query.values(deleted_at=datetime.utcnow())
await session.execute(update_deleted_at_query)
await session.commit()

async def create_workflow_run(
self, workflow_id: str, proxy_location: ProxyLocation | None = None, webhook_callback_url: str | None = None
self,
workflow_id: str,
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
) -> WorkflowRun:
try:
async with self.Session() as session:
Expand Down
2 changes: 2 additions & 0 deletions skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class WorkflowModel(Base):
title = Column(String, nullable=False)
description = Column(String, nullable=True)
workflow_definition = Column(JSON, nullable=False)
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)

created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
Expand Down
4 changes: 4 additions & 0 deletions skyvern/forge/sdk/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
workflow_id=workflow_model.workflow_id,
organization_id=workflow_model.organization_id,
title=workflow_model.title,
workflow_permanent_id=workflow_model.workflow_permanent_id,
webhook_callback_url=workflow_model.webhook_callback_url,
proxy_location=ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None,
version=workflow_model.version,
description=workflow_model.description,
workflow_definition=WorkflowDefinition.model_validate(workflow_model.workflow_definition),
created_at=workflow_model.created_at,
Expand Down
85 changes: 85 additions & 0 deletions skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,88 @@ async def create_workflow(
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
organization_id=current_org.organization_id, request=workflow_create_request
)


@base_router.put(
"/workflows/{workflow_permanent_id}",
openapi_extra={
"requestBody": {
"content": {"application/x-yaml": {"schema": WorkflowCreateYAMLRequest.model_json_schema()}},
"required": True,
},
},
response_model=Workflow,
)
@base_router.put(
"/workflows/{workflow_permanent_id}/",
openapi_extra={
"requestBody": {
"content": {"application/x-yaml": {"schema": WorkflowCreateYAMLRequest.model_json_schema()}},
"required": True,
},
},
response_model=Workflow,
include_in_schema=False,
)
async def update_workflow(
workflow_permanent_id: str,
request: Request,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Workflow:
analytics.capture("skyvern-oss-agent-workflow-update")
# validate the workflow
raw_yaml = await request.body()
try:
workflow_yaml = yaml.safe_load(raw_yaml)
except yaml.YAMLError:
raise HTTPException(status_code=422, detail="Invalid YAML")

workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml)
return await app.WORKFLOW_SERVICE.create_workflow_from_request(
organization_id=current_org.organization_id,
request=workflow_create_request,
workflow_permanent_id=workflow_permanent_id,
)


@base_router.delete("/workflows/{workflow_permanent_id}")
@base_router.delete("/workflows/{workflow_permanent_id}/", include_in_schema=False)
async def delete_workflow(
workflow_permanent_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> None:
analytics.capture("skyvern-oss-agent-workflow-delete")
await app.WORKFLOW_SERVICE.delete_workflow_by_permanent_id(workflow_permanent_id, current_org.organization_id)


@base_router.get("/workflows", response_model=list[Workflow])
@base_router.get("/workflows/", response_model=list[Workflow])
async def get_workflows(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
"""
analytics.capture("skyvern-oss-agent-workflows-get")
return await app.WORKFLOW_SERVICE.get_workflows_by_organization_id(
organization_id=current_org.organization_id,
page=page,
page_size=page_size,
)


@base_router.get("/workflows/{workflow_permanent_id}", response_model=Workflow)
@base_router.get("/workflows/{workflow_permanent_id}/", response_model=Workflow)
async def get_workflow(
workflow_permanent_id: str,
version: int | None = None,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Workflow:
analytics.capture("skyvern-oss-agent-workflows-get")
return await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=current_org.organization_id,
version=version,
)
4 changes: 4 additions & 0 deletions skyvern/forge/sdk/workflow/models/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ class Workflow(BaseModel):
workflow_id: str
organization_id: str
title: str
workflow_permanent_id: str
version: int
description: str | None = None
workflow_definition: WorkflowDefinition
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None

created_at: datetime
modified_at: datetime
Expand Down
3 changes: 3 additions & 0 deletions skyvern/forge/sdk/workflow/models/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pydantic import BaseModel, Field

from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.forge.sdk.workflow.models.parameter import ParameterType, WorkflowParameterType

Expand Down Expand Up @@ -187,4 +188,6 @@ class WorkflowDefinitionYAML(BaseModel):
class WorkflowCreateYAMLRequest(BaseModel):
title: str
description: str | None = None
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
workflow_definition: WorkflowDefinitionYAML
Loading

0 comments on commit 72d25cd

Please sign in to comment.