From adccb26afb9a27d30973daa217cc632683d3d7b0 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Tue, 26 Mar 2024 10:59:24 +0100 Subject: [PATCH] ai/core: add text streaming usage and finish reason support (#1217) --- .../src/stream-text/openai-fullstream.ts | 6 + .../ai-model-specification/errors/index.ts | 1 + .../errors/invalid-data-content-error.ts | 2 +- .../errors/invalid-response-data-error.ts | 41 +++ .../language-model/v1/language-model-v1.ts | 2 +- .../generate-text/run-tools-transformation.ts | 18 +- .../core/generate-text/stream-text.test.ts | 35 +++ .../core/core/generate-text/stream-text.ts | 12 +- .../mistral-chat-language-model.test.ts | 65 +++++ .../mistral/mistral-chat-language-model.ts | 41 ++- packages/core/mistral/mistral-facade.ts | 13 +- .../core/openai/map-openai-finish-reason.ts | 4 +- .../openai/openai-chat-language-model.test.ts | 131 ++++++++- .../core/openai/openai-chat-language-model.ts | 85 +++++- .../openai-completion-language-model.test.ts | 261 ++++++++++++++++++ .../openai-completion-language-model.ts | 35 ++- 16 files changed, 721 insertions(+), 31 deletions(-) create mode 100644 packages/core/ai-model-specification/errors/invalid-response-data-error.ts create mode 100644 packages/core/openai/openai-completion-language-model.test.ts diff --git a/examples/ai-core/src/stream-text/openai-fullstream.ts b/examples/ai-core/src/stream-text/openai-fullstream.ts index d9e32ef499b0..32837432be15 100644 --- a/examples/ai-core/src/stream-text/openai-fullstream.ts +++ b/examples/ai-core/src/stream-text/openai-fullstream.ts @@ -66,6 +66,12 @@ async function main() { break; } + case 'finish': { + console.log('Finish reason:', part.finishReason); + console.log('Usage:', part.usage); + break; + } + case 'error': console.error('Error:', part.error); break; diff --git a/packages/core/ai-model-specification/errors/index.ts b/packages/core/ai-model-specification/errors/index.ts index 2389d987f1fa..43e04f3cb720 100644 --- a/packages/core/ai-model-specification/errors/index.ts +++ b/packages/core/ai-model-specification/errors/index.ts @@ -2,6 +2,7 @@ export * from './api-call-error'; export * from './invalid-argument-error'; export * from './invalid-data-content-error'; export * from './invalid-prompt-error'; +export * from './invalid-response-data-error'; export * from './invalid-tool-arguments-error'; export * from './json-parse-error'; export * from './load-api-key-error'; diff --git a/packages/core/ai-model-specification/errors/invalid-data-content-error.ts b/packages/core/ai-model-specification/errors/invalid-data-content-error.ts index d4b502c9b1ce..05c6c515f63c 100644 --- a/packages/core/ai-model-specification/errors/invalid-data-content-error.ts +++ b/packages/core/ai-model-specification/errors/invalid-data-content-error.ts @@ -21,7 +21,7 @@ export class InvalidDataContentError extends Error { return ( error instanceof Error && error.name === 'AI_InvalidDataContentError' && - prompt != null + (error as InvalidDataContentError).content != null ); } diff --git a/packages/core/ai-model-specification/errors/invalid-response-data-error.ts b/packages/core/ai-model-specification/errors/invalid-response-data-error.ts new file mode 100644 index 000000000000..fcca7b9aa6c3 --- /dev/null +++ b/packages/core/ai-model-specification/errors/invalid-response-data-error.ts @@ -0,0 +1,41 @@ +/** + * Server returned a response with invalid data content. This should be thrown by providers when they + * cannot parse the response from the API. 
+ */ +export class InvalidResponseDataError extends Error { + readonly data: unknown; + + constructor({ + data, + message = `Invalid response data: ${JSON.stringify(data)}.`, + }: { + data: unknown; + message?: string; + }) { + super(message); + + this.name = 'AI_InvalidResponseDataError'; + + this.data = data; + } + + static isInvalidResponseDataError( + error: unknown, + ): error is InvalidResponseDataError { + return ( + error instanceof Error && + error.name === 'AI_InvalidResponseDataError' && + (error as InvalidResponseDataError).data != null + ); + } + + toJSON() { + return { + name: this.name, + message: this.message, + stack: this.stack, + + data: this.data, + }; + } +} diff --git a/packages/core/ai-model-specification/language-model/v1/language-model-v1.ts b/packages/core/ai-model-specification/language-model/v1/language-model-v1.ts index 3368db6f6dcb..2c6f1e7a9930 100644 --- a/packages/core/ai-model-specification/language-model/v1/language-model-v1.ts +++ b/packages/core/ai-model-specification/language-model/v1/language-model-v1.ts @@ -137,7 +137,7 @@ export type LanguageModelV1StreamPart = // the usage stats and finish reason should be the last part of the // stream: | { - type: 'finish-metadata'; + type: 'finish'; finishReason: LanguageModelV1FinishReason; usage: { promptTokens: number; completionTokens: number }; } diff --git a/packages/core/core/generate-text/run-tools-transformation.ts b/packages/core/core/generate-text/run-tools-transformation.ts index 61f890dc3739..c5e1ea10c89f 100644 --- a/packages/core/core/generate-text/run-tools-transformation.ts +++ b/packages/core/core/generate-text/run-tools-transformation.ts @@ -48,7 +48,7 @@ export function runToolsTransformation< break; } - // process + // process tool call: case 'tool-call': { const toolName = chunk.toolName as keyof TOOLS & string; @@ -135,8 +135,22 @@ export function runToolsTransformation< break; } + // process finish: + case 'finish': { + controller.enqueue({ + type: 'finish', + finishReason: chunk.finishReason, + usage: { + promptTokens: chunk.usage.promptTokens, + completionTokens: chunk.usage.completionTokens, + totalTokens: + chunk.usage.promptTokens + chunk.usage.completionTokens, + }, + }); + break; + } + // ignore - case 'finish-metadata': case 'tool-call-delta': { break; } diff --git a/packages/core/core/generate-text/stream-text.test.ts b/packages/core/core/generate-text/stream-text.test.ts index 5d5bf6ca2cb7..21cd41f907af 100644 --- a/packages/core/core/generate-text/stream-text.test.ts +++ b/packages/core/core/generate-text/stream-text.test.ts @@ -20,6 +20,11 @@ describe('result.textStream', () => { { type: 'text-delta', textDelta: 'Hello' }, { type: 'text-delta', textDelta: ', ' }, { type: 'text-delta', textDelta: `world!` }, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3 }, + }, ]), rawCall: { rawPrompt: 'prompt', rawSettings: {} }, }; @@ -50,6 +55,11 @@ describe('result.fullStream', () => { { type: 'text-delta', textDelta: 'Hello' }, { type: 'text-delta', textDelta: ', ' }, { type: 'text-delta', textDelta: `world!` }, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3 }, + }, ]), rawCall: { rawPrompt: 'prompt', rawSettings: {} }, }; @@ -64,6 +74,11 @@ describe('result.fullStream', () => { { type: 'text-delta', textDelta: 'Hello' }, { type: 'text-delta', textDelta: ', ' }, { type: 'text-delta', textDelta: 'world!' 
}, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 }, + }, ], ); }); @@ -102,6 +117,11 @@ describe('result.fullStream', () => { toolName: 'tool1', args: `{ "value": "value" }`, }, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3 }, + }, ]), rawCall: { rawPrompt: 'prompt', rawSettings: {} }, }; @@ -124,6 +144,11 @@ describe('result.fullStream', () => { toolName: 'tool1', args: { value: 'value' }, }, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 }, + }, ], ); }); @@ -162,6 +187,11 @@ describe('result.fullStream', () => { toolName: 'tool1', args: `{ "value": "value" }`, }, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3 }, + }, ]), rawCall: { rawPrompt: 'prompt', rawSettings: {} }, }; @@ -192,6 +222,11 @@ describe('result.fullStream', () => { args: { value: 'value' }, result: 'value-result', }, + { + type: 'finish', + finishReason: 'stop', + usage: { completionTokens: 10, promptTokens: 3, totalTokens: 13 }, + }, ], ); }); diff --git a/packages/core/core/generate-text/stream-text.ts b/packages/core/core/generate-text/stream-text.ts index 32978e7c5b60..72dce5938f32 100644 --- a/packages/core/core/generate-text/stream-text.ts +++ b/packages/core/core/generate-text/stream-text.ts @@ -1,4 +1,5 @@ import zodToJsonSchema from 'zod-to-json-schema'; +import { LanguageModelV1FinishReason } from '../../ai-model-specification'; import { LanguageModelV1, LanguageModelV1CallWarning, @@ -92,7 +93,16 @@ export type TextStreamPart> = } | ({ type: 'tool-result'; - } & ToToolResult); + } & ToToolResult) + | { + type: 'finish'; + finishReason: LanguageModelV1FinishReason; + usage: { + promptTokens: number; + completionTokens: number; + totalTokens: number; + }; + }; export class StreamTextResult> { private readonly originalStream: ReadableStream>; diff --git a/packages/core/mistral/mistral-chat-language-model.test.ts b/packages/core/mistral/mistral-chat-language-model.test.ts index 00692f187c70..f2732526520c 100644 --- a/packages/core/mistral/mistral-chat-language-model.test.ts +++ b/packages/core/mistral/mistral-chat-language-model.test.ts @@ -1,8 +1,10 @@ +import zodToJsonSchema from 'zod-to-json-schema'; import { LanguageModelV1Prompt } from '../ai-model-specification'; import { convertStreamToArray } from '../ai-model-specification/test/convert-stream-to-array'; import { JsonTestServer } from '../ai-model-specification/test/json-test-server'; import { StreamingTestServer } from '../ai-model-specification/test/streaming-test-server'; import { Mistral } from './mistral-facade'; +import { z } from 'zod'; const TEST_PROMPT: LanguageModelV1Prompt = [ { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, @@ -157,6 +159,69 @@ describe('doStream', () => { { type: 'text-delta', textDelta: ', ' }, { type: 'text-delta', textDelta: 'world!' 
}, { type: 'text-delta', textDelta: '' }, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 4, completionTokens: 32 }, + }, + ]); + }); + + it('should stream tool deltas', async () => { + server.responseChunks = [ + `data: {"id":"ad6f7ce6543c4d0890280ae184fe4dd8","object":"chat.completion.chunk","created":1711365023,"model":"mistral-large-latest",` + + `"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null,"logprobs":null}]}\n\n`, + `data: {"id":"ad6f7ce6543c4d0890280ae184fe4dd8","object":"chat.completion.chunk","created":1711365023,"model":"mistral-large-latest",` + + `"choices":[{"index":0,"delta":{"content":null,"tool_calls":[{"function":{"name":"test-tool","arguments":` + + `"{\\"value\\":\\"Sparkle Day\\"}"` + + `}}]},"finish_reason":"tool_calls","logprobs":null}],"usage":{"prompt_tokens":183,"total_tokens":316,"completion_tokens":133}}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { stream } = await new Mistral({ + apiKey: 'test-api-key', + generateId: () => 'test-id', + }) + .chat('mistral-large-latest') + .doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: zodToJsonSchema(z.object({ value: z.string() })), + }, + ], + }, + prompt: TEST_PROMPT, + }); + + expect(await convertStreamToArray(stream)).toStrictEqual([ + { + type: 'text-delta', + textDelta: '', + }, + { + type: 'tool-call-delta', + toolCallId: 'test-id', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '{"value":"Sparkle Day"}', + }, + { + type: 'tool-call', + toolCallId: 'test-id', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"Sparkle Day"}', + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { promptTokens: 183, completionTokens: 133 }, + }, ]); }); diff --git a/packages/core/mistral/mistral-chat-language-model.ts b/packages/core/mistral/mistral-chat-language-model.ts index 840e8a1b4070..f2f7949ba722 100644 --- a/packages/core/mistral/mistral-chat-language-model.ts +++ b/packages/core/mistral/mistral-chat-language-model.ts @@ -2,12 +2,12 @@ import { z } from 'zod'; import { LanguageModelV1, LanguageModelV1CallWarning, + LanguageModelV1FinishReason, LanguageModelV1StreamPart, ParseResult, UnsupportedFunctionalityError, createEventSourceResponseHandler, createJsonResponseHandler, - generateId, postJsonToApi, } from '../ai-model-specification'; import { convertToMistralChatMessages } from './convert-to-mistral-chat-messages'; @@ -22,6 +22,7 @@ type MistralChatConfig = { provider: string; baseUrl: string; headers: () => Record; + generateId: () => string; }; export class MistralChatLanguageModel implements LanguageModelV1 { @@ -174,7 +175,7 @@ export class MistralChatLanguageModel implements LanguageModelV1 { text: choice.message.content ?? 
undefined, toolCalls: choice.message.tool_calls?.map(toolCall => ({ toolCallType: 'function', - toolCallId: generateId(), + toolCallId: this.config.generateId(), toolName: toolCall.function.name, args: toolCall.function.arguments!, })), @@ -209,6 +210,14 @@ export class MistralChatLanguageModel implements LanguageModelV1 { const { messages: rawPrompt, ...rawSettings } = args; + let finishReason: LanguageModelV1FinishReason = 'other'; + let usage: { promptTokens: number; completionTokens: number } = { + promptTokens: Number.NaN, + completionTokens: Number.NaN, + }; + + const generateId = this.config.generateId; + return { stream: response.pipeThrough( new TransformStream< @@ -223,11 +232,24 @@ export class MistralChatLanguageModel implements LanguageModelV1 { const value = chunk.value; - if (value.choices?.[0]?.delta == null) { + if (value.usage != null) { + usage = { + promptTokens: value.usage.prompt_tokens, + completionTokens: value.usage.completion_tokens, + }; + } + + const choice = value.choices[0]; + + if (choice?.finish_reason != null) { + finishReason = mapMistralFinishReason(choice.finish_reason); + } + + if (choice?.delta == null) { return; } - const delta = value.choices[0].delta; + const delta = choice.delta; if (delta.content != null) { controller.enqueue({ @@ -258,6 +280,10 @@ export class MistralChatLanguageModel implements LanguageModelV1 { } } }, + + flush(controller) { + controller.enqueue({ type: 'finish', finishReason, usage }); + }, }), ), rawCall: { rawPrompt, rawSettings }, @@ -319,4 +345,11 @@ const mistralChatChunkSchema = z.object({ index: z.number(), }), ), + usage: z + .object({ + prompt_tokens: z.number(), + completion_tokens: z.number(), + }) + .optional() + .nullable(), }); diff --git a/packages/core/mistral/mistral-facade.ts b/packages/core/mistral/mistral-facade.ts index 6169321a9d93..41e276d0e34a 100644 --- a/packages/core/mistral/mistral-facade.ts +++ b/packages/core/mistral/mistral-facade.ts @@ -1,4 +1,4 @@ -import { loadApiKey } from '../ai-model-specification'; +import { generateId, loadApiKey } from '../ai-model-specification'; import { MistralChatLanguageModel } from './mistral-chat-language-model'; import { MistralChatModelId, @@ -12,11 +12,19 @@ export class Mistral { readonly baseUrl?: string; readonly apiKey?: string; + private readonly generateId: () => string; + constructor( - options: { baseUrl?: string; apiKey?: string; organization?: string } = {}, + options: { + baseUrl?: string; + apiKey?: string; + organization?: string; + generateId?: () => string; + } = {}, ) { this.baseUrl = options.baseUrl; this.apiKey = options.apiKey; + this.generateId = options.generateId ?? 
generateId; } private get baseConfig() { @@ -36,6 +44,7 @@ export class Mistral { return new MistralChatLanguageModel(modelId, settings, { provider: 'mistral.chat', ...this.baseConfig, + generateId: this.generateId, }); } } diff --git a/packages/core/openai/map-openai-finish-reason.ts b/packages/core/openai/map-openai-finish-reason.ts index fbbfaa9e3a40..a1c1ee89e295 100644 --- a/packages/core/openai/map-openai-finish-reason.ts +++ b/packages/core/openai/map-openai-finish-reason.ts @@ -8,10 +8,10 @@ export function mapOpenAIFinishReason( return 'stop'; case 'length': return 'length'; - case 'content-filter': + case 'content_filter': return 'content-filter'; case 'function_call': - case 'tool-calls': + case 'tool_calls': return 'tool-calls'; default: return 'other'; diff --git a/packages/core/openai/openai-chat-language-model.test.ts b/packages/core/openai/openai-chat-language-model.test.ts index 80ef79d17ed3..f9c4628d8967 100644 --- a/packages/core/openai/openai-chat-language-model.test.ts +++ b/packages/core/openai/openai-chat-language-model.test.ts @@ -1,3 +1,5 @@ +import { z } from 'zod'; +import zodToJsonSchema from 'zod-to-json-schema'; import { LanguageModelV1Prompt } from '../ai-model-specification'; import { convertStreamToArray } from '../ai-model-specification/test/convert-stream-to-array'; import { JsonTestServer } from '../ai-model-specification/test/json-test-server'; @@ -126,16 +128,18 @@ describe('doStream', () => { function prepareStreamResponse({ content }: { content: string[] }) { server.responseChunks = [ - `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0613",` + + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0613",` + `"system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}\n\n`, ...content.map(text => { return ( - `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0613",` + + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0613",` + `"system_fingerprint":null,"choices":[{"index":1,"delta":{"content":"${text}"},"finish_reason":null}]}\n\n` ); }), - `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0613",` + + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0613",` + `"system_fingerprint":null,"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1702657020,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":227,"total_tokens":244}}\n\n`, 'data: [DONE]\n\n', ]; } @@ -155,10 +159,129 @@ describe('doStream', () => { { type: 'text-delta', textDelta: 'Hello' }, { type: 'text-delta', textDelta: ', ' }, { type: 'text-delta', textDelta: 'World!' 
}, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 17, completionTokens: 227 }, + }, ]); }); - it('should pass the messages', async () => { + it('should stream tool deltas', async () => { + server.responseChunks = [ + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,` + + `"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\""}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\":\\""}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},` + + `"logprobs":null,"finish_reason":null}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}\n\n`, + `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125",` + + `"system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}\n\n`, + 'data: [DONE]\n\n', + ]; + + const { 
stream } = await openai.chat('gpt-3.5-turbo').doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: zodToJsonSchema(z.object({ value: z.string() })), + }, + ], + }, + prompt: TEST_PROMPT, + }); + + expect(await convertStreamToArray(stream)).toStrictEqual([ + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '{"', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'value', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '":"', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'Spark', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'le', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: ' Day', + }, + { + type: 'tool-call-delta', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '"}', + }, + { + type: 'tool-call', + toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"Sparkle Day"}', + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { promptTokens: 53, completionTokens: 17 }, + }, + ]); + }); + + it('should pass the messages and the model', async () => { prepareStreamResponse({ content: [] }); await openai.chat('gpt-3.5-turbo').doStream({ diff --git a/packages/core/openai/openai-chat-language-model.ts b/packages/core/openai/openai-chat-language-model.ts index c5388c0a4cda..ad1c6945bcb9 100644 --- a/packages/core/openai/openai-chat-language-model.ts +++ b/packages/core/openai/openai-chat-language-model.ts @@ -1,6 +1,8 @@ import { z } from 'zod'; import { + InvalidResponseDataError, LanguageModelV1, + LanguageModelV1FinishReason, LanguageModelV1StreamPart, ParseResult, UnsupportedFunctionalityError, @@ -199,14 +201,20 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 { const { messages: rawPrompt, ...rawSettings } = args; const toolCalls: Array<{ - id?: string; - type?: 'function'; - function?: { - name?: string; - arguments?: string; + id: string; + type: 'function'; + function: { + name: string; + arguments: string; }; }> = []; + let finishReason: LanguageModelV1FinishReason = 'other'; + let usage: { promptTokens: number; completionTokens: number } = { + promptTokens: Number.NaN, + completionTokens: Number.NaN, + }; + return { stream: response.pipeThrough( new TransformStream< @@ -221,11 +229,24 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 { const value = chunk.value; - if (value.choices?.[0]?.delta == null) { + if (value.usage != null) { + usage = { + promptTokens: value.usage.prompt_tokens, + completionTokens: value.usage.completion_tokens, + }; + } + + const choice = value.choices[0]; + + if (choice?.finish_reason != null) { + finishReason = mapOpenAIFinishReason(choice.finish_reason); + } + + if (choice?.delta == null) { return; } - const delta = value.choices[0].delta; + const delta = choice.delta; if (delta.content != 
null) { controller.enqueue({ @@ -238,9 +259,38 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 { for (const toolCallDelta of delta.tool_calls) { const index = toolCallDelta.index; - // new tool call, add to list + // Tool call start. OpenAI returns all information except the arguments in the first chunk. if (toolCalls[index] == null) { - toolCalls[index] = toolCallDelta; + if (toolCallDelta.type !== 'function') { + throw new InvalidResponseDataError({ + data: toolCallDelta, + message: `Expected 'function' type.`, + }); + } + + if (toolCallDelta.id == null) { + throw new InvalidResponseDataError({ + data: toolCallDelta, + message: `Expected 'id' to be a string.`, + }); + } + + if (toolCallDelta.function?.name == null) { + throw new InvalidResponseDataError({ + data: toolCallDelta, + message: `Expected 'function.name' to be a string.`, + }); + } + + toolCalls[index] = { + id: toolCallDelta.id, + type: 'function', + function: { + name: toolCallDelta.function.name, + arguments: toolCallDelta.function.arguments ?? '', + }, + }; + continue; } @@ -256,9 +306,9 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 { controller.enqueue({ type: 'tool-call-delta', toolCallType: 'function', - toolCallId: toolCall.id ?? '', // TODO empty? - toolName: toolCall.function?.name ?? '', // TODO empty? - argsTextDelta: toolCallDelta.function?.arguments ?? '', // TODO empty? + toolCallId: toolCall.id, + toolName: toolCall.function.name, + argsTextDelta: toolCallDelta.function.arguments ?? '', }); // check if tool call is complete @@ -280,6 +330,10 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 { } } }, + + flush(controller) { + controller.enqueue({ type: 'finish', finishReason, usage }); + }, }), ), rawCall: { rawPrompt, rawSettings }, @@ -347,4 +401,11 @@ const openaiChatChunkSchema = z.object({ index: z.number(), }), ), + usage: z + .object({ + prompt_tokens: z.number(), + completion_tokens: z.number(), + }) + .optional() + .nullable(), }); diff --git a/packages/core/openai/openai-completion-language-model.test.ts b/packages/core/openai/openai-completion-language-model.test.ts new file mode 100644 index 000000000000..761e50e8a5b7 --- /dev/null +++ b/packages/core/openai/openai-completion-language-model.test.ts @@ -0,0 +1,261 @@ +import { LanguageModelV1Prompt } from '../ai-model-specification'; +import { convertStreamToArray } from '../ai-model-specification/test/convert-stream-to-array'; +import { JsonTestServer } from '../ai-model-specification/test/json-test-server'; +import { StreamingTestServer } from '../ai-model-specification/test/streaming-test-server'; +import { OpenAI } from './openai-facade'; + +const TEST_PROMPT: LanguageModelV1Prompt = [ + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, +]; + +const openai = new OpenAI({ + apiKey: 'test-api-key', +}); + +describe('doGenerate', () => { + const server = new JsonTestServer('https://api.openai.com/v1/completions'); + + server.setupTestEnvironment(); + + function prepareJsonResponse({ + content = '', + usage = { + prompt_tokens: 4, + total_tokens: 34, + completion_tokens: 30, + }, + }: { + content?: string; + usage?: { + prompt_tokens: number; + total_tokens: number; + completion_tokens: number; + }; + }) { + server.responseBodyJson = { + id: 'cmpl-96cAM1v77r4jXa4qb2NSmRREV5oWB', + object: 'text_completion', + created: 1711363706, + model: 'gpt-3.5-turbo-instruct', + choices: [ + { + text: content, + index: 0, + logprobs: null, + finish_reason: 'stop', + }, + ], + usage, + }; + } + + 
it('should extract text response', async () => { + prepareJsonResponse({ content: 'Hello, World!' }); + + const { text } = await openai + .completion('gpt-3.5-turbo-instruct') + .doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(text).toStrictEqual('Hello, World!'); + }); + + it('should extract usage', async () => { + prepareJsonResponse({ + content: '', + usage: { prompt_tokens: 20, total_tokens: 25, completion_tokens: 5 }, + }); + + const { usage } = await openai + .completion('gpt-3.5-turbo-instruct') + .doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(usage).toStrictEqual({ + promptTokens: 20, + completionTokens: 5, + }); + }); + + it('should pass the model and the prompt', async () => { + prepareJsonResponse({ content: '' }); + + await openai.completion('gpt-3.5-turbo-instruct').doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'gpt-3.5-turbo-instruct', + prompt: 'Hello', + }); + }); + + it('should pass the api key as Authorization header', async () => { + prepareJsonResponse({ content: '' }); + + const openai = new OpenAI({ apiKey: 'test-api-key' }); + + await openai.completion('gpt-3.5-turbo-instruct').doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect( + (await server.getRequestHeaders()).get('Authorization'), + ).toStrictEqual('Bearer test-api-key'); + }); +}); + +describe('doStream', () => { + const server = new StreamingTestServer( + 'https://api.openai.com/v1/completions', + ); + + server.setupTestEnvironment(); + + function prepareStreamResponse({ content }: { content: string[] }) { + server.responseChunks = [ + ...content.map(text => { + return ( + `data: {"id":"cmpl-96c64EdfhOw8pjFFgVpLuT8k2MtdT","object":"text_completion","created":1711363440,` + + `"choices":[{"text":"${text}","index":0,"logprobs":null,"finish_reason":null}],"model":"gpt-3.5-turbo-instruct"}\n\n` + ); + }), + `data: {"id":"cmpl-96c3yLQE1TtZCd6n6OILVmzev8M8H","object":"text_completion","created":1711363310,` + + `"choices":[{"text":"","index":0,"logprobs":null,"finish_reason":"stop"}],"model":"gpt-3.5-turbo-instruct"}\n\n`, + `data: {"id":"cmpl-96c3yLQE1TtZCd6n6OILVmzev8M8H","object":"text_completion","created":1711363310,` + + `"model":"gpt-3.5-turbo-instruct","usage":{"prompt_tokens":10,"completion_tokens":362,"total_tokens":372},"choices":[]}\n\n`, + 'data: [DONE]\n\n', + ]; + } + + it('should stream text deltas', async () => { + prepareStreamResponse({ content: ['Hello', ', ', 'World!'] }); + + const { stream } = await openai + .completion('gpt-3.5-turbo-instruct') + .doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + // note: space moved to last chunk bc of trimming + expect(await convertStreamToArray(stream)).toStrictEqual([ + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: 'World!' 
}, + { type: 'text-delta', textDelta: '' }, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 10, completionTokens: 362 }, + }, + ]); + }); + + it('should pass the model and the prompt', async () => { + prepareStreamResponse({ content: [] }); + + await openai.completion('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + stream: true, + model: 'gpt-3.5-turbo-instruct', + prompt: 'Hello', + }); + }); + + it('should scale the temperature', async () => { + prepareStreamResponse({ content: [] }); + + await openai.completion('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + temperature: 0.5, + }); + + expect((await server.getRequestBodyJson()).temperature).toBeCloseTo(1, 5); + }); + + it('should scale the frequency penalty', async () => { + prepareStreamResponse({ content: [] }); + + await openai.completion('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + frequencyPenalty: 0.2, + }); + + expect((await server.getRequestBodyJson()).frequency_penalty).toBeCloseTo( + 0.4, + 5, + ); + }); + + it('should scale the presence penalty', async () => { + prepareStreamResponse({ content: [] }); + + await openai.completion('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + presencePenalty: -0.9, + }); + + expect((await server.getRequestBodyJson()).presence_penalty).toBeCloseTo( + -1.8, + 5, + ); + }); + + it('should pass the organization as OpenAI-Organization header', async () => { + prepareStreamResponse({ content: [] }); + + const openai = new OpenAI({ + apiKey: 'test-api-key', + organization: 'test-organization', + }); + + await openai.completion('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect( + (await server.getRequestHeaders()).get('OpenAI-Organization'), + ).toStrictEqual('test-organization'); + }); + + it('should pass the api key as Authorization header', async () => { + prepareStreamResponse({ content: [] }); + + const openai = new OpenAI({ apiKey: 'test-api-key' }); + + await openai.completion('gpt-3.5-turbo-instruct').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect( + (await server.getRequestHeaders()).get('Authorization'), + ).toStrictEqual('Bearer test-api-key'); + }); +}); diff --git a/packages/core/openai/openai-completion-language-model.ts b/packages/core/openai/openai-completion-language-model.ts index b183a8a3aa1c..be5d11403191 100644 --- a/packages/core/openai/openai-completion-language-model.ts +++ b/packages/core/openai/openai-completion-language-model.ts @@ -1,6 +1,7 @@ import { z } from 'zod'; import { LanguageModelV1, + LanguageModelV1FinishReason, LanguageModelV1StreamPart, ParseResult, UnsupportedFunctionalityError, @@ -198,6 +199,12 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 { const { prompt: rawPrompt, ...rawSettings } = args; + let finishReason: LanguageModelV1FinishReason = 'other'; + let usage: { promptTokens: number; completionTokens: number } = { + promptTokens: Number.NaN, + completionTokens: Number.NaN, + }; + return { stream: response.pipeThrough( new TransformStream< @@ -212,13 +219,30 @@ export class OpenAICompletionLanguageModel implements LanguageModelV1 { 
const value = chunk.value; - if (value.choices?.[0]?.text != null) { + if (value.usage != null) { + usage = { + promptTokens: value.usage.prompt_tokens, + completionTokens: value.usage.completion_tokens, + }; + } + + const choice = value.choices[0]; + + if (choice?.finish_reason != null) { + finishReason = mapOpenAIFinishReason(choice.finish_reason); + } + + if (choice?.text != null) { controller.enqueue({ type: 'text-delta', - textDelta: value.choices[0].text, + textDelta: choice.text, }); } }, + + flush(controller) { + controller.enqueue({ type: 'finish', finishReason, usage }); + }, }), ), rawCall: { rawPrompt, rawSettings }, @@ -256,4 +280,11 @@ const openaiCompletionChunkSchema = z.object({ index: z.number(), }), ), + usage: z + .object({ + prompt_tokens: z.number(), + completion_tokens: z.number(), + }) + .optional() + .nullable(), });
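
A minimal consumption sketch for the new 'finish' stream part, based on the provider-level doStream call shape used in the tests above. The facade import path, API-key handling, and reader loop are illustrative assumptions, not part of this patch; the part types and fields ('text-delta', 'finish', finishReason, usage) are the ones this patch introduces.

    import { OpenAI } from './openai-facade'; // path as used in the tests; a published entry point may differ

    async function run() {
      const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

      const { stream } = await openai.chat('gpt-3.5-turbo').doStream({
        inputFormat: 'prompt',
        mode: { type: 'regular' },
        prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }],
      });

      const reader = stream.getReader();
      let text = '';

      while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        switch (value.type) {
          case 'text-delta':
            text += value.textDelta;
            break;

          case 'finish':
            // Provider-level parts carry { promptTokens, completionTokens };
            // totalTokens is derived later by runToolsTransformation in ai/core.
            console.log('finishReason:', value.finishReason);
            console.log('usage:', value.usage);
            break;
        }
      }

      console.log(text);
    }

    run().catch(console.error);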