diff --git a/aana/alembic/versions/b9860676dd49_set_server_default_for_task_completed_.py b/aana/alembic/versions/b9860676dd49_set_server_default_for_task_completed_.py new file mode 100644 index 00000000..e40c309d --- /dev/null +++ b/aana/alembic/versions/b9860676dd49_set_server_default_for_task_completed_.py @@ -0,0 +1,58 @@ +"""Set server default for task.completed_at and task.assigned_at to None and add num_retries. + +Revision ID: b9860676dd49 +Revises: 5ad873484aa3 +Create Date: 2024-08-22 07:54:55.921710 + +""" +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b9860676dd49" +down_revision: str | None = "5ad873484aa3" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade database to this revision from previous.""" + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.alter_column( + "completed_at", + server_default=None, + ) + batch_op.alter_column( + "assigned_at", + server_default=None, + ) + batch_op.add_column( + sa.Column( + "num_retries", + sa.Integer(), + nullable=False, + comment="Number of retries", + server_default=sa.text("0"), + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database from this revision to previous.""" + with op.batch_alter_table("tasks", schema=None) as batch_op: + batch_op.drop_column("num_retries") + # Restore the server default (func.now()) that upgrade() removed, + # matching the previous timestamp definition in aana/storage/models/base.py. + batch_op.alter_column( + "completed_at", + server_default=sa.func.now(), + ) + batch_op.alter_column( + "assigned_at", + server_default=sa.func.now(), + ) + + # ### end Alembic commands ### diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index d234df3c..7455b193 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -19,7 +19,7 @@ from aana.exceptions.runtime import ( MultipleFileUploadNotAllowed, ) -from aana.storage.services.task import create_task +from aana.storage.repository.task import TaskRepository from aana.storage.session import get_session @@ -313,11 +313,9 @@ async def route_func_body( # noqa: C901 if not aana_settings.task_queue.enabled: raise RuntimeError("Task queue is not enabled.") # noqa: TRY003 - task_id = create_task( - endpoint=bound_path, - data=data_dict, - ) - return AanaJSONResponse(content={"task_id": task_id}) + task_repo = TaskRepository(self.session) + task = task_repo.save(endpoint=bound_path, data=data_dict) + return AanaJSONResponse(content={"task_id": str(task.id)}) if isasyncgenfunction(self.run): diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index 75647fa0..dc8cb0cf 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -3,6 +3,7 @@ from typing import Annotated, Any from uuid import uuid4 +import orjson import ray from fastapi import Depends from fastapi.openapi.utils import get_openapi @@ -13,6 +14,7 @@ from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema from aana.api.app import app from aana.api.event_handlers.event_manager import EventManager +from aana.api.exception_handler import custom_exception_handler from aana.api.responses import AanaJSONResponse from aana.configs.settings import settings as aana_settings from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse @@ -20,6 +22,7 @@ from aana.core.models.sampling import SamplingParams from aana.core.models.task import TaskId, TaskInfo from aana.deployments.aana_deployment_handle import AanaDeploymentHandle +from aana.storage.models.task import Status as TaskStatus from aana.storage.repository.task import TaskRepository from
aana.storage.session import get_session @@ -92,30 +95,46 @@ async def is_ready(self): """ return AanaJSONResponse(content={"ready": self.ready}) - async def call_endpoint(self, path: str, **kwargs: dict[str, Any]) -> Any: - """Call the endpoint from FastAPI with the given name. + async def execute_task(self, task_id: str) -> Any: + """Execute a task. Args: - path (str): The path of the endpoint. - **kwargs: The arguments to pass to the endpoint. + task_id (str): The task ID. Returns: Any: The response from the endpoint. """ - for e in self.endpoints: - if e.path == path: - endpoint = e - break - else: - raise ValueError(f"Endpoint {path} not found") # noqa: TRY003 - - if not endpoint.initialized: - await endpoint.initialize() - - if endpoint.is_streaming_response(): - return [item async for item in endpoint.run(**kwargs)] + session = get_session() + task_repo = TaskRepository(session) + try: + task = task_repo.read(task_id) + path = task.endpoint + kwargs = task.data + + task_repo.update_status(task_id, TaskStatus.RUNNING, 0) + + for e in self.endpoints: + if e.path == path: + endpoint = e + break + else: + raise ValueError(f"Endpoint {path} not found") # noqa: TRY003, TRY301 + + if not endpoint.initialized: + await endpoint.initialize() + + if endpoint.is_streaming_response(): + out = [item async for item in endpoint.run(**kwargs)] + else: + out = await endpoint.run(**kwargs) + + task_repo.update_status(task_id, TaskStatus.COMPLETED, 100, out) + except Exception as e: + error_response = custom_exception_handler(None, e) + error = orjson.loads(error_response.body) + task_repo.update_status(task_id, TaskStatus.FAILED, 0, error) else: - return await endpoint.run(**kwargs) + return out @app.get( "/tasks/get/{task_id}", diff --git a/aana/configs/settings.py b/aana/configs/settings.py index fe61b598..ec161ac6 100644 --- a/aana/configs/settings.py +++ b/aana/configs/settings.py @@ -1,13 +1,13 @@ from pathlib import Path -from pydantic import ConfigDict, field_validator -from pydantic_settings import BaseSettings +from pydantic import BaseModel, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict from aana.configs.db import DbSettings from aana.core.models.base import pydantic_protected_fields -class TestSettings(BaseSettings): +class TestSettings(BaseModel): """A pydantic model for test settings. Attributes: @@ -19,7 +19,7 @@ class TestSettings(BaseSettings): save_expected_output: bool = False -class TaskQueueSettings(BaseSettings): +class TaskQueueSettings(BaseModel): """A pydantic model for task queue settings. Attributes: @@ -28,11 +28,13 @@ class TaskQueueSettings(BaseSettings): execution_timeout (int): The maximum execution time for a task in seconds. After this time, if the task is still running, it will be considered stuck and will be reassigned to another worker. + max_retries (int): The maximum number of retries for a task.
""" enabled: bool = True num_workers: int = 4 execution_timeout: int = 600 + max_retries: int = 3 class Settings(BaseSettings): @@ -69,8 +71,10 @@ def create_tmp_data_dir(cls, path: Path) -> Path: path.mkdir(parents=True, exist_ok=True) return path - model_config = ConfigDict( - protected_namespaces=("settings", *pydantic_protected_fields) + model_config = SettingsConfigDict( + protected_namespaces=("settings", *pydantic_protected_fields), + env_nested_delimiter="__", + env_ignore_empty=True, ) diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py index 3e63ea9d..0930bddf 100644 --- a/aana/deployments/task_queue_deployment.py +++ b/aana/deployments/task_queue_deployment.py @@ -1,19 +1,15 @@ import asyncio -import concurrent.futures from typing import Any -import orjson import ray from pydantic import BaseModel, Field from ray import serve -from aana.api.exception_handler import custom_exception_handler from aana.configs.settings import settings as aana_settings from aana.deployments.base_deployment import BaseDeployment from aana.storage.models.task import Status as TaskStatus from aana.storage.repository.task import TaskRepository from aana.storage.session import get_session -from aana.utils.asyncio import run_async class TaskQueueConfig(BaseModel): @@ -29,7 +25,6 @@ class TaskQueueDeployment(BaseDeployment): def __init__(self): """Initialize the task queue deployment.""" super().__init__() - self.futures = {} loop = asyncio.get_running_loop() self.loop_task = loop.create_task(self.loop()) self.loop_task.add_done_callback( @@ -37,21 +32,32 @@ def __init__(self): ) self.session = get_session() self.task_repo = TaskRepository(self.session) + self.running_task_ids: list[str] = [] + self.deployment_responses = {} def check_health(self): """Check the health of the deployment.""" # if the loop is not running, the deployment is unhealthy if self.loop_task.done(): - raise RuntimeError("Task queue loop is not running") # noqa: TRY003 + raise RuntimeError( # noqa: TRY003 + "Task queue loop is not running" + ) from self.loop_task.exception() def __del__(self): """Clean up the deployment.""" # Cancel the loop task to prevent tasks from being reassigned self.loop_task.cancel() - # Set all non-completed tasks to NOT_FINISHED - for task_id, future in self.futures.items(): - if not future.done(): - self.task_repo.update_status(task_id, TaskStatus.NOT_FINISHED, 0) + # Cancel all deployment responses to stop the tasks + # and set all non-completed tasks to NOT_FINISHED + for task_id in self.running_task_ids: + deployment_response = self.deployment_responses.get(task_id) + if deployment_response: + deployment_response.cancel() + self.task_repo.update_status( + task_id=task_id, + status=TaskStatus.NOT_FINISHED, + progress=0, + ) async def apply_config(self, config: dict[str, Any]): """Apply the configuration. @@ -61,50 +67,25 @@ async def apply_config(self, config: dict[str, Any]): The configuration should conform to the TaskQueueConfig schema. """ config_obj = TaskQueueConfig(**config) - self.handle = None self.app_name = config_obj.app_name - self.thread_pool = concurrent.futures.ThreadPoolExecutor( - max_workers=aana_settings.task_queue.num_workers, - ) async def loop(self): # noqa: C901 """The main loop for the task queue deployment. The loop will check the queue and assign tasks to workers. 
""" - - async def handle_task(task_id: str): - """Process a task.""" - # Fetch the task details - task = self.task_repo.read(task_id) - # Initially set the task status to RUNNING - self.task_repo.update_status(task_id, TaskStatus.RUNNING, 0) - try: - # Call the endpoint asynchronously - out = await self.handle.call_endpoint.remote(task.endpoint, **task.data) - # Update the task status to COMPLETED - self.task_repo.update_status(task_id, TaskStatus.COMPLETED, 100, out) - except Exception as e: - # Handle the exception and update the task status to FAILED - error_response = custom_exception_handler(None, e) - error = orjson.loads(error_response.body) - self.task_repo.update_status(task_id, TaskStatus.FAILED, 0, error) - - def run_handle_task(task_id): - """Wrapper to run the handle_task function.""" - run_async(handle_task(task_id)) - - def is_thread_pool_full(): - """Check if the thread pool has too many tasks. - - We use it to stop assigning tasks to the thread pool if it's full - to prevent the thread pool from being overwhelmed. - We don't want to schedule all tasks from the task queue (could be millions). - """ - return ( - self.thread_pool._work_queue.qsize() - > aana_settings.task_queue.num_workers * 2 - ) + handle = None + + active_tasks = self.task_repo.get_active_tasks() + for task in active_tasks: + if task.status == TaskStatus.RUNNING: + self.running_task_ids.append(str(task.id)) + if task.status == TaskStatus.ASSIGNED: + self.task_repo.update_status( + task_id=task.id, + status=TaskStatus.NOT_FINISHED, + progress=0, + ) while True: if not self._configured: @@ -112,30 +93,60 @@ def is_thread_pool_full(): await asyncio.sleep(1) continue - # Remove completed tasks from the futures dictionary - for task_id in list(self.futures.keys()): - if self.futures[task_id].done(): - del self.futures[task_id] + # Remove completed tasks from the list of running tasks + self.running_task_ids = self.task_repo.filter_incomplete_tasks( + self.running_task_ids + ) - if is_thread_pool_full(): - # wait a bit to give the thread pool time to process tasks + # Check for expired tasks + execution_timeout = aana_settings.task_queue.execution_timeout + expired_tasks = self.task_repo.get_expired_tasks(execution_timeout) + for task in expired_tasks: + deployment_response = self.deployment_responses.get(task.id) + if deployment_response: + deployment_response.cancel() + if task.num_retries >= aana_settings.task_queue.max_retries: + self.task_repo.update_status( + task_id=task.id, + status=TaskStatus.FAILED, + progress=0, + result={ + "error": "TimeoutError", + "message": ( + f"Task execution timed out after {execution_timeout} seconds and " + f"exceeded the maximum number of retries ({aana_settings.task_queue.max_retries})" + ), + }, + ) + else: + self.task_repo.update_status( + task_id=task.id, + status=TaskStatus.NOT_FINISHED, + progress=0, + ) + + # If the queue is full, wait and retry + if len(self.running_task_ids) >= aana_settings.task_queue.num_workers: await asyncio.sleep(0.1) continue - tasks = self.task_repo.get_unprocessed_tasks( - limit=aana_settings.task_queue.num_workers * 2 + # Get new tasks from the database + num_tasks_to_assign = aana_settings.task_queue.num_workers - len( + self.running_task_ids ) + tasks = self.task_repo.get_unprocessed_tasks(limit=num_tasks_to_assign) + # If there are no tasks, wait and retry if not tasks: await asyncio.sleep(0.1) continue - if not self.handle: + if not handle: # Sometimes the app isn't available immediately after the deployment is created # so we need to 
wait for it to become available for _ in range(10): try: - self.handle = serve.get_app_handle(self.app_name) + handle = serve.get_app_handle(self.app_name) break except ray.serve.exceptions.RayServeException as e: print( @@ -146,13 +157,15 @@ def is_thread_pool_full(): # If the app is not available after all retries, try again # but without catching the exception # (if it fails, the deployment will be unhealthy, and restart will be attempted) - self.handle = serve.get_app_handle(self.app_name) + handle = serve.get_app_handle(self.app_name) + # Start processing the tasks for task in tasks: - if is_thread_pool_full(): - # wait a bit to give the thread pool time to process tasks - await asyncio.sleep(0.1) - break - self.task_repo.update_status(task.id, TaskStatus.ASSIGNED, 0) - future = self.thread_pool.submit(run_handle_task, task.id) - self.futures[task.id] = future + self.task_repo.update_status( + task_id=task.id, + status=TaskStatus.ASSIGNED, + progress=0, + ) + deployment_response = handle.execute_task.remote(task_id=task.id) + self.deployment_responses[task.id] = deployment_response + self.running_task_ids.append(str(task.id)) diff --git a/aana/storage/models/base.py b/aana/storage/models/base.py index 37ef80fc..616066b1 100644 --- a/aana/storage/models/base.py +++ b/aana/storage/models/base.py @@ -14,7 +14,7 @@ timestamp = Annotated[ datetime.datetime, - mapped_column(DateTime(timezone=True), server_default=func.now()), + mapped_column(DateTime(timezone=True)), ] T = TypeVar("T", bound="InheritanceReuseMixin") @@ -79,9 +79,11 @@ class TimeStampEntity: """Mixin for database entities that will have create/update timestamps.""" created_at: Mapped[timestamp] = mapped_column( + server_default=func.now(), comment="Timestamp when row is inserted", ) updated_at: Mapped[timestamp] = mapped_column( onupdate=func.now(), + server_default=func.now(), comment="Timestamp when row is updated", ) diff --git a/aana/storage/models/task.py b/aana/storage/models/task.py index 4f76b056..c5b15212 100644 --- a/aana/storage/models/task.py +++ b/aana/storage/models/task.py @@ -45,6 +45,7 @@ class TaskEntity(BaseEntity, TimeStampEntity): comment="Timestamp when the task was assigned", ) completed_at: Mapped[timestamp | None] = mapped_column( + server_default=None, comment="Timestamp when the task was completed", ) progress: Mapped[float] = mapped_column( @@ -53,7 +54,16 @@ class TaskEntity(BaseEntity, TimeStampEntity): result: Mapped[dict | None] = mapped_column( JSON, comment="Result of the task in JSON format" ) + num_retries: Mapped[int] = mapped_column( + nullable=False, default=0, comment="Number of retries" + ) def __repr__(self) -> str: """String representation of the task.""" - return f"<Task(id={self.id}, endpoint={self.endpoint}, status={self.status})>" + return ( + f"<Task(id={self.id}, endpoint={self.endpoint}, " + f"status={self.status}, priority={self.priority}, " + f"progress={self.progress}, num_retries={self.num_retries}, " + f"assigned_at={self.assigned_at}, completed_at={self.completed_at}, " + f"result={self.result})>" + ) diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index bc482762..4dde5a29 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -2,10 +2,9 @@ from typing import Any from uuid import UUID -from sqlalchemy import and_, desc, or_ +from sqlalchemy import and_, desc from sqlalchemy.orm import Session -from aana.configs.settings import settings as aana_settings from aana.storage.models.task import Status as TaskStatus from aana.storage.models.task import TaskEntity from aana.storage.repository.base import BaseRepository @@ -71,9 +70,7 @@ def save(self, endpoint: str, data: Any, priority: int = 0): def get_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]: """Fetches all unprocessed tasks.
- The task is considered unprocessed if it is in CREATED or NOT_FINISHED state or - in RUNNING or ASSIGNED state and the update timestamp is older - than the execution timeout (to handle stuck tasks). + The task is considered unprocessed if it is in CREATED or NOT_FINISHED state. Args: limit (int | None): The maximum number of tasks to fetch. If None, fetch all. @@ -81,22 +78,10 @@ def get_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]: Returns: list[TaskEntity]: the unprocessed tasks. """ - execution_timeout = aana_settings.task_queue.execution_timeout - cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 tasks = ( self.session.query(TaskEntity) .filter( - or_( - TaskEntity.status.in_( - [TaskStatus.CREATED, TaskStatus.NOT_FINISHED] - ), - and_( - TaskEntity.status.in_( - [TaskStatus.RUNNING, TaskStatus.ASSIGNED] - ), - TaskEntity.updated_at <= cutoff_time, - ), - ) + TaskEntity.status.in_([TaskStatus.CREATED, TaskStatus.NOT_FINISHED]) ) .order_by(desc(TaskEntity.priority), TaskEntity.created_at) .limit(limit) @@ -120,10 +105,82 @@ def update_status( result (Any): The result. """ task = self.read(task_id) - task.status = status if status == TaskStatus.COMPLETED or status == TaskStatus.FAILED: task.completed_at = datetime.now() # noqa: DTZ005 + if status == TaskStatus.ASSIGNED: + task.assigned_at = datetime.now() # noqa: DTZ005 + task.num_retries += 1 if progress is not None: task.progress = progress + task.status = status task.result = result self.session.commit() + + def get_active_tasks(self) -> list[TaskEntity]: + """Fetches all active tasks. + + The task is considered active if it is in RUNNING or ASSIGNED state. + + Returns: + list[TaskEntity]: the active tasks. + """ + tasks = ( + self.session.query(TaskEntity) + .filter(TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED])) + .all() + ) + return tasks + + def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]: + """Remove the task IDs that are already finished (COMPLETED, FAILED, or NOT_FINISHED). + + Args: + task_ids (list[str]): The task IDs to filter. + + Returns: + list[str]: The task IDs that are not finished yet. + """ + task_ids = [UUID(task_id) for task_id in task_ids] + tasks = ( + self.session.query(TaskEntity) + .filter( + and_( + TaskEntity.id.in_(task_ids), + TaskEntity.status.not_in( + [ + TaskStatus.COMPLETED, + TaskStatus.FAILED, + TaskStatus.NOT_FINISHED, + ] + ), + ) + ) + .all() + ) + incomplete_task_ids = [str(task.id) for task in tasks] + return incomplete_task_ids + + def get_expired_tasks(self, execution_timeout: float) -> list[TaskEntity]: + """Fetches all tasks that are expired. + + The task is considered expired if it is in RUNNING or ASSIGNED state and the + updated_at time is older than the execution_timeout. + + Args: + execution_timeout (float): The maximum execution time for a task in seconds. + + Returns: + list[TaskEntity]: the expired tasks.
+ """ + cutoff_time = datetime.now() - timedelta(seconds=execution_timeout) # noqa: DTZ005 + tasks = ( + self.session.query(TaskEntity) + .filter( + and_( + TaskEntity.status.in_([TaskStatus.RUNNING, TaskStatus.ASSIGNED]), + TaskEntity.updated_at <= cutoff_time, + ), + ) + .all() + ) + return tasks diff --git a/aana/storage/services/task.py b/aana/storage/services/task.py deleted file mode 100644 index b93195d0..00000000 --- a/aana/storage/services/task.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any - -from sqlalchemy.orm import Session - -from aana.storage.engine import engine -from aana.storage.repository.task import TaskRepository - - -def create_task( - endpoint: str, - data: Any, - priority: int = 0, -) -> str: - """Create a task. - - Args: - endpoint: The endpoint to which the task is assigned. - data: Data for the task. - priority: Priority of the task (0 is the lowest). - - Returns: - str: The task ID. - """ - with Session(engine) as session: - task_repo = TaskRepository(session) - task = task_repo.save(endpoint=endpoint, data=data, priority=priority) - return str(task.id) diff --git a/aana/tests/db/datastore/test_task_repo.py b/aana/tests/db/datastore/test_task_repo.py new file mode 100644 index 00000000..41fbe64f --- /dev/null +++ b/aana/tests/db/datastore/test_task_repo.py @@ -0,0 +1,301 @@ +# ruff: noqa: S101 + +import asyncio +from datetime import datetime, timedelta + +import pytest + +from aana.storage.models.task import Status as TaskStatus +from aana.storage.models.task import TaskEntity +from aana.storage.repository.task import TaskRepository + + +def test_save_task_repo(db_session): + """Test saving a task.""" + task_repo = TaskRepository(db_session) + task = task_repo.save(endpoint="/test", data={"test": "test"}) + + task_entity = task_repo.read(task.id) + assert task_entity + assert task_entity.id == task.id + assert task_entity.endpoint == "/test" + assert task_entity.data == {"test": "test"} + + +@pytest.mark.asyncio +async def test_multiple_simultaneous_tasks(db_session): + """Test creating multiple tasks in parallel.""" + task_repo = TaskRepository(db_session) + + # Create multiple tasks in parallel with asyncio + async def add_task(i): + task = task_repo.save(endpoint="/test", data={"test": i}) + return task + + async_tasks = [] + for i in range(30): + async_task = asyncio.create_task(add_task(i)) + async_tasks.append(async_task) + + # Wait for all tasks to complete + await asyncio.gather(*async_tasks) + + +def test_get_unprocessed_tasks(db_session): + """Test fetching unprocessed tasks.""" + task_repo = TaskRepository(db_session) + + # Remove all existing tasks + db_session.query(TaskEntity).delete() + db_session.commit() + + # Create sample tasks with different statuses + now = datetime.now() # noqa: DTZ005 + + task1 = TaskEntity( + endpoint="/test1", + data={"test": "data1"}, + status=TaskStatus.CREATED, + priority=1, + created_at=now - timedelta(hours=10), + ) + task2 = TaskEntity( + endpoint="/test2", + data={"test": "data2"}, + status=TaskStatus.NOT_FINISHED, + priority=2, + created_at=now - timedelta(hours=1), + ) + task3 = TaskEntity( + endpoint="/test3", + data={"test": "data3"}, + status=TaskStatus.COMPLETED, + priority=3, + created_at=now - timedelta(hours=2), + ) + task4 = TaskEntity( + endpoint="/test4", + data={"test": "data4"}, + status=TaskStatus.CREATED, + priority=2, + created_at=now - timedelta(hours=3), + ) + + db_session.add_all([task1, task2, task3, task4]) + db_session.commit() + + # Fetch unprocessed tasks without any limit + 
unprocessed_tasks = task_repo.get_unprocessed_tasks() + + # Assert that only tasks with CREATED and NOT_FINISHED status are returned + assert len(unprocessed_tasks) == 3 + assert task1 in unprocessed_tasks + assert task2 in unprocessed_tasks + assert task4 in unprocessed_tasks + + # Ensure tasks are ordered by priority and then by created_at + assert unprocessed_tasks[0].id == task4.id # Highest priority, created earlier + assert unprocessed_tasks[1].id == task2.id # Same priority, but created later + assert unprocessed_tasks[2].id == task1.id # Lowest priority + + # Fetch unprocessed tasks with a limit + limited_tasks = task_repo.get_unprocessed_tasks(limit=2) + + # Assert that only the specified number of tasks is returned + assert len(limited_tasks) == 2 + assert limited_tasks[0].id == task4.id # Highest priority, created earlier + assert limited_tasks[1].id == task2.id # Same priority, but created later + + +def test_update_status(db_session): + """Test updating the status of a task.""" + task_repo = TaskRepository(db_session) + + # Create a task with an initial status + task = TaskEntity( + endpoint="/test", data={"key": "value"}, status=TaskStatus.CREATED + ) + db_session.add(task) + db_session.commit() + + # Update the status to ASSIGNED and check fields + task_repo.update_status(task.id, TaskStatus.ASSIGNED, progress=50) + + updated_task = task_repo.read(task.id) + assert updated_task.status == TaskStatus.ASSIGNED + assert updated_task.assigned_at is not None + assert updated_task.num_retries == 1 + assert updated_task.progress == 50 + assert updated_task.result is None + + # Update the status to COMPLETED and check fields + task_repo.update_status( + task.id, TaskStatus.COMPLETED, progress=100, result={"result": "final_result"} + ) + + updated_task = task_repo.read(task.id) + assert updated_task.status == TaskStatus.COMPLETED + assert updated_task.completed_at is not None + assert updated_task.progress == 100 + assert updated_task.result == {"result": "final_result"} + + # Ensure timestamps are reasonable + assert updated_task.assigned_at < updated_task.completed_at + assert updated_task.created_at < updated_task.assigned_at + + # Update the status to FAILED and check fields + task_repo.update_status( + task.id, TaskStatus.FAILED, progress=0, result={"error": "error_message"} + ) + updated_task = task_repo.read(task.id) + + assert updated_task.status == TaskStatus.FAILED + assert updated_task.completed_at is not None + assert updated_task.progress == 0 + assert updated_task.result == {"error": "error_message"} + + # Ensure timestamps are reasonable + assert updated_task.assigned_at < updated_task.completed_at + assert updated_task.created_at < updated_task.assigned_at + + +def test_get_active_tasks(db_session): + """Test fetching active tasks.""" + task_repo = TaskRepository(db_session) + + # Remove all existing tasks + db_session.query(TaskEntity).delete() + db_session.commit() + + # Create sample tasks with different statuses + task1 = TaskEntity( + endpoint="/task1", data={"test": "data1"}, status=TaskStatus.CREATED + ) + task2 = TaskEntity( + endpoint="/task2", data={"test": "data2"}, status=TaskStatus.RUNNING + ) + task3 = TaskEntity( + endpoint="/task3", data={"test": "data3"}, status=TaskStatus.ASSIGNED + ) + task4 = TaskEntity( + endpoint="/task4", data={"test": "data4"}, status=TaskStatus.COMPLETED + ) + task5 = TaskEntity( + endpoint="/task5", data={"test": "data5"}, status=TaskStatus.FAILED + ) + task6 = TaskEntity( + endpoint="/task6", data={"test": "data6"}, status=TaskStatus.NOT_FINISHED + ) + + db_session.add_all([task1, task2, task3, task4,
task5, task6]) + db_session.commit() + + # Fetch active tasks + active_tasks = task_repo.get_active_tasks() + + # Assert that only tasks with RUNNING and ASSIGNED status are returned + assert len(active_tasks) == 2 + assert task2 in active_tasks + assert task3 in active_tasks + assert all( + task.status in [TaskStatus.RUNNING, TaskStatus.ASSIGNED] + for task in active_tasks + ) + + +def test_remove_completed_tasks(db_session): + """Test removing completed tasks.""" + task_repo = TaskRepository(db_session) + + # Remove all existing tasks + db_session.query(TaskEntity).delete() + db_session.commit() + + # Create sample tasks with different statuses + task1 = TaskEntity( + endpoint="/task1", data={"test": "data1"}, status=TaskStatus.COMPLETED + ) + task2 = TaskEntity( + endpoint="/task2", data={"test": "data2"}, status=TaskStatus.RUNNING + ) + task3 = TaskEntity( + endpoint="/task3", data={"test": "data3"}, status=TaskStatus.ASSIGNED + ) + task4 = TaskEntity( + endpoint="/task4", data={"test": "data4"}, status=TaskStatus.COMPLETED + ) + task5 = TaskEntity( + endpoint="/task5", data={"test": "data5"}, status=TaskStatus.FAILED + ) + task6 = TaskEntity( + endpoint="/task6", data={"test": "data6"}, status=TaskStatus.NOT_FINISHED + ) + + all_tasks = [task1, task2, task3, task4, task5, task6] + unfinished_tasks = [task2, task3] + + db_session.add_all(all_tasks) + db_session.commit() + + # Remove completed tasks + task_ids = [str(task.id) for task in all_tasks] + non_completed_task_ids = task_repo.filter_incomplete_tasks(task_ids) + + # Assert that only the task IDs that are not completed are returned + assert set(non_completed_task_ids) == {str(task.id) for task in unfinished_tasks} + + +def test_get_expired_tasks(db_session): + """Test fetching expired tasks.""" + task_repo = TaskRepository(db_session) + + # Remove all existing tasks + db_session.query(TaskEntity).delete() + db_session.commit() + + # Set up current time and a cutoff time + current_time = datetime.now() # noqa: DTZ005 + execution_timeout = 3600 # 1 hour in seconds + + # Create tasks with different updated_at times and statuses + task1 = TaskEntity( + endpoint="/task1", + data={"test": "data1"}, + status=TaskStatus.RUNNING, + updated_at=current_time - timedelta(hours=2), + ) + task2 = TaskEntity( + endpoint="/task2", + data={"test": "data2"}, + status=TaskStatus.ASSIGNED, + updated_at=current_time - timedelta(seconds=2), + ) + task3 = TaskEntity( + endpoint="/task3", + data={"test": "data3"}, + status=TaskStatus.RUNNING, + updated_at=current_time, + ) + task4 = TaskEntity( + endpoint="/task4", + data={"test": "data4"}, + status=TaskStatus.COMPLETED, + updated_at=current_time - timedelta(hours=2), + ) + task5 = TaskEntity( + endpoint="/task5", + data={"test": "data5"}, + status=TaskStatus.FAILED, + updated_at=current_time - timedelta(seconds=4), + ) + + db_session.add_all([task1, task2, task3, task4, task5]) + db_session.commit() + + # Fetch expired tasks + expired_tasks = task_repo.get_expired_tasks(execution_timeout) + + # Assert that only tasks with RUNNING or ASSIGNED status and an updated_at older than the cutoff are returned + expected_task_ids = {str(task1.id)} + returned_task_ids = {str(task.id) for task in expired_tasks} + + assert returned_task_ids == expected_task_ids diff --git a/aana/tests/units/test_task_queue.py b/aana/tests/units/test_task_queue.py index cd7045e4..46667d4b 100644 --- a/aana/tests/units/test_task_queue.py +++ b/aana/tests/units/test_task_queue.py @@ -11,10 +11,6 @@ from aana.api.api_generation import 
Endpoint from aana.deployments.aana_deployment_handle import AanaDeploymentHandle from aana.deployments.base_deployment import BaseDeployment -from aana.deployments.task_queue_deployment import ( -    TaskQueueConfig, -    TaskQueueDeployment, -) @@ -96,13 +92,6 @@ async def run( "name": "lowercase_deployment", "instance": Lowercase, }, -    { -        "name": "task_queue_deployment", -        "instance": TaskQueueDeployment.options( -            num_replicas=1, -            user_config=TaskQueueConfig(app_name="app").model_dump(mode="json"), -        ), -    }, ] @@ -226,3 +215,42 @@ def test_task_queue(create_app): assert task_status == "completed" assert [chunk["text"] for chunk in result] == lowercase_text + + # Send 30 tasks to the task queue + task_ids = [] + for i in range(30): + data = {"text": [f"Task {i}"]} + response = requests.post( + f"http://localhost:{port}{route_prefix}/lowercase_stream?defer=True", + data={"body": json.dumps(data)}, + ) + assert response.status_code == 200 + task_ids.append(response.json().get("task_id")) + + # Check the task statuses with a timeout of 10 seconds + start_time = time.time() + completed_tasks = [] + while time.time() - start_time < 10: + for task_id in task_ids: + if task_id in completed_tasks: + continue + response = requests.get( + f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" + ) + task_status = response.json().get("status") + result = response.json().get("result") + if task_status == "completed": + completed_tasks.append(task_id) + + if len(completed_tasks) == len(task_ids): + break + time.sleep(0.1) + + # Check that all tasks are completed + for task_id in task_ids: + response = requests.get( + f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" + ) + response = response.json() + task_status = response.get("status") + assert task_status == "completed", response
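Note on the settings change: moving `TestSettings` and `TaskQueueSettings` from `BaseSettings` to plain `BaseModel` makes them nested models of the top-level `Settings`, and `env_nested_delimiter="__"` in `SettingsConfigDict` is what keeps them configurable from the environment. A minimal sketch of how the new options can be overridden, assuming `Settings` is default-constructible and its field holding `TaskQueueSettings` is named `task_queue` (as the `aana_settings.task_queue` references in the deployment code suggest):

```python
import os

from aana.configs.settings import Settings

# With env_nested_delimiter="__", TASK_QUEUE__<FIELD> maps to settings.task_queue.<field>;
# env_ignore_empty=True lets empty env vars fall back to the defaults.
os.environ["TASK_QUEUE__NUM_WORKERS"] = "8"
os.environ["TASK_QUEUE__MAX_RETRIES"] = "5"

settings = Settings()
assert settings.task_queue.num_workers == 8
assert settings.task_queue.max_retries == 5
```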
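The retry flow introduced here lives at the repository level: `update_status` stamps `assigned_at` and increments `num_retries` whenever a task is marked ASSIGNED, and `get_expired_tasks` returns RUNNING/ASSIGNED tasks whose `updated_at` is older than `execution_timeout`. The deployment loop then either fails an expired task (once `num_retries` reaches `max_retries`) or resets it to NOT_FINISHED so `get_unprocessed_tasks` picks it up again. A minimal sketch of that lifecycle using only the methods from this diff (session handling simplified):

```python
from aana.configs.settings import settings as aana_settings
from aana.storage.models.task import Status as TaskStatus
from aana.storage.repository.task import TaskRepository
from aana.storage.session import get_session

repo = TaskRepository(get_session())

# Enqueue: this is what the ?defer=True endpoint path does.
task = repo.save(endpoint="/lowercase_stream", data={"text": ["Hello"]})

# Assign: update_status stamps assigned_at and bumps num_retries.
for t in repo.get_unprocessed_tasks(limit=aana_settings.task_queue.num_workers):
    repo.update_status(task_id=t.id, status=TaskStatus.ASSIGNED, progress=0)

# Expiry sweep: retry or fail tasks stuck past the execution timeout.
for t in repo.get_expired_tasks(aana_settings.task_queue.execution_timeout):
    if t.num_retries >= aana_settings.task_queue.max_retries:
        repo.update_status(
            task_id=t.id,
            status=TaskStatus.FAILED,
            progress=0,
            result={"error": "TimeoutError"},
        )
    else:
        repo.update_status(task_id=t.id, status=TaskStatus.NOT_FINISHED, progress=0)
```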