Merge pull request #140 from langchain-ai/harrison/add-other-bots
cr
nfcampos authored Jan 27, 2024
2 parents e765031 + 18dc274 commit 8ffdb51
Showing 3 changed files with 168 additions and 11 deletions.
60 changes: 58 additions & 2 deletions backend/app/agent.py
@@ -24,8 +24,10 @@
    TOOLS,
    AvailableTools,
    get_retrieval_tool,
    get_retriever,
)
from app.chatbot import get_chatbot_executor
from app.retrieval import get_retrieval_executor


class AgentType(str, Enum):
@@ -149,15 +151,13 @@ def get_chatbot(
class ConfigurableChatBot(RunnableBinding):
    llm: LLMType
    system_message: str = DEFAULT_SYSTEM_MESSAGE
    assistant_id: Optional[str] = None
    user_id: Optional[str] = None

    def __init__(
        self,
        *,
        llm: LLMType = LLMType.GPT_35_TURBO,
        system_message: str = DEFAULT_SYSTEM_MESSAGE,
        assistant_id: Optional[str] = None,
        kwargs: Optional[Mapping[str, Any]] = None,
        config: Optional[Mapping[str, Any]] = None,
        **others: Any,
@@ -179,6 +179,61 @@ def __init__(
    .configurable_fields(
        llm=ConfigurableField(id="llm_type", name="LLM Type"),
        system_message=ConfigurableField(id="system_message", name="System Message"),
    )
    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
)


class ConfigurableRetrieval(RunnableBinding):
    llm_type: LLMType
    system_message: str = DEFAULT_SYSTEM_MESSAGE
    assistant_id: Optional[str] = None
    user_id: Optional[str] = None

    def __init__(
        self,
        *,
        llm_type: LLMType = LLMType.GPT_35_TURBO,
        system_message: str = DEFAULT_SYSTEM_MESSAGE,
        assistant_id: Optional[str] = None,
        kwargs: Optional[Mapping[str, Any]] = None,
        config: Optional[Mapping[str, Any]] = None,
        **others: Any,
    ) -> None:
        others.pop("bound", None)
        retriever = get_retriever(assistant_id)
        checkpointer = RedisCheckpoint()
        if llm_type == LLMType.GPT_35_TURBO:
            llm = get_openai_llm()
        elif llm_type == LLMType.GPT_4:
            llm = get_openai_llm(gpt_4=True)
        elif llm_type == LLMType.AZURE_OPENAI:
            llm = get_openai_llm(azure=True)
        elif llm_type == LLMType.CLAUDE2:
            llm = get_anthropic_llm()
        elif llm_type == LLMType.BEDROCK_CLAUDE2:
            llm = get_anthropic_llm(bedrock=True)
        elif llm_type == LLMType.GEMINI:
            llm = get_google_llm()
        elif llm_type == LLMType.MIXTRAL:
            llm = get_mixtral_fireworks()
        else:
            raise ValueError("Unexpected llm type")
        chatbot = get_retrieval_executor(llm, retriever, system_message, checkpointer)
        super().__init__(
            llm_type=llm_type,
            system_message=system_message,
            bound=chatbot,
            kwargs=kwargs or {},
            config=config or {},
        )


chat_retrieval = (
    ConfigurableRetrieval(llm_type=LLMType.GPT_35_TURBO, checkpoint=RedisCheckpoint())
    .configurable_fields(
        llm_type=ConfigurableField(id="llm_type", name="LLM Type"),
        system_message=ConfigurableField(id="system_message", name="System Message"),
        assistant_id=ConfigurableField(
            id="assistant_id", name="Assistant ID", is_shared=True
        ),
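At runtime a caller can pin these fields through LangChain's standard with_config API. A minimal sketch, not part of this commit — the overrides and assistant ID below are hypothetical:

configured = chat_retrieval.with_config(
    configurable={
        "llm_type": LLMType.GPT_4,  # any LLMType member
        "system_message": "You are a careful support bot.",  # hypothetical override
        "assistant_id": "asst-123",  # hypothetical assistant namespace
    }
)
configured.invoke([HumanMessage(content="What do my docs say about pricing?")])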
@@ -216,6 +271,7 @@ def __init__(
        default_key="assistant",
        prefix_keys=True,
        chatbot=chatbot,
        chat_retrieval=chat_retrieval,
    )
    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
)
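Because the new bot is registered via configurable_alternatives with prefix_keys=True, a request selects the bot type and addresses its nested fields with prefixed config keys. A rough sketch, assuming the selector field id is "type" and the combined runnable is exposed as agent (both assumptions; only the registration above is in this diff):

configured = agent.with_config(
    configurable={
        "type": "chat_retrieval",  # assumed selector id for the alternatives
        "type==chat_retrieval/llm_type": LLMType.GPT_4,
        "type==chat_retrieval/assistant_id": "asst-123",  # hypothetical ID
    }
)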
109 changes: 103 additions & 6 deletions backend/app/retrieval.py
@@ -1,23 +1,120 @@
import json

from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import (
    SystemMessage,
    HumanMessage,
    AIMessage,
    FunctionMessage,
)
from langchain_core.runnables import chain
from langchain_core.retrievers import BaseRetriever
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langchain_core.prompts import PromptTemplate


search_prompt = PromptTemplate.from_template(
    """Given the conversation below, come up with a search query to look up.
This search query can be either a few words or a question.
Return ONLY this search query, nothing more.
>>> Conversation:
{conversation}
>>> END OF CONVERSATION
Remember, return ONLY the search query that will help you when formulating a response to the above conversation."""
)


response_prompt_template = """{instructions}
Respond to the user using ONLY the context provided below. Do not make anything up.
{context}"""


@chain
def get_search_query(inputs):
    # Condense the conversation so far into a standalone search query.
    # @chain wraps a single-argument function, so llm and messages arrive
    # together in one dict (see the invocation in invoke_retrieval below).
    llm = inputs["llm"]
    messages = inputs["messages"]
    convo = []
    for m in messages:
        if isinstance(m, AIMessage):
            if "function_call" not in m.additional_kwargs:
                convo.append(f"AI: {m.content}")
        if isinstance(m, HumanMessage):
            convo.append(f"Human: {m.content}")
    conversation = "\n".join(convo)
    prompt = search_prompt.invoke({"conversation": conversation})
    response = llm.invoke(prompt)
    return response.content
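A quick sketch of exercising this chain on its own; the model choice and messages are illustrative assumptions, not part of the diff:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")  # assumed model for illustration
followup = [
    HumanMessage(content="Tell me about LangGraph."),
    AIMessage(content="LangGraph builds stateful message graphs."),
    HumanMessage(content="How does checkpointing work?"),
]
query = get_search_query.invoke({"llm": llm, "messages": followup})
# query should come back as a standalone string like "LangGraph checkpointing"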


def get_retrieval_executor(
    llm: LanguageModelLike,
    retriever: BaseRetriever,
    system_message: str,
    checkpoint: BaseCheckpointSaver,
):
    def _get_messages(messages):
        chat_history = []
        for m in messages:
            if isinstance(m, AIMessage):
                if "function_call" not in m.additional_kwargs:
                    chat_history.append(m)
            if isinstance(m, HumanMessage):
                chat_history.append(m)
        content = messages[-1].content
        return [
            SystemMessage(
                content=response_prompt_template.format(
                    instructions=system_message, context=content
                )
            )
        ] + chat_history

    def invoke_retrieval(messages):
        if len(messages) == 1:
            human_input = messages[-1].content
            return AIMessage(
                content="",
                additional_kwargs={
                    "function_call": {
                        "name": "retrieval",
                        "arguments": json.dumps({"query": human_input}),
                    }
                },
            )
        else:
            search_query = get_search_query.invoke({"llm": llm, "messages": messages})
            return AIMessage(
                content="",
                additional_kwargs={
                    "function_call": {
                        "name": "retrieval",
                        "arguments": json.dumps({"query": search_query}),
                    }
                },
            )

    def retrieve(messages):
        params = messages[-1].additional_kwargs["function_call"]
        query = json.loads(params["arguments"])["query"]
        response = retriever.invoke(query)
        content = "\n".join([d.page_content for d in response])
        return FunctionMessage(name="retrieval", content=content)

    response = _get_messages | llm

    workflow = MessageGraph()
    workflow.add_node("invoke_retrieval", invoke_retrieval)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("response", response)
    workflow.set_entry_point("invoke_retrieval")
    workflow.add_edge("invoke_retrieval", "retrieve")
    workflow.add_edge("retrieve", "response")
    workflow.add_edge("response", END)
    app = workflow.compile(checkpointer=checkpoint)
    return app
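End to end, every turn now flows invoke_retrieval -> retrieve -> response. A minimal sketch of driving the compiled graph; llm, retriever, checkpoint, and the thread_id config key are assumptions for illustration:

executor = get_retrieval_executor(llm, retriever, "You are helpful.", checkpoint)
result = executor.invoke(
    [HumanMessage(content="What does the handbook say about PTO?")],
    config={"configurable": {"thread_id": "demo"}},  # assumed checkpoint key
)
print(result[-1].content)  # the grounded answer from the "response" node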
10 changes: 7 additions & 3 deletions backend/app/tools.py
Original file line number Diff line number Diff line change
@@ -33,11 +33,15 @@ class PythonREPLInput(BaseModel):
If the user is referencing particular files, that is often a good hint that information may be here."""


def get_retriever(assistant_id: str):
    return vstore.as_retriever(
        search_kwargs={"filter": RedisFilter.tag("namespace") == assistant_id}
    )


def get_retrieval_tool(assistant_id: str, description: str):
    return create_retriever_tool(
        get_retriever(assistant_id),
        "Retriever",
        description,
    )
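Factoring get_retriever out lets other callers reuse the per-assistant namespace filter directly. A small hypothetical usage:

retriever = get_retriever("asst-123")       # hypothetical assistant ID
docs = retriever.invoke("vacation policy")  # searches only that assistant's documents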