Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Obs AI Assistant] Improve recall speed #176428

Merged
merged 16 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 47 additions & 59 deletions x-pack/plugins/observability_ai_assistant/server/functions/recall.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import type { Serializable } from '@kbn/utility-types';
import dedent from 'dedent';
import * as t from 'io-ts';
import { last, omit } from 'lodash';
import { compact, last, omit } from 'lodash';
import { lastValueFrom } from 'rxjs';
import { FunctionRegistrationParameters } from '.';
import { MessageRole, type Message } from '../../common/types';
Expand Down Expand Up @@ -87,12 +87,17 @@ export function registerRecallFunction({
messages.filter((message) => message.message.role === MessageRole.User)
Copy link
Member

@sorenlouv sorenlouv Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is a good use case for findLast (well-supported and not nearly used enough :D )

Suggested change
messages.filter((message) => message.message.role === MessageRole.User)
messages.findLast((message) => message.message.role === MessageRole.User)

Copy link
Member Author

@dgieselaar dgieselaar Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know this exists 😄 unfortunately TS doesn't know about it either (??)

Copy link
Member

@sorenlouv sorenlouv Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately TS doesn't know about it either (??)

Argh, I see that now. Odd since CanIUse reports mainstream support in all browsers we care about since at least 2022.

Edit seems like we need to wait until Typescript 5

);

const nonEmptyQueries = queries.filter(Boolean);
dgieselaar marked this conversation as resolved.
Show resolved Hide resolved

const queriesOrUserPrompt = nonEmptyQueries.length
? nonEmptyQueries
: compact([userMessage?.message.content]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not always filter by the user query combined with the query from the LLM? Is that too restrictive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I'd like to investigate is have the LLM decide when to recall (other than the first message). E.g., if somebody asks "what does the following error mean: ..." and then "what are the consequences of this error", the latter doesn't really need a recall. If they ask "how does this affect my checkout service" it does. The LLM should be able to classify this, and rewrite it in a way that does not require us to send over the entire conversation. In that case, the query it chooses should be the only thing we use. So, I'm preparing for that future.


const suggestions = await retrieveSuggestions({
userMessage,
client,
signal,
categories,
queries,
queries: queriesOrUserPrompt,
});

resources.logger.debug(`Received ${suggestions.length} suggestions`);
Expand All @@ -107,9 +112,8 @@ export function registerRecallFunction({

const relevantDocuments = await scoreSuggestions({
suggestions,
systemMessage,
userMessage,
queries,
queries: queriesOrUserPrompt,
messages,
client,
connectorId,
signal,
Expand All @@ -126,25 +130,17 @@ export function registerRecallFunction({
}

async function retrieveSuggestions({
userMessage,
queries,
client,
categories,
signal,
}: {
userMessage?: Message;
queries: string[];
client: ObservabilityAIAssistantClient;
categories: Array<'apm' | 'lens'>;
signal: AbortSignal;
}) {
const queriesWithUserPrompt =
userMessage && userMessage.message.content
? [userMessage.message.content, ...queries]
: queries;

const recallResponse = await client.recall({
queries: queriesWithUserPrompt,
queries,
categories,
});

Expand All @@ -161,50 +157,42 @@ const scoreFunctionRequestRt = t.type({
});

const scoreFunctionArgumentsRt = t.type({
scores: t.array(
t.type({
id: t.string,
score: t.number,
})
),
scores: t.string,
});

async function scoreSuggestions({
suggestions,
systemMessage,
userMessage,
messages,
queries,
client,
connectorId,
signal,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
systemMessage: Message;
userMessage?: Message;
messages: Message[];
queries: string[];
client: ObservabilityAIAssistantClient;
connectorId: string;
signal: AbortSignal;
}) {
const systemMessageExtension =
dedent(`You have the function called score available to help you inform the user about how relevant you think a given document is to the conversation.
Please give a score between 1 and 7, fractions are allowed.
A higher score means it is more relevant.`);
const extendedSystemMessage = {
...systemMessage,
message: {
...systemMessage.message,
content: `${systemMessage.message.content}\n\n${systemMessageExtension}`,
},
};

const userMessageOrQueries =
userMessage && userMessage.message.content ? userMessage.message.content : queries.join(',');
const indexedSuggestions = suggestions.map((suggestion, index) => ({ ...suggestion, id: index }));

const newUserMessageContent =
dedent(`Given the question "${userMessageOrQueries}", can you give me a score for how relevant the following documents are?
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7,
0 being completely relevant, and 10 being extremely relevant. Information is relevant to the question if it helps in
Copy link
Member

@sorenlouv sorenlouv Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been wondering about this before: why is the scoring interval 0-7? Is it something magic?
Btw It looks like you expect a scoring between 0-10

Suggested change
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7,
0 being completely relevant, and 10 being extremely relevant. Information is relevant to the question if it helps in
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 10,
0 being completely relevant, and 10 being extremely relevant. Information is relevant to the question if it helps in

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@miltonhultgren any thoughts here on 0 to 7 vs 10?

answering the question. Judge it according to the following criteria:

- The document is relevant to the question, and the rest of the conversation
- The document has information relevant to the question that is not mentioned,
or more detailed than what is available in the conversation
- The document has a high amount of information relevant to the question compared to other documents
- The document contains new information not mentioned before in the conversation

${JSON.stringify(suggestions, null, 2)}`);
Question:
${queries.join('\n')}

Documents:
${JSON.stringify(indexedSuggestions, null, 2)}`);

const newUserMessage: Message = {
'@timestamp': new Date().toISOString(),
Expand All @@ -223,22 +211,13 @@ async function scoreSuggestions({
additionalProperties: false,
properties: {
scores: {
description: 'The document IDs and their scores',
type: 'array',
items: {
type: 'object',
additionalProperties: false,
properties: {
id: {
description: 'The ID of the document',
type: 'string',
},
score: {
description: 'The score for the document',
type: 'number',
},
},
},
description: `The document IDs and their scores, as CSV. Example:

my_id,7
my_other_id,3
my_third_id,4
`,
type: 'string',
},
},
required: ['score'],
Expand All @@ -250,18 +229,27 @@ async function scoreSuggestions({
(
await client.chat('score_suggestions', {
connectorId,
messages: [extendedSystemMessage, newUserMessage],
messages: [...messages.slice(-1), newUserMessage],
Copy link
Member

@sorenlouv sorenlouv Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we use last instead? (my head spins when doing array math)

Suggested change
messages: [...messages.slice(-1), newUserMessage],
messages: [last(messages), newUserMessage],

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a typo 😄. it's supposed to be "everything except the last", so .slice(0,-1). Will correct.

functions: [scoreFunction],
functionCall: 'score',
signal,
})
).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);

const scores = scoresAsString.split('\n').map((line) => {
const [index, score] = line
.split(',')
.map((value) => value.trim())
.map(Number);

return { id: suggestions[index].id, score };
});
Comment on lines +241 to +248
Copy link
Member

@sorenlouv sorenlouv Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How confident are we that this is the format the LLM will respond with (seeing we support multiple LLMs soon)? Should we handle the cases where the LLM will respond in non-csv format?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have a look in the Bedrock PR, but I don't think we should expect too much from or invest too much in LLMs other than OpenAI, until there is an LLM that has similar performance.


if (scores.length === 0) {
return [];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ import apm from 'elastic-apm-node';
import { decode, encode } from 'gpt-tokenizer';
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash';
import type OpenAI from 'openai';
import { filter, isObservable, lastValueFrom, Observable, shareReplay, toArray } from 'rxjs';
import {
filter,
firstValueFrom,
isObservable,
lastValueFrom,
Observable,
shareReplay,
toArray,
} from 'rxjs';
import { Readable } from 'stream';
import { v4 } from 'uuid';
import {
Expand Down Expand Up @@ -449,6 +457,8 @@ export class ObservabilityAIAssistantClient {
): Promise<Observable<ChatCompletionChunkEvent>> => {
const span = apm.startSpan(`chat ${name}`);

const spanId = (span?.ids['span.id'] || '').substring(0, 6);

const messagesForOpenAI: Array<
Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
role: MessageRole;
Expand Down Expand Up @@ -484,6 +494,8 @@ export class ObservabilityAIAssistantClient {
this.dependencies.logger.debug(`Sending conversation to connector`);
this.dependencies.logger.trace(JSON.stringify(request, null, 2));

const now = performance.now();

const executeResult = await this.dependencies.actionsClient.execute({
actionId: connectorId,
params: {
Expand All @@ -495,7 +507,11 @@ export class ObservabilityAIAssistantClient {
},
});

this.dependencies.logger.debug(`Received action client response: ${executeResult.status}`);
this.dependencies.logger.debug(
`Received action client response: ${executeResult.status} (took: ${Math.round(
performance.now() - now
)}ms)${spanId ? ` (${spanId})` : ''}`
);

if (executeResult.status === 'error' && executeResult?.serviceMessage) {
const tokenLimitRegex =
Expand All @@ -518,20 +534,37 @@ export class ObservabilityAIAssistantClient {

const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());

if (span) {
lastValueFrom(observable)
.then(
() => {
span.setOutcome('success');
},
() => {
span.setOutcome('failure');
}
)
.finally(() => {
span.end();
});
}
firstValueFrom(observable)
.then(
() => {},
() => {}
)
Copy link
Member

@sorenlouv sorenlouv Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this then necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, just doing this to catch an error so it doesn't result in an unhandled promise rejection. But I don't need a then, I can just use a catch.

.finally(() => {
this.dependencies.logger.debug(
`Received first value after ${Math.round(performance.now() - now)}ms${
spanId ? ` (${spanId})` : ''
Copy link
Contributor

@klacabane klacabane Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these intermediate metrics only available in logs or can we somehow store them in the span ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, as labels, but not sure if it has tremendous value - adding labels changes the mapping so I'm a little wary of adding more. Let's see if we need it.

}`
);
});

lastValueFrom(observable)
.then(
() => {
span?.setOutcome('success');
},
() => {
span?.setOutcome('failure');
}
)
.finally(() => {
this.dependencies.logger.debug(
`Completed response in ${Math.round(performance.now() - now)}ms${
spanId ? ` (${spanId})` : ''
}`
);

span?.end();
});

return observable;
};
Expand Down