fix(backend): Avoid long synchronous call to block FastAPI event-loop #8429

Merged · 5 commits · Oct 27, 2024
Changes from 3 commits
@@ -6,7 +6,7 @@
 from redis import Redis
 from backend.executor.database import DatabaseManager

-from autogpt_libs.utils.cache import thread_cached_property
+from autogpt_libs.utils.cache import thread_cached
 from autogpt_libs.utils.synchronize import RedisKeyedMutex

 from .types import (
@@ -21,8 +21,9 @@
 class SupabaseIntegrationCredentialsStore:
     def __init__(self, redis: "Redis"):
         self.locks = RedisKeyedMutex(redis)

-    @thread_cached_property
+    @property
+    @thread_cached
     def db_manager(self) -> "DatabaseManager":
         from backend.executor.database import DatabaseManager
         from backend.util.service import get_service_client
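
The bespoke thread_cached_property helper is replaced by stacking the built-in @property on top of @thread_cached. A minimal sketch of how the new stack behaves, assuming thread_cached is importable as above; Store and the print are illustrative stand-ins for the real credentials store and get_service_client(DatabaseManager):

import threading

from autogpt_libs.utils.cache import thread_cached


class Store:
    @property       # outermost, so `store.db_manager` stays plain attribute access
    @thread_cached  # innermost, so each thread builds and reuses its own client
    def db_manager(self) -> object:
        print(f"building client in {threading.current_thread().name}")
        return object()  # stands in for get_service_client(DatabaseManager)


store = Store()
assert store.db_manager is store.db_manager  # second access hits the cache

# A different thread builds its own instance, keeping clients thread-confined.
threading.Thread(target=lambda: store.db_manager).start()

The order matters: @property must be outermost, since thread_cached expects a plain callable while property returns a descriptor.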
9 changes: 1 addition & 8 deletions autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py
@@ -1,16 +1,13 @@
-from typing import Callable, TypeVar, ParamSpec
 import threading
 from functools import wraps
+from typing import Callable, ParamSpec, TypeVar

-T = TypeVar("T")
 P = ParamSpec("P")
 R = TypeVar("R")


 def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
     thread_local = threading.local()

     @wraps(func)
     def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
         cache = getattr(thread_local, "cache", None)
         if cache is None:
@@ -21,7 +18,3 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
         return cache[key]

     return wrapper
-
-
-def thread_cached_property(func: Callable[[T], R]) -> property:
-    return property(thread_cached(func))
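
For reference, a self-contained sketch of the surviving decorator with the folded lines filled in; the exact cache-key construction is an assumption, since those lines are collapsed in the diff above:

import threading
from functools import wraps
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
    # One threading.local per decorated function: each thread sees its own cache,
    # so cached values (e.g. service clients) are never shared across threads.
    thread_local = threading.local()

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        cache = getattr(thread_local, "cache", None)
        if cache is None:
            cache = thread_local.cache = {}
        key = (args, tuple(sorted(kwargs.items())))  # assumed; folded in the diff
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper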
5 changes: 3 additions & 2 deletions autogpt_platform/backend/backend/executor/scheduler.py
@@ -4,7 +4,7 @@

 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.cron import CronTrigger
-from autogpt_libs.utils.cache import thread_cached_property
+from autogpt_libs.utils.cache import thread_cached

 from backend.data.block import BlockInput
 from backend.data.schedule import (
@@ -37,7 +37,8 @@ def __init__(self, refresh_interval=10):
     def get_port(cls) -> int:
         return Config().execution_scheduler_port

-    @thread_cached_property
+    @property
+    @thread_cached
     def execution_client(self) -> ExecutionManager:
         return get_service_client(ExecutionManager)
12 changes: 6 additions & 6 deletions autogpt_platform/backend/backend/server/integrations/router.py
@@ -29,7 +29,7 @@ class LoginResponse(BaseModel):


 @router.get("/{provider}/login")
-async def login(
+def login(
     provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
     user_id: Annotated[str, Depends(get_user_id)],
     request: Request,
@@ -60,7 +60,7 @@ class CredentialsMetaResponse(BaseModel):


 @router.post("/{provider}/callback")
-async def callback(
+def callback(
     provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
     code: Annotated[str, Body(title="Authorization code acquired by user login")],
     state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
@@ -115,7 +115,7 @@ async def callback(


 @router.get("/{provider}/credentials")
-async def list_credentials(
+def list_credentials(
     provider: Annotated[str, Path(title="The provider to list credentials for")],
     user_id: Annotated[str, Depends(get_user_id)],
 ) -> list[CredentialsMetaResponse]:
@@ -133,7 +133,7 @@ async def list_credentials(


 @router.get("/{provider}/credentials/{cred_id}")
-async def get_credential(
+def get_credential(
     provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
     cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
     user_id: Annotated[str, Depends(get_user_id)],
@@ -149,7 +149,7 @@ async def get_credential(


 @router.post("/{provider}/credentials", status_code=201)
-async def create_api_key_credentials(
+def create_api_key_credentials(
     user_id: Annotated[str, Depends(get_user_id)],
     provider: Annotated[str, Path(title="The provider to create credentials for")],
     api_key: Annotated[str, Body(title="The API key to store")],
@@ -184,7 +184,7 @@ class CredentialsDeletionResponse(BaseModel):


 @router.delete("/{provider}/credentials/{cred_id}")
-async def delete_credentials(
+def delete_credentials(
     request: Request,
     provider: Annotated[str, Path(title="The provider to delete credentials for")],
     cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
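
Dropping async from these handlers is what moves their blocking work off the event loop: FastAPI runs plain def endpoints on a worker threadpool, while async def bodies run on the loop itself. A minimal sketch of the difference, with hypothetical routes and a 2-second sleep standing in for the synchronous credential-store calls:

import time

from fastapi import FastAPI

app = FastAPI()


def slow_io() -> str:
    time.sleep(2)  # stands in for a blocking service RPC or Redis lock
    return "done"


@app.get("/blocking")
async def blocking() -> str:
    # Body runs on the event loop; the sleep stalls every other request.
    return slow_io()


@app.get("/threadpool")
def threadpool() -> str:
    # Plain `def` endpoints run in a worker threadpool, so the loop stays free.
    return slow_io()

Under concurrent load, requests to /blocking serialize behind the sleep, while /threadpool requests overlap.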
27 changes: 19 additions & 8 deletions autogpt_platform/backend/backend/server/rest_api.py
@@ -1,3 +1,4 @@
+import asyncio
 import inspect
 import logging
 from collections import defaultdict
@@ -7,7 +8,7 @@

 import uvicorn
 from autogpt_libs.auth.middleware import auth_middleware
-from autogpt_libs.utils.cache import thread_cached_property
+from autogpt_libs.utils.cache import thread_cached
 from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
@@ -307,11 +308,13 @@ async def wrapper(*args, **kwargs):

         return wrapper

-    @thread_cached_property
+    @property
+    @thread_cached
     def execution_manager_client(self) -> ExecutionManager:
         return get_service_client(ExecutionManager)

-    @thread_cached_property
+    @property
+    @thread_cached
     def execution_scheduler_client(self) -> ExecutionScheduler:
         return get_service_client(ExecutionScheduler)

@@ -516,7 +519,7 @@ async def set_graph_active_version(
             user_id=user_id,
         )

-    async def execute_graph(
+    def execute_graph(
         self,
         graph_id: str,
         node_input: dict[Any, Any],
@@ -539,7 +542,9 @@ async def stop_graph_run(
                 404, detail=f"Agent execution #{graph_exec_id} not found"
             )

-        self.execution_manager_client.cancel_execution(graph_exec_id)
+        await asyncio.to_thread(
+            lambda: self.execution_manager_client.cancel_execution(graph_exec_id)
+        )

         # Retrieve & return canceled graph execution in its final state
         return await execution_db.get_execution_results(graph_exec_id)
@@ -614,10 +619,16 @@ async def create_schedule(
         graph = await graph_db.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
-        execution_scheduler = self.execution_scheduler_client

         return {
-            "id": execution_scheduler.add_execution_schedule(
-                graph_id, graph.version, cron, input_data, user_id=user_id
+            "id": await asyncio.to_thread(
+                lambda: self.execution_scheduler_client.add_execution_schedule(
+                    graph_id=graph_id,
+                    graph_version=graph.version,
+                    cron=cron,
+                    input_data=input_data,
+                    user_id=user_id,
+                )
             )
         }
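
Handlers that must stay async (they also await database coroutines) instead push blocking client calls onto a worker thread with asyncio.to_thread, available since Python 3.9. A minimal sketch, with cancel_execution standing in for the blocking ExecutionManager RPC:

import asyncio
import time


def cancel_execution(graph_exec_id: str) -> None:
    time.sleep(1)  # stands in for the blocking RPC to the execution manager


async def stop_graph_run(graph_exec_id: str) -> None:
    # Run the blocking call in a worker thread and await its completion; the
    # event loop keeps serving other coroutines while the RPC is in flight.
    await asyncio.to_thread(cancel_execution, graph_exec_id)


asyncio.run(stop_graph_run("exec-123"))

asyncio.to_thread(func, *args) also accepts the callable and its arguments directly; the lambda used in the PR is equivalent here.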
@@ -252,7 +252,7 @@ async def block_autogen_agent():
     test_user = await create_test_user()
     test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
     input_data = {"input": "Write me a block that writes a string into a file."}
-    response = await server.agent_server.execute_graph(
+    response = server.agent_server.execute_graph(
         test_graph.id, input_data, test_user.id
     )
     print(response)
@@ -156,7 +156,7 @@ async def reddit_marketing_agent():
     test_user = await create_test_user()
     test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
     input_data = {"subreddit": "AutoGPT"}
-    response = await server.agent_server.execute_graph(
+    response = server.agent_server.execute_graph(
         test_graph.id, input_data, test_user.id
     )
     print(response)
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/usecases/sample.py
@@ -78,7 +78,7 @@ async def sample_agent():
     test_user = await create_test_user()
     test_graph = await create_graph(create_test_graph(), test_user.id)
     input_data = {"input_1": "Hello", "input_2": "World"}
-    response = await server.agent_server.execute_graph(
+    response = server.agent_server.execute_graph(
         test_graph.id, input_data, test_user.id
     )
     print(response)
2 changes: 1 addition & 1 deletion autogpt_platform/backend/test/executor/test_manager.py
@@ -22,7 +22,7 @@ async def execute_graph(
     num_execs: int = 4,
 ) -> str:
     # --- Test adding new executions --- #
-    response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
+    response = agent_server.execute_graph(test_graph.id, input_data, test_user.id)
     graph_exec_id = response["id"]

     # Execution queue should be empty