Skip to content

Commit

Permalink
feat: Create individual user client session (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
duwenxin99 authored Dec 20, 2023
1 parent a61d8fd commit 359e5d3
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 180 deletions.
97 changes: 93 additions & 4 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,116 @@
# limitations under the License.

import os
from datetime import date, timedelta
from typing import Dict, Optional

import aiohttp
import dateutil.parser as dparser
import google.auth.transport.requests # type: ignore
import google.oauth2.id_token # type: ignore
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent import AgentExecutor
from langchain.globals import set_verbose # type: ignore
from langchain.llms.vertexai import VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import ChatPromptTemplate

from tools import convert_date, tools
from tools import initialize_tools

# Turn on verbose LangChain output unless DEBUG is unset or explicitly falsy.
# bool() on the raw environment string is True for ANY non-empty value, so
# DEBUG=0 or DEBUG=false used to enable verbose mode.
set_verbose(
    os.getenv("DEBUG", default="").strip().lower()
    not in {"", "0", "false", "no", "off"}
)
# Base URL of the backend service the tools send requests to.
BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")

# Module-wide aiohttp connector; stays None until get_connector() creates it
connector = None


# Class for setting up a dedicated llm agent for each individual user
class UserAgent:
    """Pair one user's HTTP client session with the agent executor built for them."""

    # aiohttp session associated with this user
    client: aiohttp.ClientSession
    # LangChain executor that handles this user's prompts
    agent: AgentExecutor

    def __init__(self, client: aiohttp.ClientSession, agent: AgentExecutor) -> None:
        self.agent = agent
        self.client = client


# Registry of per-user agents, keyed by the uuid stored in the user's session
user_agents: Dict[str, UserAgent] = {}


def get_id_token(url: str) -> str:
    """Return an ID token for making authenticated requests to ``url``.

    When the ``K_SERVICE`` environment variable is set (Cloud Run), the token
    is fetched with Application Default Credentials for ``url``. Otherwise the
    local ``gcloud`` CLI is asked for an identity token.

    Raises:
        subprocess.CalledProcessError: if the local ``gcloud`` command fails.
    """
    if not os.getenv("K_SERVICE"):
        # Local run: use gcloud credentials
        import subprocess

        completed = subprocess.run(
            ["gcloud", "auth", "print-identity-token"],
            stdout=subprocess.PIPE,
            check=True,
        )
        raw_token = completed.stdout.strip()
        return raw_token.decode()

    # Cloud Run: use Application Default Credentials
    auth_req = google.auth.transport.requests.Request()
    return google.oauth2.id_token.fetch_id_token(auth_req, url)


def convert_date(date_string: str) -> str:
    """Convert a loose date description into a ``YYYY-MM-DD`` string.

    The keywords "today", "tomorrow", "yesterday" and "null" are matched
    case-insensitively and with surrounding whitespace ignored. ``None``,
    "null" and blank strings all resolve to today. Anything else is handed
    unchanged to the fuzzy date parser.

    Args:
        date_string: Keyword, free-form date text, or ``None``.

    Returns:
        The resolved date formatted as ``YYYY-MM-DD``.

    Raises:
        Whatever the fuzzy parser raises when the text contains no date.
    """
    # Day offsets, relative to today, for the recognized keywords. A missing
    # value (None / blank / "null") is treated the same as "today".
    relative_offsets = {
        "": 0,
        "null": 0,
        "today": 0,
        "tomorrow": 1,
        "yesterday": -1,
    }

    keyword = (date_string or "").strip().lower()
    if keyword in relative_offsets:
        converted = date.today() + timedelta(days=relative_offsets[keyword])
    else:
        # Pass the caller's original text through so parsing is unaffected
        # by the keyword normalization above.
        converted = dparser.parse(date_string, fuzzy=True).date()

    return converted.strftime("%Y-%m-%d")


def get_header() -> Optional[dict]:
    """Return the request headers needed to reach ``BASE_URL``.

    Returns:
        ``None`` when ``BASE_URL`` uses the plain ``http://`` scheme, otherwise
        a dict with an ``Authorization`` bearer header carrying an ID token
        for ``BASE_URL``.
    """
    # Anchor the check to the scheme. A substring test would also match an
    # https URL that merely contains "http://" in its path or query, and
    # would wrongly skip authentication for it.
    if BASE_URL.lower().startswith("http://"):
        return None

    # Append ID Token to make authenticated requests to Cloud Run services
    return {"Authorization": f"Bearer {get_id_token(BASE_URL)}"}


async def get_connector():
    """Return the module-wide TCP connector, creating it on the first call."""
    global connector
    if connector is not None:
        return connector

    # First use: build the connector with a limit of 100
    connector = aiohttp.TCPConnector(limit=100)
    return connector


async def handle_error_response(response):
    """Build an error message for a response whose status is not 200.

    Args:
        response: Object exposing ``status``, ``method``, ``url`` and an
            awaitable ``text()``.

    Returns:
        A message naming the request method, URL and response body when the
        status is not 200; ``None`` for a 200 response.
    """
    # NOTE(review): create_client_session passes this function as the
    # session's ``raise_for_status`` hook. Such a hook presumably signals
    # failure by raising, so the returned message may go unused -- confirm
    # against the aiohttp documentation before relying on it.
    if response.status == 200:
        return None

    body = await response.text()
    # The original message had a stray ")" after the URL.
    return f"Error sending {response.method} request to {response.url!s}: {body}"


async def create_client_session() -> aiohttp.ClientSession:
    """Create a new aiohttp session on top of the module-wide connector.

    The session sends the headers from ``get_header()`` with its requests and
    uses ``handle_error_response`` as its ``raise_for_status`` hook.
    """
    # NOTE(review): every session shares one connector. If the session owns
    # its connector (believed to be the aiohttp default), closing a single
    # session would close the connector for all users -- confirm.
    shared_connector = await get_connector()
    request_headers = get_header()
    return aiohttp.ClientSession(
        connector=shared_connector,
        headers=request_headers,
        raise_for_status=handle_error_response,
    )


# Agent
def init_agent() -> AgentExecutor:
async def init_agent() -> UserAgent:
"""Load an agent executor with tools and LLM"""
print("Initializing agent..")
llm = VertexAI(max_output_tokens=512)
memory = ConversationBufferMemory(
memory_key="chat_history", input_key="input", output_key="output"
)

client = await create_client_session()
tools = await initialize_tools(client)
agent = initialize_agent(
tools,
llm,
Expand All @@ -59,7 +147,8 @@ def init_agent() -> AgentExecutor:
[("system", template), ("human", human_message_template)]
)
agent.agent.llm_chain.prompt = prompt # type: ignore
return agent

return UserAgent(client, agent)


PREFIX = """SFO Airport Assistant helps travelers find their way at the airport.
Expand Down
42 changes: 24 additions & 18 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import uuid
from contextlib import asynccontextmanager

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Request
Expand All @@ -24,27 +26,31 @@
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

from agent import init_agent
from tools import session
from agent import init_agent, user_agents

app = FastAPI()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown logic around the app's lifetime.

    Code before ``yield`` runs at startup. Code after it runs at shutdown and
    closes the client session of every registered user agent.
    """
    # FastAPI app startup event
    print("Loading application...")
    yield
    # FastAPI app shutdown event
    # The gather() result must be awaited; otherwise shutdown continues
    # without waiting for the sessions to finish closing.
    await asyncio.gather(
        *(user_agent.client.close() for user_agent in user_agents.values())
    )


# FastAPI setup: the lifespan handler runs the startup/shutdown logic
app = FastAPI(lifespan=lifespan)
# Serve the files in the local "static" directory under the /static path
app.mount("/static", StaticFiles(directory="static"), name="static")
# TODO: set secret_key for production (the literal below is a placeholder)
app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY")
# Jinja2 templates are loaded from the local "templates" directory
templates = Jinja2Templates(directory="templates")

# Agent executors keyed by string id
agents: dict[str, AgentExecutor] = {}
# Chat history containing only the assistant's greeting message
BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}]


async def on_shutdown():
    """Close ``session`` on application shutdown, if one exists."""
    if session is None:
        return
    await session.close()


# Run the handler above when the app emits its "shutdown" event
app.add_event_handler("shutdown", on_shutdown)


@app.get("/", response_class=HTMLResponse)
def index(request: Request):
"""Render the default template."""
Expand All @@ -71,14 +77,14 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
# Add user message to chat history
request.session["messages"] += [{"role": "user", "content": prompt}]
# Agent setup
if request.session["uuid"] in agents:
agent = agents[request.session["uuid"]]
if request.session["uuid"] in user_agents:
user_agent = user_agents[request.session["uuid"]]
else:
agent = init_agent()
agents[request.session["uuid"]] = agent
user_agent = await init_agent()
user_agents[request.session["uuid"]] = user_agent
try:
# Send prompt to LLM
response = await agent.ainvoke({"input": prompt})
response = await user_agent.agent.ainvoke({"input": prompt})
request.session["messages"] += [
{"role": "assistant", "content": response["output"]}
]
Expand Down
Loading

0 comments on commit 359e5d3

Please sign in to comment.