From 66b589245112aa8ec0ff2ba5445166afe622b65d Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Tue, 16 Apr 2024 16:19:18 +0200 Subject: [PATCH] Introduce streamMode for useChat / useCompletion. (#1350) --- .changeset/empty-windows-think.md | 5 + examples/solidstart-openai/package.json | 1 + .../src/routes/api/chat/index.ts | 32 +-- packages/core/react/use-chat.ts | 4 + packages/core/react/use-chat.ui.test.tsx | 264 +++++++++++------- packages/core/react/use-completion.ts | 2 + .../core/react/use-completion.ui.test.tsx | 180 +++++++----- packages/core/shared/call-chat-api.ts | 71 ++++- packages/core/shared/call-completion-api.ts | 56 +++- packages/core/shared/types.ts | 6 + packages/core/solid/use-chat.ts | 2 + packages/core/solid/use-chat.ui.test.tsx | 232 +++++++++------ packages/core/solid/use-completion.ts | 2 + .../core/solid/use-completion.ui.test.tsx | 162 +++++++---- packages/core/svelte/use-chat.ts | 4 + packages/core/svelte/use-completion.ts | 2 + .../core/vue/TestChatTextStreamComponent.vue | 28 ++ .../vue/TestCompletionTextStreamComponent.vue | 19 ++ packages/core/vue/use-chat.ts | 2 + packages/core/vue/use-chat.ui.test.tsx | 153 ++++++---- packages/core/vue/use-completion.ts | 2 + packages/core/vue/use-completion.ui.test.ts | 102 ++++--- pnpm-lock.yaml | 3 + 23 files changed, 879 insertions(+), 455 deletions(-) create mode 100644 .changeset/empty-windows-think.md create mode 100644 packages/core/vue/TestChatTextStreamComponent.vue create mode 100644 packages/core/vue/TestCompletionTextStreamComponent.vue diff --git a/.changeset/empty-windows-think.md b/.changeset/empty-windows-think.md new file mode 100644 index 000000000000..5ed1789737f4 --- /dev/null +++ b/.changeset/empty-windows-think.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +Add streamMode parameter to useChat and useCompletion. diff --git a/examples/solidstart-openai/package.json b/examples/solidstart-openai/package.json index 11223e274ab1..dbd5dcb54420 100644 --- a/examples/solidstart-openai/package.json +++ b/examples/solidstart-openai/package.json @@ -16,6 +16,7 @@ "vite": "^4.1.4" }, "dependencies": { + "@ai-sdk/openai": "latest", "@solidjs/meta": "0.29.3", "@solidjs/router": "0.8.2", "ai": "latest", diff --git a/examples/solidstart-openai/src/routes/api/chat/index.ts b/examples/solidstart-openai/src/routes/api/chat/index.ts index f75766bae9b3..4b77046cb4be 100644 --- a/examples/solidstart-openai/src/routes/api/chat/index.ts +++ b/examples/solidstart-openai/src/routes/api/chat/index.ts @@ -1,25 +1,19 @@ -import { OpenAIStream, StreamingTextResponse } from 'ai'; -import OpenAI from 'openai'; +import { openai } from '@ai-sdk/openai'; +import { StreamingTextResponse, experimental_streamText } from 'ai'; import { APIEvent } from 'solid-start/api'; -// Create an OpenAI API client -const openai = new OpenAI({ - apiKey: process.env['OPENAI_API_KEY'] || '', -}); - export const POST = async (event: APIEvent) => { - // Extract the `prompt` from the body of the request - const { messages } = await event.request.json(); + try { + const { messages } = await event.request.json(); - // Ask OpenAI for a streaming chat completion given the prompt - const response = await openai.chat.completions.create({ - model: 'gpt-3.5-turbo', - stream: true, - messages, - }); + const result = await experimental_streamText({ + model: openai.chat('gpt-4-turbo-preview'), + messages, + }); - // Convert the response into a friendly text-stream - const stream = OpenAIStream(response); - // Respond with the stream - return new StreamingTextResponse(stream); + return new StreamingTextResponse(result.toAIStream()); + } catch (error) { + console.error(error); + throw error; + } }; diff --git a/packages/core/react/use-chat.ts b/packages/core/react/use-chat.ts index 869cc58433c1..523b958bef19 100644 --- a/packages/core/react/use-chat.ts +++ b/packages/core/react/use-chat.ts @@ -88,6 +88,7 @@ const getStreamedResponse = async ( messagesRef: React.MutableRefObject, abortControllerRef: React.MutableRefObject, generateId: IdGenerator, + streamMode?: 'stream-data' | 'text', onFinish?: (message: Message) => void, onResponse?: (response: Response) => void | Promise, sendExtraMessageFields?: boolean, @@ -179,6 +180,7 @@ const getStreamedResponse = async ( tool_choice: chatRequest.tool_choice, }), }, + streamMode, credentials: extraMetadataRef.current.credentials, headers: { ...extraMetadataRef.current.headers, @@ -206,6 +208,7 @@ export function useChat({ sendExtraMessageFields, experimental_onFunctionCall, experimental_onToolCall, + streamMode, onResponse, onFinish, onError, @@ -292,6 +295,7 @@ export function useChat({ messagesRef, abortControllerRef, generateId, + streamMode, onFinish, onResponse, sendExtraMessageFields, diff --git a/packages/core/react/use-chat.ui.test.tsx b/packages/core/react/use-chat.ui.test.tsx index 9ac94284e4c3..7100d669df53 100644 --- a/packages/core/react/use-chat.ui.test.tsx +++ b/packages/core/react/use-chat.ui.test.tsx @@ -9,145 +9,201 @@ import { } from '../tests/utils/mock-fetch'; import { useChat } from './use-chat'; -const TestComponent = () => { - const [id, setId] = React.useState('first-id'); - const { messages, append, error, data, isLoading } = useChat({ id }); - - return ( -
-
{isLoading.toString()}
- {error &&
{error.toString()}
} - {data &&
{JSON.stringify(data)}
} - {messages.map((m, idx) => ( -
- {m.role === 'user' ? 'User: ' : 'AI: '} - {m.content} -
- ))} - -
- ); -}; - -beforeEach(() => { - render(); -}); +describe('stream data stream', () => { + const TestComponent = () => { + const [id, setId] = React.useState('first-id'); + const { messages, append, error, data, isLoading } = useChat({ id }); + + return ( +
+
{isLoading.toString()}
+ {error &&
{error.toString()}
} + {data &&
{JSON.stringify(data)}
} + {messages.map((m, idx) => ( +
+ {m.role === 'user' ? 'User: ' : 'AI: '} + {m.content} +
+ ))} + +
+ ); + }; -afterEach(() => { - vi.restoreAllMocks(); - cleanup(); -}); + beforeEach(() => { + render(); + }); -test('Shows streamed complex text response', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/chat', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); }); - await userEvent.click(screen.getByTestId('do-append')); + it('should show streamed response', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); - await screen.findByTestId('message-0'); - expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + await userEvent.click(screen.getByTestId('do-append')); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent( - 'AI: Hello, world.', - ); -}); + await screen.findByTestId('message-0'); + expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); -test('Shows streamed complex text response with data', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/chat', - chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'], + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent( + 'AI: Hello, world.', + ); }); - await userEvent.click(screen.getByTestId('do-append')); + it('should show streamed response with data', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'], + }); - await screen.findByTestId('data'); - expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]'); + await userEvent.click(screen.getByTestId('do-append')); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello'); -}); + await screen.findByTestId('data'); + expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]'); -test('Shows error response', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello'); + }); - await userEvent.click(screen.getByTestId('do-append')); + it('should show error response', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - // TODO bug? the user message does not show up - // await screen.findByTestId('message-0'); - // expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + await userEvent.click(screen.getByTestId('do-append')); - await screen.findByTestId('error'); - expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found'); -}); + // TODO bug? the user message does not show up + // await screen.findByTestId('message-0'); + // expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + + await screen.findByTestId('error'); + expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found'); + }); + + describe('loading state', () => { + it('should show loading state', async () => { + let finishGeneration: ((value?: unknown) => void) | undefined; + const finishGenerationPromise = new Promise(resolve => { + finishGeneration = resolve; + }); -describe('loading state', () => { - test('should show loading state', async () => { - let finishGeneration: ((value?: unknown) => void) | undefined; - const finishGenerationPromise = new Promise(resolve => { - finishGeneration = resolve; + mockFetchDataStreamWithGenerator({ + url: 'https://example.com/api/chat', + chunkGenerator: (async function* generate() { + const encoder = new TextEncoder(); + yield encoder.encode('0:"Hello"\n'); + await finishGenerationPromise; + })(), + }); + + await userEvent.click(screen.getByTestId('do-append')); + + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('true'); + + finishGeneration?.(); + + await findByText(await screen.findByTestId('loading'), 'false'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); - mockFetchDataStreamWithGenerator({ - url: 'https://example.com/api/chat', - chunkGenerator: (async function* generate() { - const encoder = new TextEncoder(); - yield encoder.encode('0:"Hello"\n'); - await finishGenerationPromise; - })(), + it('should reset loading state on error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + + await userEvent.click(screen.getByTestId('do-append')); + + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); + }); - await userEvent.click(screen.getByTestId('do-append')); + describe('id', () => { + it('should clear out messages when the id changes', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('true'); + await userEvent.click(screen.getByTestId('do-append')); - finishGeneration?.(); + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent( + 'AI: Hello, world.', + ); - await findByText(await screen.findByTestId('loading'), 'false'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + await userEvent.click(screen.getByTestId('do-change-id')); + + expect(screen.queryByTestId('message-0')).not.toBeInTheDocument(); + }); }); +}); - test('should reset loading state on error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); +describe('text stream', () => { + const TestComponent = () => { + const { messages, append } = useChat({ + streamMode: 'text', + }); - await userEvent.click(screen.getByTestId('do-append')); + return ( +
+ {messages.map((m, idx) => ( +
+ {m.role === 'user' ? 'User: ' : 'AI: '} + {m.content} +
+ ))} + +
+ ); + }; - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + beforeEach(() => { + render(); }); -}); -describe('id', () => { - it('should clear out messages when the id changes', async () => { + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); + + it('should show streamed response', async () => { mockFetchDataStream({ url: 'https://example.com/api/chat', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + chunks: ['Hello', ',', ' world', '.'], }); - await userEvent.click(screen.getByTestId('do-append')); + await userEvent.click(screen.getByTestId('do-append-text-stream')); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent( - 'AI: Hello, world.', + await screen.findByTestId('message-0-text-stream'); + expect(screen.getByTestId('message-0-text-stream')).toHaveTextContent( + 'User: hi', ); - await userEvent.click(screen.getByTestId('do-change-id')); - - expect(screen.queryByTestId('message-0')).not.toBeInTheDocument(); + await screen.findByTestId('message-1-text-stream'); + expect(screen.getByTestId('message-1-text-stream')).toHaveTextContent( + 'AI: Hello, world.', + ); }); }); diff --git a/packages/core/react/use-completion.ts b/packages/core/react/use-completion.ts index f5cdb2acca42..2b0dfaf5121d 100644 --- a/packages/core/react/use-completion.ts +++ b/packages/core/react/use-completion.ts @@ -69,6 +69,7 @@ export function useCompletion({ credentials, headers, body, + streamMode, onResponse, onFinish, onError, @@ -122,6 +123,7 @@ export function useCompletion({ ...extraMetadataRef.current.body, ...options?.body, }, + streamMode, setCompletion: completion => mutate(completion, false), setLoading: mutateLoading, setError, diff --git a/packages/core/react/use-completion.ui.test.tsx b/packages/core/react/use-completion.ui.test.tsx index ff7a1df39d05..2ec15c51b6b4 100644 --- a/packages/core/react/use-completion.ui.test.tsx +++ b/packages/core/react/use-completion.ui.test.tsx @@ -8,87 +8,135 @@ import { } from '../tests/utils/mock-fetch'; import { useCompletion } from './use-completion'; -const TestComponent = () => { - const { - completion, - handleSubmit, - error, - handleInputChange, - input, - isLoading, - } = useCompletion(); - - return ( -
-
{isLoading.toString()}
-
{error?.toString()}
-
{completion}
-
- -
-
- ); -}; - -beforeEach(() => { - render(); -}); +describe('stream data stream', () => { + const TestComponent = () => { + const { + completion, + handleSubmit, + error, + handleInputChange, + input, + isLoading, + } = useCompletion(); + + return ( +
+
{isLoading.toString()}
+
{error?.toString()}
+
{completion}
+
+ +
+
+ ); + }; + + beforeEach(() => { + render(); + }); -afterEach(() => { - vi.restoreAllMocks(); - cleanup(); -}); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); + + it('should render stream', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/completion', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); + + await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); -it('should render complex text stream', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/completion', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + await screen.findByTestId('completion'); + expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); }); - await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); + describe('loading state', () => { + it('should show loading state', async () => { + let finishGeneration: ((value?: unknown) => void) | undefined; + const finishGenerationPromise = new Promise(resolve => { + finishGeneration = resolve; + }); - await screen.findByTestId('completion'); - expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); -}); + mockFetchDataStreamWithGenerator({ + url: 'https://example.com/api/chat', + chunkGenerator: (async function* generate() { + const encoder = new TextEncoder(); + yield encoder.encode('0:"Hello"\n'); + await finishGenerationPromise; + })(), + }); -describe('loading state', () => { - it('should show loading state', async () => { - let finishGeneration: ((value?: unknown) => void) | undefined; - const finishGenerationPromise = new Promise(resolve => { - finishGeneration = resolve; - }); + await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); - mockFetchDataStreamWithGenerator({ - url: 'https://example.com/api/chat', - chunkGenerator: (async function* generate() { - const encoder = new TextEncoder(); - yield encoder.encode('0:"Hello"\n'); - await finishGenerationPromise; - })(), + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('true'); + + finishGeneration?.(); + + await findByText(await screen.findByTestId('loading'), 'false'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); - await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); + it('should reset loading state on error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('true'); + await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); + + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); + }); + }); +}); - finishGeneration?.(); +describe('text stream', () => { + const TestComponent = () => { + const { completion, handleSubmit, handleInputChange, input } = + useCompletion({ + streamMode: 'text', + }); + + return ( +
+
{completion}
+
+ +
+
+ ); + }; + + beforeEach(() => { + render(); + }); - await findByText(await screen.findByTestId('loading'), 'false'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); }); - it('should reset loading state on error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + it('should render stream', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/completion', + chunks: ['Hello', ',', ' world', '.'], + }); - await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); + await userEvent.type(screen.getByTestId('input-text-stream'), 'hi{enter}'); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + await screen.findByTestId('completion-text-stream'); + expect(screen.getByTestId('completion-text-stream')).toHaveTextContent( + 'Hello, world.', + ); }); }); diff --git a/packages/core/shared/call-chat-api.ts b/packages/core/shared/call-chat-api.ts index 8b30888ecb3d..54c9b495d099 100644 --- a/packages/core/shared/call-chat-api.ts +++ b/packages/core/shared/call-chat-api.ts @@ -1,10 +1,12 @@ import { parseComplexResponse } from './parse-complex-response'; import { IdGenerator, JSONValue, Message } from './types'; +import { createChunkDecoder } from './utils'; export async function callChatApi({ api, messages, body, + streamMode = 'stream-data', credentials, headers, abortController, @@ -17,6 +19,7 @@ export async function callChatApi({ api: string; messages: Omit[]; body: Record; + streamMode?: 'stream-data' | 'text'; credentials?: RequestCredentials; headers?: HeadersInit; abortController?: () => AbortController | null; @@ -64,16 +67,62 @@ export async function callChatApi({ const reader = response.body.getReader(); - return await parseComplexResponse({ - reader, - abortControllerRef: - abortController != null ? { current: abortController() } : undefined, - update: onUpdate, - onFinish(prefixMap) { - if (onFinish && prefixMap.text != null) { - onFinish(prefixMap.text); + switch (streamMode) { + case 'text': { + const decoder = createChunkDecoder(); + + const resultMessage = { + id: generateId(), + createdAt: new Date(), + role: 'assistant' as const, + content: '', + }; + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + resultMessage.content += decoder(value); + resultMessage.id = generateId(); + + // note: creating a new message object is required for Solid.js streaming + onUpdate([{ ...resultMessage }], []); + + // The request has been aborted, stop reading the stream. + if (abortController?.() === null) { + reader.cancel(); + break; + } } - }, - generateId, - }); + + onFinish?.(resultMessage); + + return { + messages: [resultMessage], + data: [], + }; + } + + case 'stream-data': { + return await parseComplexResponse({ + reader, + abortControllerRef: + abortController != null ? { current: abortController() } : undefined, + update: onUpdate, + onFinish(prefixMap) { + if (onFinish && prefixMap.text != null) { + onFinish(prefixMap.text); + } + }, + generateId, + }); + } + + default: { + const exhaustiveCheck: never = streamMode; + throw new Error(`Unknown stream mode: ${exhaustiveCheck}`); + } + } } diff --git a/packages/core/shared/call-completion-api.ts b/packages/core/shared/call-completion-api.ts index db1bc52ff403..694108d30774 100644 --- a/packages/core/shared/call-completion-api.ts +++ b/packages/core/shared/call-completion-api.ts @@ -1,5 +1,6 @@ import { readDataStream } from './read-data-stream'; import { JSONValue } from './types'; +import { createChunkDecoder } from './utils'; export async function callCompletionApi({ api, @@ -7,6 +8,7 @@ export async function callCompletionApi({ credentials, headers, body, + streamMode = 'stream-data', setCompletion, setLoading, setError, @@ -21,6 +23,7 @@ export async function callCompletionApi({ credentials?: RequestCredentials; headers?: HeadersInit; body: Record; + streamMode?: 'stream-data' | 'text'; setCompletion: (completion: string) => void; setLoading: (loading: boolean) => void; setError: (error: Error | undefined) => void; @@ -77,19 +80,52 @@ export async function callCompletionApi({ let result = ''; const reader = res.body.getReader(); - for await (const { type, value } of readDataStream(reader, { - isAborted: () => abortController === null, - })) { - switch (type) { - case 'text': { - result += value; + switch (streamMode) { + case 'text': { + const decoder = createChunkDecoder(); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + // Update the completion state with the new message tokens. + result += decoder(value); setCompletion(result); - break; + + // The request has been aborted, stop reading the stream. + if (abortController === null) { + reader.cancel(); + break; + } } - case 'data': { - onData?.(value); - break; + + break; + } + + case 'stream-data': { + for await (const { type, value } of readDataStream(reader, { + isAborted: () => abortController === null, + })) { + switch (type) { + case 'text': { + result += value; + setCompletion(result); + break; + } + case 'data': { + onData?.(value); + break; + } + } } + break; + } + + default: { + const exhaustiveCheck: never = streamMode; + throw new Error(`Unknown stream mode: ${exhaustiveCheck}`); } } diff --git a/packages/core/shared/types.ts b/packages/core/shared/types.ts index be7be26b2375..dc2279ecf620 100644 --- a/packages/core/shared/types.ts +++ b/packages/core/shared/types.ts @@ -248,6 +248,9 @@ export type UseChatOptions = { * handle the extra fields before forwarding the request to the AI service. */ sendExtraMessageFields?: boolean; + + /** Stream mode (default to "stream-data") */ + streamMode?: 'stream-data' | 'text'; }; export type UseCompletionOptions = { @@ -313,6 +316,9 @@ export type UseCompletionOptions = { * ``` */ body?: object; + + /** Stream mode (default to "stream-data") */ + streamMode?: 'stream-data' | 'text'; }; export type JSONValue = diff --git a/packages/core/solid/use-chat.ts b/packages/core/solid/use-chat.ts index 02b9a0d771ce..06ccb20ff241 100644 --- a/packages/core/solid/use-chat.ts +++ b/packages/core/solid/use-chat.ts @@ -82,6 +82,7 @@ export function useChat({ credentials, headers, body, + streamMode, generateId = generateIdFunc, }: UseChatOptions = {}): UseChatHelpers { // Generate a unique ID for the chat if not provided. @@ -158,6 +159,7 @@ export function useChat({ ...body, ...options?.body, }, + streamMode, headers: { ...headers, ...options?.headers, diff --git a/packages/core/solid/use-chat.ui.test.tsx b/packages/core/solid/use-chat.ui.test.tsx index 730c87ac1b66..c21c49dd1027 100644 --- a/packages/core/solid/use-chat.ui.test.tsx +++ b/packages/core/solid/use-chat.ui.test.tsx @@ -10,117 +10,175 @@ import { } from '../tests/utils/mock-fetch'; import { useChat } from './use-chat'; -const TestComponent = () => { - const { messages, append, error, data, isLoading } = useChat(); - - return ( -
-
{isLoading().toString()}
-
{error()?.toString()}
-
{JSON.stringify(data())}
- - - {(m, idx) => ( -
- {m.role === 'user' ? 'User: ' : 'AI: '} - {m.content} -
- )} -
- -
- ); -}; - -beforeEach(() => { - render(() => ); -}); +describe('stream data stream', () => { + const TestComponent = () => { + const { messages, append, error, data, isLoading } = useChat(); + + return ( +
+
{isLoading().toString()}
+
{error()?.toString()}
+
{JSON.stringify(data())}
+ + + {(m, idx) => ( +
+ {m.role === 'user' ? 'User: ' : 'AI: '} + {m.content} +
+ )} +
+ +
+ ); + }; + + beforeEach(() => { + render(() => ); + }); -afterEach(() => { - vi.restoreAllMocks(); - cleanup(); -}); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); + + it('should return messages', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); + + await userEvent.click(screen.getByTestId('button')); -it('should return messages', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/chat', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + await screen.findByTestId('message-0'); + expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent( + 'AI: Hello, world.', + ); }); - await userEvent.click(screen.getByTestId('button')); + it('should return messages and data', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'], + }); - await screen.findByTestId('message-0'); - expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + await userEvent.click(screen.getByTestId('button')); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent( - 'AI: Hello, world.', - ); -}); + await screen.findByTestId('data'); + expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]'); -it('should return messages and data', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/chat', - chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'], + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello'); }); - await userEvent.click(screen.getByTestId('button')); + it('should return error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - await screen.findByTestId('data'); - expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]'); + await userEvent.click(screen.getByTestId('button')); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello'); -}); + await screen.findByTestId('error'); + expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found'); + }); -it('should return error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + describe('loading state', () => { + it('should show loading state', async () => { + let finishGeneration: ((value?: unknown) => void) | undefined; + const finishGenerationPromise = new Promise(resolve => { + finishGeneration = resolve; + }); - await userEvent.click(screen.getByTestId('button')); + mockFetchDataStreamWithGenerator({ + url: 'https://example.com/api/chat', + chunkGenerator: (async function* generate() { + const encoder = new TextEncoder(); + yield encoder.encode('0:"Hello"\n'); + await finishGenerationPromise; + })(), + }); - await screen.findByTestId('error'); - expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found'); -}); + await userEvent.click(screen.getByTestId('button')); -describe('loading state', () => { - it('should show loading state', async () => { - let finishGeneration: ((value?: unknown) => void) | undefined; - const finishGenerationPromise = new Promise(resolve => { - finishGeneration = resolve; - }); + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('true'); - mockFetchDataStreamWithGenerator({ - url: 'https://example.com/api/chat', - chunkGenerator: (async function* generate() { - const encoder = new TextEncoder(); - yield encoder.encode('0:"Hello"\n'); - await finishGenerationPromise; - })(), + finishGeneration?.(); + + await findByText(await screen.findByTestId('loading'), 'false'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); - await userEvent.click(screen.getByTestId('button')); + it('should reset loading state on error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('true'); + await userEvent.click(screen.getByTestId('button')); - finishGeneration?.(); + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); + }); + }); +}); - await findByText(await screen.findByTestId('loading'), 'false'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); +describe('text stream', () => { + const TestComponent = () => { + const { messages, append } = useChat({ + streamMode: 'text', + }); + + return ( +
+ + {(m, idx) => ( +
+ {m.role === 'user' ? 'User: ' : 'AI: '} + {m.content} +
+ )} +
+ +
+ ); + }; + + beforeEach(() => { + render(() => ); }); - it('should reset loading state on error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); - await userEvent.click(screen.getByTestId('button')); + it('should show streamed response', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['Hello', ',', ' world', '.'], + }); + + await userEvent.click(screen.getByTestId('do-append-text-stream')); + + await screen.findByTestId('message-0-text-stream'); + expect(screen.getByTestId('message-0-text-stream')).toHaveTextContent( + 'User: hi', + ); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + await screen.findByTestId('message-1-text-stream'); + expect(screen.getByTestId('message-1-text-stream')).toHaveTextContent( + 'AI: Hello, world.', + ); }); }); diff --git a/packages/core/solid/use-completion.ts b/packages/core/solid/use-completion.ts index 0eee825cba3a..aedf846a1a7a 100644 --- a/packages/core/solid/use-completion.ts +++ b/packages/core/solid/use-completion.ts @@ -67,6 +67,7 @@ export function useCompletion({ credentials, headers, body, + streamMode, onResponse, onFinish, onError, @@ -115,6 +116,7 @@ export function useCompletion({ ...body, ...options?.body, }, + streamMode, setCompletion: mutate, setLoading: setIsLoading, setError, diff --git a/packages/core/solid/use-completion.ui.test.tsx b/packages/core/solid/use-completion.ui.test.tsx index 319a5353bc8b..2f561ddd897d 100644 --- a/packages/core/solid/use-completion.ui.test.tsx +++ b/packages/core/solid/use-completion.ui.test.tsx @@ -9,80 +9,124 @@ import { } from '../tests/utils/mock-fetch'; import { useCompletion } from './use-completion'; -const TestComponent = () => { - const { completion, complete, error, isLoading } = useCompletion(); - - return ( -
-
{isLoading().toString()}
-
{error()?.toString()}
- -
{completion()}
- -
- ); -}; - -beforeEach(() => { - render(() => ); -}); +describe('stream data stream', () => { + const TestComponent = () => { + const { completion, complete, error, isLoading } = useCompletion(); + + return ( +
+
{isLoading().toString()}
+
{error()?.toString()}
+ +
{completion()}
+ +
+ ); + }; + + beforeEach(() => { + render(() => ); + }); -afterEach(() => { - vi.restoreAllMocks(); - cleanup(); -}); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); + + it('should render complex text stream', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/completion', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); + + await userEvent.click(screen.getByTestId('button')); -it('should render complex text stream', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/completion', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + await screen.findByTestId('completion'); + expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); }); - await userEvent.click(screen.getByTestId('button')); + describe('loading state', () => { + it('should show loading state', async () => { + let finishGeneration: ((value?: unknown) => void) | undefined; + const finishGenerationPromise = new Promise(resolve => { + finishGeneration = resolve; + }); - await screen.findByTestId('completion'); - expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); -}); + mockFetchDataStreamWithGenerator({ + url: 'https://example.com/api/chat', + chunkGenerator: (async function* generate() { + const encoder = new TextEncoder(); + yield encoder.encode('0:"Hello"\n'); + await finishGenerationPromise; + })(), + }); -describe('loading state', () => { - it('should show loading state', async () => { - let finishGeneration: ((value?: unknown) => void) | undefined; - const finishGenerationPromise = new Promise(resolve => { - finishGeneration = resolve; - }); + await userEvent.click(screen.getByTestId('button')); - mockFetchDataStreamWithGenerator({ - url: 'https://example.com/api/chat', - chunkGenerator: (async function* generate() { - const encoder = new TextEncoder(); - yield encoder.encode('0:"Hello"\n'); - await finishGenerationPromise; - })(), + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('true'); + + finishGeneration?.(); + + await findByText(await screen.findByTestId('loading'), 'false'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); - await userEvent.click(screen.getByTestId('button')); + it('should reset loading state on error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('true'); + await userEvent.click(screen.getByTestId('button')); + + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); + }); + }); +}); - finishGeneration?.(); +describe('text stream', () => { + const TestComponent = () => { + const { completion, complete } = useCompletion({ streamMode: 'text' }); + + return ( +
+
{completion()}
+ +
+ ); + }; + + beforeEach(() => { + render(() => ); + }); - await findByText(await screen.findByTestId('loading'), 'false'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); }); - it('should reset loading state on error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + it('should render stream', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/completion', + chunks: ['Hello', ',', ' world', '.'], + }); - await userEvent.click(screen.getByTestId('button')); + await userEvent.click(screen.getByTestId('button-text-stream')); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + await screen.findByTestId('completion-text-stream'); + expect(screen.getByTestId('completion-text-stream')).toHaveTextContent( + 'Hello, world.', + ); }); }); diff --git a/packages/core/svelte/use-chat.ts b/packages/core/svelte/use-chat.ts index 2e69c16f2498..a8d29ae5c76f 100644 --- a/packages/core/svelte/use-chat.ts +++ b/packages/core/svelte/use-chat.ts @@ -72,6 +72,7 @@ const getStreamedResponse = async ( previousMessages: Message[], abortControllerRef: AbortController | null, generateId: IdGenerator, + streamMode?: 'stream-data' | 'text', onFinish?: (message: Message) => void, onResponse?: (response: Response) => void | Promise, sendExtraMessageFields?: boolean, @@ -116,6 +117,7 @@ const getStreamedResponse = async ( tool_choice: chatRequest.tool_choice, }), }, + streamMode, credentials: extraMetadata.credentials, headers: { ...extraMetadata.headers, @@ -147,6 +149,7 @@ export function useChat({ sendExtraMessageFields, experimental_onFunctionCall, experimental_onToolCall, + streamMode, onResponse, onFinish, onError, @@ -216,6 +219,7 @@ export function useChat({ get(messages), abortController, generateId, + streamMode, onFinish, onResponse, sendExtraMessageFields, diff --git a/packages/core/svelte/use-completion.ts b/packages/core/svelte/use-completion.ts index e3845aa2230f..2319577e36ca 100644 --- a/packages/core/svelte/use-completion.ts +++ b/packages/core/svelte/use-completion.ts @@ -60,6 +60,7 @@ export function useCompletion({ credentials, headers, body, + streamMode, onResponse, onFinish, onError, @@ -113,6 +114,7 @@ export function useCompletion({ ...body, ...options?.body, }, + streamMode, setCompletion: mutate, setLoading: loadingState => loading.set(loadingState), setError: err => error.set(err), diff --git a/packages/core/vue/TestChatTextStreamComponent.vue b/packages/core/vue/TestChatTextStreamComponent.vue new file mode 100644 index 000000000000..48609cc0431a --- /dev/null +++ b/packages/core/vue/TestChatTextStreamComponent.vue @@ -0,0 +1,28 @@ + + + diff --git a/packages/core/vue/TestCompletionTextStreamComponent.vue b/packages/core/vue/TestCompletionTextStreamComponent.vue new file mode 100644 index 000000000000..3bdd531308ff --- /dev/null +++ b/packages/core/vue/TestCompletionTextStreamComponent.vue @@ -0,0 +1,19 @@ + + + diff --git a/packages/core/vue/use-chat.ts b/packages/core/vue/use-chat.ts index c8376c156a3e..6cdcf97a5ea1 100644 --- a/packages/core/vue/use-chat.ts +++ b/packages/core/vue/use-chat.ts @@ -70,6 +70,7 @@ export function useChat({ initialInput = '', sendExtraMessageFields, experimental_onFunctionCall, + streamMode, onResponse, onFinish, onError, @@ -154,6 +155,7 @@ export function useChat({ ...unref(body), // Use unref to unwrap the ref value ...options?.body, }, + streamMode, headers: { ...headers, ...options?.headers, diff --git a/packages/core/vue/use-chat.ui.test.tsx b/packages/core/vue/use-chat.ui.test.tsx index 9a7272addcc1..ff96bd2f2237 100644 --- a/packages/core/vue/use-chat.ui.test.tsx +++ b/packages/core/vue/use-chat.ui.test.tsx @@ -7,95 +7,126 @@ import { mockFetchError, } from '../tests/utils/mock-fetch'; import TestChatComponent from './TestChatComponent.vue'; +import TestChatTextStreamComponent from './TestChatTextStreamComponent.vue'; -beforeEach(() => { - render(TestChatComponent); -}); +describe('stream data stream', () => { + beforeEach(() => { + render(TestChatComponent); + }); -afterEach(() => { - vi.restoreAllMocks(); - cleanup(); -}); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); + + it('should show streamed response', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); + + await userEvent.click(screen.getByTestId('button')); + + await screen.findByTestId('message-0'); + expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); -test('Shows streamed complex text response', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/chat', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent( + 'AI: Hello, world.', + ); }); - await userEvent.click(screen.getByTestId('button')); + it('should show streamed response with data', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'], + }); - await screen.findByTestId('message-0'); - expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + await userEvent.click(screen.getByTestId('button')); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent( - 'AI: Hello, world.', - ); -}); + await screen.findByTestId('data'); + expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]'); -test('Shows streamed complex text response with data', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/chat', - chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'], + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello'); }); - await userEvent.click(screen.getByTestId('button')); + it('should show error response', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + + await userEvent.click(screen.getByTestId('button')); - await screen.findByTestId('data'); - expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]'); + // TODO bug? the user message does not show up + // await screen.findByTestId('message-0'); + // expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); - await screen.findByTestId('message-1'); - expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello'); -}); + await screen.findByTestId('error'); + expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found'); + }); -test('Shows error response', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + describe('loading state', () => { + it('should show loading state', async () => { + let finishGeneration: ((value?: unknown) => void) | undefined; + const finishGenerationPromise = new Promise(resolve => { + finishGeneration = resolve; + }); - await userEvent.click(screen.getByTestId('button')); + mockFetchDataStreamWithGenerator({ + url: 'https://example.com/api/chat', + chunkGenerator: (async function* generate() { + const encoder = new TextEncoder(); + yield encoder.encode('0:"Hello"\n'); + await finishGenerationPromise; + })(), + }); - // TODO bug? the user message does not show up - // await screen.findByTestId('message-0'); - // expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + await userEvent.click(screen.getByTestId('button')); - await screen.findByTestId('error'); - expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found'); -}); + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('true'); -describe('loading state', () => { - test('should show loading state', async () => { - let finishGeneration: ((value?: unknown) => void) | undefined; - const finishGenerationPromise = new Promise(resolve => { - finishGeneration = resolve; - }); + finishGeneration?.(); - mockFetchDataStreamWithGenerator({ - url: 'https://example.com/api/chat', - chunkGenerator: (async function* generate() { - const encoder = new TextEncoder(); - yield encoder.encode('0:"Hello"\n'); - await finishGenerationPromise; - })(), + await findByText(await screen.findByTestId('loading'), 'false'); + + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); - await userEvent.click(screen.getByTestId('button')); + it('should reset loading state on error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('true'); + await userEvent.click(screen.getByTestId('button')); - finishGeneration?.(); + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); + }); + }); +}); - await findByText(await screen.findByTestId('loading'), 'false'); +describe('text stream', () => { + beforeEach(() => { + render(TestChatTextStreamComponent); + }); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); }); - test('should reset loading state on error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + it('should show streamed response', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/chat', + chunks: ['Hello', ',', ' world', '.'], + }); await userEvent.click(screen.getByTestId('button')); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + await screen.findByTestId('message-0'); + expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi'); + + await screen.findByTestId('message-1'); + expect(screen.getByTestId('message-1')).toHaveTextContent( + 'AI: Hello, world.', + ); }); }); diff --git a/packages/core/vue/use-completion.ts b/packages/core/vue/use-completion.ts index 33564640bc20..a1233eb449f0 100644 --- a/packages/core/vue/use-completion.ts +++ b/packages/core/vue/use-completion.ts @@ -63,6 +63,7 @@ export function useCompletion({ credentials, headers, body, + streamMode, onResponse, onFinish, onError, @@ -116,6 +117,7 @@ export function useCompletion({ ...unref(body), ...options?.body, }, + streamMode, setCompletion: mutate, setLoading: loading => mutateLoading(() => loading), setError: err => { diff --git a/packages/core/vue/use-completion.ui.test.ts b/packages/core/vue/use-completion.ui.test.ts index 2e86522cfeda..e22e9539f9d1 100644 --- a/packages/core/vue/use-completion.ui.test.ts +++ b/packages/core/vue/use-completion.ui.test.ts @@ -7,61 +7,87 @@ import { mockFetchError, } from '../tests/utils/mock-fetch'; import TestCompletionComponent from './TestCompletionComponent.vue'; +import TestCompletionTextStreamComponent from './TestCompletionTextStreamComponent.vue'; -beforeEach(() => { - render(TestCompletionComponent); -}); +describe('stream data stream', () => { + beforeEach(() => { + render(TestCompletionComponent); + }); -afterEach(() => { - vi.restoreAllMocks(); - cleanup(); -}); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); + }); + + it('should show streamed response', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/completion', + chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + }); + + await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); -it('should render complex text stream', async () => { - mockFetchDataStream({ - url: 'https://example.com/api/completion', - chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'], + await screen.findByTestId('completion'); + expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); }); - await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); + describe('loading state', () => { + it('should show loading state', async () => { + let finishGeneration: ((value?: unknown) => void) | undefined; + const finishGenerationPromise = new Promise(resolve => { + finishGeneration = resolve; + }); - await screen.findByTestId('completion'); - expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); -}); + mockFetchDataStreamWithGenerator({ + url: 'https://example.com/api/chat', + chunkGenerator: (async function* generate() { + const encoder = new TextEncoder(); + yield encoder.encode('0:"Hello"\n'); + await finishGenerationPromise; + })(), + }); -describe('loading state', () => { - it('should show loading state', async () => { - let finishGeneration: ((value?: unknown) => void) | undefined; - const finishGenerationPromise = new Promise(resolve => { - finishGeneration = resolve; - }); + await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); - mockFetchDataStreamWithGenerator({ - url: 'https://example.com/api/chat', - chunkGenerator: (async function* generate() { - const encoder = new TextEncoder(); - yield encoder.encode('0:"Hello"\n'); - await finishGenerationPromise; - })(), + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('true'); + + finishGeneration?.(); + + await findByText(await screen.findByTestId('loading'), 'false'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); }); - await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); + it('should reset loading state on error', async () => { + mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('true'); + await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); - finishGeneration?.(); + await screen.findByTestId('loading'); + expect(screen.getByTestId('loading')).toHaveTextContent('false'); + }); + }); +}); + +describe('stream data stream', () => { + beforeEach(() => { + render(TestCompletionTextStreamComponent); + }); - await findByText(await screen.findByTestId('loading'), 'false'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + afterEach(() => { + vi.restoreAllMocks(); + cleanup(); }); - it('should reset loading state on error', async () => { - mockFetchError({ statusCode: 404, errorMessage: 'Not found' }); + it('should show streamed response', async () => { + mockFetchDataStream({ + url: 'https://example.com/api/completion', + chunks: ['Hello', ',', ' world', '.'], + }); await userEvent.type(screen.getByTestId('input'), 'hi{enter}'); - await screen.findByTestId('loading'); - expect(screen.getByTestId('loading')).toHaveTextContent('false'); + await screen.findByTestId('completion'); + expect(screen.getByTestId('completion')).toHaveTextContent('Hello, world.'); }); }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 57477336b377..e3083ca28d0e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1010,6 +1010,9 @@ importers: examples/solidstart-openai: dependencies: + '@ai-sdk/openai': + specifier: latest + version: link:../../packages/openai '@solidjs/meta': specifier: 0.29.3 version: 0.29.3(solid-js@1.8.7)