Skip to content

Commit

Permalink
feat: create an admin return all agents route (#1620)
Browse files Browse the repository at this point in the history
  • Loading branch information
4shub authored Aug 13, 2024
1 parent 925d251 commit 738978a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 7 deletions.
7 changes: 7 additions & 0 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,13 @@ def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def list_all_agents(self) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).all()

return [r.to_record() for r in results]

@enforce_types
def list_sources(self, user_id: uuid.UUID) -> List[Source]:
with self.session_maker() as session:
Expand Down
21 changes: 21 additions & 0 deletions memgpt/server/rest_api/admin/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from fastapi import APIRouter

from memgpt.server.rest_api.agents.index import ListAgentsResponse
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
def get_all_agents():
"""
Get a list of all agents in the database
"""
interface.clear()
agents_data = server.list_agents_legacy()

return ListAgentsResponse(**agents_data)

return router
7 changes: 6 additions & 1 deletion memgpt/server/rest_api/auth/index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
Expand All @@ -13,6 +14,7 @@

class AuthResponse(BaseModel):
uuid: UUID = Field(..., description="UUID of the user")
is_admin: Optional[bool] = Field(None, description="Whether the user is an admin")


class AuthRequest(BaseModel):
Expand All @@ -29,10 +31,13 @@ def authenticate_user(request: AuthRequest) -> AuthResponse:
Currently, this is a placeholder that simply returns a UUID placeholder
"""
interface.clear()

is_admin = False
if request.password != password:
response = server.api_key_to_user(api_key=request.password)
else:
is_admin = True
response = server.authenticate_user()
return AuthResponse(uuid=response)
return AuthResponse(uuid=response, is_admin=is_admin)

return router
5 changes: 5 additions & 0 deletions memgpt/server/rest_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT
from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
from memgpt.server.rest_api.admin.tools import setup_tools_index_router
from memgpt.server.rest_api.admin.users import setup_admin_router
from memgpt.server.rest_api.agents.command import setup_agents_command_router
Expand Down Expand Up @@ -69,6 +70,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security


ADMIN_PREFIX = "/admin"
ADMIN_API_PREFIX = "/api/admin"
API_PREFIX = "/api"
OPENAI_API_PREFIX = "/v1"

Expand All @@ -89,6 +91,9 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
app.include_router(setup_admin_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)])
app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PREFIX, dependencies=[Depends(verify_password)])

# /api/admin/agents endpoints
app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)])

# /api/agents endpoints
app.include_router(setup_agents_command_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_config_router(server, interface, password), prefix=API_PREFIX)
Expand Down
17 changes: 11 additions & 6 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,13 +883,18 @@ def list_agents(
# TODO make return type pydantic
def list_agents_legacy(
self,
user_id: uuid.UUID,
user_id: Optional[uuid.UUID] = None,
) -> dict:
"""List all available agents to a user"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")

agents_states = self.ms.list_agents(user_id=user_id)
if user_id is None:
agents_states = self.ms.list_all_agents()
else:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")

agents_states = self.ms.list_agents(user_id=user_id)

agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states]

# TODO add a get_message_obj_from_message_id(...) function
Expand All @@ -900,7 +905,7 @@ def list_agents_legacy(
for agent_state, return_dict in zip(agents_states, agents_states_dicts):

# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
memgpt_agent = self._get_or_load_agent(user_id=agent_state.user_id, agent_id=agent_state.id)

# TODO remove this eventually when return type get pydanticfied
# this is to add persona_name and human_name so that the columns in UI can populate
Expand All @@ -918,7 +923,7 @@ def list_agents_legacy(
# get tool info from agent state
tools = []
for tool_name in agent_state.tools:
tool = self.ms.get_tool(tool_name, user_id)
tool = self.ms.get_tool(tool_name, agent_state.user_id)
tools.append(tool)
return_dict["tools"] = tools

Expand Down

0 comments on commit 738978a

Please sign in to comment.