Skip to content

Commit

Permalink
Merge pull request #187 from mobiusml/multiple_task_queue_deployments
Browse files Browse the repository at this point in the history
Support for Multiple Instances of TaskQueueDeployment
  • Loading branch information
movchan74 authored Oct 22, 2024
2 parents b5be4d3 + 836c21c commit 5900e72
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 97 deletions.
68 changes: 20 additions & 48 deletions aana/deployments/task_queue_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ class TaskQueueConfig(BaseModel):

@serve.deployment
class TaskQueueDeployment(BaseDeployment):
"""Deployment to serve the task queue."""
"""Deployment to serve the task queue.
IMPORTANT: If you are using SQLite, make sure to run only one instance
of this deployment to avoid database race conditions.
"""

def __init__(self):
"""Initialize the task queue deployment."""
Expand Down Expand Up @@ -76,17 +80,6 @@ async def loop(self): # noqa: C901
"""
handle = None

active_tasks = self.task_repo.get_active_tasks()
for task in active_tasks:
if task.status == TaskStatus.RUNNING:
self.running_task_ids.append(str(task.id))
if task.status == TaskStatus.ASSIGNED:
self.task_repo.update_status(
task_id=task.id,
status=TaskStatus.NOT_FINISHED,
progress=0,
)

while True:
if not self._configured:
# Wait for the deployment to be configured.
Expand All @@ -100,47 +93,20 @@ async def loop(self): # noqa: C901

# Check for expired tasks
execution_timeout = aana_settings.task_queue.execution_timeout
expired_tasks = self.task_repo.get_expired_tasks(execution_timeout)
max_retries = aana_settings.task_queue.max_retries
expired_tasks = self.task_repo.update_expired_tasks(
execution_timeout=execution_timeout, max_retries=max_retries
)
for task in expired_tasks:
deployment_response = self.deployment_responses.get(task.id)
if deployment_response:
deployment_response.cancel()
if task.num_retries >= aana_settings.task_queue.max_retries:
self.task_repo.update_status(
task_id=task.id,
status=TaskStatus.FAILED,
progress=0,
result={
"error": "TimeoutError",
"message": (
f"Task execution timed out after {execution_timeout} seconds and "
f"exceeded the maximum number of retries ({aana_settings.task_queue.max_retries})"
),
},
)
else:
self.task_repo.update_status(
task_id=task.id,
status=TaskStatus.NOT_FINISHED,
progress=0,
)

# If the queue is full, wait and retry
if len(self.running_task_ids) >= aana_settings.task_queue.num_workers:
await asyncio.sleep(0.1)
continue

# Get new tasks from the database
num_tasks_to_assign = aana_settings.task_queue.num_workers - len(
self.running_task_ids
)
tasks = self.task_repo.get_unprocessed_tasks(limit=num_tasks_to_assign)

# If there are no tasks, wait and retry
if not tasks:
await asyncio.sleep(0.1)
continue

if not handle:
# Sometimes the app isn't available immediately after the deployment is created
# so we need to wait for it to become available
Expand All @@ -159,13 +125,19 @@ async def loop(self): # noqa: C901
# (if it fails, the deployment will be unhealthy, and restart will be attempted)
handle = serve.get_app_handle(self.app_name)

# Get new tasks from the database
num_tasks_to_assign = aana_settings.task_queue.num_workers - len(
self.running_task_ids
)
tasks = self.task_repo.fetch_unprocessed_tasks(limit=num_tasks_to_assign)

# If there are no tasks, wait and retry
if not tasks:
await asyncio.sleep(0.1)
continue

# Start processing the tasks
for task in tasks:
self.task_repo.update_status(
task_id=task.id,
status=TaskStatus.ASSIGNED,
progress=0,
)
deployment_response = handle.execute_task.remote(task_id=task.id)
self.deployment_responses[task.id] = deployment_response
self.running_task_ids.append(str(task.id))
7 changes: 6 additions & 1 deletion aana/storage/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@


class DbType(str, Enum):
"""Engine types for relational database."""
"""Engine types for relational database.
Attributes:
POSTGRESQL: PostgreSQL database.
SQLITE: SQLite database.
"""

POSTGRESQL = "postgresql"
SQLITE = "sqlite"
Expand Down
69 changes: 64 additions & 5 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ def save(self, endpoint: str, data: Any, priority: int = 0):
self.session.commit()
return task

def get_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
"""Fetches all unprocessed tasks.
def fetch_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
"""Fetches unprocessed tasks and marks them as ASSIGNED.
The task is considered unprocessed if it is in CREATED or NOT_FINISHED state.
The function runs in a transaction and locks the rows to prevent race conditions
if multiple task queue deployments are running concurrently.
IMPORTANT: The lock doesn't work with SQLite. If you are using SQLite, you should
only run one task queue deployment at a time. Otherwise, you may encounter
race conditions.
Args:
limit (int | None): The maximum number of tasks to fetch. If None, fetch all.
Expand All @@ -85,8 +92,18 @@ def get_unprocessed_tasks(self, limit: int | None = None) -> list[TaskEntity]:
)
.order_by(desc(TaskEntity.priority), TaskEntity.created_at)
.limit(limit)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)
for task in tasks:
self.update_status(
task_id=task.id,
status=TaskStatus.ASSIGNED,
progress=0,
commit=False,
)
self.session.commit()
return tasks

def update_status(
Expand All @@ -95,6 +112,7 @@ def update_status(
status: TaskStatus,
progress: int | None = None,
result: Any = None,
commit: bool = True,
):
"""Update the status of a task.
Expand All @@ -103,6 +121,7 @@ def update_status(
status (TaskStatus): The new status.
progress (int | None): The progress. If None, the progress will not be updated.
result (Any): The result.
commit (bool): Whether to commit the transaction.
"""
task = self.read(task_id)
if status == TaskStatus.COMPLETED or status == TaskStatus.FAILED:
Expand All @@ -114,7 +133,8 @@ def update_status(
task.progress = progress
task.status = status
task.result = result
self.session.commit()
if commit:
self.session.commit()

def get_active_tasks(self) -> list[TaskEntity]:
"""Fetches all active tasks.
Expand Down Expand Up @@ -160,14 +180,28 @@ def filter_incomplete_tasks(self, task_ids: list[str]) -> list[str]:
incomplete_task_ids = [str(task.id) for task in tasks]
return incomplete_task_ids

def get_expired_tasks(self, execution_timeout: float) -> list[TaskEntity]:
"""Fetches all tasks that are expired.
def update_expired_tasks(
self, execution_timeout: float, max_retries: int
) -> list[TaskEntity]:
"""Fetches all tasks that are expired and updates their status.
The task is considered expired if it is in RUNNING or ASSIGNED state and the
updated_at time is older than the execution_timeout.
If the task has exceeded the maximum number of retries, it will be marked as FAILED.
If the task has not exceeded the maximum number of retries, it will be marked as
NOT_FINISHED and retried.
The function runs in a transaction and locks the rows to prevent race conditions
if multiple task queue deployments are running concurrently.
IMPORTANT: The lock doesn't work with SQLite. If you are using SQLite, you should
only run one task queue deployment at a time. Otherwise, you may encounter
race conditions.
Args:
execution_timeout (float): The maximum execution time for a task in seconds
max_retries (int): The maximum number of retries for a task
Returns:
list[TaskEntity]: the expired tasks.
Expand All @@ -181,6 +215,31 @@ def get_expired_tasks(self, execution_timeout: float) -> list[TaskEntity]:
TaskEntity.updated_at <= cutoff_time,
),
)
.populate_existing()
.with_for_update(skip_locked=True)
.all()
)
for task in tasks:
if task.num_retries >= max_retries:
self.update_status(
task_id=task.id,
status=TaskStatus.FAILED,
progress=0,
result={
"error": "TimeoutError",
"message": (
f"Task execution timed out after {execution_timeout} seconds and "
f"exceeded the maximum number of retries ({max_retries})"
),
},
commit=False,
)
else:
self.update_status(
task_id=task.id,
status=TaskStatus.NOT_FINISHED,
progress=0,
commit=False,
)
self.session.commit()
return tasks
96 changes: 53 additions & 43 deletions aana/tests/db/datastore/test_task_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,47 +45,52 @@ def test_get_unprocessed_tasks(db_session):
"""Test fetching unprocessed tasks."""
task_repo = TaskRepository(db_session)

