diff --git a/langchain/src/agents/openai/index.ts b/langchain/src/agents/openai/index.ts index 7fb10c0aaa13..7cc3be08b417 100644 --- a/langchain/src/agents/openai/index.ts +++ b/langchain/src/agents/openai/index.ts @@ -1,5 +1,5 @@ import { CallbackManager } from "../../callbacks/manager.js"; -import { ChatOpenAI } from "../../chat_models/openai.js"; +import { ChatOpenAI, ChatOpenAICallOptions } from "../../chat_models/openai.js"; import { BasePromptTemplate } from "../../prompts/base.js"; import { AIMessage, @@ -10,6 +10,7 @@ import { FunctionMessage, ChainValues, SystemMessage, + BaseMessageChunk, } from "../../schema/index.js"; import { StructuredTool } from "../../tools/base.js"; import { Agent, AgentArgs } from "../agent.js"; @@ -21,13 +22,20 @@ import { MessagesPlaceholder, SystemMessagePromptTemplate, } from "../../prompts/chat.js"; -import { BaseLanguageModel } from "../../base_language/index.js"; +import { + BaseLanguageModel, + BaseLanguageModelInput, +} from "../../base_language/index.js"; import { LLMChain } from "../../chains/llm_chain.js"; import { FunctionsAgentAction, OpenAIFunctionsAgentOutputParser, } from "./output_parser.js"; import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js"; +import { Runnable } from "../../schema/runnable/base.js"; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any; /** * Checks if the given action is a FunctionsAgentAction. 
@@ -199,16 +207,24 @@ export class OpenAIAgent extends Agent { } // Split inputs between prompt and llm - const llm = this.llmChain.llm as ChatOpenAI; + const llm = this.llmChain.llm as + | ChatOpenAI + | Runnable< + BaseLanguageModelInput, + BaseMessageChunk, + ChatOpenAICallOptions + >; + const valuesForPrompt = { ...newInputs }; - const valuesForLLM: (typeof llm)["CallOptions"] = { + const valuesForLLM: CallOptionsIfAvailable<typeof llm> = { functions: this.tools.map(formatToOpenAIFunction), }; const callKeys = "callKeys" in this.llmChain.llm ? this.llmChain.llm.callKeys : []; for (const key of callKeys) { if (key in inputs) { - valuesForLLM[key as keyof (typeof llm)["CallOptions"]] = inputs[key]; + valuesForLLM[key as keyof CallOptionsIfAvailable<typeof llm>] = + inputs[key]; delete valuesForPrompt[key]; } } @@ -216,11 +232,17 @@ export class OpenAIAgent extends Agent { const promptValue = await this.llmChain.prompt.formatPromptValue( valuesForPrompt ); - const message = await llm.predictMessages( - promptValue.toChatMessages(), - valuesForLLM, - callbackManager - ); + + const message = await ( + llm as Runnable< + BaseLanguageModelInput, + BaseMessageChunk, + ChatOpenAICallOptions + > + ).invoke(promptValue.toChatMessages(), { + ...valuesForLLM, + callbacks: callbackManager, + }); return this.outputParser.parseAIMessage(message); } } diff --git a/langchain/src/agents/tests/runnable.int.test.ts b/langchain/src/agents/tests/runnable.int.test.ts index 995db66f7f1f..598dbed0aa40 100644 --- a/langchain/src/agents/tests/runnable.int.test.ts +++ b/langchain/src/agents/tests/runnable.int.test.ts @@ -14,6 +14,8 @@ import { SerpAPI } from "../../tools/serpapi.js"; import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js"; import { Calculator } from "../../tools/calculator.js"; import { OpenAIFunctionsAgentOutputParser } from "../openai/output_parser.js"; +import { LLMChain } from "../../chains/llm_chain.js"; +import { OpenAIAgent } from "../openai/index.js"; test("Runnable 
variant", async () => { const tools = [new Calculator(), new SerpAPI()]; @@ -59,8 +61,44 @@ test("Runnable variant", async () => { const query = "What is the weather in New York?"; console.log(`Calling agent executor with query: ${query}`); - const result = await executor.call({ + const result = await executor.invoke({ input: query, }); console.log(result); }); + +test("Runnable variant works with executor", async () => { + // Prepare tools + const tools = [new Calculator(), new SerpAPI()]; + const runnableModel = new ChatOpenAI({ + modelName: "gpt-4", + temperature: 0, + }).bind({}); + + const prompt = ChatPromptTemplate.fromMessages([ + ["ai", "You are a helpful assistant"], + ["human", "{input}"], + new MessagesPlaceholder("agent_scratchpad"), + ]); + + // Prepare agent chain + const llmChain = new LLMChain({ + prompt, + llm: runnableModel, + }); + const agent = new OpenAIAgent({ + llmChain, + tools, + }); + + // Prepare and run executor + const executor = new AgentExecutor({ + agent, + tools, + }); + const result = await executor.invoke({ + input: "What is the weather in New York?", + }); + + console.log(result); +});