Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(google-genai): Add support for search retrieval and code execution tools #7138

Merged
merged 9 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/core_docs/docs/concepts/streaming.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ The `stream()` method returns an iterator that yields these chunks as they are p

```typescript
for await (const chunk of await component.stream(someInput)) {
// IMPORTANT: Keep the processing of each chunk as efficient as possible.
// While you're processing the current chunk, the upstream component is
// waiting to produce the next one. For example, if working with LangGraph,
// graph execution is paused while the current chunk is being processed.
// In extreme cases, this could even result in timeouts (e.g., when llm outputs are
// streamed from an API that has a timeout).
console.log(chunk)
// IMPORTANT: Keep the processing of each chunk as efficient as possible.
// While you're processing the current chunk, the upstream component is
// waiting to produce the next one. For example, if working with LangGraph,
// graph execution is paused while the current chunk is being processed.
// In extreme cases, this could even result in timeouts (e.g., when llm outputs are
// streamed from an API that has a timeout).
console.log(chunk);
}
```

Expand Down
370 changes: 369 additions & 1 deletion docs/core_docs/docs/integrations/chat/google_generativeai.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import { RunnablePassthrough } from "../runnables/passthrough.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type ToolChoice = string | Record<string, any> | "auto" | "any";
export type ToolChoice = string | Record<string, any> | "auto" | "any";

/**
* Represents a serialized chat model.
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-google-genai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@google/generative-ai": "^0.7.0",
"@google/generative-ai": "^0.21.0",
"zod-to-json-schema": "^3.22.4"
},
"peerDependencies": {
Expand Down
74 changes: 12 additions & 62 deletions libs/langchain-google-genai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import {
GenerateContentRequest,
SafetySetting,
Part as GenerativeAIPart,
Tool as GenerativeAITool,
ToolConfig,
FunctionCallingMode,
} from "@google/generative-ai";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
Expand All @@ -30,7 +27,6 @@ import {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
} from "@langchain/core/language_models/base";
import { StructuredToolInterface } from "@langchain/core/tools";
import {
Runnable,
RunnablePassthrough,
Expand All @@ -43,11 +39,11 @@ import { zodToGenerativeAIParameters } from "./utils/zod_to_genai_parameters.js"
import {
convertBaseMessagesToContent,
convertResponseContentToChatGenerationChunk,
convertToGenerativeAITools,
mapGenerateContentResultToChatResult,
} from "./utils/common.js";
import { GoogleGenerativeAIToolsOutputParser } from "./output_parsers.js";
import { GoogleGenerativeAIToolType } from "./types.js";
import { convertToolsToGenAI } from "./utils/tools.js";

interface TokenUsage {
completionTokens?: number;
Expand Down Expand Up @@ -682,70 +678,24 @@ export class ChatGoogleGenerativeAI
AIMessageChunk,
GoogleGenerativeAIChatCallOptions
> {
return this.bind({ tools: convertToGenerativeAITools(tools), ...kwargs });
return this.bind({ tools: convertToolsToGenAI(tools)?.tools, ...kwargs });
}

invocationParams(
options?: this["ParsedCallOptions"]
): Omit<GenerateContentRequest, "contents"> {
let genaiTools: GenerativeAITool[] | undefined;
if (
Array.isArray(options?.tools) &&
!options?.tools.some(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(t: any) => !("lc_namespace" in t)
)
) {
// Tools are in StructuredToolInterface format. Convert to GenAI format
genaiTools = convertToGenerativeAITools(
options?.tools as StructuredToolInterface[]
);
} else {
genaiTools = options?.tools as GenerativeAITool[];
}

let toolConfig: ToolConfig | undefined;
if (genaiTools?.length && options?.tool_choice) {
if (["any", "auto", "none"].some((c) => c === options.tool_choice)) {
const modeMap: Record<string, FunctionCallingMode> = {
any: FunctionCallingMode.ANY,
auto: FunctionCallingMode.AUTO,
none: FunctionCallingMode.NONE,
};

toolConfig = {
functionCallingConfig: {
mode:
modeMap[options.tool_choice as keyof typeof modeMap] ??
"MODE_UNSPECIFIED",
allowedFunctionNames: options.allowedFunctionNames,
},
};
} else if (typeof options.tool_choice === "string") {
toolConfig = {
functionCallingConfig: {
mode: FunctionCallingMode.ANY,
allowedFunctionNames: [
...(options.allowedFunctionNames ?? []),
options.tool_choice,
],
},
};
}

if (!options.tool_choice && options.allowedFunctionNames) {
toolConfig = {
functionCallingConfig: {
mode: FunctionCallingMode.ANY,
allowedFunctionNames: options.allowedFunctionNames,
},
};
}
}
const toolsAndConfig = options?.tools?.length
? convertToolsToGenAI(options.tools, {
toolChoice: options.tool_choice,
allowedFunctionNames: options.allowedFunctionNames,
})
: undefined;

return {
tools: genaiTools,
toolConfig,
...(toolsAndConfig?.tools ? { tools: toolsAndConfig.tools } : {}),
...(toolsAndConfig?.toolConfig
? { toolConfig: toolsAndConfig.toolConfig }
: {}),
};
}

Expand Down
160 changes: 159 additions & 1 deletion libs/langchain-google-genai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ import {
} from "@langchain/core/prompts";
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import { FunctionDeclarationSchemaType } from "@google/generative-ai";
import {
CodeExecutionTool,
DynamicRetrievalMode,
SchemaType as FunctionDeclarationSchemaType,
GoogleSearchRetrievalTool,
} from "@google/generative-ai";
import { concat } from "@langchain/core/utils/stream";
import { ChatGoogleGenerativeAI } from "../chat_models.js";

// Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable
Expand Down Expand Up @@ -567,3 +573,155 @@ test("Supports tool_choice", async () => {
);
expect(response.tool_calls?.length).toBe(1);
});

describe("GoogleSearchRetrievalTool", () => {
  // Prompt shared by both tests in this suite.
  const question = "Who won the 2024 MLB World Series?";

  /**
   * Creates a fresh `ChatGoogleGenerativeAI` instance (gemini-1.5-pro,
   * temperature 0, no retries) with a Google Search retrieval tool bound
   * to it. A new tool object and a new model are built on every call so
   * that each test works with its own instances.
   */
  const createSearchModel = () => {
    const groundingTool: GoogleSearchRetrievalTool = {
      googleSearchRetrieval: {
        dynamicRetrievalConfig: {
          mode: DynamicRetrievalMode.MODE_DYNAMIC,
          dynamicThreshold: 0.7, // default is 0.7
        },
      },
    };
    const baseModel = new ChatGoogleGenerativeAI({
      model: "gemini-1.5-pro",
      temperature: 0,
      maxRetries: 0,
    });
    return baseModel.bindTools([groundingTool]);
  };

  test("Supports GoogleSearchRetrievalTool", async () => {
    const response = await createSearchModel().invoke(question);

    // The response must carry grounding metadata and mention the expected team.
    expect(response.response_metadata?.groundingMetadata).toBeDefined();
    expect(response.content as string).toContain("Dodgers");
  });

  test("Can stream GoogleSearchRetrievalTool", async () => {
    const chunks = await createSearchModel().stream(question);

    // Fold every streamed chunk into a single aggregated message.
    let aggregated: AIMessageChunk | undefined;
    for await (const chunk of chunks) {
      if (aggregated === undefined) {
        aggregated = chunk;
      } else {
        aggregated = concat(aggregated, chunk);
      }
    }
    if (aggregated === undefined) {
      throw new Error("finalMsg is undefined");
    }

    expect(aggregated.response_metadata?.groundingMetadata).toBeDefined();
    expect(aggregated.content as string).toContain("Dodgers");
  });
});

