Update OpenAIAgent to support Runnable models #3346

Merged (4 commits, Nov 22, 2023)
Changes from 3 commits
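This change lets OpenAIAgent work when llmChain.llm is not a ChatOpenAI instance but a Runnable, such as a chat model with pre-bound call options. As orientation before the diff, here is a minimal sketch of the usage this enables; it mirrors the integration test added at the bottom of this PR (SerpAPI dropped so only an OpenAI API key is needed), and the entrypoint import paths are assumptions based on the package layout around this release.

```ts
import { ChatOpenAI } from "langchain/chat_models/openai";
import { ChatPromptTemplate } from "langchain/prompts";
import { LLMChain } from "langchain/chains";
import { AgentExecutor, OpenAIAgent } from "langchain/agents";
import { Calculator } from "langchain/tools/calculator";

// `.bind({})` returns a Runnable rather than a ChatOpenAI instance;
// before this change, OpenAIAgent assumed `llmChain.llm` was a ChatOpenAI
// and called `predictMessages` on it directly.
const runnableModel = new ChatOpenAI({
  modelName: "gpt-4",
  temperature: 0,
}).bind({});

const tools = [new Calculator()];

const llmChain = new LLMChain({
  prompt: ChatPromptTemplate.fromTemplate("What is 2 to the power of 8?"),
  llm: runnableModel,
});
const agent = new OpenAIAgent({ llmChain, tools });
const executor = new AgentExecutor({ agent, tools });

const result = await executor.call({});
console.log(result);
```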
47 changes: 34 additions & 13 deletions langchain/src/agents/openai/index.ts
@@ -1,5 +1,5 @@
import { CallbackManager } from "../../callbacks/manager.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { ChatOpenAI, ChatOpenAICallOptions } from "../../chat_models/openai.js";
import { BasePromptTemplate } from "../../prompts/base.js";
import {
AIMessage,
@@ -10,6 +10,7 @@ import {
FunctionMessage,
ChainValues,
SystemMessage,
BaseMessageChunk,
} from "../../schema/index.js";
import { StructuredTool } from "../../tools/base.js";
import { Agent, AgentArgs } from "../agent.js";
@@ -21,13 +22,20 @@ import {
MessagesPlaceholder,
SystemMessagePromptTemplate,
} from "../../prompts/chat.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import {
BaseLanguageModel,
BaseLanguageModelInput,
} from "../../base_language/index.js";
import { LLMChain } from "../../chains/llm_chain.js";
import {
FunctionsAgentAction,
OpenAIFunctionsAgentOutputParser,
} from "./output_parser.js";
import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js";
import { Runnable } from "../../schema/runnable/base.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any;

/**
* Checks if the given action is a FunctionsAgentAction.
@@ -199,28 +207,41 @@ export class OpenAIAgent extends Agent {
}

// Split inputs between prompt and llm
const llm = this.llmChain.llm as ChatOpenAI;
const llm = this.llmChain.llm as
| ChatOpenAI
| Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
ChatOpenAICallOptions
>;

const valuesForPrompt = { ...newInputs };
const valuesForLLM: (typeof llm)["CallOptions"] = {
const valuesForLLM: CallOptionsIfAvailable<typeof llm> = {
functions: this.tools.map(formatToOpenAIFunction),
};
const callKeys =
"callKeys" in this.llmChain.llm ? this.llmChain.llm.callKeys : [];
for (const key of callKeys) {
if (key in inputs) {
valuesForLLM[key as keyof (typeof llm)["CallOptions"]] = inputs[key];
valuesForLLM[key as keyof CallOptionsIfAvailable<typeof llm>] =
inputs[key];
delete valuesForPrompt[key];
}
}

const promptValue = await this.llmChain.prompt.formatPromptValue(
valuesForPrompt
);
const message = await llm.predictMessages(
promptValue.toChatMessages(),
valuesForLLM,
callbackManager
);
const promptValue =
await this.llmChain.prompt.formatPromptValue(valuesForPrompt);

const message = await (llm as
Runnable<
BaseLanguageModelInput,
BaseMessageChunk,
ChatOpenAICallOptions
>
).invoke(promptValue.toChatMessages(), {
...valuesForLLM,
callbacks: callbackManager,
});
return this.outputParser.parseAIMessage(message);
}
}
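Two details of the change above are worth spelling out. First, CallOptionsIfAvailable<T> is a conditional type: when the model type declares a CallOptions member (as ChatOpenAI does), the call-option shape is extracted via infer; a plain Runnable declares no such member, so it falls back to any. Second, the removed ChatOpenAI.predictMessages call took the callback manager as a separate third argument, whereas Runnable.invoke takes a single options object, so the callbacks now travel alongside the bound call options. A small self-contained sketch of the conditional-type pattern follows; the model shapes are hypothetical stand-ins, not part of the LangChain API.

```ts
// Hypothetical model shapes, used only to illustrate the conditional type.
type ChatModelLike = {
  CallOptions: { functions?: object[]; stop?: string[] };
};
type PlainRunnableLike = {
  invoke: (input: string) => Promise<string>;
};

// Same helper as in the diff: infer CallOptions when the type declares it,
// otherwise fall back to `any`.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any;

// Resolves to { functions?: object[]; stop?: string[] }
type ChatModelOptions = CallOptionsIfAvailable<ChatModelLike>;
// Resolves to `any`, since PlainRunnableLike has no CallOptions member
type PlainRunnableOptions = CallOptionsIfAvailable<PlainRunnableLike>;
```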
28 changes: 28 additions & 0 deletions langchain/src/agents/tests/runnable.int.test.ts
@@ -14,6 +14,8 @@ import { SerpAPI } from "../../tools/serpapi.js";
import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js";
import { Calculator } from "../../tools/calculator.js";
import { OpenAIFunctionsAgentOutputParser } from "../openai/output_parser.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { OpenAIAgent } from "../openai/index.js";

test("Runnable variant", async () => {
const tools = [new Calculator(), new SerpAPI()];
@@ -64,3 +66,29 @@ test("Runnable variant", async () => {
});
console.log(result);
});

test("Runnable variant works with executor", async () => {
// Prepare tools
const tools = [new Calculator(), new SerpAPI()];
const runnableModel = new ChatOpenAI({
modelName: "gpt-4",
temperature: 0,
}).bind({});

// Prepare agent chain
const llmChain = new LLMChain({
prompt: ChatPromptTemplate.fromTemplate("What is the weather in New York?"),
llm: runnableModel,
});
const agent = new OpenAIAgent({
llmChain,
tools,
});

// Prepare and run executor
const executor = new AgentExecutor({
agent,
tools,
});
await executor.call({});
});