core[patch]: Add LLM/ChatModel callbacks to cached generation (#3392)
* Add _generateCached callback to chat_models

* Add _generateCached to llms

* Fix formatting

* Remove unused ignore

* Pass llmStringKey as parameter

* Wrap generateCached arguments into object

* Add comment for defineProperty block

* Fix run managers getting filtered out

* Fix formatting

* Naming nit

* Use more type imports

* Add language model callback tests

---------

Co-authored-by: Brace Sproul <braceasproul@gmail.com>
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
3 people authored Dec 17, 2023
1 parent 19e4627 commit a22ce78
Showing 6 changed files with 356 additions and 62 deletions.
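In practice, this change means callback handlers now fire for generations served from the cache: handleChatModelStart / handleLLMStart and handleLLMEnd run as usual, handleLLMNewToken is emitted for the cached text, and the run's extra params include cached: true. Previously the cache lookup happened before any callback manager was configured, so cache hits were invisible to handlers and tracers. A rough TypeScript sketch of the observable behavior; the ChatOpenAI class, its import path, and the handler shape are illustrative assumptions about a typical langchainjs consumer, not something taken from this diff:

import { ChatOpenAI } from "@langchain/openai"; // hypothetical consumer setup

// Plain objects with handler methods are accepted in the callbacks array.
const logger = {
  handleChatModelStart: async () => {
    console.log("chat model start (fires on cache hits after this patch)");
  },
  handleLLMEnd: async () => {
    console.log("chat model end (fires on cache hits after this patch)");
  },
};

const model = new ChatOpenAI({
  cache: true, // true enables the shared in-memory cache
  callbacks: [logger],
});

await model.invoke("What is 2 + 2?"); // cache miss: calls the API, callbacks fire
await model.invoke("What is 2 + 2?"); // cache hit: previously silent, now runs the handlers above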
152 changes: 132 additions & 20 deletions langchain-core/src/language_models/chat_models.ts
@@ -1,6 +1,6 @@
import {
AIMessage,
BaseMessage,
type BaseMessage,
BaseMessageChunk,
type BaseMessageLike,
HumanMessage,
@@ -10,9 +10,10 @@ import { BasePromptValue } from "../prompt_values.js";
import {
LLMResult,
RUN_KEY,
ChatGeneration,
type ChatGeneration,
ChatGenerationChunk,
ChatResult,
type ChatResult,
type Generation,
} from "../outputs.js";
import {
BaseLanguageModel,
@@ -22,10 +23,11 @@ import {
} from "./base.js";
import {
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
type CallbackManagerForLLMRun,
type Callbacks,
} from "../callbacks/manager.js";
import { RunnableConfig } from "../runnables/config.js";
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";

/**
* Represents a serialized chat model.
@@ -76,6 +78,17 @@ export function createChatMessageChunkEncoderStream() {
});
}

interface ChatModelGenerateCachedParameters<
T extends BaseChatModel<CallOptions>,
CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions
> {
messages: BaseMessageLike[][];
cache: BaseCache<Generation[]>;
llmStringKey: string;
parsedOptions: T["ParsedCallOptions"];
handledOptions: RunnableConfig;
}

/**
* Base class for chat models. It extends the BaseLanguageModel class and
* provides methods for generating chat based on input messages.
@@ -293,6 +306,112 @@ export abstract class BaseChatModel<
return output;
}

async _generateCached({
messages,
cache,
llmStringKey,
parsedOptions,
handledOptions,
}: ChatModelGenerateCachedParameters<typeof this>): Promise<
LLMResult & { missingPromptIndices: number[] }
> {
const baseMessages = messages.map((messageList) =>
messageList.map(coerceMessageLikeToMessage)
);

// create callback manager and start run
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: 1,
cached: true,
};
const runManagers = await callbackManager_?.handleChatModelStart(
this.toJSON(),
baseMessages,
undefined,
undefined,
extra,
undefined,
undefined,
handledOptions.runName
);

// generate results
const missingPromptIndices: number[] = [];
const results = await Promise.allSettled(
baseMessages.map(async (baseMessage, index) => {
// Join all content into one string for the prompt index
const prompt =
BaseChatModel._convertInputToPromptValue(baseMessage).toString();
const result = await cache.lookup(prompt, llmStringKey);

if (result == null) {
missingPromptIndices.push(index);
}

return result;
})
);

// Map run managers to the results before filtering out null results
// Null results are just absent from the cache.
const cachedResults = results
.map((result, index) => ({ result, runManager: runManagers?.[index] }))
.filter(
({ result }) =>
(result.status === "fulfilled" && result.value != null) ||
result.status === "rejected"
);

// Handle results and call run managers
const generations: Generation[][] = [];
await Promise.all(
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
return runManager?.handleLLMEnd({
generations: [result],
});
} else {
// status === "rejected"
await runManager?.handleLLMError(promiseResult.reason);
return Promise.reject(promiseResult.reason);
}
})
);

const output = {
generations,
missingPromptIndices,
};

// This defines RUN_KEY as a non-enumerable property on the output object
// so that it is not serialized when the output is stringified, and so that
// it isn't included when listing the keys of the output object.
Object.defineProperty(output, RUN_KEY, {
value: runManagers
? { runIds: runManagers?.map((manager) => manager.runId) }
: undefined,
configurable: true,
});

return output;
}

/**
* Generates chat based on the input messages.
* @param messages An array of arrays of BaseMessage instances.
Expand Down Expand Up @@ -329,20 +448,13 @@ export abstract class BaseChatModel<
const llmStringKey =
this._getSerializedCacheKeyParametersForCall(callOptions);

const missingPromptIndices: number[] = [];
const generations = await Promise.all(
baseMessages.map(async (baseMessage, index) => {
// Join all content into one string for the prompt index
const prompt =
BaseChatModel._convertInputToPromptValue(baseMessage).toString();
const result = await cache.lookup(prompt, llmStringKey);
if (!result) {
missingPromptIndices.push(index);
}

return result;
})
);
const { generations, missingPromptIndices } = await this._generateCached({
messages: baseMessages,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
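The run-manager bookkeeping in _generateCached above (and the "Fix run managers getting filtered out" bullet in the commit message) hinges on pairing each settled cache lookup with the run manager at the same index before filtering out misses; filtering first would shift the indices and report results against the wrong runs. A stand-alone sketch of that pattern, in plain TypeScript with invented names:

type RunManagerLike = { runId: string };

async function pairResultsWithRunManagers(
  lookups: Promise<string | null>[],
  runManagers: RunManagerLike[]
) {
  const settled = await Promise.allSettled(lookups);

  // Pair each result with the run manager at the same index first,
  // so the association survives the filtering step below.
  const paired = settled.map((result, index) => ({
    result,
    runManager: runManagers[index],
  }));

  // Drop cache misses (fulfilled with null) but keep rejections, so their
  // run managers can still be notified of the error.
  return paired.filter(
    ({ result }) =>
      (result.status === "fulfilled" && result.value != null) ||
      result.status === "rejected"
  );
}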
148 changes: 131 additions & 17 deletions langchain-core/src/language_models/llms.ts
@@ -1,19 +1,29 @@
import { AIMessage, BaseMessage, getBufferString } from "../messages/index.js";
import { BasePromptValue } from "../prompt_values.js";
import { LLMResult, RUN_KEY, Generation, GenerationChunk } from "../outputs.js";
import {
BaseCallbackConfig,
AIMessage,
type BaseMessage,
getBufferString,
} from "../messages/index.js";
import type { BasePromptValue } from "../prompt_values.js";
import {
type LLMResult,
RUN_KEY,
type Generation,
GenerationChunk,
} from "../outputs.js";
import {
type BaseCallbackConfig,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
type CallbackManagerForLLMRun,
type Callbacks,
} from "../callbacks/manager.js";
import {
BaseLanguageModel,
type BaseLanguageModelCallOptions,
type BaseLanguageModelInput,
type BaseLanguageModelParams,
} from "./base.js";
import { RunnableConfig } from "../runnables/config.js";
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";

export type SerializedLLM = {
_model: string;
Expand All @@ -30,6 +40,17 @@ export interface BaseLLMParams extends BaseLanguageModelParams {

export interface BaseLLMCallOptions extends BaseLanguageModelCallOptions {}

interface LLMGenerateCachedParameters<
T extends BaseLLM<CallOptions>,
CallOptions extends BaseLLMCallOptions = BaseLLMCallOptions
> {
prompts: string[];
cache: BaseCache<Generation[]>;
llmStringKey: string;
parsedOptions: T["ParsedCallOptions"];
handledOptions: RunnableConfig;
}

/**
* LLM Wrapper. Provides a {@link call} (and a {@link generate}) function that takes in a prompt (or prompts) and returns a string.
*/
@@ -282,6 +303,102 @@ export abstract class BaseLLM<
return output;
}

async _generateCached({
prompts,
cache,
llmStringKey,
parsedOptions,
handledOptions,
}: LLMGenerateCachedParameters<typeof this>): Promise<
LLMResult & { missingPromptIndices: number[] }
> {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: prompts.length,
cached: true,
};
const runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
prompts,
undefined,
undefined,
extra,
undefined,
undefined,
handledOptions?.runName
);

// generate results
const missingPromptIndices: number[] = [];
const results = await Promise.allSettled(
prompts.map(async (prompt, index) => {
const result = await cache.lookup(prompt, llmStringKey);
if (result == null) {
missingPromptIndices.push(index);
}
return result;
})
);

// Map run managers to the results before filtering out null results
// Null results are just absent from the cache.
const cachedResults = results
.map((result, index) => ({ result, runManager: runManagers?.[index] }))
.filter(
({ result }) =>
(result.status === "fulfilled" && result.value != null) ||
result.status === "rejected"
);

// Handle results and call run managers
const generations: Generation[][] = [];
await Promise.all(
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
return runManager?.handleLLMEnd({
generations: [result],
});
} else {
// status === "rejected"
await runManager?.handleLLMError(promiseResult.reason);
return Promise.reject(promiseResult.reason);
}
})
);

const output = {
generations,
missingPromptIndices,
};

// This defines RUN_KEY as a non-enumerable property on the output object
// so that it is not serialized when the output is stringified, and so that
// it isn't included when listing the keys of the output object.
Object.defineProperty(output, RUN_KEY, {
value: runManagers
? { runIds: runManagers?.map((manager) => manager.runId) }
: undefined,
configurable: true,
});

return output;
}

/**
* Run the LLM on the given prompts and input, handling caching.
*/
@@ -312,16 +429,13 @@ export abstract class BaseLLM<
const { cache } = this;
const llmStringKey =
this._getSerializedCacheKeyParametersForCall(callOptions);
const missingPromptIndices: number[] = [];
const generations = await Promise.all(
prompts.map(async (prompt, index) => {
const result = await cache.lookup(prompt, llmStringKey);
if (!result) {
missingPromptIndices.push(index);
}
return result;
})
);
const { generations, missingPromptIndices } = await this._generateCached({
prompts,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
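Both copies of _generateCached attach the run IDs to the result under RUN_KEY as a non-enumerable property, so callers can read them without the key showing up in Object.keys or JSON.stringify output. A small stand-alone illustration of the Object.defineProperty trick; the "__run" string below is a placeholder, since langchain-core exports its own RUN_KEY constant:

const RUN_KEY = "__run"; // illustrative placeholder for the exported constant

const output: Record<string, unknown> = {
  generations: [["cached text"]],
  missingPromptIndices: [],
};

// enumerable defaults to false, so the key is hidden from iteration and serialization
Object.defineProperty(output, RUN_KEY, {
  value: { runIds: ["run-123"] },
  configurable: true,
});

console.log(Object.keys(output)); // ["generations", "missingPromptIndices"]
console.log(JSON.stringify(output)); // RUN_KEY is omitted
console.log(output[RUN_KEY]); // { runIds: ["run-123"] } is still readable directly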

2 comments on commit a22ce78

@vercel vercel bot commented on a22ce78 Dec 17, 2023

@vercel vercel bot commented on a22ce78 Dec 17, 2023

Successfully deployed to the following URLs:

langchainjs-docs – ./docs/core_docs/

langchainjs-docs-ruddy.vercel.app
langchainjs-docs-git-main-langchain.vercel.app
langchainjs-docs-langchain.vercel.app
js.langchain.com
