Introduce streamMode for useChat / useCompletion. (#1350)
lgrammel authored Apr 16, 2024
1 parent f272b01 commit 66b5892
Showing 23 changed files with 879 additions and 455 deletions.
5 changes: 5 additions & 0 deletions .changeset/empty-windows-think.md
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+Add streamMode parameter to useChat and useCompletion.
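
For orientation, a minimal client-side sketch of the new option (illustrative only, not part of this commit): the page component and the /api/chat route are hypothetical, while the option name and its 'stream-data' | 'text' values come from this diff, and the tests in this commit keep exercising the data stream protocol when the option is omitted.

import { useChat } from 'ai/react';

// Hypothetical chat page that opts into plain-text streaming.
// streamMode: 'text' tells the hook to treat the response body as raw text
// instead of the default data stream protocol ('stream-data').
export default function Chat() {
  const { messages, input, handleInputChange, handleSubmit } = useChat({
    api: '/api/chat', // assumed server route
    streamMode: 'text',
  });

  return (
    <form onSubmit={handleSubmit}>
      {messages.map(m => (
        <div key={m.id}>
          {m.role === 'user' ? 'User: ' : 'AI: '}
          {m.content}
        </div>
      ))}
      <input value={input} onChange={handleInputChange} />
    </form>
  );
}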
1 change: 1 addition & 0 deletions examples/solidstart-openai/package.json
@@ -16,6 +16,7 @@
     "vite": "^4.1.4"
   },
   "dependencies": {
+    "@ai-sdk/openai": "latest",
     "@solidjs/meta": "0.29.3",
     "@solidjs/router": "0.8.2",
     "ai": "latest",
32 changes: 13 additions & 19 deletions examples/solidstart-openai/src/routes/api/chat/index.ts
@@ -1,25 +1,19 @@
-import { OpenAIStream, StreamingTextResponse } from 'ai';
-import OpenAI from 'openai';
+import { openai } from '@ai-sdk/openai';
+import { StreamingTextResponse, experimental_streamText } from 'ai';
 import { APIEvent } from 'solid-start/api';
 
-// Create an OpenAI API client
-const openai = new OpenAI({
-  apiKey: process.env['OPENAI_API_KEY'] || '',
-});
-
 export const POST = async (event: APIEvent) => {
-  // Extract the `prompt` from the body of the request
-  const { messages } = await event.request.json();
+  try {
+    const { messages } = await event.request.json();
 
-  // Ask OpenAI for a streaming chat completion given the prompt
-  const response = await openai.chat.completions.create({
-    model: 'gpt-3.5-turbo',
-    stream: true,
-    messages,
-  });
+    const result = await experimental_streamText({
+      model: openai.chat('gpt-4-turbo-preview'),
+      messages,
+    });
 
-  // Convert the response into a friendly text-stream
-  const stream = OpenAIStream(response);
-  // Respond with the stream
-  return new StreamingTextResponse(stream);
+    return new StreamingTextResponse(result.toAIStream());
+  } catch (error) {
+    console.error(error);
+    throw error;
+  }
 };
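
The route above keeps using the data stream protocol via result.toAIStream(). As a counterpart sketch (not part of this diff), a route meant for clients that set the new streamMode: 'text' would return the raw text instead; this assumes the experimental_streamText result also exposes textStream as a web ReadableStream of strings, and otherwise uses only standard Response and TextEncoderStream APIs.

import { openai } from '@ai-sdk/openai';
import { experimental_streamText } from 'ai';
import { APIEvent } from 'solid-start/api';

// Hypothetical plain-text variant of the chat route: instead of the
// data stream protocol (0:"..." lines), the body is just the text chunks,
// which is what a useChat client configured with streamMode: 'text' expects.
export const POST = async (event: APIEvent) => {
  const { messages } = await event.request.json();

  const result = await experimental_streamText({
    model: openai.chat('gpt-4-turbo-preview'),
    messages,
  });

  // Encode the string chunks and stream them back as plain text.
  return new Response(result.textStream.pipeThrough(new TextEncoderStream()), {
    headers: { 'Content-Type': 'text/plain; charset=utf-8' },
  });
};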
4 changes: 4 additions & 0 deletions packages/core/react/use-chat.ts
@@ -88,6 +88,7 @@ const getStreamedResponse = async (
   messagesRef: React.MutableRefObject<Message[]>,
   abortControllerRef: React.MutableRefObject<AbortController | null>,
   generateId: IdGenerator,
+  streamMode?: 'stream-data' | 'text',
   onFinish?: (message: Message) => void,
   onResponse?: (response: Response) => void | Promise<void>,
   sendExtraMessageFields?: boolean,
@@ -179,6 +180,7 @@
         tool_choice: chatRequest.tool_choice,
       }),
     },
+    streamMode,
     credentials: extraMetadataRef.current.credentials,
     headers: {
       ...extraMetadataRef.current.headers,
@@ -206,6 +208,7 @@ export function useChat({
   sendExtraMessageFields,
   experimental_onFunctionCall,
   experimental_onToolCall,
+  streamMode,
   onResponse,
   onFinish,
   onError,
@@ -292,6 +295,7 @@
         messagesRef,
         abortControllerRef,
         generateId,
+        streamMode,
         onFinish,
         onResponse,
         sendExtraMessageFields,
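
The hunks above only thread streamMode from the useChat options down to getStreamedResponse; the 'text' branch then boils down to appending decoded chunks to the assistant message as they arrive. A standalone sketch of that idea, not the SDK's actual implementation (readTextStream and its onChunk callback are made up for illustration):

// Reads a plain-text streaming response incrementally, reporting the text
// accumulated so far after each chunk (e.g. to update the assistant message).
async function readTextStream(
  response: Response,
  onChunk: (textSoFar: string) => void,
): Promise<string> {
  if (response.body == null) {
    throw new Error('Response has no body to read from.');
  }

  const reader = response.body
    .pipeThrough(new TextDecoderStream())
    .getReader();
  let text = '';

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    text += value;
    onChunk(text);
  }

  return text;
}

The real hook additionally manages abort signals, React state updates, and the onResponse/onFinish callbacks visible in the diff.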
264 changes: 160 additions & 104 deletions packages/core/react/use-chat.ui.test.tsx
@@ -9,145 +9,201 @@ import {
 } from '../tests/utils/mock-fetch';
 import { useChat } from './use-chat';
 
-const TestComponent = () => {
-  const [id, setId] = React.useState<string>('first-id');
-  const { messages, append, error, data, isLoading } = useChat({ id });
-
-  return (
-    <div>
-      <div data-testid="loading">{isLoading.toString()}</div>
-      {error && <div data-testid="error">{error.toString()}</div>}
-      {data && <div data-testid="data">{JSON.stringify(data)}</div>}
-      {messages.map((m, idx) => (
-        <div data-testid={`message-${idx}`} key={m.id}>
-          {m.role === 'user' ? 'User: ' : 'AI: '}
-          {m.content}
-        </div>
-      ))}
-
-      <button
-        data-testid="do-append"
-        onClick={() => {
-          append({ role: 'user', content: 'hi' });
-        }}
-      />
-      <button
-        data-testid="do-change-id"
-        onClick={() => {
-          setId('second-id');
-        }}
-      />
-    </div>
-  );
-};
-
-beforeEach(() => {
-  render(<TestComponent />);
-});
-
-afterEach(() => {
-  vi.restoreAllMocks();
-  cleanup();
-});
-
-test('Shows streamed complex text response', async () => {
-  mockFetchDataStream({
-    url: 'https://example.com/api/chat',
-    chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
-  });
-
-  await userEvent.click(screen.getByTestId('do-append'));
-
-  await screen.findByTestId('message-0');
-  expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
-
-  await screen.findByTestId('message-1');
-  expect(screen.getByTestId('message-1')).toHaveTextContent(
-    'AI: Hello, world.',
-  );
-});
-
-test('Shows streamed complex text response with data', async () => {
-  mockFetchDataStream({
-    url: 'https://example.com/api/chat',
-    chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'],
-  });
-
-  await userEvent.click(screen.getByTestId('do-append'));
-
-  await screen.findByTestId('data');
-  expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]');
-
-  await screen.findByTestId('message-1');
-  expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello');
-});
-
-test('Shows error response', async () => {
-  mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
-
-  await userEvent.click(screen.getByTestId('do-append'));
-
-  // TODO bug? the user message does not show up
-  // await screen.findByTestId('message-0');
-  // expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
-
-  await screen.findByTestId('error');
-  expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found');
-});
-
-describe('loading state', () => {
-  test('should show loading state', async () => {
-    let finishGeneration: ((value?: unknown) => void) | undefined;
-    const finishGenerationPromise = new Promise(resolve => {
-      finishGeneration = resolve;
-    });
-
-    mockFetchDataStreamWithGenerator({
-      url: 'https://example.com/api/chat',
-      chunkGenerator: (async function* generate() {
-        const encoder = new TextEncoder();
-        yield encoder.encode('0:"Hello"\n');
-        await finishGenerationPromise;
-      })(),
-    });
-
-    await userEvent.click(screen.getByTestId('do-append'));
-
-    await screen.findByTestId('loading');
-    expect(screen.getByTestId('loading')).toHaveTextContent('true');
-
-    finishGeneration?.();
-
-    await findByText(await screen.findByTestId('loading'), 'false');
-    expect(screen.getByTestId('loading')).toHaveTextContent('false');
-  });
-
-  test('should reset loading state on error', async () => {
-    mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
-
-    await userEvent.click(screen.getByTestId('do-append'));
-
-    await screen.findByTestId('loading');
-    expect(screen.getByTestId('loading')).toHaveTextContent('false');
-  });
-});
-
-describe('id', () => {
-  it('should clear out messages when the id changes', async () => {
-    mockFetchDataStream({
-      url: 'https://example.com/api/chat',
-      chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
-    });
-
-    await userEvent.click(screen.getByTestId('do-append'));
-
-    await screen.findByTestId('message-1');
-    expect(screen.getByTestId('message-1')).toHaveTextContent(
-      'AI: Hello, world.',
-    );
-
-    await userEvent.click(screen.getByTestId('do-change-id'));
-
-    expect(screen.queryByTestId('message-0')).not.toBeInTheDocument();
-  });
-});
+describe('stream data stream', () => {
+  const TestComponent = () => {
+    const [id, setId] = React.useState<string>('first-id');
+    const { messages, append, error, data, isLoading } = useChat({ id });
+
+    return (
+      <div>
+        <div data-testid="loading">{isLoading.toString()}</div>
+        {error && <div data-testid="error">{error.toString()}</div>}
+        {data && <div data-testid="data">{JSON.stringify(data)}</div>}
+        {messages.map((m, idx) => (
+          <div data-testid={`message-${idx}`} key={m.id}>
+            {m.role === 'user' ? 'User: ' : 'AI: '}
+            {m.content}
+          </div>
+        ))}
+
+        <button
+          data-testid="do-append"
+          onClick={() => {
+            append({ role: 'user', content: 'hi' });
+          }}
+        />
+        <button
+          data-testid="do-change-id"
+          onClick={() => {
+            setId('second-id');
+          }}
+        />
+      </div>
+    );
+  };
+
+  beforeEach(() => {
+    render(<TestComponent />);
+  });
+
+  afterEach(() => {
+    vi.restoreAllMocks();
+    cleanup();
+  });
+
+  it('should show streamed response', async () => {
+    mockFetchDataStream({
+      url: 'https://example.com/api/chat',
+      chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
+    });
+
+    await userEvent.click(screen.getByTestId('do-append'));
+
+    await screen.findByTestId('message-0');
+    expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
+
+    await screen.findByTestId('message-1');
+    expect(screen.getByTestId('message-1')).toHaveTextContent(
+      'AI: Hello, world.',
+    );
+  });
+
+  it('should show streamed response with data', async () => {
+    mockFetchDataStream({
+      url: 'https://example.com/api/chat',
+      chunks: ['2:[{"t1":"v1"}]\n', '0:"Hello"\n'],
+    });
+
+    await userEvent.click(screen.getByTestId('do-append'));
+
+    await screen.findByTestId('data');
+    expect(screen.getByTestId('data')).toHaveTextContent('[{"t1":"v1"}]');
+
+    await screen.findByTestId('message-1');
+    expect(screen.getByTestId('message-1')).toHaveTextContent('AI: Hello');
+  });
+
+  it('should show error response', async () => {
+    mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
+
+    await userEvent.click(screen.getByTestId('do-append'));
+
+    // TODO bug? the user message does not show up
+    // await screen.findByTestId('message-0');
+    // expect(screen.getByTestId('message-0')).toHaveTextContent('User: hi');
+
+    await screen.findByTestId('error');
+    expect(screen.getByTestId('error')).toHaveTextContent('Error: Not found');
+  });
+
+  describe('loading state', () => {
+    it('should show loading state', async () => {
+      let finishGeneration: ((value?: unknown) => void) | undefined;
+      const finishGenerationPromise = new Promise(resolve => {
+        finishGeneration = resolve;
+      });
+
+      mockFetchDataStreamWithGenerator({
+        url: 'https://example.com/api/chat',
+        chunkGenerator: (async function* generate() {
+          const encoder = new TextEncoder();
+          yield encoder.encode('0:"Hello"\n');
+          await finishGenerationPromise;
+        })(),
+      });
+
+      await userEvent.click(screen.getByTestId('do-append'));
+
+      await screen.findByTestId('loading');
+      expect(screen.getByTestId('loading')).toHaveTextContent('true');
+
+      finishGeneration?.();
+
+      await findByText(await screen.findByTestId('loading'), 'false');
+      expect(screen.getByTestId('loading')).toHaveTextContent('false');
+    });
+
+    it('should reset loading state on error', async () => {
+      mockFetchError({ statusCode: 404, errorMessage: 'Not found' });
+
+      await userEvent.click(screen.getByTestId('do-append'));
+
+      await screen.findByTestId('loading');
+      expect(screen.getByTestId('loading')).toHaveTextContent('false');
+    });
+  });
+
+  describe('id', () => {
+    it('should clear out messages when the id changes', async () => {
+      mockFetchDataStream({
+        url: 'https://example.com/api/chat',
+        chunks: ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'],
+      });
+
+      await userEvent.click(screen.getByTestId('do-append'));
+
+      await screen.findByTestId('message-1');
+      expect(screen.getByTestId('message-1')).toHaveTextContent(
+        'AI: Hello, world.',
+      );
+
+      await userEvent.click(screen.getByTestId('do-change-id'));
+
+      expect(screen.queryByTestId('message-0')).not.toBeInTheDocument();
+    });
+  });
+});
+
+describe('text stream', () => {
+  const TestComponent = () => {
+    const { messages, append } = useChat({
+      streamMode: 'text',
+    });
+
+    return (
+      <div>
+        {messages.map((m, idx) => (
+          <div data-testid={`message-${idx}-text-stream`} key={m.id}>
+            {m.role === 'user' ? 'User: ' : 'AI: '}
+            {m.content}
+          </div>
+        ))}
+
+        <button
+          data-testid="do-append-text-stream"
+          onClick={() => {
+            append({ role: 'user', content: 'hi' });
+          }}
+        />
+      </div>
+    );
+  };
+
+  beforeEach(() => {
+    render(<TestComponent />);
+  });
+
+  afterEach(() => {
+    vi.restoreAllMocks();
+    cleanup();
+  });
+
+  it('should show streamed response', async () => {
+    mockFetchDataStream({
+      url: 'https://example.com/api/chat',
+      chunks: ['Hello', ',', ' world', '.'],
+    });
+
+    await userEvent.click(screen.getByTestId('do-append-text-stream'));
+
+    await screen.findByTestId('message-0-text-stream');
+    expect(screen.getByTestId('message-0-text-stream')).toHaveTextContent(
+      'User: hi',
+    );
+
+    await screen.findByTestId('message-1-text-stream');
+    expect(screen.getByTestId('message-1-text-stream')).toHaveTextContent(
+      'AI: Hello, world.',
+    );
+  });
+});
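
Read together, the two describe blocks pin down the two wire formats the hook now accepts; the mock chunks from the tests make the difference concrete:

// Default 'stream-data' mode: every chunk is a typed stream part,
// e.g. 0:"..." for text deltas and 2:[...] for data values.
const streamDataChunks = ['0:"Hello"\n', '0:","\n', '0:" world"\n', '0:"."\n'];

// New 'text' mode: the body is plain text, split into arbitrary chunks.
const textChunks = ['Hello', ',', ' world', '.'];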