diff --git a/skyvern/config.py b/skyvern/config.py index 8b01afc82..ac7c79912 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -48,6 +48,7 @@ class Settings(BaseSettings): # S3 bucket settings AWS_REGION: str = "us-east-1" AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads" + MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB SKYVERN_TELEMETRY: bool = True ANALYTICS_ID: str = "anonymous" diff --git a/skyvern/forge/sdk/api/aws.py b/skyvern/forge/sdk/api/aws.py index 9a9c80615..92ed4ac41 100644 --- a/skyvern/forge/sdk/api/aws.py +++ b/skyvern/forge/sdk/api/aws.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Callable +from typing import IO, Any, Callable from urllib.parse import urlparse import aioboto3 @@ -55,6 +55,17 @@ async def upload_file(self, uri: str, data: bytes, client: AioBaseClient = None) LOG.exception("S3 upload failed.", uri=uri) return None + @execute_with_async_client(client_type=AWSClientType.S3) + async def upload_file_stream(self, uri: str, file_obj: IO[bytes], client: AioBaseClient = None) -> str | None: + try: + parsed_uri = S3Uri(uri) + await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key) + LOG.debug("Upload file stream success", uri=uri) + return uri + except Exception: + LOG.exception("S3 upload stream failed.", uri=uri) + return None + @execute_with_async_client(client_type=AWSClientType.S3) async def upload_file_from_path(self, uri: str, file_path: str, client: AioBaseClient = None) -> None: try: @@ -137,3 +148,6 @@ def key(self) -> str: @property def uri(self) -> str: return self._parsed.geturl() + + +aws_client = AsyncAWSClient() diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index de5602614..7fc8d4afc 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1,8 +1,21 @@ +import datetime +import uuid from typing import Annotated, Any import structlog import yaml -from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, status +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + Header, + HTTPException, + Query, + Request, + Response, + UploadFile, + status, +) from fastapi.responses import ORJSONResponse from pydantic import BaseModel @@ -10,6 +23,7 @@ from skyvern.exceptions import StepNotFound from skyvern.forge import app from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.api.aws import aws_client from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.core import skyvern_context @@ -736,3 +750,43 @@ async def update_organization( max_steps_per_run=org_update.max_steps_per_run, max_retries_per_step=org_update.max_retries_per_step, ) + + +async def validate_file_size(file: UploadFile) -> UploadFile: + # Check the file size + if file.size > app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE: + raise HTTPException( + status_code=413, + detail=f"File size exceeds the maximum allowed size ({app.SETTINGS_MANAGER.MAX_UPLOAD_FILE_SIZE} bytes)", + ) + return file + + +@base_router.post("/upload_file/", include_in_schema=False) +@base_router.post("/upload_file") +async def upload_file( + file: UploadFile = Depends(validate_file_size), + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> Response: + bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS + todays_date = datetime.datetime.now().strftime("%Y-%m-%d") + uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}" + s3_uri = ( + f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}" + ) + # Stream the file to S3 + uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file) + if not uploaded_s3_uri: + raise HTTPException(status_code=500, detail="Failed to upload file to S3.") + + # Generate a presigned URL for the uploaded file + presigned_urls = await aws_client.create_presigned_urls([uploaded_s3_uri]) + if not presigned_urls: + raise HTTPException(status_code=500, detail="Failed to generate presigned URL.") + + presigned_url = presigned_urls[0] + return ORJSONResponse( + content={"s3_uri": uploaded_s3_uri, "presigned_url": presigned_url}, + status_code=200, + media_type="application/json", + )