Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai/core: add text streaming usage and finish reason support #1217

Merged
merged 7 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/ai-core/src/stream-text/openai-fullstream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ async function main() {
break;
}

case 'finish': {
console.log('Finish reason:', part.finishReason);
console.log('Usage:', part.usage);
break;
}

case 'error':
console.error('Error:', part.error);
break;
Expand Down
1 change: 1 addition & 0 deletions packages/core/ai-model-specification/errors/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export * from './api-call-error';
export * from './invalid-argument-error';
export * from './invalid-data-content-error';
export * from './invalid-prompt-error';
export * from './invalid-response-data-error';
export * from './invalid-tool-arguments-error';
export * from './json-parse-error';
export * from './load-api-key-error';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export class InvalidDataContentError extends Error {
return (
error instanceof Error &&
error.name === 'AI_InvalidDataContentError' &&
prompt != null
(error as InvalidDataContentError).content != null
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
 * Error thrown by providers when the server returns response data that
 * cannot be parsed or processed (e.g. a malformed streaming chunk).
 */
export class InvalidResponseDataError extends Error {
  /** The raw response data that could not be processed. */
  readonly data: unknown;

  constructor({
    data,
    message = `Invalid response data: ${JSON.stringify(data)}.`,
  }: {
    data: unknown;
    message?: string;
  }) {
    super(message);

    // Stable name used by the static type predicate below; do not change.
    this.name = 'AI_InvalidResponseDataError';
    this.data = data;
  }

  /**
   * Type predicate that recognizes instances across realm/bundle boundaries
   * by checking the stable error name rather than relying on `instanceof`
   * alone. Requires a non-nullish `data` payload.
   */
  static isInvalidResponseDataError(
    error: unknown,
  ): error is InvalidResponseDataError {
    if (!(error instanceof Error)) {
      return false;
    }

    if (error.name !== 'AI_InvalidResponseDataError') {
      return false;
    }

    return (error as InvalidResponseDataError).data != null;
  }

  /** Serializable representation for logging/transport. */
  toJSON() {
    const { name, message, stack, data } = this;
    return { name, message, stack, data };
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ export type LanguageModelV1StreamPart =
// the usage stats and finish reason should be the last part of the
// stream:
| {
type: 'finish-metadata';
type: 'finish';
finishReason: LanguageModelV1FinishReason;
usage: { promptTokens: number; completionTokens: number };
}
Expand Down
18 changes: 16 additions & 2 deletions packages/core/core/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export function runToolsTransformation<
break;
}

// process
// process tool call:
case 'tool-call': {
const toolName = chunk.toolName as keyof TOOLS & string;

Expand Down Expand Up @@ -135,8 +135,22 @@ export function runToolsTransformation<
break;
}

// process finish:
case 'finish': {
controller.enqueue({
type: 'finish',
finishReason: chunk.finishReason,
usage: {
promptTokens: chunk.usage.promptTokens,
completionTokens: chunk.usage.completionTokens,
totalTokens:
chunk.usage.promptTokens + chunk.usage.completionTokens,
},
});
break;
}

// ignore
case 'finish-metadata':
case 'tool-call-delta': {
break;
}
Expand Down
35 changes: 35 additions & 0 deletions packages/core/core/generate-text/stream-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ describe('result.textStream', () => {
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: `world!` },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand Down Expand Up @@ -50,6 +55,11 @@ describe('result.fullStream', () => {
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: `world!` },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand All @@ -64,6 +74,11 @@ describe('result.fullStream', () => {
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: 'world!' },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
},
],
);
});
Expand Down Expand Up @@ -102,6 +117,11 @@ describe('result.fullStream', () => {
toolName: 'tool1',
args: `{ "value": "value" }`,
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand All @@ -124,6 +144,11 @@ describe('result.fullStream', () => {
toolName: 'tool1',
args: { value: 'value' },
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
},
],
);
});
Expand Down Expand Up @@ -162,6 +187,11 @@ describe('result.fullStream', () => {
toolName: 'tool1',
args: `{ "value": "value" }`,
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand Down Expand Up @@ -192,6 +222,11 @@ describe('result.fullStream', () => {
args: { value: 'value' },
result: 'value-result',
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
},
],
);
});
Expand Down
12 changes: 11 additions & 1 deletion packages/core/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import zodToJsonSchema from 'zod-to-json-schema';
import { LanguageModelV1FinishReason } from '../../ai-model-specification';
import {
LanguageModelV1,
LanguageModelV1CallWarning,
Expand Down Expand Up @@ -92,7 +93,16 @@ export type TextStreamPart<TOOLS extends Record<string, ExperimentalTool>> =
}
| ({
type: 'tool-result';
} & ToToolResult<TOOLS>);
} & ToToolResult<TOOLS>)
| {
type: 'finish';
finishReason: LanguageModelV1FinishReason;
usage: {
promptTokens: number;
completionTokens: number;
totalTokens: number;
};
};

