openai[minor]: Add support for json schema response format #6438

Merged · 10 commits · Aug 16, 2024
2 changes: 1 addition & 1 deletion langchain-core/src/language_models/base.ts
@@ -269,7 +269,7 @@ export type StructuredOutputType = z.infer<z.ZodObject<any, any, any, any>>;
 export type StructuredOutputMethodOptions<IncludeRaw extends boolean = false> =
   {
     name?: string;
-    method?: "functionCalling" | "jsonMode";
+    method?: "functionCalling" | "jsonMode" | "jsonSchema" | string;
     includeRaw?: IncludeRaw;
   };
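Beyond the new `"jsonSchema"` literal, the union now ends in `| string`, so integration packages can accept provider-specific method names without another change to core. A minimal sketch of the widened type in use; the custom method name below is illustrative:

```ts
import type { StructuredOutputMethodOptions } from "@langchain/core/language_models/base";

// "jsonSchema" is now part of the union; the trailing `| string`
// also admits provider-specific values.
const jsonSchemaOpts: StructuredOutputMethodOptions = { method: "jsonSchema" };
const customOpts: StructuredOutputMethodOptions = { method: "myProviderMode" };
```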
2 changes: 1 addition & 1 deletion libs/langchain-openai/langchain.config.js
@@ -11,7 +11,7 @@ function abs(relativePath) {


 export const config = {
-  internals: [/node\:/, /@langchain\/core\//],
+  internals: [/node\:/, /@langchain\/core\//, "openai/helpers/zod"],
   entrypoints: {
     index: "index",
   },
154 changes: 139 additions & 15 deletions libs/langchain-openai/src/chat_models.ts
@@ -36,7 +36,6 @@ import {
   type StructuredOutputMethodParams,
 } from "@langchain/core/language_models/base";
 import { NewTokenIndices } from "@langchain/core/callbacks/base";
-import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
 import { z } from "zod";
 import {
   Runnable,
@@ -56,12 +55,20 @@
 } from "@langchain/core/output_parsers/openai_tools";
 import { zodToJsonSchema } from "zod-to-json-schema";
 import { ToolCallChunk } from "@langchain/core/messages/tool";
+import { zodResponseFormat } from "openai/helpers/zod";
+import type {
+  ResponseFormatText,
+  ResponseFormatJSONObject,
+  ResponseFormatJSONSchema,
+} from "openai/resources/shared";
+import { ParsedChatCompletion } from "openai/resources/beta/chat/completions.mjs";
 import type {
   AzureOpenAIInput,
   OpenAICallOptions,
   OpenAIChatInput,
   OpenAICoreRequestOptions,
   LegacyOpenAIInput,
+  ChatOpenAIResponseFormat,
 } from "./types.js";
 import { type OpenAIEndpointConfig, getEndpoint } from "./utils/azure.js";
 import {
@@ -73,6 +80,7 @@
   FunctionDef,
   formatFunctionDefinitions,
 } from "./utils/openai-format-fndef.js";
+import { _convertToOpenAITool } from "./utils/tools.js";

 export type { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput };

@@ -295,7 +303,7 @@ function _convertChatOpenAIToolTypeToOpenAITool(

     return tool;
   }
-  return convertToOpenAITool(tool, fields);
+  return _convertToOpenAITool(tool, fields);
 }

 export interface ChatOpenAIStructuredOutputMethodOptions<
@@ -324,7 +332,7 @@ export interface ChatOpenAICallOptions
   tools?: ChatOpenAIToolType[];
   tool_choice?: OpenAIToolChoice;
   promptIndex?: number;
-  response_format?: { type: "json_object" };
+  response_format?: ChatOpenAIResponseFormat;
   seed?: number;
   /**
    * Additional options to pass to streamed completions.
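With the call option widened from the `json_object`-only shape to `ChatOpenAIResponseFormat`, a `json_schema` response format can be supplied per invocation. A sketch of the intended call, assuming the `ChatOpenAIResponseFormat` type added in `types.js` (not shown in this diff) accepts a Zod object for `schema`, which is what the `isZodSchema` check in `createResponseFormat` below implies; the model name, prompt, and schema are illustrative:

```ts
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

// Illustrative extraction schema.
const Extraction = z.object({
  name: z.string(),
  age: z.number(),
});

const model = new ChatOpenAI({ model: "gpt-4o-2024-08-06" });

// The Zod schema is detected by createResponseFormat and converted with
// zodResponseFormat; a plain JSON-schema object would pass through as-is.
const res = await model.invoke("Extract: John is 30.", {
  response_format: {
    type: "json_schema",
    json_schema: {
      name: "extraction",
      description: "Extracted fields",
      schema: Extraction,
      strict: true,
    },
  },
});
```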
@@ -1027,6 +1035,34 @@ export class ChatOpenAI<
     } as Partial<CallOptions>);
   }

+  private createResponseFormat(
+    resFormat?: CallOptions["response_format"]
+  ):
+    | ResponseFormatText
+    | ResponseFormatJSONObject
+    | ResponseFormatJSONSchema
+    | undefined {
+    if (
+      resFormat &&
+      resFormat.type === "json_schema" &&
+      resFormat.json_schema.schema &&
+      isZodSchema(resFormat.json_schema.schema)
+    ) {
+      return zodResponseFormat(
+        resFormat.json_schema.schema,
+        resFormat.json_schema.name,
+        {
+          description: resFormat.json_schema.description,
+        }
+      );
+    }
+    return resFormat as
+      | ResponseFormatText
+      | ResponseFormatJSONObject
+      | ResponseFormatJSONSchema
+      | undefined;
+  }
+
   /**
    * Get the parameters used to invoke the model
    */
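`createResponseFormat` defers the Zod-to-JSON-schema conversion to the SDK's `zodResponseFormat` helper rather than converting by hand. A small sketch of what that helper returns (shape abbreviated; the helper sets `strict: true` by default):

```ts
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";

const Step = z.object({
  explanation: z.string(),
  output: z.string(),
});

const rf = zodResponseFormat(Step, "step", { description: "A single reasoning step" });
// rf is roughly:
// {
//   type: "json_schema",
//   json_schema: { name: "step", description: "...", strict: true, schema: { ... } }
// }
```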
@@ -1049,6 +1085,7 @@
     } else if (this.streamUsage && (this.streaming || extra?.streaming)) {
       streamOptionsConfig = { stream_options: { include_usage: true } };
     }
+
     const params: Omit<
       OpenAIClient.Chat.ChatCompletionCreateParams,
       "messages"
@@ -1075,7 +1112,7 @@
           )
         : undefined,
       tool_choice: formatToOpenAIToolChoice(options?.tool_choice),
-      response_format: options?.response_format,
+      response_format: this.createResponseFormat(options?.response_format),
       seed: options?.seed,
       ...streamOptionsConfig,
       parallel_tool_calls: options?.parallel_tool_calls,
@@ -1113,6 +1150,32 @@
       stream: true as const,
     };
     let defaultRole: OpenAIRoleEnum | undefined;
+    if (
+      params.response_format &&
+      params.response_format.type === "json_schema"
+    ) {
+      console.warn(
+        `OpenAI does not yet support streaming with "response_format" set to "json_schema". Falling back to non-streaming mode.`
+      );
+      const res = await this._generate(messages, options, runManager);
+      const chunk = new ChatGenerationChunk({
+        message: new AIMessageChunk({
+          ...res.generations[0].message,
+        }),
+        text: res.generations[0].text,
+        generationInfo: res.generations[0].generationInfo,
+      });
+      yield chunk;
+      return runManager?.handleLLMNewToken(
+        res.generations[0].text ?? "",
+        undefined,
+        undefined,
+        undefined,
+        undefined,
+        { chunk }
+      );
+    }
+
     const streamIterable = await this.completionWithRetry(params, options);
     let usage: OpenAIClient.Completions.CompletionUsage | undefined;
     for await (const data of streamIterable) {
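Because the fallback happens inside `_streamResponseChunks`, callers can still use `.stream()`; they simply receive the whole completion as a single chunk. A hedged sketch of that behavior (model name, prompt, and schema are illustrative):

```ts
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

const Colors = z.object({ colors: z.array(z.string()) });
const model = new ChatOpenAI({ model: "gpt-4o-2024-08-06" });

const stream = await model.stream("List three colors.", {
  response_format: {
    type: "json_schema",
    json_schema: { name: "colors", schema: Colors },
  },
});

for await (const chunk of stream) {
  // The fallback emits one aggregated chunk containing the full completion.
  console.log(chunk.content);
}
```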
@@ -1248,17 +1311,36 @@
       tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage;
       return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
     } else {
-      const data = await this.completionWithRetry(
-        {
-          ...params,
-          stream: false,
-          messages: messagesMapped,
-        },
-        {
-          signal: options?.signal,
-          ...options?.options,
-        }
-      );
+      let data;
+      if (
+        options.response_format &&
+        options.response_format.type === "json_schema"
+      ) {
+        data = await this.betaParsedCompletionWithRetry(
+          {
+            ...params,
+            stream: false,
+            messages: messagesMapped,
+          },
+          {
+            signal: options?.signal,
+            ...options?.options,
+          }
+        );
+      } else {
+        data = await this.completionWithRetry(
+          {
+            ...params,
+            stream: false,
+            messages: messagesMapped,
+          },
+          {
+            signal: options?.signal,
+            ...options?.options,
+          }
+        );
+      }
+
       const {
         completion_tokens: completionTokens,
         prompt_tokens: promptTokens,
@@ -1478,6 +1560,31 @@
     });
   }

+  /**
+   * Call the beta chat completions parse endpoint. This should only be called if
+   * response_format is set to "json_schema".
+   * @param {OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming} request
+   * @param {OpenAICoreRequestOptions | undefined} options
+   */
+  async betaParsedCompletionWithRetry(
+    request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
+    options?: OpenAICoreRequestOptions
+  ): Promise<ParsedChatCompletion<null>> {
+    const requestOptions = this._getClientOptions(options);
+    return this.caller.call(async () => {
+      try {
+        const res = await this.client.beta.chat.completions.parse(
+          request,
+          requestOptions
+        );
+        return res;
+      } catch (e) {
+        const error = wrapOpenAIClientError(e);
+        throw error;
+      }
+    });
+  }
+
   protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
     if (!this.client) {
       const openAIEndpointConfig: OpenAIEndpointConfig = {
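`betaParsedCompletionWithRetry` wraps the SDK's beta parse endpoint in the same retry and error-normalization path as `completionWithRetry`. For reference, the underlying openai-node call it delegates to looks like this (standalone sketch; model, prompt, and schema are illustrative):

```ts
import OpenAI from "openai";
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";

const client = new OpenAI();
const Weather = z.object({ city: z.string(), forecast: z.string() });

const completion = await client.beta.chat.completions.parse({
  model: "gpt-4o-2024-08-06",
  messages: [{ role: "user", content: "Weather in Paris tomorrow?" }],
  response_format: zodResponseFormat(Weather, "weather"),
});

// `parsed` is populated when a schema-bearing response_format is supplied.
console.log(completion.choices[0].message.parsed);
```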
@@ -1620,6 +1727,23 @@
       } else {
         outputParser = new JsonOutputParser<RunOutput>();
       }
+    } else if (method === "jsonSchema") {
+      llm = this.bind({
+        response_format: {
+          type: "json_schema",
+          json_schema: {
+            name: name ?? "extract",
+            description: schema.description,
+            schema,
+            strict: config?.strict,
+          },
+        },
+      } as Partial<CallOptions>);
+      if (isZodSchema(schema)) {
+        outputParser = StructuredOutputParser.fromZodSchema(schema);
+      } else {
+        outputParser = new JsonOutputParser<RunOutput>();
+      }
     } else {
       let functionName = name ?? "extract";
       // Is function calling
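End to end, the new branch lets `withStructuredOutput` target OpenAI's structured-output mode directly. A sketch of the intended usage (schema and prompt are illustrative; `strict` requires a model that supports structured outputs, such as gpt-4o-2024-08-06):

```ts
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

// Illustrative schema; field descriptions are carried into the JSON schema.
const Joke = z.object({
  setup: z.string().describe("The setup of the joke"),
  punchline: z.string().describe("The punchline"),
});

const model = new ChatOpenAI({ model: "gpt-4o-2024-08-06" });

const structured = model.withStructuredOutput(Joke, {
  name: "joke",
  method: "jsonSchema",
  strict: true,
});

const joke = await structured.invoke("Tell me a joke about cats.");
// Zod input means the result is parsed via StructuredOutputParser.fromZodSchema(Joke).
```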