Skip to content

Commit

Permalink
Merge pull request #25 from FacerAin/feat/agent
Browse files Browse the repository at this point in the history
[#14] Add ExecutorAgent
  • Loading branch information
FacerAin authored Dec 6, 2023
2 parents 52e4bff + b3b67db commit b756e0f
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 25 deletions.
2 changes: 1 addition & 1 deletion app/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .agent import ChatAgent
from .agent import ExecutorAgent
63 changes: 58 additions & 5 deletions app/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,75 @@
import datetime
import os

from langchain.agents import AgentExecutor, AgentType, LLMSingleActionAgent, Tool, initialize_agent
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI

from app.agent.context import SAMPLE_CONTEXT
from app.agent.prompts import system_prompt_template
from app.agent.parser import CustomAgentOutputParser
from app.agent.prompts import AgentPromptTemplate, agent_prompt_template, retriever_prompt_template
from app.agent.retriever import PineconeRetriever
from app.core.config import settings


class ChatAgent:
class ExecutorAgent:
def __init__(self):
self.retriever = PineconeRetriever(index_name="khugpt")

self.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0, openai_api_key=settings.OPENAI_API_KEY)
self.tools = [
Tool(
name="retreiver",
func=self.retriever.get_relevant_doc_string,
description="""useful for when you need to answer questions about campus internal information.
Please input the information in the form of a question that is easy to search for.
Since the database is stored in Korean, please ask in Korean.""",
),
Tool(
name="cafeterial_menu",
func=simple_meal_info,
description="If a user is looking for campus cafeterial menu information, use this information.",
),
]
self.agent_prompt = AgentPromptTemplate(
template=agent_prompt_template,
tools=self.tools,
input_variables=["input", "intermediate_steps"],
)
self.agent = initialize_agent(
tools=self.tools, llm=self.llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=True, max_iterations=3
)
self.output_parser = CustomAgentOutputParser()
llm_chain = LLMChain(llm=self.llm, prompt=self.agent_prompt)
tool_names = [tool.name for tool in self.tools]
self.agent = LLMSingleActionAgent(
llm_chain=llm_chain, output_parser=self.output_parser, stop=["\nObservation:"], allowed_tools=tool_names
)

self.executor = AgentExecutor.from_agent_and_tools(
agent=self.agent, tools=self.tools, verbose=True, max_iterations=2
)

def run(self, query):
response = self.executor.run(query)
print(response)
return response


def simple_meal_info(query):
return """
If a user is looking for campus cafeterial menu information, use the link below. You should directly check the meal information from the link below.
경희대학교 학생 식당: https://www.khu.ac.kr/kor/forum/list.do?type=RESTAURANT&category=INTL&page=1
경희대학교 제 2기숙사 식당: https://dorm2.khu.ac.kr/40/4050.kmc
"""


class RetrieverAgent:
def __init__(self, index_name: str = "khugpt") -> None:
self.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0, openai_api_key=settings.OPENAI_API_KEY)
self.retreiver = PineconeRetriever(index_name=index_name)

def run(self, query: str):
context = self.retreiver.get_relevant_doc_string(query)
system_prompt = system_prompt_template.format(
system_prompt = retriever_prompt_template.format(
question=query, context=context, current_date=datetime.datetime.now()
)
answer = self.llm.predict(system_prompt)
Expand Down
27 changes: 27 additions & 0 deletions app/agent/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Union

import re

from langchain.agents import AgentOutputParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException


class CustomAgentOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
# Check if agent should finish
if "Final Answer:" in llm_output:
return AgentFinish(
# Return values is generally always a dictionary with a single `output` key
# It is not recommended to try anything else at the moment :)
return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
log=llm_output,
)
# Parse out the action and action input
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
match = re.search(regex, llm_output, re.DOTALL)
if not match:
raise OutputParserException(f"Could not parse LLM output: `{llm_output}`")
action = match.group(1).strip()
action_input = match.group(2)
# Return the action and action input
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
65 changes: 53 additions & 12 deletions app/agent/prompts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
from langchain.prompts import PromptTemplate
from typing import List

system_prompt_template = PromptTemplate(
from langchain.agents import Tool
from langchain.prompts import PromptTemplate, StringPromptTemplate

agent_prompt_template = """You are a helpful assistant for Kyung Hee University students.
Answer the following questions as best you can. If a page_url is provided in the document, please also provide a link to the related page.
You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin! Remember to speak in a korean when giving your final answer.
Question: {input}
{agent_scratchpad}"""

retriever_prompt_template = PromptTemplate(
input_variables=["current_date", "context", "question"],
template="""
You are a helpful assistant for Kyung Hee University students.
Use the following pieces of context to answer the question at the end.
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Users tend to want up-to-date information, so please refer to the current date to answer. Unless otherwise instructed, use this year’s information whenever possible.
Each context is separated by [SEP].
Expand All @@ -15,18 +40,34 @@
When the context cannot be found, instead of saying that the context was not found, advise the user on suitable alternative actions.
Attach the page link of the corresponding context at the bottom of the answer.
If a user is looking for campus cafeterial menu information, use the link below. You should directly check the meal information from the link below.
경희대학교 학생 식당: https://www.khu.ac.kr/kor/forum/list.do?type=RESTAURANT&category=INTL&page=1
경희대학교 제 2기숙사 식당: https://dorm2.khu.ac.kr/40/4050.kmc
You must answer in Korean.
Current date: {current_date}
Contexts: {context}
Question: {question}
Query: {question}
Helpful answer:
""",
)


class AgentPromptTemplate(StringPromptTemplate):
# The template to use
template: str
# The list of tools available
tools: List[Tool]

def format(self, **kwargs) -> str:
# Get the intermediate steps (AgentAction, Observation tuples)
# Format them in a particular way
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
# Set the agent_scratchpad variable to that value
kwargs["agent_scratchpad"] = thoughts
# Create a tools variable from the list of tools provided
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
# Create a list of tool names for the tools provided
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
return self.template.format(**kwargs)
4 changes: 2 additions & 2 deletions app/agent/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def _combine_documents(self, responses: List[Dict]) -> List[str]:
doc_string = DOCUMENT_SEPERATOR.join(docs)
return doc_string

def similarity_search(self, query: str, top_k: int = 10, **kwargs: Any):
def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any):
embeddings = self._embedding_model.embed_query(query)
responses = self._index.query([embeddings], top_k=top_k, include_metadata=True)
return responses

def get_relevant_doc_string(self, query: str, top_k: int = 10):
def get_relevant_doc_string(self, query: str, top_k: int = 5):
responses = self.similarity_search(query=query, top_k=top_k)
doc_string = self._combine_documents(responses=responses["matches"])
return doc_string
7 changes: 3 additions & 4 deletions app/api/api_v1/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from fastapi import APIRouter

from app.agent import ChatAgent
from app.agent import ExecutorAgent
from app.agent.retriever import PineconeRetriever
from app.schemas.chat import ResponseAnswer, ReuqestQuery

router = APIRouter()

chat_agent = ChatAgent()
agent = ExecutorAgent()


@router.post("/completion", response_model=ResponseAnswer)
def make_chat(req: ReuqestQuery):
answer = chat_agent.run(req.query)
answer = agent.run(req.query)
return {"answer": answer}


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
langchain==0.0.310
langchain==0.0.346
fastapi==0.103.0
uvicorn==0.23.2
pinecone-client==2.2.4
Expand Down

0 comments on commit b756e0f

Please sign in to comment.