From f008c04678ab93d1f9634eb5cb6b499a9970e63b Mon Sep 17 00:00:00 2001 From: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Date: Fri, 5 Jan 2024 19:05:56 -0500 Subject: [PATCH] feat: Create reset button to clear session cookies (#152) --- .../int.tests.cloudbuild.yaml | 21 +++++++++-- langchain_tools_demo/main.py | 37 +++++++++++++------ langchain_tools_demo/static/index.css | 22 ++++++++--- langchain_tools_demo/static/index.js | 13 +++++++ langchain_tools_demo/templates/index.html | 3 ++ 5 files changed, 77 insertions(+), 19 deletions(-) diff --git a/langchain_tools_demo/int.tests.cloudbuild.yaml b/langchain_tools_demo/int.tests.cloudbuild.yaml index 9a1a5f77..06e77f6f 100644 --- a/langchain_tools_demo/int.tests.cloudbuild.yaml +++ b/langchain_tools_demo/int.tests.cloudbuild.yaml @@ -34,10 +34,25 @@ steps: export ID_TOKEN=$(gcloud auth print-identity-token --audiences $$URL) # Test `/` route - curl -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL + curl -c cookies.txt -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL - # Test `/chat`` route - curl -si --fail --show-error \ + # Test `/chat` route should fail + msg=$(curl -si --show-error \ + -X POST \ + -H "Authorization: Bearer $$ID_TOKEN" \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"How can you help me?"}' \ + $$URL/chat) + + if grep -q "400" <<< "$msg"; then + echo "Chat Handler Test: PASSED" + else + echo "Chat Handler Test: FAILED" + echo $msg && exit 1 + fi + + # Test `/chat` route + curl -b cookies.txt -si --fail --show-error \ -X POST \ -H "Authorization: Bearer $$ID_TOKEN" \ -H 'Content-Type: application/json' \ diff --git a/langchain_tools_demo/main.py b/langchain_tools_demo/main.py index a78b5e31..211ba803 100644 --- a/langchain_tools_demo/main.py +++ b/langchain_tools_demo/main.py @@ -22,7 +22,6 @@ from fastapi.responses import HTMLResponse, PlainTextResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from langchain.agents.agent import AgentExecutor from markdown import markdown from starlette.middleware.sessions import SessionMiddleware @@ -52,12 +51,17 @@ async def lifespan(app: FastAPI): @app.get("/", response_class=HTMLResponse) -def index(request: Request): +async def index(request: Request): """Render the default template.""" - request.session.clear() # Clear chat history, if needed if "uuid" not in request.session: request.session["uuid"] = str(uuid.uuid4()) request.session["messages"] = BASE_HISTORY + # Agent setup + if request.session["uuid"] in user_agents: + user_agent = user_agents[request.session["uuid"]] + else: + user_agent = await init_agent() + user_agents[request.session["uuid"]] = user_agent return templates.TemplateResponse( "index.html", {"request": request, "messages": request.session["messages"]} ) @@ -71,17 +75,14 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): raise HTTPException(status_code=400, detail="Error: No user query") if "uuid" not in request.session: - request.session["uuid"] = str(uuid.uuid4()) - request.session["messages"] = BASE_HISTORY + raise HTTPException( + status_code=400, detail="Error: Invoke index handler before start chatting" + ) # Add user message to chat history request.session["messages"] += [{"role": "user", "content": prompt}] - # Agent setup - if request.session["uuid"] in user_agents: - user_agent = user_agents[request.session["uuid"]] - else: - user_agent = await init_agent() - user_agents[request.session["uuid"]] = user_agent + + user_agent = user_agents[request.session["uuid"]] try: # Send prompt to LLM response = await user_agent.agent.ainvoke({"input": prompt}) @@ -95,6 +96,20 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") +@app.post("/reset") +async def reset(request: Request): + """Reset agent""" + global user_agents + uuid = request.session["uuid"] + + if uuid not in user_agents.keys(): + raise HTTPException(status_code=500, detail=f"Current agent not found") + + await user_agents[uuid].client.close() + del user_agents[uuid] + request.session.clear() + + if __name__ == "__main__": PORT = int(os.getenv("PORT", default=8081)) uvicorn.run(app, host="0.0.0.0", port=PORT) diff --git a/langchain_tools_demo/static/index.css b/langchain_tools_demo/static/index.css index b51feeaa..4e945144 100644 --- a/langchain_tools_demo/static/index.css +++ b/langchain_tools_demo/static/index.css @@ -35,11 +35,18 @@ body { .chat-header { position: relative; + min-width: 650px; font-size: 16px; font-weight: 500; text-align: center; } +.chat-header span.reset-button { + position: absolute; + margin-right: 0px; + font-size: 38px; +} + .chat-wrapper { display: flex; flex-direction: column; @@ -77,11 +84,16 @@ div.chat-content>span { border: none; } -.chat-input-container span.btn-group { - position: relative; - margin: auto; - margin-right: 10px; - display: flex; +#resetButton { + font-size: 35px; + cursor: pointer; + position: absolute; + top: 47px; + right: 40px; +} + +#resetButton:hover { + background-color: #c9d4e9; } .chat-bubble { diff --git a/langchain_tools_demo/static/index.js b/langchain_tools_demo/static/index.js index 564660ca..3c7d8bb4 100644 --- a/langchain_tools_demo/static/index.js +++ b/langchain_tools_demo/static/index.js @@ -26,6 +26,11 @@ $(document).on("keypress",async (e) => { } }); +// Reset current user via click +$('#resetButton').click(async (e) => { + await reset(); +}); + async function submitMessage() { let msg = $('.chat-bar input').val(); // Add message to UI @@ -62,6 +67,14 @@ async function askQuestion(prompt) { } } +async function reset() { + await fetch('/reset', { + method: 'POST', + }).then(()=>{ + window.location.reload() + }) +} + // Helper function to print to chatroom function log(name, msg) { let message = `${msg}`; diff --git a/langchain_tools_demo/templates/index.html b/langchain_tools_demo/templates/index.html index cf94adde..7e77699b 100644 --- a/langchain_tools_demo/templates/index.html +++ b/langchain_tools_demo/templates/index.html @@ -33,10 +33,13 @@
+