describe("CodeExecutionTool", () => {
  // Prompt shared by every test in this suite.
  const sumPrompt =
    "Use code execution to find the sum of the first and last 3 numbers in the following list: [1, 2, 3, 72638, 8, 727, 4, 5, 6]";

  /**
   * Creates a fresh `ChatGoogleGenerativeAI` instance (gemini-1.5-pro,
   * temperature 0, no retries) with the code execution tool bound to it.
   */
  const createCodeExecutionModel = () => {
    const executionTool: CodeExecutionTool = {
      codeExecution: {}, // Simply pass an empty object to enable it.
    };
    const baseModel = new ChatGoogleGenerativeAI({
      model: "gemini-1.5-pro",
      temperature: 0,
      maxRetries: 0,
    });
    return baseModel.bindTools([executionTool]);
  };

  /**
   * Shared assertions for a message produced with code execution enabled:
   * the content must be an array whose text parts contain "21", and it must
   * include both an "executableCode" part and a "codeExecutionResult" part.
   */
  const expectCodeExecutionContent = (content: AIMessageChunk["content"]) => {
    expect(Array.isArray(content)).toBeTruthy();
    if (!Array.isArray(content)) {
      throw new Error("Content is not an array");
    }

    // Gather the text of every part that has a `text` field.
    const textParts: string[] = [];
    for (const part of content) {
      if ("text" in part) {
        textParts.push(part.text);
      }
    }
    expect(textParts.join("\n")).toContain("21");

    const codePart = content.find((part) => part.type === "executableCode");
    expect(codePart).toBeDefined();
    const outcomePart = content.find(
      (part) => part.type === "codeExecutionResult"
    );
    expect(outcomePart).toBeDefined();
  };

  test("Supports CodeExecutionTool", async () => {
    const response = await createCodeExecutionModel().invoke(sumPrompt);

    expectCodeExecutionContent(response.content);
  });

  test("CodeExecutionTool contents can be passed in chat history", async () => {
    const boundModel = createCodeExecutionModel();

    // First turn: let the model run code to answer the question.
    const firstTurn = await boundModel.invoke(sumPrompt);

    // Second turn: feed the code-execution message back in as history.
    const followUp = await boundModel.invoke([
      firstTurn,
      {
        role: "user",
        content:
          "Please explain the question I asked, the code you wrote, and the answer you got.",
      },
    ]);

    expect(typeof followUp.content).toBe("string");
    expect(followUp.content.length).toBeGreaterThan(10);
  });

  test("Can stream CodeExecutionTool", async () => {
    const chunks = await createCodeExecutionModel().stream(sumPrompt);

    // Fold every streamed chunk into a single aggregated message.
    let aggregated: AIMessageChunk | undefined;
    for await (const chunk of chunks) {
      if (aggregated === undefined) {
        aggregated = chunk;
      } else {
        aggregated = concat(aggregated, chunk);
      }
    }

    if (aggregated === undefined) {
      throw new Error("finalMsg is undefined");
    }
    expectCodeExecutionContent(aggregated.content);
  });
});
10 changes: 8 additions & 2 deletions libs/langchain-google-genai/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import { FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool } from "@google/generative-ai";
import {
CodeExecutionTool,
FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool,
GoogleSearchRetrievalTool,
} from "@google/generative-ai";
import { BindToolsInput } from "@langchain/core/language_models/chat_models";

export type GoogleGenerativeAIToolType =
| BindToolsInput
| GoogleGenerativeAIFunctionDeclarationsTool;
| GoogleGenerativeAIFunctionDeclarationsTool
| CodeExecutionTool
| GoogleSearchRetrievalTool;
Loading
Loading