From fc58a0d3a71dd946fb24a75050930030c002d2a4 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Thu, 8 Feb 2024 17:27:24 +0100 Subject: [PATCH] [Obs AI Assistant] Improve recall speed (#176428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improves recall speed by outputting as CSV with zero-indexed document "ids". Previously, it was a JSON object, with the real document ids. This causes the LLM to "think" for longer, for whatever reason. I didn't actually see a difference in completion speed, but emitting the first value took significantly less time when using the CSV output. I also tried sending a single document per request using the old format, and while that certainly improves things, the slowest request becomes the bottleneck. These are results from about 10 tries per strategy (I'd love to see others reproduce at least the `batch` vs `csv` strategy results): `batch`: 24.7s `chunk`: 10s `csv`: 4.9s --------- Co-authored-by: Søren Louv-Jansen Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> --- .../server/functions/recall.ts | 106 ++++++++---------- .../server/service/client/index.ts | 62 +++++++--- 2 files changed, 91 insertions(+), 77 deletions(-) diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts index ee0fae1f91ed1..909a823286cc6 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts @@ -9,7 +9,7 @@ import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils'; import type { Serializable } from '@kbn/utility-types'; import dedent from 'dedent'; import * as t from 'io-ts'; -import { last, omit } from 'lodash'; +import { compact, last, omit } from 'lodash'; import { lastValueFrom } from 'rxjs'; import { FunctionRegistrationParameters } from '.'; import { MessageRole, type Message } from '../../common/types'; @@ -88,12 +88,17 @@ export function registerRecallFunction({ messages.filter((message) => message.message.role === MessageRole.User) ); + const nonEmptyQueries = compact(queries); + + const queriesOrUserPrompt = nonEmptyQueries.length + ? nonEmptyQueries + : compact([userMessage?.message.content]); + const suggestions = await retrieveSuggestions({ userMessage, client, - signal, categories, - queries, + queries: queriesOrUserPrompt, }); if (suggestions.length === 0) { @@ -104,9 +109,8 @@ export function registerRecallFunction({ const relevantDocuments = await scoreSuggestions({ suggestions, - systemMessage, - userMessage, - queries, + queries: queriesOrUserPrompt, + messages, client, connectorId, signal, @@ -121,25 +125,17 @@ export function registerRecallFunction({ } async function retrieveSuggestions({ - userMessage, queries, client, categories, - signal, }: { userMessage?: Message; queries: string[]; client: ObservabilityAIAssistantClient; categories: Array<'apm' | 'lens'>; - signal: AbortSignal; }) { - const queriesWithUserPrompt = - userMessage && userMessage.message.content - ? [userMessage.message.content, ...queries] - : queries; - const recallResponse = await client.recall({ - queries: queriesWithUserPrompt, + queries, categories, }); @@ -156,18 +152,12 @@ const scoreFunctionRequestRt = t.type({ }); const scoreFunctionArgumentsRt = t.type({ - scores: t.array( - t.type({ - id: t.string, - score: t.number, - }) - ), + scores: t.string, }); async function scoreSuggestions({ suggestions, - systemMessage, - userMessage, + messages, queries, client, connectorId, @@ -175,35 +165,31 @@ async function scoreSuggestions({ resources, }: { suggestions: Awaited>; - systemMessage: Message; - userMessage?: Message; + messages: Message[]; queries: string[]; client: ObservabilityAIAssistantClient; connectorId: string; signal: AbortSignal; resources: RespondFunctionResources; }) { - resources.logger.debug(`Suggestions: ${JSON.stringify(suggestions, null, 2)}`); + const indexedSuggestions = suggestions.map((suggestion, index) => ({ ...suggestion, id: index })); - const systemMessageExtension = - dedent(`You have the function called score available to help you inform the user about how relevant you think a given document is to the conversation. - Please give a score between 1 and 7, fractions are allowed. - A higher score means it is more relevant.`); - const extendedSystemMessage = { - ...systemMessage, - message: { - ...systemMessage.message, - content: `${systemMessage.message.content}\n\n${systemMessageExtension}`, - }, - }; + const newUserMessageContent = + dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7, + 0 being completely relevant, and 7 being extremely relevant. Information is relevant to the question if it helps in + answering the question. Judge it according to the following criteria: - const userMessageOrQueries = - userMessage && userMessage.message.content ? userMessage.message.content : queries.join(','); + - The document is relevant to the question, and the rest of the conversation + - The document has information relevant to the question that is not mentioned, + or more detailed than what is available in the conversation + - The document has a high amount of information relevant to the question compared to other documents + - The document contains new information not mentioned before in the conversation - const newUserMessageContent = - dedent(`Given the question "${userMessageOrQueries}", can you give me a score for how relevant the following documents are? + Question: + ${queries.join('\n')} - ${JSON.stringify(suggestions, null, 2)}`); + Documents: + ${JSON.stringify(indexedSuggestions, null, 2)}`); const newUserMessage: Message = { '@timestamp': new Date().toISOString(), @@ -222,22 +208,13 @@ async function scoreSuggestions({ additionalProperties: false, properties: { scores: { - description: 'The document IDs and their scores', - type: 'array', - items: { - type: 'object', - additionalProperties: false, - properties: { - id: { - description: 'The ID of the document', - type: 'string', - }, - score: { - description: 'The score for the document', - type: 'number', - }, - }, - }, + description: `The document IDs and their scores, as CSV. Example: + + my_id,7 + my_other_id,3 + my_third_id,4 + `, + type: 'string', }, }, required: ['score'], @@ -249,7 +226,7 @@ async function scoreSuggestions({ ( await client.chat('score_suggestions', { connectorId, - messages: [extendedSystemMessage, newUserMessage], + messages: [...messages.slice(0, -1), newUserMessage], functions: [scoreFunction], functionCall: 'score', signal, @@ -257,11 +234,18 @@ async function scoreSuggestions({ ).pipe(concatenateChatCompletionChunks()) ); const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response); - const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))( + const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))( scoreFunctionRequest.message.function_call.arguments ); - resources.logger.debug(`Scores: ${JSON.stringify(scores, null, 2)}`); + const scores = scoresAsString.split('\n').map((line) => { + const [index, score] = line + .split(',') + .map((value) => value.trim()) + .map(Number); + + return { id: suggestions[index].id, score }; + }); if (scores.length === 0) { return []; diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index f3ab3e917979b..afd34aa8ea966 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -14,7 +14,15 @@ import apm from 'elastic-apm-node'; import { decode, encode } from 'gpt-tokenizer'; import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash'; import type OpenAI from 'openai'; -import { filter, isObservable, lastValueFrom, Observable, shareReplay, toArray } from 'rxjs'; +import { + filter, + firstValueFrom, + isObservable, + lastValueFrom, + Observable, + shareReplay, + toArray, +} from 'rxjs'; import { Readable } from 'stream'; import { v4 } from 'uuid'; import { @@ -455,6 +463,8 @@ export class ObservabilityAIAssistantClient { ): Promise> => { const span = apm.startSpan(`chat ${name}`); + const spanId = (span?.ids['span.id'] || '').substring(0, 6); + const messagesForOpenAI: Array< Omit & { role: MessageRole; @@ -490,6 +500,8 @@ export class ObservabilityAIAssistantClient { this.dependencies.logger.debug(`Sending conversation to connector`); this.dependencies.logger.trace(JSON.stringify(request, null, 2)); + const now = performance.now(); + const executeResult = await this.dependencies.actionsClient.execute({ actionId: connectorId, params: { @@ -501,7 +513,11 @@ export class ObservabilityAIAssistantClient { }, }); - this.dependencies.logger.debug(`Received action client response: ${executeResult.status}`); + this.dependencies.logger.debug( + `Received action client response: ${executeResult.status} (took: ${Math.round( + performance.now() - now + )}ms)${spanId ? ` (${spanId})` : ''}` + ); if (executeResult.status === 'error' && executeResult?.serviceMessage) { const tokenLimitRegex = @@ -524,20 +540,34 @@ export class ObservabilityAIAssistantClient { const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay()); - if (span) { - lastValueFrom(observable) - .then( - () => { - span.setOutcome('success'); - }, - () => { - span.setOutcome('failure'); - } - ) - .finally(() => { - span.end(); - }); - } + firstValueFrom(observable) + .catch(noop) + .finally(() => { + this.dependencies.logger.debug( + `Received first value after ${Math.round(performance.now() - now)}ms${ + spanId ? ` (${spanId})` : '' + }` + ); + }); + + lastValueFrom(observable) + .then( + () => { + span?.setOutcome('success'); + }, + () => { + span?.setOutcome('failure'); + } + ) + .finally(() => { + this.dependencies.logger.debug( + `Completed response in ${Math.round(performance.now() - now)}ms${ + spanId ? ` (${spanId})` : '' + }` + ); + + span?.end(); + }); return observable; };