Skip to content

Commit

Permalink
feat: support a new secure flag (#2030)
Browse files Browse the repository at this point in the history
Co-authored-by: Shubham Naik <shub@memgpt.ai>
  • Loading branch information
4shub and Shubham Naik authored Nov 13, 2024
1 parent 8d9adf2 commit 266df8b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions letta/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def server(
host: Annotated[Optional[str], typer.Option(help="Host to run the server on (default to localhost)")] = None,
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
ade: Annotated[bool, typer.Option(help="Allows remote access")] = False,
secure: Annotated[bool, typer.Option(help="Adds simple security access")] = False,
):
"""Launch a Letta server process"""
if type == ServerChoice.rest_api:
Expand Down
27 changes: 27 additions & 0 deletions letta/server/rest_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware

from letta.__init__ import __version__
Expand Down Expand Up @@ -94,6 +96,27 @@ def generate_openapi_schema(app: FastAPI):
Path(f"openapi_{name}.json").write_text(json.dumps(docs, indent=2))


# middleware that only allows requests to pass through if user provides a password thats randomly generated and stored in memory
def generate_password():
import secrets

return secrets.token_urlsafe(16)


random_password = generate_password()


class CheckPasswordMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
if request.headers.get("X-BARE-PASSWORD") == f"password {random_password}":
return await call_next(request)

return JSONResponse(
content={"detail": "Unauthorized"},
status_code=401,
)


def create_application() -> "FastAPI":
"""the application start routine"""
# global server
Expand All @@ -113,6 +136,10 @@ def create_application() -> "FastAPI":
settings.cors_origins.append("https://app.letta.com")
print(f"▶ View using ADE at: https://app.letta.com/local-project/agents")

if "--secure" in sys.argv:
print(f"▶ Using secure mode with password: {random_password}")
app.add_middleware(CheckPasswordMiddleware)

app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
Expand Down

0 comments on commit 266df8b

Please sign in to comment.