Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task Queue Redesign #168

Merged
merged 10 commits into from
Aug 30, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Set server default for task.completed_at and task.assigned_at to none and add num_retries.

Revision ID: b9860676dd49
Revises: 5ad873484aa3
Create Date: 2024-08-22 07:54:55.921710

"""
from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b9860676dd49"
down_revision: str | None = "5ad873484aa3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
"""Upgrade database to this revision from previous."""
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.alter_column(
"completed_at",
server_default=None,
)
batch_op.alter_column(
"assigned_at",
server_default=None,
)
batch_op.add_column(
sa.Column(
"num_retries",
sa.Integer(),
nullable=False,
comment="Number of retries",
server_default=sa.text("0"),
)
)

# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade database from this revision to previous."""
with op.batch_alter_table("tasks", schema=None) as batch_op:
batch_op.drop_column("num_retries")

# ### end Alembic commands ###
10 changes: 4 additions & 6 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from aana.exceptions.runtime import (
MultipleFileUploadNotAllowed,
)
from aana.storage.services.task import create_task
from aana.storage.repository.task import TaskRepository
from aana.storage.session import get_session


Expand Down Expand Up @@ -313,11 +313,9 @@ async def route_func_body( # noqa: C901
if not aana_settings.task_queue.enabled:
raise RuntimeError("Task queue is not enabled.") # noqa: TRY003

task_id = create_task(
endpoint=bound_path,
data=data_dict,
)
return AanaJSONResponse(content={"task_id": task_id})
task_repo = TaskRepository(self.session)
task = task_repo.save(endpoint=bound_path, data=data_dict)
return AanaJSONResponse(content={"task_id": str(task.id)})

if isasyncgenfunction(self.run):

Expand Down
53 changes: 36 additions & 17 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Annotated, Any
from uuid import uuid4

import orjson
import ray
from fastapi import Depends
from fastapi.openapi.utils import get_openapi
Expand All @@ -13,13 +14,15 @@
from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema
from aana.api.app import app
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.configs.settings import settings as aana_settings
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.task import TaskId, TaskInfo
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.storage.models.task import Status as TaskStatus
from aana.storage.repository.task import TaskRepository
from aana.storage.session import get_session

Expand Down Expand Up @@ -92,30 +95,46 @@ async def is_ready(self):
"""
return AanaJSONResponse(content={"ready": self.ready})

async def call_endpoint(self, path: str, **kwargs: dict[str, Any]) -> Any:
"""Call the endpoint from FastAPI with the given name.
async def execute_task(self, task_id: str) -> Any:
"""Execute a task.

Args:
path (str): The path of the endpoint.
**kwargs: The arguments to pass to the endpoint.
task_id (str): The task ID.

Returns:
Any: The response from the endpoint.
"""
for e in self.endpoints:
if e.path == path:
endpoint = e
break
else:
raise ValueError(f"Endpoint {path} not found") # noqa: TRY003

if not endpoint.initialized:
await endpoint.initialize()

if endpoint.is_streaming_response():
return [item async for item in endpoint.run(**kwargs)]
session = get_session()
HRashidi marked this conversation as resolved.
Show resolved Hide resolved
task_repo = TaskRepository(session)
try:
task = task_repo.read(task_id)
path = task.endpoint
kwargs = task.data

task_repo.update_status(task_id, TaskStatus.RUNNING, 0)

for e in self.endpoints:
if e.path == path:
endpoint = e
break
else:
raise ValueError(f"Endpoint {path} not found") # noqa: TRY003, TRY301

if not endpoint.initialized:
await endpoint.initialize()

if endpoint.is_streaming_response():
out = [item async for item in endpoint.run(**kwargs)]
else:
out = await endpoint.run(**kwargs)

task_repo.update_status(task_id, TaskStatus.COMPLETED, 100, out)
except Exception as e:
error_response = custom_exception_handler(None, e)
error = orjson.loads(error_response.body)
task_repo.update_status(task_id, TaskStatus.FAILED, 0, error)
else:
return await endpoint.run(**kwargs)
return out

@app.get(
"/tasks/get/{task_id}",
Expand Down
16 changes: 10 additions & 6 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from pathlib import Path

from pydantic import ConfigDict, field_validator
from pydantic_settings import BaseSettings
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from aana.configs.db import DbSettings
from aana.core.models.base import pydantic_protected_fields


class TestSettings(BaseSettings):
class TestSettings(BaseModel):
"""A pydantic model for test settings.

Attributes:
Expand All @@ -19,7 +19,7 @@ class TestSettings(BaseSettings):
save_expected_output: bool = False


class TaskQueueSettings(BaseSettings):
class TaskQueueSettings(BaseModel):
"""A pydantic model for task queue settings.

Attributes:
Expand All @@ -28,11 +28,13 @@ class TaskQueueSettings(BaseSettings):
execution_timeout (int): The maximum execution time for a task in seconds.
After this time, if the task is still running,
it will be considered as stuck and will be reassign to another worker.
max_retries (int): The maximum number of retries for a task.
"""

enabled: bool = True
num_workers: int = 4
execution_timeout: int = 600
max_retries: int = 3
HRashidi marked this conversation as resolved.
Show resolved Hide resolved


class Settings(BaseSettings):
Expand Down Expand Up @@ -69,8 +71,10 @@ def create_tmp_data_dir(cls, path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path

model_config = ConfigDict(
protected_namespaces=("settings", *pydantic_protected_fields)
model_config = SettingsConfigDict(
protected_namespaces=("settings", *pydantic_protected_fields),
env_nested_delimiter="__",
env_ignore_empty=True,
)


Expand Down
Loading
Loading