export class StreamTextResult<TOOLS extends Record<string, ExperimentalTool>> {
private readonly originalStream: ReadableStream<TextStreamPart<TOOLS>>;
Expand Down
65 changes: 65 additions & 0 deletions packages/core/mistral/mistral-chat-language-model.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import zodToJsonSchema from 'zod-to-json-schema';
import { LanguageModelV1Prompt } from '../ai-model-specification';
import { convertStreamToArray } from '../ai-model-specification/test/convert-stream-to-array';
import { JsonTestServer } from '../ai-model-specification/test/json-test-server';
import { StreamingTestServer } from '../ai-model-specification/test/streaming-test-server';
import { Mistral } from './mistral-facade';
import { z } from 'zod';

const TEST_PROMPT: LanguageModelV1Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'Hello' }] },
Expand Down Expand Up @@ -157,6 +159,69 @@ describe('doStream', () => {
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: 'world!' },
{ type: 'text-delta', textDelta: '' },
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 32 },
},
]);
});

it('should stream tool deltas', async () => {
server.responseChunks = [
`data: {"id":"ad6f7ce6543c4d0890280ae184fe4dd8","object":"chat.completion.chunk","created":1711365023,"model":"mistral-large-latest",` +
`"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null,"logprobs":null}]}\n\n`,
`data: {"id":"ad6f7ce6543c4d0890280ae184fe4dd8","object":"chat.completion.chunk","created":1711365023,"model":"mistral-large-latest",` +
`"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"name":"test-tool","arguments":` +
`"{\\"value\\":\\"Sparkle Day\\"}"` +
`}}]},"finish_reason":"tool_calls","logprobs":null}],"usage":{"prompt_tokens":183,"total_tokens":316,"completion_tokens":133}}\n\n`,
'data: [DONE]\n\n',
];

const { stream } = await new Mistral({
apiKey: 'test-api-key',
generateId: () => 'test-id',
})
.chat('mistral-large-latest')
.doStream({
inputFormat: 'prompt',
mode: {
type: 'regular',
tools: [
{
type: 'function',
name: 'test-tool',
parameters: zodToJsonSchema(z.object({ value: z.string() })),
},
],
},
prompt: TEST_PROMPT,
});

expect(await convertStreamToArray(stream)).toStrictEqual([
{
type: 'text-delta',
textDelta: '',
},
{
type: 'tool-call-delta',
toolCallId: 'test-id',
toolCallType: 'function',
toolName: 'test-tool',
argsTextDelta: '{"value":"Sparkle Day"}',
},
{
type: 'tool-call',
toolCallId: 'test-id',
toolCallType: 'function',
toolName: 'test-tool',
args: '{"value":"Sparkle Day"}',
},
{
type: 'finish',
finishReason: 'tool-calls',
usage: { promptTokens: 183, completionTokens: 133 },
},
]);
});

Expand Down
41 changes: 37 additions & 4 deletions packages/core/mistral/mistral-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ import { z } from 'zod';
import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1StreamPart,
ParseResult,
UnsupportedFunctionalityError,
createEventSourceResponseHandler,
createJsonResponseHandler,
generateId,
postJsonToApi,
} from '../ai-model-specification';
import { convertToMistralChatMessages } from './convert-to-mistral-chat-messages';
Expand All @@ -22,6 +22,7 @@ type MistralChatConfig = {
provider: string;
baseUrl: string;
headers: () => Record<string, string | undefined>;
generateId: () => string;
};

export class MistralChatLanguageModel implements LanguageModelV1 {
Expand Down Expand Up @@ -174,7 +175,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
text: choice.message.content ?? undefined,
toolCalls: choice.message.tool_calls?.map(toolCall => ({
toolCallType: 'function',
toolCallId: generateId(),
toolCallId: this.config.generateId(),
toolName: toolCall.function.name,
args: toolCall.function.arguments!,
})),
Expand Down Expand Up @@ -209,6 +210,14 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

const { messages: rawPrompt, ...rawSettings } = args;

let finishReason: LanguageModelV1FinishReason = 'other';
let usage: { promptTokens: number; completionTokens: number } = {
promptTokens: Number.NaN,
completionTokens: Number.NaN,
};

const generateId = this.config.generateId;

return {
stream: response.pipeThrough(
new TransformStream<
Expand All @@ -223,11 +232,24 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

const value = chunk.value;

if (value.choices?.[0]?.delta == null) {
if (value.usage != null) {
usage = {
promptTokens: value.usage.prompt_tokens,
completionTokens: value.usage.completion_tokens,
};
}

const choice = value.choices[0];

if (choice?.finish_reason != null) {
finishReason = mapMistralFinishReason(choice.finish_reason);
}

if (choice?.delta == null) {
return;
}

const delta = value.choices[0].delta;
const delta = choice.delta;

if (delta.content != null) {
controller.enqueue({
Expand Down Expand Up @@ -258,6 +280,10 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
}
}
},

flush(controller) {
controller.enqueue({ type: 'finish', finishReason, usage });
},
}),
),
rawCall: { rawPrompt, rawSettings },
Expand Down Expand Up @@ -319,4 +345,11 @@ const mistralChatChunkSchema = z.object({
index: z.number(),
}),
),
usage: z
.object({
prompt_tokens: z.number(),
completion_tokens: z.number(),
})
.optional()
.nullable(),
});
Loading
Loading