
core[minor]: Fix caching of complex message types #6028

Merged: 7 commits, Jul 11, 2024. Changes shown from all commits.
2 changes: 1 addition & 1 deletion langchain-core/langchain.config.js

@@ -13,7 +13,7 @@ export const config = {
   internals: [/node\:/, /js-tiktoken/, /langsmith/],
   entrypoints: {
     agents: "agents",
-    caches: "caches",
+    caches: "caches/base",
     "callbacks/base": "callbacks/base",
     "callbacks/manager": "callbacks/manager",
     "callbacks/promises": "callbacks/promises",
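Only the source path behind the entrypoint changes; the public entrypoint name `caches` is untouched, so consumer imports should keep working. A minimal usage sketch of the cache behind it, assuming the standard `@langchain/core/caches` import path (illustrative values):

import { InMemoryCache } from "@langchain/core/caches";

async function demo() {
  const cache = new InMemoryCache();
  // Entries are keyed by the rendered prompt plus a serialized LLM key.
  await cache.update("What is 2 + 2?", "llm-key", [{ text: "4" }]);
  console.log(await cache.lookup("What is 2 + 2?", "llm-key")); // [{ text: "4" }]
}

demo().catch(console.error);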
langchain-core/src/caches/base.ts (path inferred from the "caches/base" entrypoint and import changes; previously langchain-core/src/caches.ts)

@@ -1,17 +1,17 @@
-import { insecureHash } from "./utils/hash.js";
-import type { Generation, ChatGeneration } from "./outputs.js";
-import { mapStoredMessageToChatMessage } from "./messages/utils.js";
-import { type StoredGeneration } from "./messages/base.js";
+import { insecureHash } from "../utils/hash.js";
+import type { Generation, ChatGeneration } from "../outputs.js";
+import { mapStoredMessageToChatMessage } from "../messages/utils.js";
+import { type StoredGeneration } from "../messages/base.js";

 /**
- * This cache key should be consistent across all versions of langchain.
- * It is currently NOT consistent across versions of langchain.
+ * This cache key should be consistent across all versions of LangChain.
+ * It is currently NOT consistent across versions of LangChain.
  *
  * A huge benefit of having a remote cache (like redis) is that you can
  * access the cache from different processes/machines. The allows you to
- * seperate concerns and scale horizontally.
+ * separate concerns and scale horizontally.
  *
- * TODO: Make cache key consistent across versions of langchain.
+ * TODO: Make cache key consistent across versions of LangChain.
  */
 export const getCacheKey = (...strings: string[]): string =>
   insecureHash(strings.join("_"));
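A quick sketch of what getCacheKey produces, grounded in the one-liner above (the import path and inputs are illustrative assumptions):

import { getCacheKey } from "@langchain/core/caches";

// Equivalent by definition to insecureHash(`${prompt}_${llmKey}`).
const key = getCacheKey("Human: Hello there!", '{"temperature":0}');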
40 changes: 40 additions & 0 deletions langchain-core/src/caches/tests/in_memory_cache.test.ts
@@ -0,0 +1,40 @@
import { MessageContentComplex } from "../../messages/base.js";
import { InMemoryCache } from "../base.js";

test("InMemoryCache works", async () => {
const cache = new InMemoryCache();

await cache.update("prompt", "key1", [
{
text: "text1",
},
]);

const result = await cache.lookup("prompt", "key1");
expect(result).toBeDefined();
if (!result) {
return;
}
expect(result[0].text).toBe("text1");
});

test("InMemoryCache works with complex message types", async () => {
const cache = new InMemoryCache<MessageContentComplex[]>();

await cache.update("prompt", "key1", [
{
type: "text",
text: "text1",
},
]);

const result = await cache.lookup("prompt", "key1");
expect(result).toBeDefined();
if (!result) {
return;
}
expect(result[0]).toEqual({
type: "text",
text: "text1",
});
});
4 changes: 2 additions & 2 deletions langchain-core/src/language_models/base.ts
@@ -1,7 +1,7 @@
import type { Tiktoken, TiktokenModel } from "js-tiktoken/lite";

import { z } from "zod";
-import { type BaseCache, InMemoryCache } from "../caches.js";
+import { type BaseCache, InMemoryCache } from "../caches/base.js";
import {
type BasePromptValueInterface,
StringPromptValue,
@@ -481,7 +481,7 @@ export abstract class BaseLanguageModel<
* @param callOptions Call options for the model
* @returns A unique cache key.
*/
-  protected _getSerializedCacheKeyParametersForCall(
+  _getSerializedCacheKeyParametersForCall(
// TODO: Fix when we remove the RunnableLambda backwards compatibility shim.
{ config, ...callOptions }: CallOptions & { config?: RunnableConfig }
): string {
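(Note: the protected modifier is dropped so code outside the class, such as the new standard integration test below, can call _getSerializedCacheKeyParametersForCall to derive the same cache key the model uses internally.)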
2 changes: 1 addition & 1 deletion langchain-core/src/language_models/chat_models.ts
@@ -32,7 +32,7 @@ import {
type Callbacks,
} from "../callbacks/manager.js";
import type { RunnableConfig } from "../runnables/config.js";
-import type { BaseCache } from "../caches.js";
+import type { BaseCache } from "../caches/base.js";
import { StructuredToolInterface } from "../tools.js";
import {
Runnable,
2 changes: 1 addition & 1 deletion langchain-core/src/language_models/llms.ts
@@ -23,7 +23,7 @@ import {
type BaseLanguageModelParams,
} from "./base.js";
import type { RunnableConfig } from "../runnables/config.js";
-import type { BaseCache } from "../caches.js";
+import type { BaseCache } from "../caches/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";
39 changes: 39 additions & 0 deletions langchain-core/src/language_models/tests/chat_models.test.ts
@@ -4,6 +4,9 @@ import { test } from "@jest/globals";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";
+import { HumanMessage } from "../../messages/human.js";
+import { getBufferString } from "../../messages/utils.js";
+import { AIMessage } from "../../messages/ai.js";

test("Test ChatModel accepts array shorthand for messages", async () => {
const model = new FakeChatModel({});
@@ -189,3 +192,39 @@ test("Test ChatModel withStructuredOutput new syntax and includeRaw", async () =
// No error
console.log(response.parsed);
});

test("Test ChatModel can cache complex messages", async () => {
const model = new FakeChatModel({
cache: true,
});
if (!model.cache) {
throw new Error("Cache not enabled");
}

const contentToCache = [
{
type: "text",
text: "Hello there!",
},
];
const humanMessage = new HumanMessage({
content: contentToCache,
});

const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({});

// Invoke model to trigger cache update
await model.invoke([humanMessage]);

const value = await model.cache.lookup(prompt, llmKey);
expect(value).toBeDefined();
if (!value) return;

expect(value[0].text).toEqual(JSON.stringify(contentToCache, null, 2));

expect("message" in value[0]).toBeTruthy();
if (!("message" in value[0])) return;
const cachedMsg = value[0].message as AIMessage;
expect(cachedMsg.content).toEqual(JSON.stringify(contentToCache, null, 2));
});
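(FakeChatModel echoes the incoming message content back as the generation text, so with the serialization change to utils/testing/index.ts below, both the cached generation text and the cached message content are the pretty-printed JSON of the complex input.)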
2 changes: 1 addition & 1 deletion langchain-core/src/load/import_map.ts
@@ -1,7 +1,7 @@
// Auto-generated by `scripts/create-entrypoints.js`. Do not edit manually.

export * as agents from "../agents.js";
-export * as caches from "../caches.js";
+export * as caches from "../caches/base.js";
export * as callbacks__base from "../callbacks/base.js";
export * as callbacks__manager from "../callbacks/manager.js";
export * as callbacks__promises from "../callbacks/promises.js";
68 changes: 68 additions & 0 deletions langchain-core/src/messages/tests/message_utils.test.ts
@@ -8,6 +8,7 @@ import { AIMessage } from "../ai.js";
import { HumanMessage } from "../human.js";
import { SystemMessage } from "../system.js";
import { BaseMessage } from "../base.js";
+import { getBufferString } from "../utils.js";

describe("filterMessage", () => {
const getMessages = () => [
@@ -431,3 +432,70 @@ describe("trimMessages can trim", () => {
expect(typeof (trimmedMessages as any).func).toBe("function");
});
});

test("getBufferString can handle complex messages", () => {
const messageArr1 = [new HumanMessage("Hello there!")];
const messageArr2 = [
new AIMessage({
content: [
{
type: "text",
text: "Hello there!",
},
],
}),
];
const messageArr3 = [
new HumanMessage({
content: [
{
type: "image_url",
image_url: {
url: "https://example.com/image.jpg",
},
},
{
type: "image_url",
image_url: "https://example.com/image.jpg",
},
],
}),
];

const bufferString1 = getBufferString(messageArr1);
expect(bufferString1).toBe("Human: Hello there!");

const bufferString2 = getBufferString(messageArr2);
expect(bufferString2).toBe(
`AI: ${JSON.stringify(
[
{
type: "text",
text: "Hello there!",
},
],
null,
2
)}`
);

const bufferString3 = getBufferString(messageArr3);
expect(bufferString3).toBe(
`Human: ${JSON.stringify(
[
{
type: "image_url",
image_url: {
url: "https://example.com/image.jpg",
},
},
{
type: "image_url",
image_url: "https://example.com/image.jpg",
},
],
null,
2
)}`
);
});
6 changes: 5 additions & 1 deletion langchain-core/src/messages/utils.ts
@@ -82,7 +82,11 @@ export function getBufferString(
throw new Error(`Got unsupported message type: ${m._getType()}`);
}
const nameStr = m.name ? `${m.name}, ` : "";
-    string_messages.push(`${role}: ${nameStr}${m.content}`);
+    const readableContent =
+      typeof m.content === "string"
+        ? m.content
+        : JSON.stringify(m.content, null, 2);
+    string_messages.push(`${role}: ${nameStr}${readableContent}`);
}
return string_messages.join("\n");
}
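To see why the old interpolation was lossy, a minimal sketch (illustrative content; this relies only on JavaScript's default toString behavior):

const content = [{ type: "text", text: "Hello there!" }];

// Before: non-string content was interpolated directly.
console.log(`Human: ${content}`); // "Human: [object Object]"

// After: structured content is pretty-printed, so distinct complex messages
// no longer collapse to the same buffer string (which the caching path uses
// as the prompt half of the cache key).
console.log(`Human: ${JSON.stringify(content, null, 2)}`);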
2 changes: 1 addition & 1 deletion langchain-core/src/tests/caches.test.ts
@@ -1,6 +1,6 @@
import { test, expect } from "@jest/globals";

-import { InMemoryCache } from "../caches.js";
+import { InMemoryCache } from "../caches/base.js";

test("InMemoryCache", async () => {
const cache = new InMemoryCache();
9 changes: 8 additions & 1 deletion langchain-core/src/utils/testing/index.ts
@@ -184,7 +184,14 @@ export class FakeChatModel extends BaseChatModel {
],
};
}
-    const text = messages.map((m) => m.content).join("\n");
+    const text = messages
+      .map((m) => {
+        if (typeof m.content === "string") {
+          return m.content;
+        }
+        return JSON.stringify(m.content, null, 2);
+      })
+      .join("\n");
await runManager?.handleLLMNewToken(text);
return {
generations: [
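(This mirrors the getBufferString change above, so the fake model's echoed text matches the serialized form of complex content that the cache tests assert against.)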
10 changes: 9 additions & 1 deletion libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
@@ -44,7 +44,15 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests<
this.skipTestMessage(
"testToolMessageHistoriesListContent",
"ChatGroq",
"Not properly implemented."
"Complex message types not properly implemented"
);
}

async testCacheComplexMessageTypes() {
this.skipTestMessage(
"testCacheComplexMessageTypes",
"ChatGroq",
"Complex message types not properly implemented"
);
}
}
libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts (path inferred from the test class below)

@@ -23,6 +23,14 @@ class ChatMistralAIStandardIntegrationTests extends ChatModelIntegrationTests<
functionId: "123456789",
});
}

async testCacheComplexMessageTypes() {
this.skipTestMessage(
"testCacheComplexMessageTypes",
"ChatMistralAI",
"Complex message types not properly implemented"
);
}
}

const testClass = new ChatMistralAIStandardIntegrationTests();
52 changes: 52 additions & 0 deletions libs/langchain-standard-tests/src/integration_tests/chat_models.ts
@@ -7,6 +7,7 @@
HumanMessage,
ToolMessage,
UsageMetadata,
+  getBufferString,
} from "@langchain/core/messages";
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
@@ -438,6 +439,50 @@
expect(tool_calls[0].name).toBe("math_addition");
}

async testCacheComplexMessageTypes() {
const model = new this.Cls({
...this.constructorArgs,
cache: true,
});
if (!model.cache) {
throw new Error("Cache not enabled");
}

const humanMessage = new HumanMessage({
content: [
{
type: "text",
text: "Hello there!",
},
],
});
const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({} as any);

// Invoke the model to trigger a cache update.
await model.invoke([humanMessage]);
const cacheValue = await model.cache.lookup(prompt, llmKey);

// Ensure only one generation was added to the cache.
expect(cacheValue !== null).toBeTruthy();
if (!cacheValue) return;
expect(cacheValue).toHaveLength(1);

expect("message" in cacheValue[0]).toBeTruthy();
if (!("message" in cacheValue[0])) return;
const cachedMessage = cacheValue[0].message as AIMessage;

// Invoke the model again with the same prompt, triggering a cache hit.
const result = await model.invoke([humanMessage]);

expect(result.content).toBe(cacheValue[0].text);
expect(result).toEqual(cachedMessage);

// Verify a second generation was not added to the cache.
const cacheValue2 = await model.cache.lookup(prompt, llmKey);
expect(cacheValue2).toEqual(cacheValue);
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
@@ -449,42 +494,42 @@

try {
await this.testInvoke();
} catch (e: any) {
allTestsPassed = false;
console.error("testInvoke failed", e);
}

try {
await this.testStream();
} catch (e: any) {
allTestsPassed = false;
console.error("testStream failed", e);
}

try {
await this.testBatch();
} catch (e: any) {
allTestsPassed = false;
console.error("testBatch failed", e);
}

try {
await this.testConversation();
} catch (e: any) {
allTestsPassed = false;
console.error("testConversation failed", e);
}

try {
await this.testUsageMetadata();
} catch (e: any) {
allTestsPassed = false;
console.error("testUsageMetadata failed", e);
}

try {
await this.testUsageMetadataStreaming();
} catch (e: any) {
allTestsPassed = false;
console.error("testUsageMetadataStreaming failed", e);
}
Expand Down Expand Up @@ -531,6 +576,13 @@
console.error("testBindToolsWithOpenAIFormattedTools failed", e);
}

try {
await this.testCacheComplexMessageTypes();
} catch (e: any) {
allTestsPassed = false;
console.error("testCacheComplexMessageTypes failed", e);
}

return allTestsPassed;
}
}