Skip to content

Commit

Permalink
feat: Adding ability to attach 'annotations' (custom metadata) to mes…
Browse files Browse the repository at this point in the history
…sages (#879)
  • Loading branch information
nick-inkeep authored Jan 30, 2024
1 parent fb799f1 commit 7851fa0
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/tasty-bobcat-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

StreamData: add `annotations` and `appendMessageAnnotation` support
41 changes: 41 additions & 0 deletions packages/core/shared/parse-complex-response.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,45 @@ describe('parseComplexResponse function', () => {
data: [{ t1: 'v1' }, 3, null, false, 'text'],
});
});

it('should parse a combination of a text message and message annotations', async () => {
const mockUpdate = vi.fn();

// Execute the parser function
const result = await parseComplexResponse({
reader: createTestReader([
'0:"Sample text message."\n',
'8:[{"key":"value"}, 2]\n',
]),
abortControllerRef: { current: new AbortController() },
update: mockUpdate,
generateId: () => 'test-id',
getCurrentDate: () => new Date(0),
});

// check the mockUpdate call:
expect(mockUpdate).toHaveBeenCalledTimes(2);

expect(mockUpdate.mock.calls[0][0]).toEqual([
assistantTextMessage('Sample text message.'),
]);

expect(mockUpdate.mock.calls[1][0]).toEqual([
{
...assistantTextMessage('Sample text message.'),
annotations: [{ key: 'value' }, 2],
},
]);

// check the result
expect(result).toEqual({
messages: [
{
...assistantTextMessage('Sample text message.'),
annotations: [{ key: 'value' }, 2],
},
],
data: [],
});
});
});
33 changes: 33 additions & 0 deletions packages/core/shared/parse-complex-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ type PrefixMap = {
data: JSONValue[];
};

function initializeMessage({
generateId,
...rest
}: {
generateId: () => string;
content: string;
createdAt: Date;
annotations?: JSONValue[];
}): Message {
return {
id: generateId(),
role: 'assistant',
...rest,
};
}

