From 7851fa0d16be705ab930babdd44f6513694c5582 Mon Sep 17 00:00:00 2001 From: Nick Gomez <122398915+nick-inkeep@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:49:19 -0500 Subject: [PATCH] feat: Adding ability to attach 'annotations' (custom metadata) to messages (#879) --- .changeset/tasty-bobcat-check.md | 5 +++ .../shared/parse-complex-response.test.ts | 41 +++++++++++++++++++ .../core/shared/parse-complex-response.ts | 33 +++++++++++++++ packages/core/shared/stream-parts.test.ts | 9 ++++ packages/core/shared/stream-parts.ts | 26 ++++++++++-- packages/core/shared/types.ts | 5 +++ packages/core/streams/ai-stream.ts | 20 +++++++-- packages/core/streams/stream-data.ts | 24 +++++++++++ 8 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 .changeset/tasty-bobcat-check.md diff --git a/.changeset/tasty-bobcat-check.md b/.changeset/tasty-bobcat-check.md new file mode 100644 index 000000000000..795339393803 --- /dev/null +++ b/.changeset/tasty-bobcat-check.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +StreamData: add `annotations` and `appendMessageAnnotation` support diff --git a/packages/core/shared/parse-complex-response.test.ts b/packages/core/shared/parse-complex-response.test.ts index e61dd1a93209..480dd8fd42f2 100644 --- a/packages/core/shared/parse-complex-response.test.ts +++ b/packages/core/shared/parse-complex-response.test.ts @@ -224,4 +224,45 @@ describe('parseComplexResponse function', () => { data: [{ t1: 'v1' }, 3, null, false, 'text'], }); }); + + it('should parse a combination of a text message and message annotations', async () => { + const mockUpdate = vi.fn(); + + // Execute the parser function + const result = await parseComplexResponse({ + reader: createTestReader([ + '0:"Sample text message."\n', + '8:[{"key":"value"}, 2]\n', + ]), + abortControllerRef: { current: new AbortController() }, + update: mockUpdate, + generateId: () => 'test-id', + getCurrentDate: () => new Date(0), + }); + + // check the mockUpdate call: + expect(mockUpdate).toHaveBeenCalledTimes(2); + + expect(mockUpdate.mock.calls[0][0]).toEqual([ + assistantTextMessage('Sample text message.'), + ]); + + expect(mockUpdate.mock.calls[1][0]).toEqual([ + { + ...assistantTextMessage('Sample text message.'), + annotations: [{ key: 'value' }, 2], + }, + ]); + + // check the result + expect(result).toEqual({ + messages: [ + { + ...assistantTextMessage('Sample text message.'), + annotations: [{ key: 'value' }, 2], + }, + ], + data: [], + }); + }); }); diff --git a/packages/core/shared/parse-complex-response.ts b/packages/core/shared/parse-complex-response.ts index b4d8ee2f162c..11241324f963 100644 --- a/packages/core/shared/parse-complex-response.ts +++ b/packages/core/shared/parse-complex-response.ts @@ -15,6 +15,22 @@ type PrefixMap = { data: JSONValue[]; }; +function initializeMessage({ + generateId, + ...rest +}: { + generateId: () => string; + content: string; + createdAt: Date; + annotations?: JSONValue[]; +}): Message { + return { + id: generateId(), + role: 'assistant', + ...rest, + }; +} + export async function parseComplexResponse({ reader, abortControllerRef, @@ -57,6 +73,23 @@ export async function parseComplexResponse({ } } + if (type == 'message_annotations') { + if (prefixMap['text']) { + prefixMap['text'] = { + ...prefixMap['text'], + annotations: [...(prefixMap['text'].annotations || []), ...value], + }; + } else { + prefixMap['text'] = { + id: generateId(), + role: 'assistant', + content: '', + annotations: [...value], + createdAt, + }; + } + } + let functionCallMessage: Message | null = null; if (type === 'function_call') { diff --git a/packages/core/shared/stream-parts.test.ts b/packages/core/shared/stream-parts.test.ts index e079eaef94f2..8b00adb6c4f3 100644 --- a/packages/core/shared/stream-parts.test.ts +++ b/packages/core/shared/stream-parts.test.ts @@ -69,6 +69,15 @@ describe('stream-parts', () => { expect(parseStreamPart(input)).toEqual(expectedOutput); }); + it('should parse a message data line', () => { + const input = '8:[{"test":"value"}]'; + const expectedOutput = { + type: 'message_annotations', + value: [{ test: 'value' }], + }; + expect(parseStreamPart(input)).toEqual(expectedOutput); + }); + it('should throw an error if the input does not contain a colon separator', () => { const input = 'invalid stream string'; expect(() => parseStreamPart(input)).toThrow(); diff --git a/packages/core/shared/stream-parts.ts b/packages/core/shared/stream-parts.ts index e013d790704c..55392a3475af 100644 --- a/packages/core/shared/stream-parts.ts +++ b/packages/core/shared/stream-parts.ts @@ -221,6 +221,22 @@ const toolCallStreamPart: StreamPart< }, }; +const messageAnnotationsStreamPart: StreamPart< + '8', + 'message_annotations', + Array +> = { + code: '8', + name: 'message_annotations', + parse: (value: JSONValue) => { + if (!Array.isArray(value)) { + throw new Error('"message_annotations" parts expect an array value.'); + } + + return { type: 'message_annotations', value }; + }, +}; + const streamParts = [ textStreamPart, functionCallStreamPart, @@ -230,6 +246,7 @@ const streamParts = [ assistantControlDataStreamPart, dataMessageStreamPart, toolCallStreamPart, + messageAnnotationsStreamPart, ] as const; // union type of all stream parts @@ -241,8 +258,8 @@ type StreamParts = | typeof assistantMessageStreamPart | typeof assistantControlDataStreamPart | typeof dataMessageStreamPart - | typeof toolCallStreamPart; - + | typeof toolCallStreamPart + | typeof messageAnnotationsStreamPart; /** * Maps the type of a stream part to its value type. */ @@ -258,7 +275,8 @@ export type StreamPartType = | ReturnType | ReturnType | ReturnType - | ReturnType; + | ReturnType + | ReturnType; export const streamPartsByCode = { [textStreamPart.code]: textStreamPart, @@ -269,6 +287,7 @@ export const streamPartsByCode = { [assistantControlDataStreamPart.code]: assistantControlDataStreamPart, [dataMessageStreamPart.code]: dataMessageStreamPart, [toolCallStreamPart.code]: toolCallStreamPart, + [messageAnnotationsStreamPart.code]: messageAnnotationsStreamPart, } as const; /** @@ -302,6 +321,7 @@ export const StreamStringPrefixes = { [assistantControlDataStreamPart.name]: assistantControlDataStreamPart.code, [dataMessageStreamPart.name]: dataMessageStreamPart.code, [toolCallStreamPart.name]: toolCallStreamPart.code, + [messageAnnotationsStreamPart.name]: messageAnnotationsStreamPart.code, } as const; export const validCodes = streamParts.map(part => part.code); diff --git a/packages/core/shared/types.ts b/packages/core/shared/types.ts index a88dbc28bca0..be7be26b2375 100644 --- a/packages/core/shared/types.ts +++ b/packages/core/shared/types.ts @@ -110,6 +110,11 @@ export interface Message { * the tool call name and arguments. Otherwise, the field should not be set. */ tool_calls?: string | ToolCall[]; + + /** + * Additional message-specific information added on the server via StreamData + */ + annotations?: JSONValue[] | undefined; } export type CreateMessage = Omit & { diff --git a/packages/core/streams/ai-stream.ts b/packages/core/streams/ai-stream.ts index 441d030b5c45..c92785ddfaea 100644 --- a/packages/core/streams/ai-stream.ts +++ b/packages/core/streams/ai-stream.ts @@ -43,14 +43,24 @@ export interface AIStreamCallbacksAndOptions { experimental_streamData?: boolean; } -// new TokenData() -// data: TokenData, +/** + * Options for the AIStreamParser. + * @interface + * @property {string} event - The event (type) from the server side event stream. + */ +export interface AIStreamParserOptions { + event?: string; +} + /** * Custom parser for AIStream data. * @interface + * @param {string} data - The data to be parsed. + * @param {AIStreamParserOptions} options - The options for the parser. + * @returns {string | void} The parsed data or void. */ export interface AIStreamParser { - (data: string): string | void; + (data: string, options: AIStreamParserOptions): string | void; } /** @@ -82,7 +92,9 @@ export function createEventStreamTransformer( if ('data' in event) { const parsedMessage = customParser - ? customParser(event.data) + ? customParser(event.data, { + event: event.event, + }) : event.data; if (parsedMessage) controller.enqueue(parsedMessage); } diff --git a/packages/core/streams/stream-data.ts b/packages/core/streams/stream-data.ts index 2f1c452d6825..905e3da5c807 100644 --- a/packages/core/streams/stream-data.ts +++ b/packages/core/streams/stream-data.ts @@ -19,6 +19,8 @@ export class experimental_StreamData { // array to store appended data private data: JSONValue[] = []; + private messageAnnotations: JSONValue[] = []; + constructor() { this.isClosedPromise = new Promise(resolve => { this.isClosedPromiseResolver = resolve; @@ -39,6 +41,13 @@ export class experimental_StreamData { controller.enqueue(encodedData); } + if (self.messageAnnotations.length) { + const encodedmessageAnnotations = self.encoder.encode( + formatStreamPart('message_annotations', self.messageAnnotations), + ); + controller.enqueue(encodedmessageAnnotations); + } + controller.enqueue(chunk); }, async flush(controller) { @@ -64,6 +73,13 @@ export class experimental_StreamData { ); controller.enqueue(encodedData); } + + if (self.messageAnnotations.length) { + const encodedData = self.encoder.encode( + formatStreamPart('message_annotations', self.messageAnnotations), + ); + controller.enqueue(encodedData); + } }, }); } @@ -88,6 +104,14 @@ export class experimental_StreamData { this.data.push(value); } + + appendMessageAnnotation(value: JSONValue): void { + if (this.isClosed) { + throw new Error('Data Stream has already been closed.'); + } + + this.messageAnnotations.push(value); + } } /**