From a22ce78ef7d237421240cf3c499c76b0dadcdd35 Mon Sep 17 00:00:00 2001 From: Gram Liu Date: Sat, 16 Dec 2023 19:11:00 -0800 Subject: [PATCH] core[patch]: Add LLM/ChatModel callbacks to cached generation (#3392) * Add _generateCached callback to chat_models * Add _generateCached to llms * Fix formatting * Remove unused ignore * Pass llmStringKey as parameter * Wrap generateCached arguments into object * Add coment for defineProperty block * Fix run managers getting filtered out * Fix formatting * Naming nit * Use more type imports * Add language model callback tests --------- Co-authored-by: Brace Sproul Co-authored-by: jacoblee93 --- .../src/language_models/chat_models.ts | 152 +++++++++++++++--- langchain-core/src/language_models/llms.ts | 148 +++++++++++++++-- .../language_models/tests/chat_models.test.ts | 38 +++++ .../src/language_models/tests/llms.test.ts | 38 +++++ .../tests/output_parser.test.ts | 20 +-- langchain-core/src/utils/testing/index.ts | 22 ++- 6 files changed, 356 insertions(+), 62 deletions(-) create mode 100644 langchain-core/src/language_models/tests/chat_models.test.ts create mode 100644 langchain-core/src/language_models/tests/llms.test.ts diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts index 8fa5e364134f..7c8b44bf9e7b 100644 --- a/langchain-core/src/language_models/chat_models.ts +++ b/langchain-core/src/language_models/chat_models.ts @@ -1,6 +1,6 @@ import { AIMessage, - BaseMessage, + type BaseMessage, BaseMessageChunk, type BaseMessageLike, HumanMessage, @@ -10,9 +10,10 @@ import { BasePromptValue } from "../prompt_values.js"; import { LLMResult, RUN_KEY, - ChatGeneration, + type ChatGeneration, ChatGenerationChunk, - ChatResult, + type ChatResult, + type Generation, } from "../outputs.js"; import { BaseLanguageModel, @@ -22,10 +23,11 @@ import { } from "./base.js"; import { CallbackManager, - CallbackManagerForLLMRun, - Callbacks, + type CallbackManagerForLLMRun, + type Callbacks, } from "../callbacks/manager.js"; -import { RunnableConfig } from "../runnables/config.js"; +import type { RunnableConfig } from "../runnables/config.js"; +import type { BaseCache } from "../caches.js"; /** * Represents a serialized chat model. @@ -76,6 +78,17 @@ export function createChatMessageChunkEncoderStream() { }); } +interface ChatModelGenerateCachedParameters< + T extends BaseChatModel, + CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions +> { + messages: BaseMessageLike[][]; + cache: BaseCache; + llmStringKey: string; + parsedOptions: T["ParsedCallOptions"]; + handledOptions: RunnableConfig; +} + /** * Base class for chat models. It extends the BaseLanguageModel class and * provides methods for generating chat based on input messages. 
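The new `ChatModelGenerateCachedParameters` interface above bundles everything the cached path needs into a single argument object. For orientation, this is roughly how it is populated at the call site that appears in the `_generate` hunk further down; the local names (`baseMessages`, `callOptions`, `runnableConfig`, `llmStringKey`) are taken from that hunk, not new API.

```ts
// Sketch of the call site inside BaseChatModel._generate (see the later hunk);
// the surrounding variables are assumed to be in scope there.
const { generations, missingPromptIndices } = await this._generateCached({
  messages: baseMessages,         // BaseMessageLike[][], already coerced upstream
  cache,                          // the BaseCache configured on this model
  llmStringKey,                   // serialized call options, part of the cache key
  parsedOptions: callOptions,     // this["ParsedCallOptions"]
  handledOptions: runnableConfig, // RunnableConfig carrying callbacks, tags, metadata
});
```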
@@ -293,6 +306,112 @@ export abstract class BaseChatModel< return output; } + async _generateCached({ + messages, + cache, + llmStringKey, + parsedOptions, + handledOptions, + }: ChatModelGenerateCachedParameters): Promise< + LLMResult & { missingPromptIndices: number[] } + > { + const baseMessages = messages.map((messageList) => + messageList.map(coerceMessageLikeToMessage) + ); + + // create callback manager and start run + const callbackManager_ = await CallbackManager.configure( + handledOptions.callbacks, + this.callbacks, + handledOptions.tags, + this.tags, + handledOptions.metadata, + this.metadata, + { verbose: this.verbose } + ); + const extra = { + options: parsedOptions, + invocation_params: this?.invocationParams(parsedOptions), + batch_size: 1, + cached: true, + }; + const runManagers = await callbackManager_?.handleChatModelStart( + this.toJSON(), + baseMessages, + undefined, + undefined, + extra, + undefined, + undefined, + handledOptions.runName + ); + + // generate results + const missingPromptIndices: number[] = []; + const results = await Promise.allSettled( + baseMessages.map(async (baseMessage, index) => { + // Join all content into one string for the prompt index + const prompt = + BaseChatModel._convertInputToPromptValue(baseMessage).toString(); + const result = await cache.lookup(prompt, llmStringKey); + + if (result == null) { + missingPromptIndices.push(index); + } + + return result; + }) + ); + + // Map run managers to the results before filtering out null results + // Null results are just absent from the cache. + const cachedResults = results + .map((result, index) => ({ result, runManager: runManagers?.[index] })) + .filter( + ({ result }) => + (result.status === "fulfilled" && result.value != null) || + result.status === "rejected" + ); + + // Handle results and call run managers + const generations: Generation[][] = []; + await Promise.all( + cachedResults.map(async ({ result: promiseResult, runManager }, i) => { + if (promiseResult.status === "fulfilled") { + const result = promiseResult.value as Generation[]; + generations[i] = result; + if (result.length) { + await runManager?.handleLLMNewToken(result[0].text); + } + return runManager?.handleLLMEnd({ + generations: [result], + }); + } else { + // status === "rejected" + await runManager?.handleLLMError(promiseResult.reason); + return Promise.reject(promiseResult.reason); + } + }) + ); + + const output = { + generations, + missingPromptIndices, + }; + + // This defines RUN_KEY as a non-enumerable property on the output object + // so that it is not serialized when the output is stringified, and so that + // it isnt included when listing the keys of the output object. + Object.defineProperty(output, RUN_KEY, { + value: runManagers + ? { runIds: runManagers?.map((manager) => manager.runId) } + : undefined, + configurable: true, + }); + + return output; + } + /** * Generates chat based on the input messages. * @param messages An array of arrays of BaseMessage instances. 
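The observable effect of `_generateCached` is that callbacks now fire even when every prompt is served from the cache: run managers are started with `cached: true` in the extra payload, the cached generation's text is replayed through `handleLLMNewToken`, and `handleLLMEnd` is called per prompt. A minimal sketch of that behavior, closely following the new `chat_models.test.ts` added later in this patch (the `FakeChatModel` helper and the relative import path come from that test file):

```ts
import { FakeChatModel } from "../../utils/testing/index.js";

const model = new FakeChatModel({ cache: true });

// First call: a real generation, which also populates the in-memory cache.
const first = await model.invoke("Hello there!");

// Second call: answered by _generateCached, yet the token callback still
// fires because the cached generation is replayed through the run manager.
let acc = "";
const second = await model.invoke("Hello there!", {
  callbacks: [
    {
      handleLLMNewToken: (token: string) => {
        acc += token;
      },
    },
  ],
});

// As asserted in the new test: same content, and the callback saw all of it.
// first.content === second.content, and second.content === acc
```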
@@ -329,20 +448,13 @@ export abstract class BaseChatModel< const llmStringKey = this._getSerializedCacheKeyParametersForCall(callOptions); - const missingPromptIndices: number[] = []; - const generations = await Promise.all( - baseMessages.map(async (baseMessage, index) => { - // Join all content into one string for the prompt index - const prompt = - BaseChatModel._convertInputToPromptValue(baseMessage).toString(); - const result = await cache.lookup(prompt, llmStringKey); - if (!result) { - missingPromptIndices.push(index); - } - - return result; - }) - ); + const { generations, missingPromptIndices } = await this._generateCached({ + messages: baseMessages, + cache, + llmStringKey, + parsedOptions: callOptions, + handledOptions: runnableConfig, + }); let llmOutput = {}; if (missingPromptIndices.length > 0) { diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts index c5ccb205cc7d..9332a7e84ee4 100644 --- a/langchain-core/src/language_models/llms.ts +++ b/langchain-core/src/language_models/llms.ts @@ -1,11 +1,20 @@ -import { AIMessage, BaseMessage, getBufferString } from "../messages/index.js"; -import { BasePromptValue } from "../prompt_values.js"; -import { LLMResult, RUN_KEY, Generation, GenerationChunk } from "../outputs.js"; import { - BaseCallbackConfig, + AIMessage, + type BaseMessage, + getBufferString, +} from "../messages/index.js"; +import type { BasePromptValue } from "../prompt_values.js"; +import { + type LLMResult, + RUN_KEY, + type Generation, + GenerationChunk, +} from "../outputs.js"; +import { + type BaseCallbackConfig, CallbackManager, - CallbackManagerForLLMRun, - Callbacks, + type CallbackManagerForLLMRun, + type Callbacks, } from "../callbacks/manager.js"; import { BaseLanguageModel, @@ -13,7 +22,8 @@ import { type BaseLanguageModelInput, type BaseLanguageModelParams, } from "./base.js"; -import { RunnableConfig } from "../runnables/config.js"; +import type { RunnableConfig } from "../runnables/config.js"; +import type { BaseCache } from "../caches.js"; export type SerializedLLM = { _model: string; @@ -30,6 +40,17 @@ export interface BaseLLMParams extends BaseLanguageModelParams { export interface BaseLLMCallOptions extends BaseLanguageModelCallOptions {} +interface LLMGenerateCachedParameters< + T extends BaseLLM, + CallOptions extends BaseLLMCallOptions = BaseLLMCallOptions +> { + prompts: string[]; + cache: BaseCache; + llmStringKey: string; + parsedOptions: T["ParsedCallOptions"]; + handledOptions: RunnableConfig; +} + /** * LLM Wrapper. Provides an {@link call} (an {@link generate}) function that takes in a prompt (or prompts) and returns a string. 
*/ @@ -282,6 +303,102 @@ export abstract class BaseLLM< return output; } + async _generateCached({ + prompts, + cache, + llmStringKey, + parsedOptions, + handledOptions, + }: LLMGenerateCachedParameters): Promise< + LLMResult & { missingPromptIndices: number[] } + > { + const callbackManager_ = await CallbackManager.configure( + handledOptions.callbacks, + this.callbacks, + handledOptions.tags, + this.tags, + handledOptions.metadata, + this.metadata, + { verbose: this.verbose } + ); + const extra = { + options: parsedOptions, + invocation_params: this?.invocationParams(parsedOptions), + batch_size: prompts.length, + cached: true, + }; + const runManagers = await callbackManager_?.handleLLMStart( + this.toJSON(), + prompts, + undefined, + undefined, + extra, + undefined, + undefined, + handledOptions?.runName + ); + + // generate results + const missingPromptIndices: number[] = []; + const results = await Promise.allSettled( + prompts.map(async (prompt, index) => { + const result = await cache.lookup(prompt, llmStringKey); + if (result == null) { + missingPromptIndices.push(index); + } + return result; + }) + ); + + // Map run managers to the results before filtering out null results + // Null results are just absent from the cache. + const cachedResults = results + .map((result, index) => ({ result, runManager: runManagers?.[index] })) + .filter( + ({ result }) => + (result.status === "fulfilled" && result.value != null) || + result.status === "rejected" + ); + + // Handle results and call run managers + const generations: Generation[][] = []; + await Promise.all( + cachedResults.map(async ({ result: promiseResult, runManager }, i) => { + if (promiseResult.status === "fulfilled") { + const result = promiseResult.value as Generation[]; + generations[i] = result; + if (result.length) { + await runManager?.handleLLMNewToken(result[0].text); + } + return runManager?.handleLLMEnd({ + generations: [result], + }); + } else { + // status === "rejected" + await runManager?.handleLLMError(promiseResult.reason); + return Promise.reject(promiseResult.reason); + } + }) + ); + + const output = { + generations, + missingPromptIndices, + }; + + // This defines RUN_KEY as a non-enumerable property on the output object + // so that it is not serialized when the output is stringified, and so that + // it isnt included when listing the keys of the output object. + Object.defineProperty(output, RUN_KEY, { + value: runManagers + ? { runIds: runManagers?.map((manager) => manager.runId) } + : undefined, + configurable: true, + }); + + return output; + } + /** * Run the LLM on the given prompts and input, handling caching. 
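The trailing `Object.defineProperty` mirrors the chat-model version: since neither `enumerable` nor `writable` is set, the run information rides along on the result without showing up in `Object.keys` or `JSON.stringify`. A small illustrative sketch, assuming an LLM instance `llm` and the argument variables are already in scope (for example inside `BaseLLM._generate`); it is meant to show the property behavior, not a supported public entry point:

```ts
import { RUN_KEY } from "../outputs.js";

// Assumes llm, prompts, cache, llmStringKey, parsedOptions and handledOptions
// are in scope, e.g. inside BaseLLM._generate.
const output = await llm._generateCached({
  prompts,
  cache,
  llmStringKey,
  parsedOptions,
  handledOptions,
});

Object.keys(output);    // ["generations", "missingPromptIndices"]; RUN_KEY is hidden
JSON.stringify(output); // run ids are not part of the serialized payload
// The run info is still reachable for callers that know the key:
const runInfo = (output as Record<PropertyKey, unknown>)[RUN_KEY]; // { runIds: [...] } or undefined
```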
*/ @@ -312,16 +429,13 @@ export abstract class BaseLLM< const { cache } = this; const llmStringKey = this._getSerializedCacheKeyParametersForCall(callOptions); - const missingPromptIndices: number[] = []; - const generations = await Promise.all( - prompts.map(async (prompt, index) => { - const result = await cache.lookup(prompt, llmStringKey); - if (!result) { - missingPromptIndices.push(index); - } - return result; - }) - ); + const { generations, missingPromptIndices } = await this._generateCached({ + prompts, + cache, + llmStringKey, + parsedOptions: callOptions, + handledOptions: runnableConfig, + }); let llmOutput = {}; if (missingPromptIndices.length > 0) { diff --git a/langchain-core/src/language_models/tests/chat_models.test.ts b/langchain-core/src/language_models/tests/chat_models.test.ts new file mode 100644 index 000000000000..e0465983e1d2 --- /dev/null +++ b/langchain-core/src/language_models/tests/chat_models.test.ts @@ -0,0 +1,38 @@ +import { test } from "@jest/globals"; +import { FakeChatModel } from "../../utils/testing/index.js"; + +test("Test ChatModel uses callbacks", async () => { + const model = new FakeChatModel({}); + let acc = ""; + const response = await model.invoke("Hello there!", { + callbacks: [ + { + handleLLMNewToken: (token: string) => { + console.log(token); + acc += token; + }, + }, + ], + }); + expect(response.content).toEqual(acc); +}); + +test("Test ChatModel uses callbacks with a cache", async () => { + const model = new FakeChatModel({ + cache: true, + }); + let acc = ""; + const response = await model.invoke("Hello there!"); + const response2 = await model.invoke("Hello there!", { + callbacks: [ + { + handleLLMNewToken: (token: string) => { + console.log(token); + acc += token; + }, + }, + ], + }); + expect(response.content).toEqual(response2.content); + expect(response2.content).toEqual(acc); +}); diff --git a/langchain-core/src/language_models/tests/llms.test.ts b/langchain-core/src/language_models/tests/llms.test.ts new file mode 100644 index 000000000000..6dd39b4cee91 --- /dev/null +++ b/langchain-core/src/language_models/tests/llms.test.ts @@ -0,0 +1,38 @@ +import { test } from "@jest/globals"; +import { FakeLLM } from "../../utils/testing/index.js"; + +test("Test FakeLLM uses callbacks", async () => { + const model = new FakeLLM({}); + let acc = ""; + const response = await model.invoke("Hello there!", { + callbacks: [ + { + handleLLMNewToken: (token: string) => { + console.log(token); + acc += token; + }, + }, + ], + }); + expect(response).toEqual(acc); +}); + +test("Test FakeLLM uses callbacks with a cache", async () => { + const model = new FakeLLM({ + cache: true, + }); + let acc = ""; + const response = await model.invoke("Hello there!"); + const response2 = await model.invoke("Hello there!", { + callbacks: [ + { + handleLLMNewToken: (token: string) => { + console.log(token); + acc += token; + }, + }, + ], + }); + expect(response).toEqual(response2); + expect(response2).toEqual(acc); +}); diff --git a/langchain-core/src/output_parsers/tests/output_parser.test.ts b/langchain-core/src/output_parsers/tests/output_parser.test.ts index 4392073ed28c..915b3afc8d62 100644 --- a/langchain-core/src/output_parsers/tests/output_parser.test.ts +++ b/langchain-core/src/output_parsers/tests/output_parser.test.ts @@ -1,26 +1,8 @@ /* eslint-disable no-promise-executor-return */ import { test } from "@jest/globals"; -import { LLM } from "../../language_models/llms.js"; +import { FakeStreamingLLM } from "../../utils/testing/index.js"; import { 
BytesOutputParser } from "../bytes.js"; -import { GenerationChunk } from "../../outputs.js"; - -class FakeStreamingLLM extends LLM { - _llmType() { - return "fake"; - } - - async _call(prompt: string): Promise { - return prompt; - } - - async *_streamResponseChunks(input: string) { - for (const c of input) { - await new Promise((resolve) => setTimeout(resolve, 50)); - yield { text: c, generationInfo: {} } as GenerationChunk; - } - } -} test("BytesOutputParser", async () => { const llm = new FakeStreamingLLM({}); diff --git a/langchain-core/src/utils/testing/index.ts b/langchain-core/src/utils/testing/index.ts index 334fe05934ba..2d6350abd3a3 100644 --- a/langchain-core/src/utils/testing/index.ts +++ b/langchain-core/src/utils/testing/index.ts @@ -14,7 +14,7 @@ import { BaseChatModel, BaseChatModelParams, } from "../../language_models/chat_models.js"; -import { LLM } from "../../language_models/llms.js"; +import { BaseLLMParams, LLM } from "../../language_models/llms.js"; import { BaseMessage, AIMessage, @@ -72,8 +72,10 @@ export class FakeLLM extends LLM { thrownErrorString?: string; - constructor(fields: { response?: string; thrownErrorString?: string }) { - super({}); + constructor( + fields: { response?: string; thrownErrorString?: string } & BaseLLMParams + ) { + super(fields); this.response = fields.response; this.thrownErrorString = fields.thrownErrorString; } @@ -82,11 +84,17 @@ export class FakeLLM extends LLM { return "fake"; } - async _call(prompt: string): Promise { + async _call( + prompt: string, + _options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun + ): Promise { if (this.thrownErrorString) { throw new Error(this.thrownErrorString); } - return this.response ?? prompt; + const response = this.response ?? prompt; + await runManager?.handleLLMNewToken(response); + return response; } } @@ -118,7 +126,8 @@ export class FakeChatModel extends BaseChatModel { async _generate( messages: BaseMessage[], - options?: this["ParsedCallOptions"] + options?: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun ): Promise { if (options?.stop?.length) { return { @@ -131,6 +140,7 @@ export class FakeChatModel extends BaseChatModel { }; } const text = messages.map((m) => m.content).join("\n"); + await runManager?.handleLLMNewToken(text); return { generations: [ {