# Remove all existing tasks
db_session.query(TaskEntity).delete()
db_session.commit()

# Create sample tasks with different statuses
now = datetime.now() # noqa: DTZ005

task1 = TaskEntity(
endpoint="/test1",
data={"test": "data1"},
status=TaskStatus.CREATED,
priority=1,
created_at=now - timedelta(hours=10),
)
task2 = TaskEntity(
endpoint="/test2",
data={"test": "data2"},
status=TaskStatus.NOT_FINISHED,
priority=2,
created_at=now - timedelta(hours=1),
)
task3 = TaskEntity(
endpoint="/test3",
data={"test": "data3"},
status=TaskStatus.COMPLETED,
priority=3,
created_at=now - timedelta(hours=2),
)
task4 = TaskEntity(
endpoint="/test4",
data={"test": "data4"},
status=TaskStatus.CREATED,
priority=2,
created_at=now - timedelta(hours=3),
)

db_session.add_all([task1, task2, task3, task4])
db_session.commit()
def _create_sample_tasks():
# Remove all existing tasks
db_session.query(TaskEntity).delete()
db_session.commit()

# Create sample tasks with different statuses
now = datetime.now() # noqa: DTZ005

task1 = TaskEntity(
endpoint="/test1",
data={"test": "data1"},
status=TaskStatus.CREATED,
priority=1,
created_at=now - timedelta(hours=10),
)
task2 = TaskEntity(
endpoint="/test2",
data={"test": "data2"},
status=TaskStatus.NOT_FINISHED,
priority=2,
created_at=now - timedelta(hours=1),
)
task3 = TaskEntity(
endpoint="/test3",
data={"test": "data3"},
status=TaskStatus.COMPLETED,
priority=3,
created_at=now - timedelta(hours=2),
)
task4 = TaskEntity(
endpoint="/test4",
data={"test": "data4"},
status=TaskStatus.CREATED,
priority=2,
created_at=now - timedelta(hours=3),
)

