[Obs AI Assistant] Improve recall speed #176428

Merged · 16 commits · Feb 8, 2024
106 changes: 45 additions & 61 deletions x-pack/plugins/observability_ai_assistant/server/functions/recall.ts
@@ -9,7 +9,7 @@ import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import type { Serializable } from '@kbn/utility-types';
import dedent from 'dedent';
import * as t from 'io-ts';
import { last, omit } from 'lodash';
import { compact, last, omit } from 'lodash';
import { lastValueFrom } from 'rxjs';
import { FunctionRegistrationParameters } from '.';
import { MessageRole, type Message } from '../../common/types';
@@ -88,12 +88,17 @@ export function registerRecallFunction({
messages.filter((message) => message.message.role === MessageRole.User)
@sorenlouv (Member) commented on Feb 8, 2024:

nit: this is a good use case for findLast (well-supported and not nearly used enough :D)

Suggested change:
- messages.filter((message) => message.message.role === MessageRole.User)
+ messages.findLast((message) => message.message.role === MessageRole.User)

@dgieselaar (Member, Author) replied on Feb 8, 2024:

I did not know this exists 😄 Unfortunately TS doesn't know about it either (??)

@sorenlouv (Member) replied on Feb 8, 2024:

> unfortunately TS doesn't know about it either (??)

Argh, I see that now. Odd, since CanIUse reports mainstream support in all browsers we care about since at least 2022.

Edit: seems like we need to wait until TypeScript 5.
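For reference, a minimal typed sketch (not part of this PR) of how the call site could keep the findLast shape until the repo is on TypeScript 5.0, whose es2023 lib types declare Array.prototype.findLast. The helper name and the usage line are hypothetical; only the Message/MessageRole types are taken from the plugin.

```ts
// Hypothetical helper, not from this PR: a typed findLast until the TS lib declares it.
function findLast<T>(items: readonly T[], predicate: (item: T) => boolean): T | undefined {
  for (let i = items.length - 1; i >= 0; i--) {
    if (predicate(items[i])) {
      return items[i];
    }
  }
  return undefined;
}

// Usage, assuming the Message / MessageRole types from this plugin:
// const userMessage = findLast(messages, (m) => m.message.role === MessageRole.User);
```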

);

const nonEmptyQueries = compact(queries);

const queriesOrUserPrompt = nonEmptyQueries.length
? nonEmptyQueries
: compact([userMessage?.message.content]);
A reviewer (Member) commented:

Why not always filter by the user query combined with the query from the LLM? Is that too restrictive?

The author (Member) replied:

What I'd like to investigate is having the LLM decide when to recall (other than on the first message). E.g., if somebody asks "what does the following error mean: ..." and then "what are the consequences of this error", the latter doesn't really need a recall. If they ask "how does this affect my checkout service", it does. The LLM should be able to classify this, and rewrite it in a way that does not require us to send over the entire conversation. In that case, the query it chooses should be the only thing we use. So, I'm preparing for that future.


const suggestions = await retrieveSuggestions({
userMessage,
client,
signal,
categories,
queries,
queries: queriesOrUserPrompt,
});

if (suggestions.length === 0) {
@@ -104,9 +109,8 @@

const relevantDocuments = await scoreSuggestions({
suggestions,
systemMessage,
userMessage,
queries,
queries: queriesOrUserPrompt,
messages,
client,
connectorId,
signal,
@@ -121,25 +125,17 @@
}

async function retrieveSuggestions({
userMessage,
queries,
client,
categories,
signal,
}: {
userMessage?: Message;
queries: string[];
client: ObservabilityAIAssistantClient;
categories: Array<'apm' | 'lens'>;
signal: AbortSignal;
}) {
const queriesWithUserPrompt =
userMessage && userMessage.message.content
? [userMessage.message.content, ...queries]
: queries;

const recallResponse = await client.recall({
queries: queriesWithUserPrompt,
queries,
categories,
});

@@ -156,54 +152,44 @@ const scoreFunctionRequestRt = t.type({
});

const scoreFunctionArgumentsRt = t.type({
scores: t.array(
t.type({
id: t.string,
score: t.number,
})
),
scores: t.string,
});

async function scoreSuggestions({
suggestions,
systemMessage,
userMessage,
messages,
queries,
client,
connectorId,
signal,
resources,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
systemMessage: Message;
userMessage?: Message;
messages: Message[];
queries: string[];
client: ObservabilityAIAssistantClient;
connectorId: string;
signal: AbortSignal;
resources: RespondFunctionResources;
}) {
resources.logger.debug(`Suggestions: ${JSON.stringify(suggestions, null, 2)}`);
const indexedSuggestions = suggestions.map((suggestion, index) => ({ ...suggestion, id: index }));

const systemMessageExtension =
dedent(`You have the function called score available to help you inform the user about how relevant you think a given document is to the conversation.
Please give a score between 1 and 7, fractions are allowed.
A higher score means it is more relevant.`);
const extendedSystemMessage = {
...systemMessage,
message: {
...systemMessage.message,
content: `${systemMessage.message.content}\n\n${systemMessageExtension}`,
},
};
const newUserMessageContent =
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7,
0 being completely relevant, and 7 being extremely relevant. Information is relevant to the question if it helps in
answering the question. Judge it according to the following criteria:
const userMessageOrQueries =
userMessage && userMessage.message.content ? userMessage.message.content : queries.join(',');
- The document is relevant to the question, and the rest of the conversation
- The document has information relevant to the question that is not mentioned,
or more detailed than what is available in the conversation
- The document has a high amount of information relevant to the question compared to other documents
- The document contains new information not mentioned before in the conversation
const newUserMessageContent =
dedent(`Given the question "${userMessageOrQueries}", can you give me a score for how relevant the following documents are?
Question:
${queries.join('\n')}
${JSON.stringify(suggestions, null, 2)}`);
Documents:
${JSON.stringify(indexedSuggestions, null, 2)}`);

const newUserMessage: Message = {
'@timestamp': new Date().toISOString(),
@@ -222,22 +208,13 @@ async function scoreSuggestions({
additionalProperties: false,
properties: {
scores: {
description: 'The document IDs and their scores',
type: 'array',
items: {
type: 'object',
additionalProperties: false,
properties: {
id: {
description: 'The ID of the document',
type: 'string',
},
score: {
description: 'The score for the document',
type: 'number',
},
},
},
description: `The document IDs and their scores, as CSV. Example:
my_id,7
my_other_id,3
my_third_id,4
`,
type: 'string',
},
},
required: ['score'],
@@ -249,19 +226,26 @@
(
await client.chat('score_suggestions', {
connectorId,
messages: [extendedSystemMessage, newUserMessage],
messages: [...messages.slice(0, -1), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
})
).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);

resources.logger.debug(`Scores: ${JSON.stringify(scores, null, 2)}`);
const scores = scoresAsString.split('\n').map((line) => {
const [index, score] = line
.split(',')
.map((value) => value.trim())
.map(Number);

return { id: suggestions[index].id, score };
});
Comment on lines +241 to +248
@sorenlouv (Member) commented on Feb 8, 2024:

How confident are we that this is the format the LLM will respond with (seeing as we will support multiple LLMs soon)? Should we handle the cases where the LLM responds in a non-CSV format?

The author (Member) replied:

I'll have a look in the Bedrock PR, but I don't think we should expect too much from or invest too much in LLMs other than OpenAI, until there is an LLM that has similar performance.
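As an illustration of the defensive handling discussed in this thread (not code from this PR), the parsing above could keep only lines that match the expected "<index>,<score>" shape and silently drop anything else. It reuses the `scoresAsString` and `suggestions` names from the diff; the `parsedScores` name and the exact filtering criteria are assumptions.

```ts
// Sketch: tolerate non-CSV output by keeping only well-formed "<index>,<score>" pairs.
const parsedScores = scoresAsString
  .split('\n')
  .map((line) => line.split(',').map((value) => Number(value.trim())))
  .filter(
    ([index, score]) =>
      Number.isInteger(index) && Number.isFinite(score) && suggestions[index] !== undefined
  )
  .map(([index, score]) => ({ id: suggestions[index].id, score }));
```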


if (scores.length === 0) {
return [];
@@ -14,7 +14,15 @@ import apm from 'elastic-apm-node';
import { decode, encode } from 'gpt-tokenizer';
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash';
import type OpenAI from 'openai';
import { filter, isObservable, lastValueFrom, Observable, shareReplay, toArray } from 'rxjs';
import {
filter,
firstValueFrom,
isObservable,
lastValueFrom,
Observable,
shareReplay,
toArray,
} from 'rxjs';
import { Readable } from 'stream';
import { v4 } from 'uuid';
import {
@@ -455,6 +463,8 @@
): Promise<Observable<ChatCompletionChunkEvent>> => {
const span = apm.startSpan(`chat ${name}`);

const spanId = (span?.ids['span.id'] || '').substring(0, 6);

const messagesForOpenAI: Array<
Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
role: MessageRole;
@@ -490,6 +500,8 @@
this.dependencies.logger.debug(`Sending conversation to connector`);
this.dependencies.logger.trace(JSON.stringify(request, null, 2));

const now = performance.now();

const executeResult = await this.dependencies.actionsClient.execute({
actionId: connectorId,
params: {
@@ -501,7 +513,11 @@
},
});

this.dependencies.logger.debug(`Received action client response: ${executeResult.status}`);
this.dependencies.logger.debug(
`Received action client response: ${executeResult.status} (took: ${Math.round(
performance.now() - now
)}ms)${spanId ? ` (${spanId})` : ''}`
);

if (executeResult.status === 'error' && executeResult?.serviceMessage) {
const tokenLimitRegex =
@@ -524,20 +540,34 @@

const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());

if (span) {
lastValueFrom(observable)
.then(
() => {
span.setOutcome('success');
},
() => {
span.setOutcome('failure');
}
)
.finally(() => {
span.end();
});
}
firstValueFrom(observable)
.catch(noop)
.finally(() => {
this.dependencies.logger.debug(
`Received first value after ${Math.round(performance.now() - now)}ms${
spanId ? ` (${spanId})` : ''
@klacabane (Contributor) commented on Feb 8, 2024:

Are these intermediate metrics only available in logs, or can we somehow store them in the span?

The author (Member) replied:

We can, as labels, but I'm not sure it has tremendous value; adding labels changes the mapping, so I'm a little wary of adding more. Let's see if we need it.
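A minimal sketch of the label approach mentioned above (not part of this PR), assuming elastic-apm-node's `span.addLabels` and the surrounding `span` and `now` variables from this method; the `time_to_first_value_ms` label name is made up here, and as noted, new labels would affect the mapping.

```ts
// Hypothetical: record time-to-first-value on the APM span instead of only logging it.
if (span) {
  span.addLabels({
    time_to_first_value_ms: Math.round(performance.now() - now),
  });
}
```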

}`
);
});

lastValueFrom(observable)
.then(
() => {
span?.setOutcome('success');
},
() => {
span?.setOutcome('failure');
}
)
.finally(() => {
this.dependencies.logger.debug(
`Completed response in ${Math.round(performance.now() - now)}ms${
spanId ? ` (${spanId})` : ''
}`
);

span?.end();
});

return observable;
};