Skip to content

Commit

Permalink
ai/core: add text streaming usage and finish reason support (#1217)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Mar 26, 2024
1 parent 8088de8 commit adccb26
Show file tree
Hide file tree
Showing 16 changed files with 721 additions and 31 deletions.
6 changes: 6 additions & 0 deletions examples/ai-core/src/stream-text/openai-fullstream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ async function main() {
break;
}

case 'finish': {
console.log('Finish reason:', part.finishReason);
console.log('Usage:', part.usage);
break;
}

case 'error':
console.error('Error:', part.error);
break;
Expand Down
1 change: 1 addition & 0 deletions packages/core/ai-model-specification/errors/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export * from './api-call-error';
export * from './invalid-argument-error';
export * from './invalid-data-content-error';
export * from './invalid-prompt-error';
export * from './invalid-response-data-error';
export * from './invalid-tool-arguments-error';
export * from './json-parse-error';
export * from './load-api-key-error';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export class InvalidDataContentError extends Error {
return (
error instanceof Error &&
error.name === 'AI_InvalidDataContentError' &&
prompt != null
(error as InvalidDataContentError).content != null
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
 * Server returned a response with invalid data content. This should be thrown by providers when they
 * cannot parse the response from the API.
 */
export class InvalidResponseDataError extends Error {
  // Raw payload that could not be parsed; retained for diagnostics.
  readonly data: unknown;

  constructor(options: { data: unknown; message?: string }) {
    // Default message embeds the offending payload to ease debugging.
    super(
      options.message ?? `Invalid response data: ${JSON.stringify(options.data)}.`,
    );

    this.name = 'AI_InvalidResponseDataError';
    this.data = options.data;
  }

  /**
   * Narrowing guard: matches an `Error` tagged with the
   * `AI_InvalidResponseDataError` name that carries a non-null `data` payload.
   */
  static isInvalidResponseDataError(
    error: unknown,
  ): error is InvalidResponseDataError {
    if (!(error instanceof Error)) {
      return false;
    }

    const candidate = error as InvalidResponseDataError;
    return (
      candidate.name === 'AI_InvalidResponseDataError' && candidate.data != null
    );
  }

  /** Plain-object form (e.g. for structured logging / serialization). */
  toJSON() {
    return {
      name: this.name,
      message: this.message,
      stack: this.stack,
      data: this.data,
    };
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ export type LanguageModelV1StreamPart =
// the usage stats and finish reason should be the last part of the
// stream:
| {
type: 'finish-metadata';
type: 'finish';
finishReason: LanguageModelV1FinishReason;
usage: { promptTokens: number; completionTokens: number };
}
Expand Down
18 changes: 16 additions & 2 deletions packages/core/core/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export function runToolsTransformation<
break;
}

// process
// process tool call:
case 'tool-call': {
const toolName = chunk.toolName as keyof TOOLS & string;

Expand Down Expand Up @@ -135,8 +135,22 @@ export function runToolsTransformation<
break;
}

// process finish:
case 'finish': {
controller.enqueue({
type: 'finish',
finishReason: chunk.finishReason,
usage: {
promptTokens: chunk.usage.promptTokens,
completionTokens: chunk.usage.completionTokens,
totalTokens:
chunk.usage.promptTokens + chunk.usage.completionTokens,
},
});
break;
}

// ignore
case 'finish-metadata':
case 'tool-call-delta': {
break;
}
Expand Down
35 changes: 35 additions & 0 deletions packages/core/core/generate-text/stream-text.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ describe('result.textStream', () => {
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: `world!` },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand Down Expand Up @@ -50,6 +55,11 @@ describe('result.fullStream', () => {
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: `world!` },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand All @@ -64,6 +74,11 @@ describe('result.fullStream', () => {
{ type: 'text-delta', textDelta: 'Hello' },
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: 'world!' },
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
},
],
);
});
Expand Down Expand Up @@ -102,6 +117,11 @@ describe('result.fullStream', () => {
toolName: 'tool1',
args: `{ "value": "value" }`,
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand All @@ -124,6 +144,11 @@ describe('result.fullStream', () => {
toolName: 'tool1',
args: { value: 'value' },
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
},
],
);
});
Expand Down Expand Up @@ -162,6 +187,11 @@ describe('result.fullStream', () => {
toolName: 'tool1',
args: `{ "value": "value" }`,
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3 },
},
]),
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
};
Expand Down Expand Up @@ -192,6 +222,11 @@ describe('result.fullStream', () => {
args: { value: 'value' },
result: 'value-result',
},
{
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 },
},
],
);
});
Expand Down
12 changes: 11 additions & 1 deletion packages/core/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import zodToJsonSchema from 'zod-to-json-schema';
import { LanguageModelV1FinishReason } from '../../ai-model-specification';
import {
LanguageModelV1,
LanguageModelV1CallWarning,
Expand Down Expand Up @@ -92,7 +93,16 @@ export type TextStreamPart<TOOLS extends Record<string, ExperimentalTool>> =
}
| ({
type: 'tool-result';
} & ToToolResult<TOOLS>);
} & ToToolResult<TOOLS>)
| {
type: 'finish';
finishReason: LanguageModelV1FinishReason;
usage: {
promptTokens: number;
completionTokens: number;
totalTokens: number;
};
};

