diff --git a/docs/docs/use_cases/agent_simulations/violation_of_expectations_chain.mdx b/docs/docs/use_cases/agent_simulations/violation_of_expectations_chain.mdx
new file mode 100644
index 000000000000..31226a9b7d66
--- /dev/null
+++ b/docs/docs/use_cases/agent_simulations/violation_of_expectations_chain.mdx
@@ -0,0 +1,67 @@
+# Violation of Expectations Chain
+
+This page demonstrates how to use the `ViolationOfExpectationsChain`. This chain extracts insights from chat conversations
+by comparing an LLM's prediction of the user's next message (informed by the user's inferred mental state) against the user's actual next message,
+and is intended to provide a form of reflection for long-term memory.
+
+The `ViolationOfExpectationsChain` was implemented using the results of a paper by [Plastic Labs](https://plasticlabs.ai/). Their paper, `Violation of Expectation via Metacognitive Prompting Reduces
+Theory of Mind Prediction Error in Large Language Models`, can be found [here](https://arxiv.org/abs/2310.06983).
+
+## Usage
+
+The example below features a chat between a human and an AI about a journal entry the user made.
+
+import CodeBlock from "@theme/CodeBlock";
+import ViolationOfExpectationsChainExample from "@examples/use_cases/advanced/violation_of_expectations_chain.ts";
+
+<CodeBlock language="typescript">
+  {ViolationOfExpectationsChainExample}
+</CodeBlock>
+
+## Explanation
+
+Now let's go over everything the chain is doing, step by step.
+
+Under the hood, the `ViolationOfExpectationsChain` performs four main steps:
+
+### Step 1. Predict the user's next message using only the chat history.
+
+The LLM is tasked with generating three key pieces of information:
+
+- Concise reasoning about the user's internal mental state.
+- A prediction of how they will respond to the AI's most recent message.
+- A concise list of any additional insights that would be useful to improve the prediction.
+
+Once the LLM response is returned, we query our retriever with each generated insight and keep the first document from each result.
+Then, the retrieved documents and generated insights are deduplicated and returned together.
+
+### Step 2. Generate prediction violations.
+
+Using the results from step 1, we query the LLM to generate the following:
+
+- How exactly was the original prediction violated? Which parts were wrong? State the exact differences.
+- If there were errors with the prediction, what were they and why?
+
+To do this, we pass the LLM the predicted response, the generated (and retrieved) insights from step 1, and the actual response from the user.
+
+Once we have the difference between the predicted and actual response, we can move on to step 3.
+
+### Step 3. Regenerate the prediction.
+
+Using the original prediction, the key insights, and the differences between the actual response and our prediction, we can generate a new, more accurate prediction.
+These predictions will help us in the next step to generate an insight that isn't just parts of the user's conversation verbatim.
+
+### Step 4. Generate an insight.
+
+Lastly, we prompt the LLM to generate one concise insight given the following context:
+
+- The ways in which our original prediction was violated.
+- Our revised prediction from step 3.
+- The actual response from the user.
+
+Given these three data points, we prompt the LLM to return one fact relevant to the specific user response.
+A key point here is giving it the ways in which our original prediction was violated.
+This list contains the exact differences, and often specific facts themselves, between the predicted and actual response.
+
+We perform these steps on every human message, so if you have a conversation with 10 messages (5 human, 5 AI), you'll get 5 insights.
+The list of messages is chunked by iterating over the entire chat history, stopping at each AI message and returning it along with all messages that preceded it.
+
+Once our `.call({...})` method returns the array of insights, we can save them to our vector store.
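+As a minimal sketch of that last step (assuming the `vectorStore`, `voeChain`, and `dummyMessages` from the usage example above, and that the chain returns its insights as an array of strings under the `insights` key):
+
+```typescript
+import { Document } from "langchain/document";
+
+// Run the chain over the conversation to get an array of insight strings.
+const { insights } = await voeChain.call({
+  chat_history: dummyMessages,
+});
+
+// Persist each insight so that later runs can retrieve it.
+await vectorStore.addDocuments(
+  insights.map((insight: string) => new Document({ pageContent: insight }))
+);
+```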
+Later, we can retrieve them during future insight generation, or use them as context in a chatbot.
diff --git a/environment_tests/test-exports-bun/src/entrypoints.js b/environment_tests/test-exports-bun/src/entrypoints.js
index 2bdf86e1754a..9c36721f05a5 100644
--- a/environment_tests/test-exports-bun/src/entrypoints.js
+++ b/environment_tests/test-exports-bun/src/entrypoints.js
@@ -82,5 +82,6 @@ export * from "langchain/experimental/babyagi";
 export * from "langchain/experimental/generative_agents";
 export * from "langchain/experimental/plan_and_execute";
 export * from "langchain/experimental/chat_models/bittensor";
+export * from "langchain/experimental/chains/violation_of_expectations";
 export * from "langchain/evaluation";
 export * from "langchain/runnables/remote";
diff --git a/environment_tests/test-exports-cf/src/entrypoints.js b/environment_tests/test-exports-cf/src/entrypoints.js
index 2bdf86e1754a..9c36721f05a5 100644
--- a/environment_tests/test-exports-cf/src/entrypoints.js
+++ b/environment_tests/test-exports-cf/src/entrypoints.js
@@ -82,5 +82,6 @@ export * from "langchain/experimental/babyagi";
 export * from "langchain/experimental/generative_agents";
 export * from "langchain/experimental/plan_and_execute";
 export * from "langchain/experimental/chat_models/bittensor";
+export * from "langchain/experimental/chains/violation_of_expectations";
 export * from "langchain/evaluation";
 export * from "langchain/runnables/remote";
diff --git a/environment_tests/test-exports-cjs/src/entrypoints.js b/environment_tests/test-exports-cjs/src/entrypoints.js
index 6dc031278b3e..41cacf76314b 100644
--- a/environment_tests/test-exports-cjs/src/entrypoints.js
+++ b/environment_tests/test-exports-cjs/src/entrypoints.js
@@ -82,5 +82,6 @@ const experimental_babyagi = require("langchain/experimental/babyagi");
 const experimental_generative_agents = require("langchain/experimental/generative_agents");
 const experimental_plan_and_execute = require("langchain/experimental/plan_and_execute");
 const experimental_chat_models_bittensor = require("langchain/experimental/chat_models/bittensor");
+const experimental_chains_violation_of_expectations = require("langchain/experimental/chains/violation_of_expectations");
 const evaluation = require("langchain/evaluation");
 const runnables_remote = require("langchain/runnables/remote");
diff --git a/environment_tests/test-exports-esbuild/src/entrypoints.js b/environment_tests/test-exports-esbuild/src/entrypoints.js
index ff7933bb6f16..cde0f3318c55 100644
--- a/environment_tests/test-exports-esbuild/src/entrypoints.js
+++ b/environment_tests/test-exports-esbuild/src/entrypoints.js
@@ -82,5 +82,6 @@ import * as experimental_babyagi from "langchain/experimental/babyagi";
 import * as experimental_generative_agents from "langchain/experimental/generative_agents";
 import * as experimental_plan_and_execute from "langchain/experimental/plan_and_execute";
 import * as experimental_chat_models_bittensor from
"langchain/experimental/chat_models/bittensor"; +import * as experimental_chains_violation_of_expectations from "langchain/experimental/chains/violation_of_expectations"; import * as evaluation from "langchain/evaluation"; import * as runnables_remote from "langchain/runnables/remote"; diff --git a/environment_tests/test-exports-esm/src/entrypoints.js b/environment_tests/test-exports-esm/src/entrypoints.js index ff7933bb6f16..cde0f3318c55 100644 --- a/environment_tests/test-exports-esm/src/entrypoints.js +++ b/environment_tests/test-exports-esm/src/entrypoints.js @@ -82,5 +82,6 @@ import * as experimental_babyagi from "langchain/experimental/babyagi"; import * as experimental_generative_agents from "langchain/experimental/generative_agents"; import * as experimental_plan_and_execute from "langchain/experimental/plan_and_execute"; import * as experimental_chat_models_bittensor from "langchain/experimental/chat_models/bittensor"; +import * as experimental_chains_violation_of_expectations from "langchain/experimental/chains/violation_of_expectations"; import * as evaluation from "langchain/evaluation"; import * as runnables_remote from "langchain/runnables/remote"; diff --git a/environment_tests/test-exports-vercel/src/entrypoints.js b/environment_tests/test-exports-vercel/src/entrypoints.js index 2bdf86e1754a..9c36721f05a5 100644 --- a/environment_tests/test-exports-vercel/src/entrypoints.js +++ b/environment_tests/test-exports-vercel/src/entrypoints.js @@ -82,5 +82,6 @@ export * from "langchain/experimental/babyagi"; export * from "langchain/experimental/generative_agents"; export * from "langchain/experimental/plan_and_execute"; export * from "langchain/experimental/chat_models/bittensor"; +export * from "langchain/experimental/chains/violation_of_expectations"; export * from "langchain/evaluation"; export * from "langchain/runnables/remote"; diff --git a/environment_tests/test-exports-vite/src/entrypoints.js b/environment_tests/test-exports-vite/src/entrypoints.js index 2bdf86e1754a..9c36721f05a5 100644 --- a/environment_tests/test-exports-vite/src/entrypoints.js +++ b/environment_tests/test-exports-vite/src/entrypoints.js @@ -82,5 +82,6 @@ export * from "langchain/experimental/babyagi"; export * from "langchain/experimental/generative_agents"; export * from "langchain/experimental/plan_and_execute"; export * from "langchain/experimental/chat_models/bittensor"; +export * from "langchain/experimental/chains/violation_of_expectations"; export * from "langchain/evaluation"; export * from "langchain/runnables/remote"; diff --git a/examples/src/use_cases/advanced/violation_of_expectations_chain.ts b/examples/src/use_cases/advanced/violation_of_expectations_chain.ts new file mode 100644 index 000000000000..6cb9eee3357e --- /dev/null +++ b/examples/src/use_cases/advanced/violation_of_expectations_chain.ts @@ -0,0 +1,77 @@ +import { ViolationOfExpectationsChain } from "langchain/experimental/chains/violation_of_expectations"; +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { AIMessage, HumanMessage } from "langchain/schema"; +import { HNSWLib } from "langchain/vectorstores/hnswlib"; + +// Short GPT generated conversation between a human and an AI. +const dummyMessages = [ + new HumanMessage( + "I've been thinking about the importance of time with myself to discover my voice. I feel like 1-2 hours is never enough." + ), + new AIMessage( + "The concept of 'adequate time' varies. 
Have you tried different formats of introspection, such as morning pages or long-form writing, to see if they make the process more efficient?" + ), + new HumanMessage( + "I have tried journaling but never consistently. Sometimes it feels like writing doesn't capture everything." + ), + new AIMessage( + "Writing has its limits. What about other mediums like digital art, or interactive journal apps with dynamic prompts that dig deeper? Even coding a personal project can be a form of self-discovery." + ), + new HumanMessage( + "That's an interesting idea. I've never thought about coding as a form of self-discovery." + ), + new AIMessage( + "Since you're comfortable with code, consider building a tool to log and analyze your emotional state, thoughts, or personal growth metrics. It merges skill with introspection, makes the data quantifiable." + ), + new HumanMessage( + "The idea of quantifying emotions and personal growth is fascinating. But I wonder how much it can really capture the 'dark zone' within us." + ), + new AIMessage( + "Good point. The 'dark zone' isn't fully quantifiable. But a tool could serve as a scaffold to explore those areas. It gives a structured approach to an unstructured problem." + ), + new HumanMessage( + "You might be onto something. A structured approach could help unearth patterns or triggers I hadn't noticed." + ), + new AIMessage( + "Exactly. It's about creating a framework to understand what can't easily be understood. Then you can allocate those 5+ hours more effectively, targeting areas that your data flags." + ), +]; + +// Instantiate with an empty string to start, since we have no data yet. +const vectorStore = await HNSWLib.fromTexts( + [" "], + [{ id: 1 }], + new OpenAIEmbeddings() +); +const retriever = vectorStore.asRetriever(); + +// Instantiate the LLM, +const llm = new ChatOpenAI({ + modelName: "gpt-4", +}); + +// And the chain. +const voeChain = ViolationOfExpectationsChain.fromLLM(llm, retriever); + +// Requires an input key of "chat_history" with an array of messages. +const result = await voeChain.call({ + chat_history: dummyMessages, +}); + +console.log({ + result, +}); + +/** + * Output: +{ + result: [ + 'The user has experience with coding and has tried journaling before, but struggles with maintaining consistency and fully expressing their thoughts and feelings through writing.', + 'The user shows a thoughtful approach towards new concepts and is willing to engage with and contemplate novel ideas before making a decision. They also consider time effectiveness as a crucial factor in their decision-making process.', + 'The user is curious and open-minded about new concepts, but also values introspection and self-discovery in understanding emotions and personal growth.', + 'The user is open to new ideas and strategies, specifically those that involve a structured approach to identifying patterns or triggers.', + 'The user may not always respond or engage with prompts, indicating a need for varied and adaptable communication strategies.' 
+ ] +} + */ diff --git a/langchain/.gitignore b/langchain/.gitignore index 14553254749f..d1f195bd7c98 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -691,6 +691,9 @@ experimental/llms/bittensor.d.ts experimental/hubs/makersuite/googlemakersuitehub.cjs experimental/hubs/makersuite/googlemakersuitehub.js experimental/hubs/makersuite/googlemakersuitehub.d.ts +experimental/chains/violation_of_expectations.cjs +experimental/chains/violation_of_expectations.js +experimental/chains/violation_of_expectations.d.ts evaluation.cjs evaluation.js evaluation.d.ts diff --git a/langchain/package.json b/langchain/package.json index 933bc97ac713..face90177bfa 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -703,6 +703,9 @@ "experimental/hubs/makersuite/googlemakersuitehub.cjs", "experimental/hubs/makersuite/googlemakersuitehub.js", "experimental/hubs/makersuite/googlemakersuitehub.d.ts", + "experimental/chains/violation_of_expectations.cjs", + "experimental/chains/violation_of_expectations.js", + "experimental/chains/violation_of_expectations.d.ts", "evaluation.cjs", "evaluation.js", "evaluation.d.ts", @@ -2451,6 +2454,11 @@ "import": "./experimental/hubs/makersuite/googlemakersuitehub.js", "require": "./experimental/hubs/makersuite/googlemakersuitehub.cjs" }, + "./experimental/chains/violation_of_expectations": { + "types": "./experimental/chains/violation_of_expectations.d.ts", + "import": "./experimental/chains/violation_of_expectations.js", + "require": "./experimental/chains/violation_of_expectations.cjs" + }, "./evaluation": { "types": "./evaluation.d.ts", "import": "./evaluation.js", diff --git a/langchain/scripts/create-entrypoints.js b/langchain/scripts/create-entrypoints.js index c447a933964e..dc87e80bc1de 100644 --- a/langchain/scripts/create-entrypoints.js +++ b/langchain/scripts/create-entrypoints.js @@ -276,6 +276,7 @@ const entrypoints = { "experimental/llms/bittensor": "experimental/llms/bittensor", "experimental/hubs/makersuite/googlemakersuitehub": "experimental/hubs/makersuite/googlemakersuitehub", + "experimental/chains/violation_of_expectations": "experimental/chains/violation_of_expectations/index", // evaluation evaluation: "evaluation/index", // runnables diff --git a/langchain/src/experimental/chains/tests/violation_of_expectations_chain.int.test.ts b/langchain/src/experimental/chains/tests/violation_of_expectations_chain.int.test.ts new file mode 100644 index 000000000000..0ece5872b07e --- /dev/null +++ b/langchain/src/experimental/chains/tests/violation_of_expectations_chain.int.test.ts @@ -0,0 +1,42 @@ +import { ChatOpenAI } from "../../../chat_models/openai.js"; +import { OpenAIEmbeddings } from "../../../embeddings/openai.js"; +import { AIMessage, HumanMessage } from "../../../schema/index.js"; +import { HNSWLib } from "../../../vectorstores/hnswlib.js"; +import { ViolationOfExpectationsChain } from "../violation_of_expectations/violation_of_expectations_chain.js"; + +const dummyMessages = [ + new HumanMessage( + "I've been thinking about the importance of time with myself to discover my voice. I feel like 1-2 hours is never enough." + ), + new AIMessage( + "The concept of 'adequate time' varies. Have you tried different formats of introspection, such as morning pages or long-form writing, to see if they make the process more efficient?" + ), + new HumanMessage( + "I have tried journaling but never consistently. Sometimes it feels like writing doesn't capture everything." 
+  ),
+];
+
+test("should respond with the proper schema", async () => {
+  const vectorStore = await HNSWLib.fromTexts(
+    [" "],
+    [{ id: 1 }],
+    new OpenAIEmbeddings()
+  );
+  const retriever = vectorStore.asRetriever();
+
+  const llm = new ChatOpenAI({
+    modelName: "gpt-4",
+  });
+  const chain = new ViolationOfExpectationsChain({
+    llm,
+    retriever,
+  });
+
+  const res = await chain.call({
+    chat_history: dummyMessages,
+  });
+
+  console.log({
+    res,
+  });
+});
diff --git a/langchain/src/experimental/chains/violation_of_expectations/index.ts b/langchain/src/experimental/chains/violation_of_expectations/index.ts
new file mode 100644
index 000000000000..2c5b18343fb5
--- /dev/null
+++ b/langchain/src/experimental/chains/violation_of_expectations/index.ts
@@ -0,0 +1,4 @@
+export {
+  type ViolationOfExpectationsChainInput,
+  ViolationOfExpectationsChain,
+} from "./violation_of_expectations_chain.js";
diff --git a/langchain/src/experimental/chains/violation_of_expectations/types.ts b/langchain/src/experimental/chains/violation_of_expectations/types.ts
new file mode 100644
index 000000000000..1dffcb5d6274
--- /dev/null
+++ b/langchain/src/experimental/chains/violation_of_expectations/types.ts
@@ -0,0 +1,77 @@
+import { BaseMessage, HumanMessage } from "../../../schema/index.js";
+
+/**
+ * Contains the chunk of messages, along with the
+ * user's response, which is the next message after the chunk.
+ */
+export type MessageChunkResult = {
+  chunkedMessages: BaseMessage[];
+  /**
+   * The user response can be undefined if the last message in
+   * the chat history was from the AI.
+   */
+  userResponse?: HumanMessage;
+};
+
+export type PredictNextUserMessageResponse = {
+  userState: string;
+  predictedUserMessage: string;
+  insights: Array<string>;
+};
+
+export type GetPredictionViolationsResponse = {
+  userResponse?: HumanMessage;
+  revisedPrediction: string;
+  explainedPredictionErrors: Array<string>;
+};
+
+export const PREDICT_NEXT_USER_MESSAGE_FUNCTION = {
+  name: "predictNextUserMessage",
+  description: "Predicts the next user message, along with insights.",
+  parameters: {
+    type: "object",
+    properties: {
+      userState: {
+        type: "string",
+        description: "Concise reasoning about the user's internal mental state.",
+      },
+      predictedUserMessage: {
+        type: "string",
+        description:
+          "Your prediction on how they will respond to the AI's most recent message.",
+      },
+      insights: {
+        type: "array",
+        items: {
+          type: "string",
+        },
+        description:
+          "A concise list of any additional insights that would be useful to improve prediction.",
+      },
+    },
+    required: ["userState", "predictedUserMessage", "insights"],
+  },
+};
+
+export const PREDICTION_VIOLATIONS_FUNCTION = {
+  name: "predictionViolations",
+  description:
+    "Generates violations, errors and differences between the predicted user response and the actual response.",
+  parameters: {
+    type: "object",
+    properties: {
+      violationExplanation: {
+        type: "string",
+        description: "How was the prediction violated?",
+      },
+      explainedPredictionErrors: {
+        type: "array",
+        items: {
+          type: "string",
+        },
+        description: "Explanations of how the prediction was violated and why",
+      },
+    },
+    required: ["violationExplanation", "explainedPredictionErrors"],
+  },
+};
diff --git a/langchain/src/experimental/chains/violation_of_expectations/violation_of_expectations_chain.ts b/langchain/src/experimental/chains/violation_of_expectations/violation_of_expectations_chain.ts
new file mode 100644
index 000000000000..a4777e06c244
--- /dev/null
+++ b/langchain/src/experimental/chains/violation_of_expectations/violation_of_expectations_chain.ts
@@ -0,0 +1,475 @@
+import { CallbackManagerForChainRun } from "../../../callbacks/manager.js";
+import { ChatOpenAI } from "../../../chat_models/openai.js";
+import { JsonOutputFunctionsParser } from "../../../output_parsers/openai_functions.js";
+import {
+  BaseMessage,
+  ChainValues,
+  HumanMessage,
+  isBaseMessage,
+} from "../../../schema/index.js";
+import { StringOutputParser } from "../../../schema/output_parser.js";
+import { BaseRetriever } from "../../../schema/retriever.js";
+import { BaseChain, ChainInputs } from "../../../chains/base.js";
+import {
+  GetPredictionViolationsResponse,
+  MessageChunkResult,
+  PREDICTION_VIOLATIONS_FUNCTION,
+  PREDICT_NEXT_USER_MESSAGE_FUNCTION,
+  PredictNextUserMessageResponse,
+} from "./types.js";
+import {
+  GENERATE_FACTS_PROMPT,
+  GENERATE_REVISED_PREDICTION_PROMPT,
+  PREDICTION_VIOLATIONS_PROMPT,
+  PREDICT_NEXT_USER_MESSAGE_PROMPT,
+} from "./violation_of_expectations_prompt.js";
+
+/**
+ * Interface for the input parameters of the ViolationOfExpectationsChain class.
+ */
+export interface ViolationOfExpectationsChainInput extends ChainInputs {
+  /**
+   * The retriever to use for retrieving stored
+   * thoughts and insights.
+   */
+  retriever: BaseRetriever;
+  /**
+   * The LLM to use.
+   */
+  llm: ChatOpenAI;
+}
+
+/**
+ * Chain that generates key insights/facts about a user based on
+ * a chat conversation with an AI.
+ */
+export class ViolationOfExpectationsChain
+  extends BaseChain
+  implements ViolationOfExpectationsChainInput
+{
+  static lc_name() {
+    return "ViolationOfExpectationsChain";
+  }
+
+  _chainType(): string {
+    return "violation_of_expectation_chain";
+  }
+
+  chatHistoryKey = "chat_history";
+
+  thoughtsKey = "thoughts";
+
+  get inputKeys() {
+    return [this.chatHistoryKey];
+  }
+
+  get outputKeys() {
+    return [this.thoughtsKey];
+  }
+
+  retriever: BaseRetriever;
+
+  llm: ChatOpenAI;
+
+  jsonOutputParser: JsonOutputFunctionsParser;
+
+  stringOutputParser: StringOutputParser;
+
+  constructor(fields: ViolationOfExpectationsChainInput) {
+    super(fields);
+    this.retriever = fields.retriever;
+    this.llm = fields.llm;
+    this.jsonOutputParser = new JsonOutputFunctionsParser();
+    this.stringOutputParser = new StringOutputParser();
+  }
+
+  getChatHistoryString(chatHistory: BaseMessage[]): string {
+    return chatHistory
+      .map((chatMessage) => {
+        if (chatMessage._getType() === "human") {
+          return `Human: ${chatMessage.content}`;
+        } else if (chatMessage._getType() === "ai") {
+          return `AI: ${chatMessage.content}`;
+        } else {
+          return `${chatMessage.content}`;
+        }
+      })
+      .join("\n");
+  }
+
+  removeDuplicateStrings(strings: Array<string>): Array<string> {
+    return [...new Set(strings)];
+  }
+
+  /**
+   * This method breaks down the chat history into chunks of messages.
+   * Each chunk consists of a sequence of messages ending with an AI message and the subsequent user response, if any.
+   *
+   * @param {BaseMessage[]} chatHistory - The chat history to be chunked.
+   *
+   * @returns {MessageChunkResult[]} An array of message chunks. Each chunk includes a sequence of messages and the subsequent user response.
+   *
+   * @description
+   * The method iterates over the chat history and pushes each message into a temporary array.
+   * When it encounters an AI message, it checks for a subsequent user message.
+   * If a user message is found, it is considered as the user response to the AI message.
+   * If no user message is found after the AI message, the user response is undefined.
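+   * For example, the history [Human1, AI1, Human2, AI2] produces two chunks:
+   * one ending at AI1, with Human2 as its user response, and one ending at
+   * AI2, with an undefined user response since no message follows it.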
+   * The method then pushes the chunk (sequence of messages and user response) into the result array.
+   * This process continues until all messages in the chat history have been processed.
+   */
+  chunkMessagesByAIResponse(chatHistory: BaseMessage[]): MessageChunkResult[] {
+    const newArray: MessageChunkResult[] = [];
+    const tempArray: BaseMessage[] = [];
+
+    chatHistory.forEach((item, index) => {
+      tempArray.push(item);
+      if (item._getType() === "ai") {
+        let userResponse: BaseMessage | undefined = chatHistory[index + 1];
+        if (!userResponse || userResponse._getType() !== "human") {
+          userResponse = undefined;
+        }
+
+        newArray.push({
+          chunkedMessages: tempArray,
+          userResponse: userResponse
+            ? new HumanMessage(userResponse)
+            : undefined,
+        });
+      }
+    });
+
+    return newArray;
+  }
+
+  /**
+   * This method processes a chat history to generate insights about the user.
+   *
+   * @param {ChainValues} values - The input values for the chain. It should contain a key for chat history.
+   * @param {CallbackManagerForChainRun} [runManager] - Optional callback manager for the chain run.
+   *
+   * @returns {Promise<ChainValues>} A promise that resolves to a list of insights about the user.
+   *
+   * @throws {Error} If the chat history key is not found in the input values or if the chat history is not an array of BaseMessages.
+   *
+   * @description
+   * The method performs the following steps:
+   * 1. Checks if the chat history key is present in the input values and if the chat history is an array of BaseMessages.
+   * 2. Breaks the chat history into chunks of messages.
+   * 3. For each chunk, it generates an initial prediction for the user's next message.
+   * 4. For each prediction, it generates insights and prediction violations, and regenerates the prediction based on the violations.
+   * 5. For each set of messages, it generates a fact/insight about the user.
+   * The method returns a list of these insights.
+   */
+  async _call(
+    values: ChainValues,
+    runManager?: CallbackManagerForChainRun
+  ): Promise<ChainValues> {
+    if (!(this.chatHistoryKey in values)) {
+      throw new Error(`Chat history key ${this.chatHistoryKey} not found`);
+    }
+
+    const chatHistory: unknown[] = values[this.chatHistoryKey];
+
+    const isEveryMessageBaseMessage = chatHistory.every((message) =>
+      isBaseMessage(message)
+    );
+    if (!isEveryMessageBaseMessage) {
+      throw new Error("Chat history must be an array of BaseMessages");
+    }
+
+    const messageChunks = this.chunkMessagesByAIResponse(
+      chatHistory as BaseMessage[]
+    );
+
+    // Generate the initial prediction for every user message.
+    const userPredictions = await Promise.all(
+      messageChunks.map(async (chatHistoryChunk) => ({
+        userPredictions: await this.predictNextUserMessage(
+          chatHistoryChunk.chunkedMessages
+        ),
+        userResponse: chatHistoryChunk.userResponse,
+        runManager,
+      }))
+    );
+
+    // Generate insights and prediction violations for every user message.
+    // This call also regenerates the prediction based on the violations.
+    const predictionViolations = await Promise.all(
+      userPredictions.map((prediction) =>
+        this.getPredictionViolations({
+          userPredictions: prediction.userPredictions,
+          userResponse: prediction.userResponse,
+          runManager,
+        })
+      )
+    );
+
+    // Generate a fact/insight about the user for every set of messages.
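+    // Each fact/insight is distilled from the revised prediction, the
+    // explained prediction errors, and the user's actual response.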
+    const insights = await Promise.all(
+      predictionViolations.map((violation) =>
+        this.generateFacts({
+          userResponse: violation.userResponse,
+          predictions: {
+            revisedPrediction: violation.revisedPrediction,
+            explainedPredictionErrors: violation.explainedPredictionErrors,
+          },
+        })
+      )
+    );
+
+    return {
+      insights,
+    };
+  }
+
+  /**
+   * This method predicts the next user message based on the chat history.
+   *
+   * @param {BaseMessage[]} chatHistory - The chat history based on which the next user message is predicted.
+   * @param {CallbackManagerForChainRun} [runManager] - Optional callback manager for the chain run.
+   *
+   * @returns {Promise<PredictNextUserMessageResponse>} A promise that resolves to the predicted next user message, the user state, and any insights.
+   *
+   * @throws {Error} If the response from the language model does not contain the expected keys: 'userState', 'predictedUserMessage', and 'insights'.
+   */
+  private async predictNextUserMessage(
+    chatHistory: BaseMessage[],
+    runManager?: CallbackManagerForChainRun
+  ): Promise<PredictNextUserMessageResponse> {
+    const messageString = this.getChatHistoryString(chatHistory);
+
+    const llmWithFunctions = this.llm.bind({
+      functions: [PREDICT_NEXT_USER_MESSAGE_FUNCTION],
+      function_call: { name: PREDICT_NEXT_USER_MESSAGE_FUNCTION.name },
+    });
+
+    const chain = PREDICT_NEXT_USER_MESSAGE_PROMPT.pipe(llmWithFunctions).pipe(
+      this.jsonOutputParser
+    );
+
+    const res = await chain.invoke(
+      {
+        chat_history: messageString,
+      },
+      runManager?.getChild("prediction")
+    );
+
+    if (
+      !(
+        "userState" in res &&
+        "predictedUserMessage" in res &&
+        "insights" in res
+      )
+    ) {
+      throw new Error(`Invalid response from LLM: ${JSON.stringify(res)}`);
+    }
+
+    const predictionResponse = res as PredictNextUserMessageResponse;
+
+    // Query the retriever for relevant insights, using the generated insights as the query.
+    const retrievedDocs = await this.retrieveRelevantInsights(
+      predictionResponse.insights
+    );
+    const relevantDocs = this.removeDuplicateStrings([
+      ...predictionResponse.insights,
+      ...retrievedDocs,
+    ]);
+
+    return {
+      ...predictionResponse,
+      insights: relevantDocs,
+    };
+  }
+
+  /**
+   * Retrieves relevant insights based on the provided insights.
+   *
+   * @param {Array<string>} insights - An array of insights to be used for retrieving relevant documents.
+   *
+   * @returns {Promise<Array<string>>} A promise that resolves to an array of relevant insights content.
+   */
+  private async retrieveRelevantInsights(
+    insights: Array<string>
+  ): Promise<Array<string>> {
+    // Only extract the first relevant doc from the retriever. We don't need more than one.
+    const relevantInsightsDocuments = await Promise.all(
+      insights.map(async (insight) => {
+        const relevantInsight = await this.retriever.getRelevantDocuments(
+          insight
+        );
+        return relevantInsight[0];
+      })
+    );
+
+    const relevantInsightsContent = relevantInsightsDocuments.map(
+      (document) => document.pageContent
+    );
+
+    return relevantInsightsContent;
+  }
+
+  /**
+   * This method generates prediction violations based on the predicted and actual user responses.
+   * It also generates a revised prediction based on the identified violations.
+   *
+   * @param {Object} params - The parameters for the method.
+   * @param {PredictNextUserMessageResponse} params.userPredictions - The predicted user message, user state, and insights.
+   * @param {BaseMessage} [params.userResponse] - The actual user response.
+   * @param {CallbackManagerForChainRun} [params.runManager] - Optional callback manager for the chain run.
+   *
+   * @returns {Promise<{ userResponse: BaseMessage | undefined; revisedPrediction: string; explainedPredictionErrors: Array<string>; }>} A promise that resolves to an object containing the actual user response, the revised prediction, and the explained prediction errors.
+   *
+   * @throws {Error} If the response from the language model does not contain the expected keys: 'violationExplanation', 'explainedPredictionErrors', and 'accuratePrediction'.
+   */
+  private async getPredictionViolations({
+    userPredictions,
+    userResponse,
+    runManager,
+  }: {
+    userPredictions: PredictNextUserMessageResponse;
+    userResponse?: BaseMessage;
+    runManager?: CallbackManagerForChainRun;
+  }): Promise<GetPredictionViolationsResponse> {
+    const llmWithFunctions = this.llm.bind({
+      functions: [PREDICTION_VIOLATIONS_FUNCTION],
+      function_call: { name: PREDICTION_VIOLATIONS_FUNCTION.name },
+    });
+
+    const chain = PREDICTION_VIOLATIONS_PROMPT.pipe(llmWithFunctions).pipe(
+      this.jsonOutputParser
+    );
+
+    const res = (await chain.invoke(
+      {
+        predicted_output: userPredictions.predictedUserMessage,
+        actual_output: userResponse?.content ?? "",
+        user_insights: userPredictions.insights.join("\n"),
+      },
+      runManager?.getChild("prediction_violations")
+    )) as Awaited<{
+      violationExplanation: string;
+      explainedPredictionErrors: Array<string>;
+      accuratePrediction: boolean;
+    }>;
+
+    // Generate a revised prediction based on violations.
+    const revisedPrediction = await this.generateRevisedPrediction({
+      originalPrediction: userPredictions.predictedUserMessage,
+      explainedPredictionErrors: res.explainedPredictionErrors,
+      userInsights: userPredictions.insights,
+      runManager,
+    });
+
+    return {
+      userResponse,
+      revisedPrediction,
+      explainedPredictionErrors: res.explainedPredictionErrors,
+    };
+  }
+
+  /**
+   * This method generates a revised prediction based on the original prediction, explained prediction errors, and user insights.
+   *
+   * @param {Object} params - The parameters for the method.
+   * @param {string} params.originalPrediction - The original prediction made by the model.
+   * @param {Array<string>} params.explainedPredictionErrors - An array of explained prediction errors.
+   * @param {Array<string>} params.userInsights - An array of insights about the user.
+   * @param {CallbackManagerForChainRun} [params.runManager] - Optional callback manager for the chain run.
+   *
+   * @returns {Promise<string>} A promise that resolves to a revised prediction.
+   */
+  private async generateRevisedPrediction({
+    originalPrediction,
+    explainedPredictionErrors,
+    userInsights,
+    runManager,
+  }: {
+    originalPrediction: string;
+    explainedPredictionErrors: Array<string>;
+    userInsights: Array<string>;
+    runManager?: CallbackManagerForChainRun;
+  }): Promise<string> {
+    const revisedPredictionChain = GENERATE_REVISED_PREDICTION_PROMPT.pipe(
+      this.llm
+    ).pipe(this.stringOutputParser);
+
+    const revisedPredictionRes = await revisedPredictionChain.invoke(
+      {
+        prediction: originalPrediction,
+        explained_prediction_errors: explainedPredictionErrors.join("\n"),
+        user_insights: userInsights.join("\n"),
+      },
+      runManager?.getChild("prediction_revision")
+    );
+
+    return revisedPredictionRes;
+  }
+
+  /**
+   * This method generates facts or insights about the user based on the revised prediction, explained prediction errors, and the user's response.
+   *
+   * @param {Object} params - The parameters for the method.
+   * @param {BaseMessage} [params.userResponse] - The actual user response.
+   * @param {Object} params.predictions - The revised prediction and explained prediction errors.
+   * @param {string} params.predictions.revisedPrediction - The revised prediction made by the model.
+   * @param {Array<string>} params.predictions.explainedPredictionErrors - An array of explained prediction errors.
+   * @param {CallbackManagerForChainRun} [params.runManager] - Optional callback manager for the chain run.
+   *
+   * @returns {Promise<string>} A promise that resolves to a string containing the generated facts or insights about the user.
+   */
+  private async generateFacts({
+    userResponse,
+    predictions,
+    runManager,
+  }: {
+    userResponse?: BaseMessage;
+    /**
+     * Optional if the prediction was accurate.
+     */
+    predictions: {
+      revisedPrediction: string;
+      explainedPredictionErrors: Array<string>;
+    };
+    runManager?: CallbackManagerForChainRun;
+  }): Promise<string> {
+    const chain = GENERATE_FACTS_PROMPT.pipe(this.llm).pipe(
+      this.stringOutputParser
+    );
+
+    const res = await chain.invoke(
+      {
+        prediction_violations: predictions.explainedPredictionErrors.join("\n"),
+        prediction: predictions.revisedPrediction,
+        user_message: userResponse?.content ?? "",
+      },
+      runManager?.getChild("generate_facts")
+    );
+
+    return res;
+  }
+
+  /**
+   * Static method that creates a ViolationOfExpectationsChain instance from a
+   * ChatOpenAI and retriever. It also accepts optional options
+   * to customize the chain.
+   *
+   * @param llm The ChatOpenAI instance.
+   * @param retriever The retriever used for similarity search.
+   * @param options Optional options to customize the chain.
+   *
+   * @returns A new instance of ViolationOfExpectationsChain.
+   */
+  static fromLLM(
+    llm: ChatOpenAI,
+    retriever: BaseRetriever,
+    options?: Partial<
+      Omit<ViolationOfExpectationsChainInput, "llm" | "retriever">
+    >
+  ): ViolationOfExpectationsChain {
+    return new this({
+      retriever,
+      llm,
+      ...options,
+    });
+  }
+}
diff --git a/langchain/src/experimental/chains/violation_of_expectations/violation_of_expectations_prompt.ts b/langchain/src/experimental/chains/violation_of_expectations/violation_of_expectations_prompt.ts
new file mode 100644
index 000000000000..b5957c6c538d
--- /dev/null
+++ b/langchain/src/experimental/chains/violation_of_expectations/violation_of_expectations_prompt.ts
@@ -0,0 +1,50 @@
+import { PromptTemplate } from "../../../prompts/prompt.js";
+
+export const PREDICT_NEXT_USER_MESSAGE_PROMPT =
+  /* #__PURE__ */ PromptTemplate.fromTemplate(`
+You have been tasked with coming up with insights and data-points based on a chat history between a human and an AI.
+Given the user's chat history, provide the following:
+- Concise reasoning about the user's internal mental state.
+- Your prediction on how they will respond to the AI's most recent message.
+- A concise list of any additional insights that would be useful to improve prediction.
+--------
+Chat History: {chat_history}`);
+
+export const PREDICTION_VIOLATIONS_PROMPT =
+  /* #__PURE__ */ PromptTemplate.fromTemplate(`You have been given a prediction and an actual message from a human and AI conversation.
+Using the prediction, actual message, and additional user insights, generate the following:
+- How exactly was the original prediction violated? Which parts were wrong? State the exact differences.
+- If there were errors with the prediction, what were they and why?
+--------
+Predicted Output: {predicted_output}
+--------
+Actual Output: {actual_output}
+--------
+User Insights: {user_insights}
+--------
+`);
+
+export const GENERATE_REVISED_PREDICTION_PROMPT =
+  /* #__PURE__ */ PromptTemplate.fromTemplate(`
+You have been tasked with revising a prediction on what a user might say in a chat conversation.
+--------
+Your previous prediction: {prediction}
+--------
+Ways in which your prediction was off: {explained_prediction_errors}
+--------
+Key insights about the user: {user_insights}
+--------
+Given the above, revise your prediction to be more accurate.
+Revised Prediction:`);
+
+export const GENERATE_FACTS_PROMPT =
+  /* #__PURE__ */ PromptTemplate.fromTemplate(`
+Given a user message, an LLM-generated prediction of what that message might be, and a list of ways the prediction was violated by the actual message, generate a fact about the user that is relevant to the user's message.
+--------
+Prediction violations: {prediction_violations}
+--------
+Revised prediction: {prediction}
+--------
+Actual user message: {user_message}
+--------
+Relevant fact:`);
diff --git a/langchain/src/load/import_map.ts b/langchain/src/load/import_map.ts
index 5fc225648fa5..a968697794d1 100644
--- a/langchain/src/load/import_map.ts
+++ b/langchain/src/load/import_map.ts
@@ -83,5 +83,6 @@ export * as experimental__babyagi from "../experimental/babyagi/index.js";
 export * as experimental__generative_agents from "../experimental/generative_agents/index.js";
 export * as experimental__plan_and_execute from "../experimental/plan_and_execute/index.js";
 export * as experimental__chat_models__bittensor from "../experimental/chat_models/bittensor.js";
+export * as experimental__chains__violation_of_expectations from "../experimental/chains/violation_of_expectations/index.js";
 export * as evaluation from "../evaluation/index.js";
 export * as runnables__remote from "../runnables/remote.js";
diff --git a/langchain/tsconfig.json b/langchain/tsconfig.json
index 0eee7f0e388b..06bc64000386 100644
--- a/langchain/tsconfig.json
+++ b/langchain/tsconfig.json
@@ -263,6 +263,7 @@
     "src/experimental/chat_models/bittensor.ts",
     "src/experimental/llms/bittensor.ts",
     "src/experimental/hubs/makersuite/googlemakersuitehub.ts",
+    "src/experimental/chains/violation_of_expectations/index.ts",
     "src/evaluation/index.ts",
     "src/runnables/remote.ts"
   ],
diff --git a/package.json b/package.json
index 6da3b1490609..46fcb548bfba 100644
--- a/package.json
+++ b/package.json
@@ -30,7 +30,7 @@
     "publish": "bash langchain/scripts/release-branch.sh && turbo run --filter langchain build lint test && yarn run test:exports:docker && yarn workspace langchain run release && echo '🔗 Open https://github.com/hwchase17/langchainjs/compare/release?expand=1 and merge the release PR'",
     "example": "yarn workspace examples start",
     "precommit": "turbo run precommit",
-    "docs": "yarn workspace docs dev",
+    "docs": "yarn workspace docs start",
     "postinstall": "husky install"
   },
   "author": "LangChain",