Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for user_id in header #1755

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions memgpt/server/rest_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import typer
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.cors import CORSMiddleware

from memgpt.server.constants import REST_DEFAULT_PORT
Expand Down Expand Up @@ -38,8 +39,6 @@
# TODO(ethan)
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
interface: StreamingServerInterface = StreamingServerInterface
# global server
# server: SyncServer = None
server = SyncServer(default_interface_factory=lambda: interface())

# TODO(ethan): eventuall remove
Expand Down Expand Up @@ -77,6 +76,21 @@ def create_application() -> "FastAPI":
allow_headers=["*"],
)

@app.middleware("http")
async def set_current_user_middleware(request: Request, call_next):
user_id = request.headers.get("user_id")
if user_id:
try:
server.set_current_user(user_id)
except ValueError as e:
# Return an HTTP 401 Unauthorized response
# raise HTTPException(status_code=401, detail=str(e))
return JSONResponse(status_code=401, content={"detail": str(e)})
else:
server.set_current_user(None)
response = await call_next(request)
return response

for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.
Expand Down
23 changes: 23 additions & 0 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,13 +1771,36 @@ def retry_agent_message(self, agent_id: str) -> List[Message]:
memgpt_agent = self._get_or_load_agent(agent_id=agent_id)
return memgpt_agent.retry_message()

def set_current_user(self, user_id: Optional[str]):
"""Very hacky way to set the current user for the server, to be replaced once server becomes stateless

NOTE: clearly not thread-safe, only exists to provide basic user_id support for REST API for now
"""

# Make sure the user_id actually exists
if user_id is not None:
user_obj = self.get_user(user_id)
if not user_obj:
raise ValueError(f"User with id {user_id} not found")

self._current_user = user_id

# TODO(ethan) wire back to real method in future ORM PR
def get_current_user(self) -> User:
"""Returns the currently authed user.

Since server is the core gateway this needs to pass through server as the
first touchpoint.
"""

# Check if _current_user is set and if it's non-null:
if hasattr(self, "_current_user") and self._current_user is not None:
current_user = self.get_user(self._current_user)
if not current_user:
warnings.warn(f"Provided user '{self._current_user}' not found, using default user")
else:
return current_user

# NOTE: same code as local client to get the default user
config = MemGPTConfig.load()
user_id = config.anon_clientid
Expand Down
Loading