From 81a1478e2fdc3e462492656832239677d4890b2a Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Mon, 8 Apr 2024 22:56:49 -0700 Subject: [PATCH] Introduce new `TaskStatus`es: `queued` and `timed_out` (#170) --- ...d7fecef9_add_new_indices_to_tasks_table.py | 32 +++++++++++++++++++ skyvern/forge/sdk/db/models.py | 8 +++-- skyvern/forge/sdk/schemas/tasks.py | 11 +++++-- 3 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 alembic/versions/2024_04_09_0058-8335d7fecef9_add_new_indices_to_tasks_table.py diff --git a/alembic/versions/2024_04_09_0058-8335d7fecef9_add_new_indices_to_tasks_table.py b/alembic/versions/2024_04_09_0058-8335d7fecef9_add_new_indices_to_tasks_table.py new file mode 100644 index 000000000..964a9c848 --- /dev/null +++ b/alembic/versions/2024_04_09_0058-8335d7fecef9_add_new_indices_to_tasks_table.py @@ -0,0 +1,32 @@ +"""Add new indices to tasks table + +Revision ID: 8335d7fecef9 +Revises: ea8e24d0bc8e +Create Date: 2024-04-09 00:58:53.060477+00:00 + +""" +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "8335d7fecef9" +down_revision: Union[str, None] = "ea8e24d0bc8e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f("ix_tasks_created_at"), "tasks", ["created_at"], unique=False) + op.create_index(op.f("ix_tasks_modified_at"), "tasks", ["modified_at"], unique=False) + op.create_index(op.f("ix_tasks_status"), "tasks", ["status"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_tasks_status"), table_name="tasks") + op.drop_index(op.f("ix_tasks_modified_at"), table_name="tasks") + op.drop_index(op.f("ix_tasks_created_at"), table_name="tasks") + # ### end Alembic commands ### diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index ae2e85144..d3bf48bf7 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -30,7 +30,7 @@ class TaskModel(Base): task_id = Column(String, primary_key=True, index=True, default=generate_task_id) organization_id = Column(String, ForeignKey("organizations.organization_id")) - status = Column(String) + status = Column(String, index=True) webhook_callback_url = Column(String) title = Column(String) url = Column(String) @@ -46,8 +46,10 @@ class TaskModel(Base): retry = Column(Integer, nullable=True) error_code_mapping = Column(JSON, nullable=True) errors = Column(JSON, default=[], nullable=False) - created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) - modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True) + modified_at = Column( + DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False, index=True + ) class StepModel(Base): diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index f468f1c19..8c9baf9c1 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -73,19 +73,23 @@ class TaskRequest(BaseModel): class TaskStatus(StrEnum): created = "created" + queued = "queued" running = "running" + timed_out = "timed_out" failed = "failed" terminated = "terminated" completed = "completed" def is_final(self) -> bool: - return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed} + return self in {TaskStatus.failed, TaskStatus.terminated, TaskStatus.completed, TaskStatus.timed_out} def can_update_to(self, new_status: TaskStatus) -> bool: allowed_transitions: dict[TaskStatus, set[TaskStatus]] = { - TaskStatus.created: {TaskStatus.running}, - TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated}, + TaskStatus.created: {TaskStatus.queued, TaskStatus.running, TaskStatus.timed_out}, + TaskStatus.queued: {TaskStatus.running, TaskStatus.timed_out}, + TaskStatus.running: {TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated, TaskStatus.timed_out}, TaskStatus.failed: set(), + TaskStatus.terminated: set(), TaskStatus.completed: set(), } return new_status in allowed_transitions[self] @@ -97,6 +101,7 @@ def requires_extracted_info(self) -> bool: def cant_have_extracted_info(self) -> bool: status_cant_have_extracted_information = { TaskStatus.created, + TaskStatus.queued, TaskStatus.running, TaskStatus.failed, TaskStatus.terminated,