[Obs AI Assistant] Improve recall speed (elastic#176428)
Improves recall speed by having the LLM output scores as CSV with zero-indexed document
"ids". Previously the output was a JSON object with the real document ids,
which caused the LLM to "think" for longer, for whatever reason. I didn't
actually see a difference in completion speed, but emitting the first
value took significantly less time with the CSV output. I also tried
sending a single document per request using the old format; while that
certainly improves things, the slowest request then becomes the
bottleneck. These are the results from about 10 tries per strategy (I'd love
to see others reproduce at least the `batch` vs `csv` strategy results):

`batch`: 24.7s
`chunk`: 10s
`csv`: 4.9s
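
For illustration, here is a minimal sketch of the two output shapes (hypothetical document ids, not the shipped code). The JSON shape forces the model to emit structural tokens and a real document id before the first score; the CSV shape reaches the first score almost immediately:

// Old shape: JSON with real document ids. Before the first score the model
// must emit `{"scores":[{"id":"<long id>","score":`, all overhead tokens.
const jsonArguments = {
  scores: [
    { id: 'apm-error-rate-guide-2f9c1b', score: 7 }, // hypothetical ids
    { id: 'lens-formula-docs-88d0ae', score: 3 },
  ],
};

// New shape: CSV keyed by zero-based index. The first score arrives after
// just a handful of tokens ("0,7").
const csvArguments = {
  scores: '0,7\n1,3',
};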

---------

Co-authored-by: Søren Louv-Jansen <sorenlouv@gmail.com>
Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
3 people authored and fkanout committed Mar 4, 2024
1 parent 48d1094 commit d7cef59
Showing 2 changed files with 91 additions and 77 deletions.
106 changes: 45 additions & 61 deletions x-pack/plugins/observability_ai_assistant/server/functions/recall.ts
@@ -9,7 +9,7 @@ import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import type { Serializable } from '@kbn/utility-types';
import dedent from 'dedent';
import * as t from 'io-ts';
import { last, omit } from 'lodash';
import { compact, last, omit } from 'lodash';
import { lastValueFrom } from 'rxjs';
import { FunctionRegistrationParameters } from '.';
import { MessageRole, type Message } from '../../common/types';
@@ -88,12 +88,17 @@ export function registerRecallFunction({
messages.filter((message) => message.message.role === MessageRole.User)
);

const nonEmptyQueries = compact(queries);

const queriesOrUserPrompt = nonEmptyQueries.length
? nonEmptyQueries
: compact([userMessage?.message.content]);

const suggestions = await retrieveSuggestions({
userMessage,
client,
signal,
categories,
queries,
queries: queriesOrUserPrompt,
});

if (suggestions.length === 0) {
@@ -104,9 +109,8 @@

const relevantDocuments = await scoreSuggestions({
suggestions,
systemMessage,
userMessage,
queries,
queries: queriesOrUserPrompt,
messages,
client,
connectorId,
signal,
@@ -121,25 +125,17 @@
}

async function retrieveSuggestions({
userMessage,
queries,
client,
categories,
signal,
}: {
userMessage?: Message;
queries: string[];
client: ObservabilityAIAssistantClient;
categories: Array<'apm' | 'lens'>;
signal: AbortSignal;
}) {
const queriesWithUserPrompt =
userMessage && userMessage.message.content
? [userMessage.message.content, ...queries]
: queries;

const recallResponse = await client.recall({
queries: queriesWithUserPrompt,
queries,
categories,
});

@@ -156,54 +152,44 @@ const scoreFunctionRequestRt = t.type({
});

const scoreFunctionArgumentsRt = t.type({
scores: t.array(
t.type({
id: t.string,
score: t.number,
})
),
scores: t.string,
});

async function scoreSuggestions({
suggestions,
systemMessage,
userMessage,
messages,
queries,
client,
connectorId,
signal,
resources,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
systemMessage: Message;
userMessage?: Message;
messages: Message[];
queries: string[];
client: ObservabilityAIAssistantClient;
connectorId: string;
signal: AbortSignal;
resources: RespondFunctionResources;
}) {
resources.logger.debug(`Suggestions: ${JSON.stringify(suggestions, null, 2)}`);
const indexedSuggestions = suggestions.map((suggestion, index) => ({ ...suggestion, id: index }));

const systemMessageExtension =
dedent(`You have the function called score available to help you inform the user about how relevant you think a given document is to the conversation.
Please give a score between 1 and 7, fractions are allowed.
A higher score means it is more relevant.`);
const extendedSystemMessage = {
...systemMessage,
message: {
...systemMessage.message,
content: `${systemMessage.message.content}\n\n${systemMessageExtension}`,
},
};
const newUserMessageContent =
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7,
0 being completely relevant, and 7 being extremely relevant. Information is relevant to the question if it helps in
answering the question. Judge it according to the following criteria:
- The document is relevant to the question, and the rest of the conversation
- The document has information relevant to the question that is not mentioned,
or more detailed than what is available in the conversation
- The document has a high amount of information relevant to the question compared to other documents
- The document contains new information not mentioned before in the conversation
Question:
${queries.join('\n')}
${JSON.stringify(suggestions, null, 2)}`);
const userMessageOrQueries =
userMessage && userMessage.message.content ? userMessage.message.content : queries.join(',');
const newUserMessageContent =
dedent(`Given the question "${userMessageOrQueries}", can you give me a score for how relevant the following documents are?
Documents:
${JSON.stringify(indexedSuggestions, null, 2)}`);

const newUserMessage: Message = {
'@timestamp': new Date().toISOString(),
@@ -222,22 +208,13 @@ async function scoreSuggestions({
additionalProperties: false,
properties: {
scores: {
description: 'The document IDs and their scores',
type: 'array',
items: {
type: 'object',
additionalProperties: false,
properties: {
id: {
description: 'The ID of the document',
type: 'string',
},
score: {
description: 'The score for the document',
type: 'number',
},
},
},
description: `The document IDs and their scores, as CSV. Example:
my_id,7
my_other_id,3
my_third_id,4
`,
type: 'string',
},
},
required: ['score'],
@@ -249,19 +226,26 @@
(
await client.chat('score_suggestions', {
connectorId,
messages: [extendedSystemMessage, newUserMessage],
messages: [...messages.slice(0, -1), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
})
).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
);

resources.logger.debug(`Scores: ${JSON.stringify(scores, null, 2)}`);
const scores = scoresAsString.split('\n').map((line) => {
const [index, score] = line
.split(',')
.map((value) => value.trim())
.map(Number);

return { id: suggestions[index].id, score };
});

if (scores.length === 0) {
return [];
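A side note on the parsing step above: the CSV contract is enforced only by the prompt, so a standalone variant might guard against malformed lines. A minimal sketch (hypothetical names, not the shipped code) that maps the zero-based indices back to the real suggestion ids:

interface Suggestion {
  id: string;
}

// Parse "index,score" lines, dropping anything that does not parse cleanly
// or points outside the suggestions array.
function parseScores(
  scoresAsString: string,
  suggestions: Suggestion[]
): Array<{ id: string; score: number }> {
  return scoresAsString
    .split('\n')
    .map((line) => line.split(',').map((value) => Number(value.trim())))
    .filter(
      ([index, score]) =>
        Number.isInteger(index) &&
        index >= 0 &&
        index < suggestions.length &&
        Number.isFinite(score)
    )
    .map(([index, score]) => ({ id: suggestions[index].id, score }));
}

// parseScores('0,7\n1,3\nnot-a-line', [{ id: 'doc-a' }, { id: 'doc-b' }])
// returns [{ id: 'doc-a', score: 7 }, { id: 'doc-b', score: 3 }]
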
62 changes: 46 additions & 16 deletions x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
@@ -14,7 +14,15 @@ import apm from 'elastic-apm-node';
import { decode, encode } from 'gpt-tokenizer';
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash';
import type OpenAI from 'openai';
import { filter, isObservable, lastValueFrom, Observable, shareReplay, toArray } from 'rxjs';
import {
filter,
firstValueFrom,
isObservable,
lastValueFrom,
Observable,
shareReplay,
toArray,
} from 'rxjs';
import { Readable } from 'stream';
import { v4 } from 'uuid';
import {
@@ -455,6 +463,8 @@
): Promise<Observable<ChatCompletionChunkEvent>> => {
const span = apm.startSpan(`chat ${name}`);

const spanId = (span?.ids['span.id'] || '').substring(0, 6);

const messagesForOpenAI: Array<
Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
role: MessageRole;
@@ -490,6 +500,8 @@
this.dependencies.logger.debug(`Sending conversation to connector`);
this.dependencies.logger.trace(JSON.stringify(request, null, 2));

const now = performance.now();

const executeResult = await this.dependencies.actionsClient.execute({
actionId: connectorId,
params: {
@@ -501,7 +513,11 @@
},
});

this.dependencies.logger.debug(`Received action client response: ${executeResult.status}`);
this.dependencies.logger.debug(
`Received action client response: ${executeResult.status} (took: ${Math.round(
performance.now() - now
)}ms)${spanId ? ` (${spanId})` : ''}`
);

if (executeResult.status === 'error' && executeResult?.serviceMessage) {
const tokenLimitRegex =
@@ -524,20 +540,34 @@

const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());

if (span) {
lastValueFrom(observable)
.then(
() => {
span.setOutcome('success');
},
() => {
span.setOutcome('failure');
}
)
.finally(() => {
span.end();
});
}
firstValueFrom(observable)
.catch(noop)
.finally(() => {
this.dependencies.logger.debug(
`Received first value after ${Math.round(performance.now() - now)}ms${
spanId ? ` (${spanId})` : ''
}`
);
});

lastValueFrom(observable)
.then(
() => {
span?.setOutcome('success');
},
() => {
span?.setOutcome('failure');
}
)
.finally(() => {
this.dependencies.logger.debug(
`Completed response in ${Math.round(performance.now() - now)}ms${
spanId ? ` (${spanId})` : ''
}`
);

span?.end();
});

return observable;
};
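
The instrumentation pattern in this file generalizes: because the stream is wrapped in shareReplay(), firstValueFrom and lastValueFrom can each subscribe without triggering a second request, so time-to-first-value and total duration are measured from the same underlying response. A self-contained sketch with a simulated stream (assumes RxJS 7 and a runtime with a global performance object; not the Kibana code):

import { firstValueFrom, interval, lastValueFrom, map, shareReplay, take } from 'rxjs';

// Hypothetical stand-in for the chunked LLM response stream.
const chunks$ = interval(100).pipe(
  take(5),
  map((i) => `chunk-${i}`),
  // shareReplay() lets several subscribers observe one underlying stream
  // without restarting it, and replays values to late subscribers.
  shareReplay()
);

const start = performance.now();

// Time to the first emitted value (analogous to "time to first token").
firstValueFrom(chunks$)
  .catch(() => undefined)
  .finally(() => {
    console.log(`first value after ${Math.round(performance.now() - start)}ms`);
  });

// Total time until the stream completes.
lastValueFrom(chunks$)
  .catch(() => undefined)
  .finally(() => {
    console.log(`completed in ${Math.round(performance.now() - start)}ms`);
  });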
