From 403f079d6ea4e60885eba7f4574a460d0d55dcba Mon Sep 17 00:00:00 2001 From: ykeremy Date: Sat, 20 Apr 2024 09:14:42 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=84=20synced=20local=20'skyvern/'=20wi?= =?UTF-8?q?th=20remote=20'skyvern/'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- skyvern/forge/sdk/agent.py | 16 ----------- skyvern/forge/sdk/routes/agent_protocol.py | 33 ++-------------------- 2 files changed, 3 insertions(+), 46 deletions(-) diff --git a/skyvern/forge/sdk/agent.py b/skyvern/forge/sdk/agent.py index 5c7981c06..61cb763f1 100644 --- a/skyvern/forge/sdk/agent.py +++ b/skyvern/forge/sdk/agent.py @@ -47,8 +47,6 @@ def get_agent_app(self, router: APIRouter = base_router) -> FastAPI: app.include_router(router, prefix="/api/v1") - app.add_middleware(AgentMiddleware, agent=self) - app.add_middleware( RawContextMiddleware, plugins=( @@ -85,20 +83,6 @@ async def request_middleware(request: Request, call_next: Callable[[Request], Aw return app -class AgentMiddleware: - """ - Middleware that injects the agent instance into the request scope. - """ - - def __init__(self, app: FastAPI, agent: Agent): - self.app = app - self.agent = agent - - async def __call__(self, scope, receive, send): # type: ignore - scope["agent"] = self.agent - await self.app(scope, receive, send) - - class ExecutionDatePlugin(Plugin): key = "execution_date" diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 19d41f351..3c839ea4f 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -83,19 +83,17 @@ async def check_server_status() -> Response: @base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse) async def create_agent_task( background_tasks: BackgroundTasks, - request: Request, task: TaskRequest, current_org: Organization = Depends(org_auth_service.get_current_org), x_api_key: Annotated[str | None, Header()] = None, x_max_steps_override: Annotated[int | None, Header()] = None, ) -> CreateTaskResponse: analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url}) - agent = request["agent"] if current_org and current_org.organization_name == "CoverageCat": task.proxy_location = ProxyLocation.RESIDENTIAL - created_task = await agent.create_task(task, current_org.organization_id) + created_task = await app.agent.create_task(task, current_org.organization_id) if x_max_steps_override: LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) await AsyncExecutorFactory.get_executor().execute_task( @@ -121,13 +119,11 @@ async def create_agent_task( summary="Executes the next step", ) async def execute_agent_task_step( - request: Request, task_id: str, step_id: str | None = None, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> Response: analytics.capture("skyvern-oss-agent-task-step-execute") - agent = request["agent"] task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id) if not task: raise HTTPException( @@ -171,7 +167,7 @@ async def execute_agent_task_step( status_code=status.HTTP_404_NOT_FOUND, detail=f"No step found with id {step_id}", ) - step, _, _ = await agent.execute_step(current_org, task, step) + step, _, _ = await app.agent.execute_step(current_org, task, step) return Response( content=step.model_dump_json() if step else "", status_code=200, @@ -181,12 +177,10 @@ async def execute_agent_task_step( @base_router.get("/tasks/{task_id}", response_model=TaskResponse) async def get_task( - request: Request, task_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> TaskResponse: analytics.capture("skyvern-oss-agent-task-get") - request["agent"] task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id) if not task_obj: raise HTTPException( @@ -270,13 +264,11 @@ async def get_task( response_model=TaskResponse, ) async def retry_webhook( - request: Request, task_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), x_api_key: Annotated[str | None, Header()] = None, ) -> TaskResponse: analytics.capture("skyvern-oss-agent-task-retry-webhook") - agent = request["agent"] task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id) if not task_obj: raise HTTPException( @@ -290,20 +282,18 @@ async def retry_webhook( return task_obj.to_task_response() # retry the webhook - await agent.execute_task_webhook(task=task_obj, last_step=latest_step, api_key=x_api_key) + await app.agent.execute_task_webhook(task=task_obj, last_step=latest_step, api_key=x_api_key) return task_obj.to_task_response() @base_router.get("/internal/tasks/{task_id}", response_model=list[Task]) async def get_task_internal( - request: Request, task_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> Response: """ Get all tasks. - :param request: :param page: Starting page, defaults to 1 :param page_size: :return: List of tasks with pagination without steps populated. Steps can be populated by calling the @@ -321,80 +311,68 @@ async def get_task_internal( @base_router.get("/tasks", tags=["agent"], response_model=list[Task]) async def get_agent_tasks( - request: Request, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> Response: """ Get all tasks. - :param request: :param page: Starting page, defaults to 1 :param page_size: Page size, defaults to 10 :return: List of tasks with pagination without steps populated. Steps can be populated by calling the get_agent_task endpoint. """ analytics.capture("skyvern-oss-agent-tasks-get") - request["agent"] tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id) return ORJSONResponse([task.to_task_response().model_dump() for task in tasks]) @base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task]) async def get_agent_tasks_internal( - request: Request, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> Response: """ Get all tasks. - :param request: :param page: Starting page, defaults to 1 :param page_size: Page size, defaults to 10 :return: List of tasks with pagination without steps populated. Steps can be populated by calling the get_agent_task endpoint. """ analytics.capture("skyvern-oss-agent-tasks-get-internal") - request["agent"] tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id) return ORJSONResponse([task.model_dump() for task in tasks]) @base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step]) async def get_agent_task_steps( - request: Request, task_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> Response: """ Get all steps for a task. - :param request: :param task_id: :return: List of steps for a task with pagination. """ analytics.capture("skyvern-oss-agent-task-steps-get") - request["agent"] steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id) return ORJSONResponse([step.model_dump() for step in steps]) @base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact]) async def get_agent_task_step_artifacts( - request: Request, task_id: str, step_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> Response: """ Get all artifacts for a list of steps. - :param request: :param task_id: :param step_id: :return: List of artifacts for a list of steps. """ analytics.capture("skyvern-oss-agent-task-step-artifacts-get") - request["agent"] artifacts = await app.DATABASE.get_artifacts_for_task_step( task_id, step_id, @@ -416,12 +394,10 @@ class ActionResultTmp(BaseModel): @base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp]) async def get_task_actions( - request: Request, task_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> list[ActionResultTmp]: analytics.capture("skyvern-oss-agent-task-actions-get") - request["agent"] steps = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id) results: list[ActionResultTmp] = [] for step_s in steps: @@ -435,7 +411,6 @@ async def get_task_actions( @base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse) async def execute_workflow( background_tasks: BackgroundTasks, - request: Request, workflow_id: str, workflow_request: WorkflowRequestBody, current_org: Organization = Depends(org_auth_service.get_current_org), @@ -470,13 +445,11 @@ async def execute_workflow( @base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse) async def get_workflow_run( - request: Request, workflow_id: str, workflow_run_id: str, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> WorkflowRunStatusResponse: analytics.capture("skyvern-oss-agent-workflow-run-get") - request["agent"] return await app.WORKFLOW_SERVICE.build_workflow_run_status_response( workflow_id=workflow_id, workflow_run_id=workflow_run_id,