core[patch]: Add LLM/ChatModel callbacks to cached generation #3392

Merged (16 commits) on Dec 17, 2023
152 changes: 132 additions & 20 deletions langchain-core/src/language_models/chat_models.ts
@@ -1,6 +1,6 @@
import {
AIMessage,
BaseMessage,
type BaseMessage,
BaseMessageChunk,
type BaseMessageLike,
HumanMessage,
@@ -10,9 +10,10 @@ import { BasePromptValue } from "../prompt_values.js";
import {
LLMResult,
RUN_KEY,
ChatGeneration,
type ChatGeneration,
ChatGenerationChunk,
ChatResult,
type ChatResult,
type Generation,
} from "../outputs.js";
import {
BaseLanguageModel,
@@ -22,10 +23,11 @@ import {
} from "./base.js";
import {
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
type CallbackManagerForLLMRun,
type Callbacks,
} from "../callbacks/manager.js";
import { RunnableConfig } from "../runnables/config.js";
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";

/**
* Represents a serialized chat model.
@@ -76,6 +78,17 @@ export function createChatMessageChunkEncoderStream() {
});
}

interface ChatModelGenerateCachedParameters<
T extends BaseChatModel<CallOptions>,
CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions
> {
messages: BaseMessageLike[][];
cache: BaseCache<Generation[]>;
llmStringKey: string;
parsedOptions: T["ParsedCallOptions"];
handledOptions: RunnableConfig;
}

/**
* Base class for chat models. It extends the BaseLanguageModel class and
* provides methods for generating chat based on input messages.
@@ -293,6 +306,112 @@ export abstract class BaseChatModel<
return output;
}

async _generateCached({
messages,
cache,
llmStringKey,
parsedOptions,
handledOptions,
}: ChatModelGenerateCachedParameters<typeof this>): Promise<
LLMResult & { missingPromptIndices: number[] }
> {
const baseMessages = messages.map((messageList) =>
messageList.map(coerceMessageLikeToMessage)
);

// create callback manager and start run
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: 1,
cached: true,
};
const runManagers = await callbackManager_?.handleChatModelStart(
this.toJSON(),
baseMessages,
undefined,
undefined,
extra,
undefined,
undefined,
handledOptions.runName
);

// generate results
const missingPromptIndices: number[] = [];
const results = await Promise.allSettled(
baseMessages.map(async (baseMessage, index) => {
// Join all content into one string for the prompt index
const prompt =
BaseChatModel._convertInputToPromptValue(baseMessage).toString();
const result = await cache.lookup(prompt, llmStringKey);

if (result == null) {
missingPromptIndices.push(index);
}

return result;
})
);

// Map run managers to the results before filtering out null results
// Null results are just absent from the cache.
const cachedResults = results
.map((result, index) => ({ result, runManager: runManagers?.[index] }))
.filter(
({ result }) =>
(result.status === "fulfilled" && result.value != null) ||
result.status === "rejected"
);

// Handle results and call run managers
const generations: Generation[][] = [];
await Promise.all(
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
return runManager?.handleLLMEnd({
generations: [result],
});
} else {
// status === "rejected"
await runManager?.handleLLMError(promiseResult.reason);
return Promise.reject(promiseResult.reason);
}
})
);

const output = {
generations,
missingPromptIndices,
};

// This defines RUN_KEY as a non-enumerable property on the output object
// so that it is not serialized when the output is stringified, and so that
// it isn't included when listing the keys of the output object.
Object.defineProperty(output, RUN_KEY, {
value: runManagers
? { runIds: runManagers?.map((manager) => manager.runId) }
: undefined,
configurable: true,
});

return output;
}

/**
* Generates chat based on the input messages.
* @param messages An array of arrays of BaseMessage instances.
@@ -329,20 +448,13 @@ export abstract class BaseChatModel<
const llmStringKey =
this._getSerializedCacheKeyParametersForCall(callOptions);

const missingPromptIndices: number[] = [];
const generations = await Promise.all(
baseMessages.map(async (baseMessage, index) => {
// Join all content into one string for the prompt index
const prompt =
BaseChatModel._convertInputToPromptValue(baseMessage).toString();
const result = await cache.lookup(prompt, llmStringKey);
if (!result) {
missingPromptIndices.push(index);
}

return result;
})
);
const { generations, missingPromptIndices } = await this._generateCached({
messages: baseMessages,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
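The practical effect of the new `_generateCached` path above: when a chat generation is served from the cache, callback handlers still run, and the run is started via `handleChatModelStart` with `cached: true` in its extra params. A minimal usage sketch, not part of this diff (`ChatOpenAI`, the `@langchain/openai` import, and `cache: true` for the default in-memory cache are illustrative assumptions):

```typescript
import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI({
  cache: true, // use the default in-memory cache
  callbacks: [
    {
      handleLLMEnd(output) {
        // Fires for fresh and (after this PR) cached generations alike.
        console.log("run finished:", JSON.stringify(output.generations));
      },
    },
  ],
});

// First call misses the cache, populates it, and fires callbacks as before.
await model.invoke("What is the capital of France?");

// Second call is a cache hit. Previously the callback system was skipped here;
// with this change the start/end callbacks (plus a single handleLLMNewToken
// carrying the cached text) are still emitted.
await model.invoke("What is the capital of France?");
```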
148 changes: 131 additions & 17 deletions langchain-core/src/language_models/llms.ts
@@ -1,19 +1,29 @@
import { AIMessage, BaseMessage, getBufferString } from "../messages/index.js";
import { BasePromptValue } from "../prompt_values.js";
import { LLMResult, RUN_KEY, Generation, GenerationChunk } from "../outputs.js";
import {
BaseCallbackConfig,
AIMessage,
type BaseMessage,
getBufferString,
} from "../messages/index.js";
import type { BasePromptValue } from "../prompt_values.js";
import {
type LLMResult,
RUN_KEY,
type Generation,
GenerationChunk,
} from "../outputs.js";
import {
type BaseCallbackConfig,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
type CallbackManagerForLLMRun,
type Callbacks,
} from "../callbacks/manager.js";
import {
BaseLanguageModel,
type BaseLanguageModelCallOptions,
type BaseLanguageModelInput,
type BaseLanguageModelParams,
} from "./base.js";
import { RunnableConfig } from "../runnables/config.js";
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";

export type SerializedLLM = {
_model: string;
@@ -30,6 +40,17 @@ export interface BaseLLMParams extends BaseLanguageModelParams {

export interface BaseLLMCallOptions extends BaseLanguageModelCallOptions {}

interface LLMGenerateCachedParameters<
T extends BaseLLM<CallOptions>,
CallOptions extends BaseLLMCallOptions = BaseLLMCallOptions
> {
prompts: string[];
cache: BaseCache<Generation[]>;
llmStringKey: string;
parsedOptions: T["ParsedCallOptions"];
handledOptions: RunnableConfig;
}

/**
* LLM Wrapper. Provides a {@link call} (and {@link generate}) function that takes in a prompt (or prompts) and returns a string.
*/
@@ -282,6 +303,102 @@ export abstract class BaseLLM<
return output;
}

async _generateCached({
prompts,
cache,
llmStringKey,
parsedOptions,
handledOptions,
}: LLMGenerateCachedParameters<typeof this>): Promise<
LLMResult & { missingPromptIndices: number[] }
> {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: prompts.length,
cached: true,
};
const runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
prompts,
undefined,
undefined,
extra,
undefined,
undefined,
handledOptions?.runName
);

// generate results
const missingPromptIndices: number[] = [];
const results = await Promise.allSettled(
prompts.map(async (prompt, index) => {
const result = await cache.lookup(prompt, llmStringKey);
if (result == null) {
missingPromptIndices.push(index);
}
return result;
})
);

// Map run managers to the results before filtering out null results
// Null results are just absent from the cache.
const cachedResults = results
.map((result, index) => ({ result, runManager: runManagers?.[index] }))
.filter(
({ result }) =>
(result.status === "fulfilled" && result.value != null) ||
result.status === "rejected"
);

// Handle results and call run managers
const generations: Generation[][] = [];
await Promise.all(
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
return runManager?.handleLLMEnd({
generations: [result],
});
} else {
// status === "rejected"
await runManager?.handleLLMError(promiseResult.reason);
return Promise.reject(promiseResult.reason);
}
})
);

const output = {
generations,
missingPromptIndices,
};

// This defines RUN_KEY as a non-enumerable property on the output object
// so that it is not serialized when the output is stringified, and so that
// it isn't included when listing the keys of the output object.
Object.defineProperty(output, RUN_KEY, {
value: runManagers
? { runIds: runManagers?.map((manager) => manager.runId) }
: undefined,
configurable: true,
});

return output;
}

/**
* Run the LLM on the given prompts and input, handling caching.
*/
@@ -312,16 +429,13 @@ export abstract class BaseLLM<
const { cache } = this;
const llmStringKey =
this._getSerializedCacheKeyParametersForCall(callOptions);
const missingPromptIndices: number[] = [];
const generations = await Promise.all(
prompts.map(async (prompt, index) => {
const result = await cache.lookup(prompt, llmStringKey);
if (!result) {
missingPromptIndices.push(index);
}
return result;
})
);
const { generations, missingPromptIndices } = await this._generateCached({
prompts,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
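The same behavior now applies to plain LLMs via the `_generateCached` method above: a cache hit starts a run through `handleLLMStart` with `cached: true` in the extra params, emits one `handleLLMNewToken` with the cached text, and finishes with `handleLLMEnd`. A rough sketch of observing this from a handler (the `OpenAI` class and import path are illustrative; forwarding of the `cached` flag to handlers as `extraParams` is an assumption based on the `extra` object built above):

```typescript
import { OpenAI } from "@langchain/openai";

const llm = new OpenAI({
  cache: true, // use the default in-memory cache
  callbacks: [
    {
      handleLLMStart(_llm, prompts, _runId, _parentRunId, extraParams) {
        // `cached` is set by _generateCached; fresh runs omit it.
        const cached = extraParams?.cached === true;
        console.log(`LLM start (${cached ? "cache hit" : "fresh"}):`, prompts);
      },
    },
  ],
});

await llm.invoke("Tell me a joke."); // fresh call, populates the cache
await llm.invoke("Tell me a joke."); // cache hit, now also traced
```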