db_session.add_all([task1, task2, task3, task4])
db_session.commit()
return task1, task2, task3, task4

# Create sample tasks
task1, task2, task3, task4 = _create_sample_tasks()

# Fetch unprocessed tasks without any limit
unprocessed_tasks = task_repo.get_unprocessed_tasks()
unprocessed_tasks = task_repo.fetch_unprocessed_tasks()

# Assert that only tasks with CREATED and NOT_FINISHED status are returned
assert len(unprocessed_tasks) == 3
Expand All @@ -98,8 +103,11 @@ def test_get_unprocessed_tasks(db_session):
assert unprocessed_tasks[1].id == task2.id # Same priority, but a newer task
assert unprocessed_tasks[2].id == task1.id # Lowest priority

# Create sample tasks
task1, task2, task3, task4 = _create_sample_tasks()

# Fetch unprocessed tasks with a limit
limited_tasks = task_repo.get_unprocessed_tasks(limit=2)
limited_tasks = task_repo.fetch_unprocessed_tasks(limit=2)

# Assert that only the specified number of tasks is returned
assert len(limited_tasks) == 2
Expand Down Expand Up @@ -245,8 +253,8 @@ def test_remove_completed_tasks(db_session):
assert set(non_completed_task_ids) == {str(task.id) for task in unfinished_tasks}


def test_get_expired_tasks(db_session):
"""Test fetching expired tasks."""
def test_update_expired_tasks(db_session):
"""Test updating expired tasks."""
task_repo = TaskRepository(db_session)

# Remove all existing tasks
Expand Down Expand Up @@ -293,7 +301,9 @@ def test_get_expired_tasks(db_session):
db_session.commit()

# Fetch expired tasks
expired_tasks = task_repo.get_expired_tasks(execution_timeout)
expired_tasks = task_repo.update_expired_tasks(
execution_timeout=execution_timeout, max_retries=3
)

# Assert that only tasks with RUNNING or ASSIGNED status and an updated_at older than the cutoff are returned
expected_task_ids = {str(task1.id)}
Expand Down

0 comments on commit 5900e72

Please sign in to comment.