diff --git a/alembic/versions/2024_04_08_2347-ea8e24d0bc8e_add_orgs_max_retries_per_step.py b/alembic/versions/2024_04_08_2347-ea8e24d0bc8e_add_orgs_max_retries_per_step.py new file mode 100644 index 000000000..5d984c6b8 --- /dev/null +++ b/alembic/versions/2024_04_08_2347-ea8e24d0bc8e_add_orgs_max_retries_per_step.py @@ -0,0 +1,30 @@ +"""Add orgs.max_retries_per_step + +Revision ID: ea8e24d0bc8e +Revises: 4630ab8c198e +Create Date: 2024-04-08 23:47:46.306300+00:00 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "ea8e24d0bc8e" +down_revision: Union[str, None] = "4630ab8c198e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("organizations", sa.Column("max_retries_per_step", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("organizations", "max_retries_per_step") + # ### end Alembic commands ### diff --git a/skyvern/config.py b/skyvern/config.py index 087573c08..3edcf0745 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): # Ratio should be between 0 and 1. # If the task has been running for more steps than this ratio of the max steps per run, then we'll log a warning. LONG_RUNNING_TASK_WARNING_RATIO: float = 0.95 - MAX_RETRIES_PER_STEP: int = 2 + MAX_RETRIES_PER_STEP: int = 5 DEBUG_MODE: bool = False DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern" PROMPT_ACTION_HISTORY_WINDOW: int = 5 diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 7ed79e0bf..5018dbdeb 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -227,7 +227,7 @@ async def execute_step( # If the step failed, mark the step as failed and retry if step.status == StepStatus.failed: - maybe_next_step = await self.handle_failed_step(task, step) + maybe_next_step = await self.handle_failed_step(organization, task, step) # If there is no next step, it means that the task has failed if maybe_next_step: next_step = maybe_next_step @@ -965,8 +965,14 @@ async def update_task( **updates, ) - async def handle_failed_step(self, task: Task, step: Step) -> Step | None: - if step.retry_index >= SettingsManager.get_settings().MAX_RETRIES_PER_STEP: + async def handle_failed_step(self, organization: Organization, task: Task, step: Step) -> Step | None: + max_retries_per_step = ( + organization.max_retries_per_step + # we need to check by None because 0 is a valid value for max_retries_per_step + if organization.max_retries_per_step is not None + else SettingsManager.get_settings().MAX_RETRIES_PER_STEP + ) + if step.retry_index >= max_retries_per_step: LOG.warning( "Step failed after max retries, marking task as failed", task_id=task.task_id, @@ -978,7 +984,7 @@ async def handle_failed_step(self, task: Task, step: Step) -> Step | None: await self.update_task( task, TaskStatus.failed, - failure_reason=f"Max retries per step ({SettingsManager.get_settings().MAX_RETRIES_PER_STEP}) exceeded", + failure_reason=f"Max retries per step ({max_retries_per_step}) exceeded", ) return None else: diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 0917088c6..366af3330 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -418,12 +418,14 @@ async def create_organization( organization_name: str, webhook_callback_url: str | None = None, max_steps_per_run: int | None = None, + max_retries_per_step: int | None = None, ) -> Organization: async with self.Session() as session: org = OrganizationModel( organization_name=organization_name, webhook_callback_url=webhook_callback_url, max_steps_per_run=max_steps_per_run, + max_retries_per_step=max_retries_per_step, ) session.add(org) await session.commit() diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 30df02887..ae2e85144 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -74,7 +74,8 @@ class OrganizationModel(Base): organization_id = Column(String, primary_key=True, index=True, default=generate_org_id) organization_name = Column(String, nullable=False) webhook_callback_url = Column(UnicodeText) - max_steps_per_run = Column(Integer) + max_steps_per_run = Column(Integer, nullable=True) + max_retries_per_step = Column(Integer, nullable=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 5a1a406dd..5f6bfe218 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -104,6 +104,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization: organization_name=org_model.organization_name, webhook_callback_url=org_model.webhook_callback_url, max_steps_per_run=org_model.max_steps_per_run, + max_retries_per_step=org_model.max_retries_per_step, created_at=org_model.created_at, modified_at=org_model.modified_at, ) diff --git a/skyvern/forge/sdk/models.py b/skyvern/forge/sdk/models.py index 60aea5089..2d6bff41d 100644 --- a/skyvern/forge/sdk/models.py +++ b/skyvern/forge/sdk/models.py @@ -117,6 +117,7 @@ class Organization(BaseModel): organization_name: str webhook_callback_url: str | None = None max_steps_per_run: int | None = None + max_retries_per_step: int | None = None created_at: datetime modified_at: datetime