core[minor],openai[patch]: Add usage metadata to AIMessage/Chunk #5586

Merged · 12 commits · May 31, 2024
58 changes: 55 additions & 3 deletions langchain-core/src/messages/ai.ts
@@ -18,6 +18,25 @@ import {
export type AIMessageFields = BaseMessageFields & {
tool_calls?: ToolCall[];
invalid_tool_calls?: InvalidToolCall[];
usage_metadata?: UsageMetadata;
};

/**
* Usage metadata for a message, such as token counts.
*/
export type UsageMetadata = {
/**
* The count of input (or prompt) tokens.
*/
input_tokens: number;
/**
* The count of output (or completion) tokens.
*/
output_tokens: number;
/**
* The total token count.
*/
total_tokens: number;
};

/**
@@ -30,6 +49,11 @@ export class AIMessage extends BaseMessage {

invalid_tool_calls?: InvalidToolCall[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

get lc_aliases(): Record<string, string> {
// exclude snake case conversion to pascal case
return {
@@ -94,6 +118,7 @@ export class AIMessage extends BaseMessage {
this.invalid_tool_calls =
initParams.invalid_tool_calls ?? this.invalid_tool_calls;
}
this.usage_metadata = initParams.usage_metadata;
}

static lc_name() {
@@ -127,6 +152,11 @@ export class AIMessageChunk extends BaseMessageChunk {

tool_call_chunks?: ToolCallChunk[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

constructor(fields: string | AIMessageChunkFields) {
let initParams: AIMessageChunkFields;
if (typeof fields === "string") {
@@ -177,10 +207,11 @@ export class AIMessageChunk extends BaseMessageChunk {
// properties with initializers, so we have to check types twice.
super(initParams);
this.tool_call_chunks =
-  initParams?.tool_call_chunks ?? this.tool_call_chunks;
- this.tool_calls = initParams?.tool_calls ?? this.tool_calls;
+  initParams.tool_call_chunks ?? this.tool_call_chunks;
+ this.tool_calls = initParams.tool_calls ?? this.tool_calls;
this.invalid_tool_calls =
-  initParams?.invalid_tool_calls ?? this.invalid_tool_calls;
+  initParams.invalid_tool_calls ?? this.invalid_tool_calls;
this.usage_metadata = initParams.usage_metadata;
}

get lc_aliases(): Record<string, string> {
Expand Down Expand Up @@ -226,6 +257,27 @@ export class AIMessageChunk extends BaseMessageChunk {
combinedFields.tool_call_chunks = rawToolCalls;
}
}
if (
this.usage_metadata !== undefined ||
chunk.usage_metadata !== undefined
) {
const left: UsageMetadata = this.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const right: UsageMetadata = chunk.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const usage_metadata: UsageMetadata = {
input_tokens: left.input_tokens + right.input_tokens,
output_tokens: left.output_tokens + right.output_tokens,
total_tokens: left.total_tokens + right.total_tokens,
};
combinedFields.usage_metadata = usage_metadata;
}
return new AIMessageChunk(combinedFields);
}
}
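
Taken together, the core changes above add a `usage_metadata` field to `AIMessage`/`AIMessageChunk` and teach `concat` to sum the counts across chunks. A minimal sketch of how this could be exercised from user code — the token counts are illustrative, and the import assumes the standard `@langchain/core/messages` entry point:

```typescript
import { AIMessageChunk } from "@langchain/core/messages";

// An intermediate content chunk with no usage information yet.
const contentChunk = new AIMessageChunk({ content: "Hello!" });

// A final, content-less chunk carrying only the token counts.
const usageChunk = new AIMessageChunk({
  content: "",
  usage_metadata: { input_tokens: 5, output_tokens: 2, total_tokens: 7 },
});

// concat() treats a missing usage_metadata on either side as zeros and sums
// the counts, so the merged chunk reports the combined usage.
const merged = contentChunk.concat(usageChunk);
console.log(merged.usage_metadata);
// -> { input_tokens: 5, output_tokens: 2, total_tokens: 7 }
```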
28 changes: 28 additions & 0 deletions libs/langchain-openai/src/chat_models.ts
@@ -528,6 +528,8 @@ export class ChatOpenAI<
)
);
}
const llmType = this._llmType();

const params: Omit<
OpenAIClient.Chat.ChatCompletionCreateParams,
"messages"
@@ -553,6 +555,14 @@
tool_choice: options?.tool_choice,
response_format: options?.response_format,
seed: options?.seed,
// Only set stream_options if the model is OpenAI
...(llmType === "openai"
? {
stream_options: {
include_usage: true,
},
}
: {}),
...this.modelKwargs,
};
return params;
@@ -586,8 +596,12 @@
};
let defaultRole: OpenAIRoleEnum | undefined;
const streamIterable = await this.completionWithRetry(params, options);
let usage: OpenAIClient.Completions.CompletionUsage | undefined;
for await (const data of streamIterable) {
const choice = data?.choices[0];
if (data.usage) {
usage = data.usage;
}
if (!choice) {
continue;
}
@@ -632,6 +646,20 @@
{ chunk: generationChunk }
);
}
if (usage) {
const generationChunk = new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
usage_metadata: {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
},
}),
text: "",
});
yield generationChunk;
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}
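
On the OpenAI side, setting `stream_options: { include_usage: true }` makes the API emit one extra final chunk whose `choices` array is empty and whose `usage` field holds the aggregate token counts; the streaming loop above turns that into a trailing `AIMessageChunk` with empty content. A rough sketch of the shapes involved, with made-up values:

```typescript
// Illustrative final chunk from the OpenAI streaming API when
// stream_options.include_usage is enabled (id and counts are made up).
const finalApiChunk = {
  id: "chatcmpl-abc123",
  object: "chat.completion.chunk",
  choices: [], // the usage-only chunk carries no delta
  usage: { prompt_tokens: 13, completion_tokens: 42, total_tokens: 55 },
};

// The loop above maps it onto the new message field roughly like so:
const usageMetadata = {
  input_tokens: finalApiChunk.usage.prompt_tokens,
  output_tokens: finalApiChunk.usage.completion_tokens,
  total_tokens: finalApiChunk.usage.total_tokens,
};
```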
37 changes: 37 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -1,5 +1,6 @@
import { test, jest, expect } from "@jest/globals";
import {
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
@@ -767,3 +768,39 @@ test("Test ChatOpenAI token usage reporting for streaming calls", async () => {
expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed);
}
});

test("Streaming tokens can be found in usage_metadata field", async () => {
const model = new ChatOpenAI();
const response = await model.stream("Hello, how are you?");
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}
console.log({
usage_metadata: finalResult?.usage_metadata,
});
expect(finalResult).toBeTruthy();
expect(finalResult?.usage_metadata).toBeTruthy();
expect(finalResult?.usage_metadata?.input_tokens).toBeGreaterThan(0);
expect(finalResult?.usage_metadata?.output_tokens).toBeGreaterThan(0);
expect(finalResult?.usage_metadata?.total_tokens).toBeGreaterThan(0);
});

test("streaming: true tokens can be found in usage_metadata field", async () => {
const model = new ChatOpenAI({
streaming: true,
});
const response = await model.invoke("Hello, how are you?");
console.log({
usage_metadata: response?.usage_metadata,
});
expect(response).toBeTruthy();
expect(response?.usage_metadata).toBeTruthy();
expect(response?.usage_metadata?.input_tokens).toBeGreaterThan(0);
expect(response?.usage_metadata?.output_tokens).toBeGreaterThan(0);
expect(response?.usage_metadata?.total_tokens).toBeGreaterThan(0);
});