Skip to content

Commit

Permalink
feat: Call agent rpc call in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Dec 26, 2024
1 parent 3a62875 commit 26dd587
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions src/ai/backend/manager/models/gql_models/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import json
import logging
import uuid
Expand Down Expand Up @@ -933,32 +934,34 @@ async def mutate(
else:
agent_ids = [agent.id async for agent in graph_ctx.registry.enumerate_instances()]

async def _rescan_alloc_map_task(reporter: ProgressReporter) -> None:
for agent_id in agent_ids:
await reporter.update(
message=f"Agent {agent_id} GPU alloc map scannning...",
)
async def _scan_single_agent(agent_id: str, reporter: ProgressReporter) -> None:
await reporter.update(message=f"Agent {agent_id} GPU alloc map scanning...")

reporter_msg = ""
try:
alloc_map: Mapping[str, Any] = await graph_ctx.registry.scan_gpu_alloc_map(
AgentId(agent_id)
)
key = f"gpu_alloc_map.{agent_id}"
await redis_helper.execute(
graph_ctx.registry.redis_stat,
lambda r: r.set(name=key, value=json.dumps(alloc_map)),
)
except Exception as e:
reporter_msg = f"Failed to scan GPU alloc map for agent {agent_id}: {str(e)}"
log.error(reporter_msg)
else:
reporter_msg = f"Agent {agent_id} GPU alloc map scanned."

await reporter.update(
increment=1,
message=reporter_msg,
reporter_msg = ""
try:
alloc_map: Mapping[str, Any] = await graph_ctx.registry.scan_gpu_alloc_map(
AgentId(agent_id)
)
key = f"gpu_alloc_map.{agent_id}"
await redis_helper.execute(
graph_ctx.registry.redis_stat,
lambda r: r.set(name=key, value=json.dumps(alloc_map)),
)
except Exception as e:
reporter_msg = f"Failed to scan GPU alloc map for agent {agent_id}: {str(e)}"
log.error(reporter_msg)
else:
reporter_msg = f"Agent {agent_id} GPU alloc map scanned."

await reporter.update(
increment=1,
message=reporter_msg,
)

async def _rescan_alloc_map_task(reporter: ProgressReporter) -> None:
    """Rescan the GPU allocation map of every agent in ``agent_ids`` concurrently.

    One ``_scan_single_agent`` task is spawned per agent inside a single
    ``asyncio.TaskGroup``; leaving the ``async with`` block waits for all of
    them to finish. A final progress message is then sent through ``reporter``.

    Args:
        reporter: Progress reporter shared by every per-agent task and used
            for the closing "completed" message.
    """
    async with asyncio.TaskGroup() as scan_group:
        for target_agent_id in agent_ids:
            # Each agent gets its own task so the per-agent calls overlap
            # instead of running one after another.
            scan_group.create_task(_scan_single_agent(target_agent_id, reporter))

    # Reached only after the task group has exited, i.e. all scans are done.
    await reporter.update(message="GPU alloc map scanning completed")


Expand Down

0 comments on commit 26dd587

Please sign in to comment.