export async function parseComplexResponse({
reader,
abortControllerRef,
Expand Down Expand Up @@ -57,6 +73,23 @@ export async function parseComplexResponse({
}
}

if (type == 'message_annotations') {
if (prefixMap['text']) {
prefixMap['text'] = {
...prefixMap['text'],
annotations: [...(prefixMap['text'].annotations || []), ...value],
};
} else {
prefixMap['text'] = {
id: generateId(),
role: 'assistant',
content: '',
annotations: [...value],
createdAt,
};
}
}

let functionCallMessage: Message | null = null;

if (type === 'function_call') {
Expand Down
9 changes: 9 additions & 0 deletions packages/core/shared/stream-parts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ describe('stream-parts', () => {
expect(parseStreamPart(input)).toEqual(expectedOutput);
});

it('should parse a message data line', () => {
const input = '8:[{"test":"value"}]';
const expectedOutput = {
type: 'message_annotations',
value: [{ test: 'value' }],
};
expect(parseStreamPart(input)).toEqual(expectedOutput);
});

it('should throw an error if the input does not contain a colon separator', () => {
const input = 'invalid stream string';
expect(() => parseStreamPart(input)).toThrow();
Expand Down
26 changes: 23 additions & 3 deletions packages/core/shared/stream-parts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ const toolCallStreamPart: StreamPart<
},
};

const messageAnnotationsStreamPart: StreamPart<
'8',
'message_annotations',
Array<JSONValue>
> = {
code: '8',
name: 'message_annotations',
parse: (value: JSONValue) => {
if (!Array.isArray(value)) {
throw new Error('"message_annotations" parts expect an array value.');
}

return { type: 'message_annotations', value };
},
};

const streamParts = [
textStreamPart,
functionCallStreamPart,
Expand All @@ -230,6 +246,7 @@ const streamParts = [
assistantControlDataStreamPart,
dataMessageStreamPart,
toolCallStreamPart,
messageAnnotationsStreamPart,
] as const;

// union type of all stream parts
Expand All @@ -241,8 +258,8 @@ type StreamParts =
| typeof assistantMessageStreamPart
| typeof assistantControlDataStreamPart
| typeof dataMessageStreamPart
| typeof toolCallStreamPart;

| typeof toolCallStreamPart
| typeof messageAnnotationsStreamPart;
/**
* Maps the type of a stream part to its value type.
*/
Expand All @@ -258,7 +275,8 @@ export type StreamPartType =
| ReturnType<typeof assistantMessageStreamPart.parse>
| ReturnType<typeof assistantControlDataStreamPart.parse>
| ReturnType<typeof dataMessageStreamPart.parse>
| ReturnType<typeof toolCallStreamPart.parse>;
| ReturnType<typeof toolCallStreamPart.parse>
| ReturnType<typeof messageAnnotationsStreamPart.parse>;

export const streamPartsByCode = {
[textStreamPart.code]: textStreamPart,
Expand All @@ -269,6 +287,7 @@ export const streamPartsByCode = {
[assistantControlDataStreamPart.code]: assistantControlDataStreamPart,
[dataMessageStreamPart.code]: dataMessageStreamPart,
[toolCallStreamPart.code]: toolCallStreamPart,
[messageAnnotationsStreamPart.code]: messageAnnotationsStreamPart,
} as const;

/**
Expand Down Expand Up @@ -302,6 +321,7 @@ export const StreamStringPrefixes = {
[assistantControlDataStreamPart.name]: assistantControlDataStreamPart.code,
[dataMessageStreamPart.name]: dataMessageStreamPart.code,
[toolCallStreamPart.name]: toolCallStreamPart.code,
[messageAnnotationsStreamPart.name]: messageAnnotationsStreamPart.code,
} as const;

export const validCodes = streamParts.map(part => part.code);
Expand Down
5 changes: 5 additions & 0 deletions packages/core/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ export interface Message {
* the tool call name and arguments. Otherwise, the field should not be set.
*/
tool_calls?: string | ToolCall[];

/**
* Additional message-specific information added on the server via StreamData
*/
annotations?: JSONValue[] | undefined;
}

export type CreateMessage = Omit<Message, 'id'> & {
Expand Down
20 changes: 16 additions & 4 deletions packages/core/streams/ai-stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,24 @@ export interface AIStreamCallbacksAndOptions {
experimental_streamData?: boolean;
}

// new TokenData()
// data: TokenData,
/**
* Options for the AIStreamParser.
* @interface
* @property {string} event - The event (type) from the server side event stream.
*/
export interface AIStreamParserOptions {
event?: string;
}

/**
* Custom parser for AIStream data.
* @interface
* @param {string} data - The data to be parsed.
* @param {AIStreamParserOptions} options - The options for the parser.
* @returns {string | void} The parsed data or void.
*/
export interface AIStreamParser {
(data: string): string | void;
(data: string, options: AIStreamParserOptions): string | void;
}

/**
Expand Down Expand Up @@ -82,7 +92,9 @@ export function createEventStreamTransformer(

if ('data' in event) {
const parsedMessage = customParser
? customParser(event.data)
? customParser(event.data, {
event: event.event,
})
: event.data;
if (parsedMessage) controller.enqueue(parsedMessage);
}
Expand Down
24 changes: 24 additions & 0 deletions packages/core/streams/stream-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export class experimental_StreamData {

// array to store appended data
private data: JSONValue[] = [];
private messageAnnotations: JSONValue[] = [];

constructor() {
this.isClosedPromise = new Promise(resolve => {
this.isClosedPromiseResolver = resolve;
Expand All @@ -39,6 +41,13 @@ export class experimental_StreamData {
controller.enqueue(encodedData);
}

if (self.messageAnnotations.length) {
const encodedmessageAnnotations = self.encoder.encode(
formatStreamPart('message_annotations', self.messageAnnotations),
);
controller.enqueue(encodedmessageAnnotations);
}

controller.enqueue(chunk);
},
async flush(controller) {
Expand All @@ -64,6 +73,13 @@ export class experimental_StreamData {
);
controller.enqueue(encodedData);
}

if (self.messageAnnotations.length) {
const encodedData = self.encoder.encode(
formatStreamPart('message_annotations', self.messageAnnotations),
);
controller.enqueue(encodedData);
}
},
});
}
Expand All @@ -88,6 +104,14 @@ export class experimental_StreamData {

this.data.push(value);
}

appendMessageAnnotation(value: JSONValue): void {
if (this.isClosed) {
throw new Error('Data Stream has already been closed.');
}

this.messageAnnotations.push(value);
}
}

/**
Expand Down

0 comments on commit 7851fa0

Please sign in to comment.