diff --git a/backend/app/agent.py b/backend/app/agent.py
index 16e277cd..b69e1c39 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -24,8 +24,10 @@
     TOOLS,
     AvailableTools,
     get_retrieval_tool,
+    get_retriever,
 )
 from app.chatbot import get_chatbot_executor
+from app.retrieval import get_retrieval_executor
 
 
 class AgentType(str, Enum):
@@ -149,7 +151,6 @@ def get_chatbot(
 class ConfigurableChatBot(RunnableBinding):
     llm: LLMType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
-    assistant_id: Optional[str] = None
     user_id: Optional[str] = None
 
     def __init__(
@@ -157,7 +158,6 @@
         *,
         llm: LLMType = LLMType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
-        assistant_id: Optional[str] = None,
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[Mapping[str, Any]] = None,
         **others: Any,
@@ -179,6 +179,61 @@
     .configurable_fields(
         llm=ConfigurableField(id="llm_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="System Message"),
+    )
+    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
+)
+
+
+class ConfigurableRetrieval(RunnableBinding):
+    llm_type: LLMType
+    system_message: str = DEFAULT_SYSTEM_MESSAGE
+    assistant_id: Optional[str] = None
+    user_id: Optional[str] = None
+
+    def __init__(
+        self,
+        *,
+        llm_type: LLMType = LLMType.GPT_35_TURBO,
+        system_message: str = DEFAULT_SYSTEM_MESSAGE,
+        assistant_id: Optional[str] = None,
+        kwargs: Optional[Mapping[str, Any]] = None,
+        config: Optional[Mapping[str, Any]] = None,
+        **others: Any,
+    ) -> None:
+        others.pop("bound", None)
+        retriever = get_retriever(assistant_id)
+        checkpointer = RedisCheckpoint()
+        if llm_type == LLMType.GPT_35_TURBO:
+            llm = get_openai_llm()
+        elif llm_type == LLMType.GPT_4:
+            llm = get_openai_llm(gpt_4=True)
+        elif llm_type == LLMType.AZURE_OPENAI:
+            llm = get_openai_llm(azure=True)
+        elif llm_type == LLMType.CLAUDE2:
+            llm = get_anthropic_llm()
+        elif llm_type == LLMType.BEDROCK_CLAUDE2:
+            llm = get_anthropic_llm(bedrock=True)
+        elif llm_type == LLMType.GEMINI:
+            llm = get_google_llm()
+        elif llm_type == LLMType.MIXTRAL:
+            llm = get_mixtral_fireworks()
+        else:
+            raise ValueError("Unexpected llm type")
+        chatbot = get_retrieval_executor(llm, retriever, system_message, checkpointer)
+        super().__init__(
+            llm_type=llm_type,
+            system_message=system_message,
+            bound=chatbot,
+            kwargs=kwargs or {},
+            config=config or {},
+        )
+
+
+chat_retrieval = (
+    ConfigurableRetrieval(llm_type=LLMType.GPT_35_TURBO, checkpoint=RedisCheckpoint())
+    .configurable_fields(
+        llm_type=ConfigurableField(id="llm_type", name="LLM Type"),
+        system_message=ConfigurableField(id="system_message", name="System Message"),
         assistant_id=ConfigurableField(
             id="assistant_id", name="Assistant ID", is_shared=True
         ),
@@ -216,6 +271,7 @@
         default_key="assistant",
         prefix_keys=True,
         chatbot=chatbot,
+        chat_retrieval=chat_retrieval,
     )
     .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
 )
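Usage note: with `prefix_keys=True` on the `configurable_alternatives` call above, the new `chat_retrieval` branch is selected at request time through the `type` field, and its nested fields are addressed with a `type==chat_retrieval/` prefix, while fields marked `is_shared=True` (like `assistant_id`) keep their plain ids. A minimal sketch, assuming the composed runnable is exported as `agent`, that the enum value string is "GPT 3.5 Turbo", and that the checkpointer reads `user_id`/`thread_id` from the config; none of these are shown in this diff:

```python
from langchain_core.messages import HumanMessage

from app.agent import agent  # assumed export name for the composed runnable

configured = agent.with_config(
    configurable={
        "type": "chat_retrieval",
        # Nested ids are prefixed because of prefix_keys=True ...
        "type==chat_retrieval/llm_type": "GPT 3.5 TurBO".replace("BO", "bo"),  # assumed enum value, "GPT 3.5 Turbo"
        # ... while is_shared=True fields keep their unprefixed id.
        "assistant_id": "assistant-123",
        "user_id": "user-1",
        "thread_id": "thread-1",
    }
)
messages = configured.invoke([HumanMessage(content="What do my docs say about pricing?")])
```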
diff --git a/backend/app/retrieval.py b/backend/app/retrieval.py
index e86ec5c2..3fd6633a 100644
--- a/backend/app/retrieval.py
+++ b/backend/app/retrieval.py
@@ -1,23 +1,121 @@
+import json
+
 from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.messages import SystemMessage
+from langchain_core.messages import (
+    SystemMessage,
+    HumanMessage,
+    AIMessage,
+    FunctionMessage,
+)
+from langchain_core.runnables import chain
+from langchain_core.retrievers import BaseRetriever
 from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END
 from langgraph.graph.message import MessageGraph
+from langchain_core.prompts import PromptTemplate
+
+
+search_prompt = PromptTemplate.from_template(
+    """Given the conversation below, come up with a search query to look up.
+
+This search query can be either a few words or a question.
+
+Return ONLY this search query, nothing more.
+
+>>> Conversation:
+{conversation}
+>>> END OF CONVERSATION
+
+Remember, return ONLY the search query that will help you when formulating a response to the above conversation."""
+)
+
+
+response_prompt_template = """{instructions}
+
+Respond to the user using ONLY the context provided below. Do not make anything up.
+
+{context}"""
+
+
+@chain
+def get_search_query(params):
+    llm, messages = params["llm"], params["messages"]
+    convo = []
+    for m in messages:
+        if isinstance(m, AIMessage):
+            if "function_call" not in m.additional_kwargs:
+                convo.append(f"AI: {m.content}")
+        if isinstance(m, HumanMessage):
+            convo.append(f"Human: {m.content}")
+    conversation = "\n".join(convo)
+    prompt = search_prompt.invoke({"conversation": conversation})
+    response = llm.invoke(prompt)
+    return response.content
 
 
 def get_retrieval_executor(
     llm: LanguageModelLike,
+    retriever: BaseRetriever,
     system_message: str,
     checkpoint: BaseCheckpointSaver,
 ):
     def _get_messages(messages):
-        return [SystemMessage(content=system_message)] + messages
+        chat_history = []
+        for m in messages:
+            if isinstance(m, AIMessage):
+                if "function_call" not in m.additional_kwargs:
+                    chat_history.append(m)
+            if isinstance(m, HumanMessage):
+                chat_history.append(m)
+        content = messages[-1].content
+        return [
+            SystemMessage(
+                content=response_prompt_template.format(
+                    instructions=system_message, context=content
+                )
+            )
+        ] + chat_history
+
+    def invoke_retrieval(messages):
+        if len(messages) == 1:
+            human_input = messages[-1].content
+            return AIMessage(
+                content="",
+                additional_kwargs={
+                    "function_call": {
+                        "name": "retrieval",
+                        "arguments": json.dumps({"query": human_input}),
+                    }
+                },
+            )
+        else:
+            search_query = get_search_query.invoke({"llm": llm, "messages": messages})
+            return AIMessage(
+                content="",
+                additional_kwargs={
+                    "function_call": {
+                        "name": "retrieval",
+                        "arguments": json.dumps({"query": search_query}),
+                    }
+                },
+            )
+
+    def retrieve(messages):
+        params = messages[-1].additional_kwargs["function_call"]
+        query = json.loads(params["arguments"])["query"]
+        response = retriever.invoke(query)
+        content = "\n".join([d.page_content for d in response])
+        return FunctionMessage(name="retrieval", content=content)
 
-    chatbot = _get_messages | llm
+    response = _get_messages | llm
 
     workflow = MessageGraph()
-    workflow.add_node("chatbot", chatbot)
-    workflow.set_entry_point("chatbot")
-    workflow.add_edge("chatbot", END)
+    workflow.add_node("invoke_retrieval", invoke_retrieval)
+    workflow.add_node("retrieve", retrieve)
+    workflow.add_node("response", response)
+    workflow.set_entry_point("invoke_retrieval")
+    workflow.add_edge("invoke_retrieval", "retrieve")
+    workflow.add_edge("retrieve", "response")
+    workflow.add_edge("response", END)
 
     app = workflow.compile(checkpointer=checkpoint)
     return app
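The executor is now a fixed three-node pipeline instead of a single chatbot node: `invoke_retrieval` fabricates a `retrieval` function call (the raw question on the first turn, an LLM-generated search query on follow-ups), `retrieve` runs that query against the retriever, and `response` answers from the retrieved context alone. A self-contained sketch of exercising the graph with stub components; `FakeListChatModel`, the `MemorySaver` import path, and the `thread_id` key are assumptions about the installed langchain/langgraph versions:

```python
from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import HumanMessage
from langchain_core.retrievers import BaseRetriever
from langgraph.checkpoint.memory import MemorySaver  # location varies by langgraph version

from app.retrieval import get_retrieval_executor


class StubRetriever(BaseRetriever):
    """Returns one canned document so the graph runs without a vector store."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=f"(stub context for: {query})")]


llm = FakeListChatModel(responses=["An answer grounded in the stub context."])
app = get_retrieval_executor(llm, StubRetriever(), "You are helpful.", MemorySaver())

# On a single human turn, invoke_retrieval short-circuits to the raw question,
# so the model is called exactly once: to produce the final answer.
out = app.invoke(
    [HumanMessage(content="What is in my documents?")],
    config={"configurable": {"thread_id": "demo"}},
)
for m in out:  # Human -> AI(function_call) -> Function(retrieval) -> AI(answer)
    print(type(m).__name__)
```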
diff --git a/backend/app/tools.py b/backend/app/tools.py
index 74e0eecf..2c37a362 100644
--- a/backend/app/tools.py
+++ b/backend/app/tools.py
@@ -33,11 +33,15 @@ class PythonREPLInput(BaseModel):
 If the user is referencing particular files, that is often a good hint that information may be here."""
 
 
+def get_retriever(assistant_id: str):
+    return vstore.as_retriever(
+        search_kwargs={"filter": RedisFilter.tag("namespace") == assistant_id}
+    )
+
+
 def get_retrieval_tool(assistant_id: str, description: str):
     return create_retriever_tool(
-        vstore.as_retriever(
-            search_kwargs={"filter": RedisFilter.tag("namespace") == assistant_id}
-        ),
+        get_retriever(assistant_id),
         "Retriever",
         description,
     )
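The retriever construction is now shared between the retrieval tool and the new `chat_retrieval` executor. `RedisFilter.tag("namespace") == assistant_id` only matches documents written with that tag, so ingestion has to stamp every chunk with the assistant's id. A hedged sketch of the ingestion side; `ingest_texts` is a hypothetical helper, and the exact metadata handling depends on the Redis index schema configured elsewhere in the app:

```python
from app.tools import vstore  # the module-level Redis vector store used above


def ingest_texts(texts: list[str], assistant_id: str) -> None:
    # Tag each chunk with the assistant's namespace so the tag filter in
    # get_retriever(assistant_id) can match it at query time.
    vstore.add_texts(texts, metadatas=[{"namespace": assistant_id} for _ in texts])


# Query side, scoped to one assistant:
# docs = get_retriever("assistant-123").invoke("pricing details")
```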