diff --git a/memgpt/server/rest_api/app.py b/memgpt/server/rest_api/app.py index ed48232b8b..f53ad27abe 100644 --- a/memgpt/server/rest_api/app.py +++ b/memgpt/server/rest_api/app.py @@ -6,7 +6,8 @@ import typer import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from starlette.middleware.cors import CORSMiddleware from memgpt.server.constants import REST_DEFAULT_PORT @@ -38,8 +39,6 @@ # TODO(ethan) # NOTE(charles): @ethan I had to add this to get the global as the bottom to work interface: StreamingServerInterface = StreamingServerInterface -# global server -# server: SyncServer = None server = SyncServer(default_interface_factory=lambda: interface()) # TODO(ethan): eventuall remove @@ -77,6 +76,21 @@ def create_application() -> "FastAPI": allow_headers=["*"], ) + @app.middleware("http") + async def set_current_user_middleware(request: Request, call_next): + user_id = request.headers.get("user_id") + if user_id: + try: + server.set_current_user(user_id) + except ValueError as e: + # Return an HTTP 401 Unauthorized response + # raise HTTPException(status_code=401, detail=str(e)) + return JSONResponse(status_code=401, content={"detail": str(e)}) + else: + server.set_current_user(None) + response = await call_next(request) + return response + for route in v1_routes: app.include_router(route, prefix=API_PREFIX) # this gives undocumented routes for "latest" and bare api calls. diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 4eb80d8893..834117d046 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1771,6 +1771,20 @@ def retry_agent_message(self, agent_id: str) -> List[Message]: memgpt_agent = self._get_or_load_agent(agent_id=agent_id) return memgpt_agent.retry_message() + def set_current_user(self, user_id: Optional[str]): + """Very hacky way to set the current user for the server, to be replaced once server becomes stateless + + NOTE: clearly not thread-safe, only exists to provide basic user_id support for REST API for now + """ + + # Make sure the user_id actually exists + if user_id is not None: + user_obj = self.get_user(user_id) + if not user_obj: + raise ValueError(f"User with id {user_id} not found") + + self._current_user = user_id + # TODO(ethan) wire back to real method in future ORM PR def get_current_user(self) -> User: """Returns the currently authed user. @@ -1778,6 +1792,15 @@ def get_current_user(self) -> User: Since server is the core gateway this needs to pass through server as the first touchpoint. """ + + # Check if _current_user is set and if it's non-null: + if hasattr(self, "_current_user") and self._current_user is not None: + current_user = self.get_user(self._current_user) + if not current_user: + warnings.warn(f"Provided user '{self._current_user}' not found, using default user") + else: + return current_user + # NOTE: same code as local client to get the default user config = MemGPTConfig.load() user_id = config.anon_clientid