export class StreamTextResult<TOOLS extends Record<string, ExperimentalTool>> {
private readonly originalStream: ReadableStream<TextStreamPart<TOOLS>>;
Expand Down
65 changes: 65 additions & 0 deletions packages/core/mistral/mistral-chat-language-model.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import zodToJsonSchema from 'zod-to-json-schema';
import { LanguageModelV1Prompt } from '../ai-model-specification';
import { convertStreamToArray } from '../ai-model-specification/test/convert-stream-to-array';
import { JsonTestServer } from '../ai-model-specification/test/json-test-server';
import { StreamingTestServer } from '../ai-model-specification/test/streaming-test-server';
import { Mistral } from './mistral-facade';
import { z } from 'zod';

const TEST_PROMPT: LanguageModelV1Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'Hello' }] },
Expand Down Expand Up @@ -157,6 +159,69 @@ describe('doStream', () => {
{ type: 'text-delta', textDelta: ', ' },
{ type: 'text-delta', textDelta: 'world!' },
{ type: 'text-delta', textDelta: '' },
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 32 },
},
]);
});

it('should stream tool deltas', async () => {
server.responseChunks = [
`data: {"id":"ad6f7ce6543c4d0890280ae184fe4dd8","object":"chat.completion.chunk","created":1711365023,"model":"mistral-large-latest",` +
`"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null,"logprobs":null}]}\n\n`,
`data: {"id":"ad6f7ce6543c4d0890280ae184fe4dd8","object":"chat.completion.chunk","created":1711365023,"model":"mistral-large-latest",` +
`"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"name":"test-tool","arguments":` +
`"{\\"value\\":\\"Sparkle Day\\"}"` +
`}}]},"finish_reason":"tool_calls","logprobs":null}],"usage":{"prompt_tokens":183,"total_tokens":316,"completion_tokens":133}}\n\n`,
'data: [DONE]\n\n',
];

const { stream } = await new Mistral({
apiKey: 'test-api-key',
generateId: () => 'test-id',
})
.chat('mistral-large-latest')
.doStream({
inputFormat: 'prompt',
mode: {
type: 'regular',
tools: [
{
type: 'function',
name: 'test-tool',
parameters: zodToJsonSchema(z.object({ value: z.string() })),
},
],
},
prompt: TEST_PROMPT,
});

expect(await convertStreamToArray(stream)).toStrictEqual([
{
type: 'text-delta',
textDelta: '',
},
{
type: 'tool-call-delta',
toolCallId: 'test-id',
toolCallType: 'function',
toolName: 'test-tool',
argsTextDelta: '{"value":"Sparkle Day"}',
},
{
type: 'tool-call',
toolCallId: 'test-id',
toolCallType: 'function',
toolName: 'test-tool',
args: '{"value":"Sparkle Day"}',
},
{
type: 'finish',
finishReason: 'tool-calls',
usage: { promptTokens: 183, completionTokens: 133 },
},
]);
});

Expand Down
41 changes: 37 additions & 4 deletions packages/core/mistral/mistral-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ import { z } from 'zod';
import {
LanguageModelV1,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1StreamPart,
ParseResult,
UnsupportedFunctionalityError,
createEventSourceResponseHandler,
createJsonResponseHandler,
generateId,
postJsonToApi,
} from '../ai-model-specification';
import { convertToMistralChatMessages } from './convert-to-mistral-chat-messages';
Expand All @@ -22,6 +22,7 @@ type MistralChatConfig = {
provider: string;
baseUrl: string;
headers: () => Record<string, string | undefined>;
generateId: () => string;
};

export class MistralChatLanguageModel implements LanguageModelV1 {
Expand Down Expand Up @@ -174,7 +175,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
text: choice.message.content ?? undefined,
toolCalls: choice.message.tool_calls?.map(toolCall => ({
toolCallType: 'function',
toolCallId: generateId(),
toolCallId: this.config.generateId(),
toolName: toolCall.function.name,
args: toolCall.function.arguments!,
})),
Expand Down Expand Up @@ -209,6 +210,14 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

const { messages: rawPrompt, ...rawSettings } = args;

let finishReason: LanguageModelV1FinishReason = 'other';
let usage: { promptTokens: number; completionTokens: number } = {
promptTokens: Number.NaN,
completionTokens: Number.NaN,
};

const generateId = this.config.generateId;

return {
stream: response.pipeThrough(
new TransformStream<
Expand All @@ -223,11 +232,24 @@ export class MistralChatLanguageModel implements LanguageModelV1 {

const value = chunk.value;

if (value.choices?.[0]?.delta == null) {
if (value.usage != null) {
usage = {
promptTokens: value.usage.prompt_tokens,
completionTokens: value.usage.completion_tokens,
};
}

const choice = value.choices[0];

if (choice?.finish_reason != null) {
finishReason = mapMistralFinishReason(choice.finish_reason);
}

if (choice?.delta == null) {
return;
}

const delta = value.choices[0].delta;
const delta = choice.delta;

if (delta.content != null) {
controller.enqueue({
Expand Down Expand Up @@ -258,6 +280,10 @@ export class MistralChatLanguageModel implements LanguageModelV1 {
}
}
},

flush(controller) {
controller.enqueue({ type: 'finish', finishReason, usage });
},
}),
),
rawCall: { rawPrompt, rawSettings },
Expand Down Expand Up @@ -319,4 +345,11 @@ const mistralChatChunkSchema = z.object({
index: z.number(),
}),
),
usage: z
.object({
prompt_tokens: z.number(),
completion_tokens: z.number(),
})
.optional()
.nullable(),
});
Loading

0 comments on commit adccb26

Please sign in to comment.