diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index faa504de0..1c5f54359 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -503,6 +503,10 @@ async def agent_step( step_retry=step.retry_index, ) step = await self.update_step(step=step, status=StepStatus.running) + await app.AGENT_FUNCTION.prepare_step_execution( + organization=organization, task=task, step=step, browser_state=browser_state + ) + ( scraped_page, extract_action_prompt, diff --git a/skyvern/forge/agent_functions.py b/skyvern/forge/agent_functions.py index 9e704ffa7..6af808d07 100644 --- a/skyvern/forge/agent_functions.py +++ b/skyvern/forge/agent_functions.py @@ -5,6 +5,7 @@ from skyvern.forge.async_operations import AsyncOperation from skyvern.forge.sdk.models import Organization, Step, StepStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus +from skyvern.webeye.browser_factory import BrowserState class AgentFunction: @@ -14,7 +15,7 @@ async def validate_step_execution( step: Step, ) -> None: """ - Checks if the step can be executed. + Checks if the step can be executed. It is called before the step is executed. :return: A tuple of whether the step can be executed and a list of reasons why it can't be executed. """ reasons = [] @@ -36,6 +37,18 @@ async def validate_step_execution( if not can_execute: raise StepUnableToExecuteError(step_id=step.step_id, reason=f"Cannot execute step. Reasons: {reasons}") + async def prepare_step_execution( + self, + organization: Organization | None, + task: Task, + step: Step, + browser_state: BrowserState, + ) -> None: + """ + Get prepared for the step execution. It's called at the first beginning when step running. + """ + return + def generate_async_operations( self, organization: Organization, diff --git a/skyvern/webeye/browser_factory.py b/skyvern/webeye/browser_factory.py index 9d68f9ec9..be44486da 100644 --- a/skyvern/webeye/browser_factory.py +++ b/skyvern/webeye/browser_factory.py @@ -29,10 +29,13 @@ LOG = structlog.get_logger() +BrowserCleanupFunc = Callable[[], None] | None + + class BrowserContextCreator(Protocol): def __call__( self, playwright: Playwright, **kwargs: dict[str, Any] - ) -> Awaitable[tuple[BrowserContext, BrowserArtifacts]]: ... + ) -> Awaitable[tuple[BrowserContext, BrowserArtifacts, BrowserCleanupFunc]]: ... class BrowserContextFactory: @@ -93,7 +96,7 @@ def register_type(cls, browser_type: str, creator: BrowserContextCreator) -> Non @classmethod async def create_browser_context( cls, playwright: Playwright, **kwargs: Any - ) -> tuple[BrowserContext, BrowserArtifacts]: + ) -> tuple[BrowserContext, BrowserArtifacts, BrowserCleanupFunc]: browser_type = SettingsManager.get_settings().BROWSER_TYPE try: creator = cls._creators.get(browser_type) @@ -123,14 +126,18 @@ class BrowserArtifacts(BaseModel): traces_dir: str | None = None -async def _create_headless_chromium(playwright: Playwright, **kwargs: dict) -> tuple[BrowserContext, BrowserArtifacts]: +async def _create_headless_chromium( + playwright: Playwright, **kwargs: dict +) -> tuple[BrowserContext, BrowserArtifacts, BrowserCleanupFunc]: browser_args = BrowserContextFactory.build_browser_args() browser_artifacts = BrowserContextFactory.build_browser_artifacts(har_path=browser_args["record_har_path"]) browser_context = await playwright.chromium.launch_persistent_context(**browser_args) - return browser_context, browser_artifacts + return browser_context, browser_artifacts, None -async def _create_headful_chromium(playwright: Playwright, **kwargs: dict) -> tuple[BrowserContext, BrowserArtifacts]: +async def _create_headful_chromium( + playwright: Playwright, **kwargs: dict +) -> tuple[BrowserContext, BrowserArtifacts, BrowserCleanupFunc]: browser_args = BrowserContextFactory.build_browser_args() browser_args.update( { @@ -139,7 +146,7 @@ async def _create_headful_chromium(playwright: Playwright, **kwargs: dict) -> tu ) browser_artifacts = BrowserContextFactory.build_browser_artifacts(har_path=browser_args["record_har_path"]) browser_context = await playwright.chromium.launch_persistent_context(**browser_args) - return browser_context, browser_artifacts + return browser_context, browser_artifacts, None BrowserContextFactory.register_type("chromium-headless", _create_headless_chromium) @@ -155,11 +162,13 @@ def __init__( browser_context: BrowserContext | None = None, page: Page | None = None, browser_artifacts: BrowserArtifacts = BrowserArtifacts(), + browser_cleanup: BrowserCleanupFunc = None, ): self.pw = pw self.browser_context = browser_context self.page = page self.browser_artifacts = browser_artifacts + self.browser_cleanup = browser_cleanup def __assert_page(self) -> Page: if self.page is not None: @@ -190,6 +199,7 @@ async def check_and_fix_state( ( browser_context, browser_artifacts, + browser_cleanup, ) = await BrowserContextFactory.create_browser_context( self.pw, url=url, @@ -198,6 +208,7 @@ async def check_and_fix_state( ) self.browser_context = browser_context self.browser_artifacts = browser_artifacts + self.browser_cleanup = browser_cleanup LOG.info("browser context is created") assert self.browser_context is not None @@ -300,6 +311,9 @@ async def close(self, close_browser_on_completion: bool = True) -> None: LOG.info("Closing browser context and its pages") await self.browser_context.close() LOG.info("Main browser context and all its pages are closed") + if self.browser_cleanup is not None: + self.browser_cleanup() + LOG.info("Main browser cleanup is excuted") if self.pw and close_browser_on_completion: LOG.info("Stopping playwright") await self.pw.stop() diff --git a/skyvern/webeye/browser_manager.py b/skyvern/webeye/browser_manager.py index 76a23d011..4a0a3c351 100644 --- a/skyvern/webeye/browser_manager.py +++ b/skyvern/webeye/browser_manager.py @@ -30,6 +30,7 @@ async def _create_browser_state( ( browser_context, browser_artifacts, + browser_cleanup, ) = await BrowserContextFactory.create_browser_context( pw, proxy_location=proxy_location, @@ -41,6 +42,7 @@ async def _create_browser_state( browser_context=browser_context, page=None, browser_artifacts=browser_artifacts, + browser_cleanup=browser_cleanup, ) async def get_or_create_for_task(self, task: Task) -> BrowserState: