cr #140

Merged · 1 commit · Jan 27, 2024
60 changes: 58 additions & 2 deletions backend/app/agent.py
@@ -24,8 +24,10 @@
TOOLS,
AvailableTools,
get_retrieval_tool,
get_retriever,
)
from app.chatbot import get_chatbot_executor
from app.retrieval import get_retrieval_executor


class AgentType(str, Enum):
@@ -149,15 +151,13 @@ def get_chatbot(
class ConfigurableChatBot(RunnableBinding):
llm: LLMType
system_message: str = DEFAULT_SYSTEM_MESSAGE
assistant_id: Optional[str] = None
user_id: Optional[str] = None

def __init__(
self,
*,
llm: LLMType = LLMType.GPT_35_TURBO,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
assistant_id: Optional[str] = None,
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[Mapping[str, Any]] = None,
**others: Any,
@@ -179,6 +179,61 @@ def __init__(
.configurable_fields(
llm=ConfigurableField(id="llm_type", name="LLM Type"),
system_message=ConfigurableField(id="system_message", name="System Message"),
)
.with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
)


class ConfigurableRetrieval(RunnableBinding):
llm_type: LLMType
system_message: str = DEFAULT_SYSTEM_MESSAGE
assistant_id: Optional[str] = None
user_id: Optional[str] = None

def __init__(
self,
*,
llm_type: LLMType = LLMType.GPT_35_TURBO,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
assistant_id: Optional[str] = None,
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[Mapping[str, Any]] = None,
**others: Any,
) -> None:
others.pop("bound", None)
retriever = get_retriever(assistant_id)
checkpointer = RedisCheckpoint()
if llm_type == LLMType.GPT_35_TURBO:
llm = get_openai_llm()
elif llm_type == LLMType.GPT_4:
llm = get_openai_llm(gpt_4=True)
elif llm_type == LLMType.AZURE_OPENAI:
llm = get_openai_llm(azure=True)
elif llm_type == LLMType.CLAUDE2:
llm = get_anthropic_llm()
elif llm_type == LLMType.BEDROCK_CLAUDE2:
llm = get_anthropic_llm(bedrock=True)
elif llm_type == LLMType.GEMINI:
llm = get_google_llm()
elif llm_type == LLMType.MIXTRAL:
llm = get_mixtral_fireworks()
else:
raise ValueError("Unexpected llm type")
chatbot = get_retrieval_executor(llm, retriever, system_message, checkpointer)
super().__init__(
llm_type=llm_type,
system_message=system_message,
bound=chatbot,
kwargs=kwargs or {},
config=config or {},
)


chat_retrieval = (
ConfigurableRetrieval(llm_type=LLMType.GPT_35_TURBO, checkpoint=RedisCheckpoint())
.configurable_fields(
llm_type=ConfigurableField(id="llm_type", name="LLM Type"),
system_message=ConfigurableField(id="system_message", name="System Message"),
assistant_id=ConfigurableField(
id="assistant_id", name="Assistant ID", is_shared=True
),
@@ -216,6 +271,7 @@ def __init__(
default_key="assistant",
prefix_keys=True,
chatbot=chatbot,
chat_retrieval=chat_retrieval,
)
.with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
)
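
For context, a minimal sketch of invoking the new chat_retrieval runnable (assumes module scope, so LLMType is already imported; the assistant_id and thread_id values are hypothetical, and thread_id as the RedisCheckpoint key is an assumption):

    from langchain_core.messages import HumanMessage

    answer = chat_retrieval.invoke(
        [HumanMessage(content="What does the handbook say about PTO?")],
        config={
            "configurable": {
                "llm_type": LLMType.GPT_4,
                "system_message": "You are a helpful HR assistant.",
                "assistant_id": "asst-123",  # hypothetical id; scopes the retriever
                "thread_id": "thread-1",     # assumption: checkpoint key
            }
        },
    )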
109 changes: 103 additions & 6 deletions backend/app/retrieval.py
@@ -1,23 +1,120 @@
import json

from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import SystemMessage
from langchain_core.messages import (
SystemMessage,
HumanMessage,
AIMessage,
FunctionMessage,
)
from langchain_core.runnables import chain
from langchain_core.retrievers import BaseRetriever
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph import END
from langgraph.graph.message import MessageGraph
from langchain_core.prompts import PromptTemplate


search_prompt = PromptTemplate.from_template(
"""Given the conversation below, come up with a search query to look up.

This search query can be either a few words or a question

Return ONLY this search query, nothing more.

>>> Conversation:
{conversation}
>>> END OF CONVERSATION

Remember, return ONLY the search query that will help you when formulating a response to the above conversation."""
)


response_prompt_template = """{instructions}

Respond to the user using ONLY the context provided below. Do not make anything up.

{context}"""


@chain
def get_search_query(inputs):
    # @chain wraps a single-argument runnable, so accept one dict and unpack
    # it to match the .invoke({"llm": ..., "messages": ...}) call below.
    llm, messages = inputs["llm"], inputs["messages"]
    convo = []
    for m in messages:
        if isinstance(m, AIMessage):
            if "function_call" not in m.additional_kwargs:
                convo.append(f"AI: {m.content}")
        if isinstance(m, HumanMessage):
            convo.append(f"Human: {m.content}")
    conversation = "\n".join(convo)
    prompt = search_prompt.invoke({"conversation": conversation})
    response = llm.invoke(prompt)
    return response.content
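
As an isolated illustration of the chain above (assumes the single-dict signature shown; FakeListChatModel is a stand-in chat model so the sketch runs without API keys):

    from langchain_core.language_models import FakeListChatModel
    from langchain_core.messages import AIMessage, HumanMessage

    fake_llm = FakeListChatModel(responses=["weather in Berlin today"])
    history = [
        HumanMessage(content="What's the weather like?"),
        AIMessage(content="Where are you located?"),
        HumanMessage(content="Berlin."),
    ]
    # Flattens the history into Human:/AI: lines, fills search_prompt,
    # and returns the model's condensed query string.
    print(get_search_query.invoke({"llm": fake_llm, "messages": history}))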


def get_retrieval_executor(
llm: LanguageModelLike,
retriever: BaseRetriever,
system_message: str,
checkpoint: BaseCheckpointSaver,
):
def _get_messages(messages):
return [SystemMessage(content=system_message)] + messages
chat_history = []
for m in messages:
if isinstance(m, AIMessage):
if "function_call" not in m.additional_kwargs:
chat_history.append(m)
if isinstance(m, HumanMessage):
chat_history.append(m)
content = messages[-1].content
return [
SystemMessage(
content=response_prompt_template.format(
instructions=system_message, context=content
)
)
] + chat_history

def invoke_retrieval(messages):
if len(messages) == 1:
human_input = messages[-1].content
return AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "retrieval",
"arguments": json.dumps({"query": human_input}),
}
},
)
else:
search_query = get_search_query.invoke({"llm": llm, "messages": messages})
return AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "retrieval",
"arguments": json.dumps({"query": search_query}),
}
},
)

def retrieve(messages):
params = messages[-1].additional_kwargs["function_call"]
query = json.loads(params["arguments"])["query"]
response = retriever.invoke(query)
content = "\n".join([d.page_content for d in response])
return FunctionMessage(name="retrieval", content=content)

chatbot = _get_messages | llm
response = _get_messages | llm

workflow = MessageGraph()
workflow.add_node("chatbot", chatbot)
workflow.set_entry_point("chatbot")
workflow.add_edge("chatbot", END)
workflow.add_node("invoke_retrieval", invoke_retrieval)
workflow.add_node("retrieve", retrieve)
workflow.add_node("response", response)
workflow.set_entry_point("invoke_retrieval")
workflow.add_edge("invoke_retrieval", "retrieve")
workflow.add_edge("retrieve", "response")
workflow.add_edge("response", END)
app = workflow.compile(checkpointer=checkpoint)
return app
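
The compiled graph is a fixed linear pipeline: invoke_retrieval fabricates a retrieval function_call (the raw question on the first turn, an LLM-condensed query on later turns), retrieve appends the fetched documents as a FunctionMessage, and response answers using only that context. A rough usage sketch, assuming llm, retriever, and checkpoint are already constructed and that thread_id is the checkpoint key:

    from langchain_core.messages import HumanMessage

    executor = get_retrieval_executor(
        llm, retriever, "You are a helpful assistant.", checkpoint
    )
    messages = executor.invoke(
        [HumanMessage(content="Summarize the onboarding doc.")],
        config={"configurable": {"thread_id": "thread-1"}},
    )
    # The returned history ends with the function_call AIMessage, the
    # retrieval FunctionMessage, and the grounded final AIMessage.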
10 changes: 7 additions & 3 deletions backend/app/tools.py
@@ -33,11 +33,15 @@ class PythonREPLInput(BaseModel):
If the user is referencing particular files, that is often a good hint that information may be here."""


def get_retriever(assistant_id: str):
return vstore.as_retriever(
search_kwargs={"filter": RedisFilter.tag("namespace") == assistant_id}
)


def get_retrieval_tool(assistant_id: str, description: str):
return create_retriever_tool(
vstore.as_retriever(
search_kwargs={"filter": RedisFilter.tag("namespace") == assistant_id}
),
get_retriever(assistant_id),
"Retriever",
description,
)
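
A quick sketch of what factoring out get_retriever enables: fetching an assistant's namespaced documents directly, outside the tool wrapper (the assistant id and query are hypothetical):

    retriever = get_retriever("asst-123")         # hypothetical assistant id
    docs = retriever.invoke("quarterly revenue")  # Redis tag-filtered search
    print("\n\n".join(d.page_content for d in docs))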