-
Notifications
You must be signed in to change notification settings - Fork 8.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Obs AI Assistant] Improve recall speed #176428
Changes from all commits
260ba9b
749f22b
46d82b5
80877b7
8db2c42
5d99236
3d2d81b
4ee1847
394e3ce
127a422
6a095e6
9bfbc32
9ce2046
1b36de7
4dac0f3
37fb2c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils'; | |
import type { Serializable } from '@kbn/utility-types'; | ||
import dedent from 'dedent'; | ||
import * as t from 'io-ts'; | ||
import { last, omit } from 'lodash'; | ||
import { compact, last, omit } from 'lodash'; | ||
import { lastValueFrom } from 'rxjs'; | ||
import { FunctionRegistrationParameters } from '.'; | ||
import { MessageRole, type Message } from '../../common/types'; | ||
|
@@ -88,12 +88,17 @@ export function registerRecallFunction({ | |
messages.filter((message) => message.message.role === MessageRole.User) | ||
); | ||
|
||
const nonEmptyQueries = compact(queries); | ||
|
||
const queriesOrUserPrompt = nonEmptyQueries.length | ||
? nonEmptyQueries | ||
: compact([userMessage?.message.content]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not always filter by the user query combined with the query from the LLM? Is that too restrictive? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I'd like to investigate is have the LLM decide when to recall (other than the first message). E.g., if somebody asks "what does the following error mean: ..." and then "what are the consequences of this error", the latter doesn't really need a recall. If they ask "how does this affect my checkout service" it does. The LLM should be able to classify this, and rewrite it in a way that does not require us to send over the entire conversation. In that case, the query it chooses should be the only thing we use. So, I'm preparing for that future. |
||
|
||
const suggestions = await retrieveSuggestions({ | ||
userMessage, | ||
client, | ||
signal, | ||
categories, | ||
queries, | ||
queries: queriesOrUserPrompt, | ||
}); | ||
|
||
if (suggestions.length === 0) { | ||
|
@@ -104,9 +109,8 @@ export function registerRecallFunction({ | |
|
||
const relevantDocuments = await scoreSuggestions({ | ||
suggestions, | ||
systemMessage, | ||
userMessage, | ||
queries, | ||
queries: queriesOrUserPrompt, | ||
messages, | ||
client, | ||
connectorId, | ||
signal, | ||
|
@@ -121,25 +125,17 @@ export function registerRecallFunction({ | |
} | ||
|
||
async function retrieveSuggestions({ | ||
userMessage, | ||
queries, | ||
client, | ||
categories, | ||
signal, | ||
}: { | ||
userMessage?: Message; | ||
queries: string[]; | ||
client: ObservabilityAIAssistantClient; | ||
categories: Array<'apm' | 'lens'>; | ||
signal: AbortSignal; | ||
}) { | ||
const queriesWithUserPrompt = | ||
userMessage && userMessage.message.content | ||
? [userMessage.message.content, ...queries] | ||
: queries; | ||
|
||
const recallResponse = await client.recall({ | ||
queries: queriesWithUserPrompt, | ||
queries, | ||
categories, | ||
}); | ||
|
||
|
@@ -156,54 +152,44 @@ const scoreFunctionRequestRt = t.type({ | |
}); | ||
|
||
const scoreFunctionArgumentsRt = t.type({ | ||
scores: t.array( | ||
t.type({ | ||
id: t.string, | ||
score: t.number, | ||
}) | ||
), | ||
scores: t.string, | ||
}); | ||
|
||
async function scoreSuggestions({ | ||
suggestions, | ||
systemMessage, | ||
userMessage, | ||
messages, | ||
queries, | ||
client, | ||
connectorId, | ||
signal, | ||
resources, | ||
}: { | ||
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>; | ||
systemMessage: Message; | ||
userMessage?: Message; | ||
messages: Message[]; | ||
queries: string[]; | ||
client: ObservabilityAIAssistantClient; | ||
connectorId: string; | ||
signal: AbortSignal; | ||
resources: RespondFunctionResources; | ||
}) { | ||
resources.logger.debug(`Suggestions: ${JSON.stringify(suggestions, null, 2)}`); | ||
const indexedSuggestions = suggestions.map((suggestion, index) => ({ ...suggestion, id: index })); | ||
|
||
const systemMessageExtension = | ||
dedent(`You have the function called score available to help you inform the user about how relevant you think a given document is to the conversation. | ||
Please give a score between 1 and 7, fractions are allowed. | ||
A higher score means it is more relevant.`); | ||
const extendedSystemMessage = { | ||
...systemMessage, | ||
message: { | ||
...systemMessage.message, | ||
content: `${systemMessage.message.content}\n\n${systemMessageExtension}`, | ||
}, | ||
}; | ||
const newUserMessageContent = | ||
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7, | ||
0 being completely relevant, and 7 being extremely relevant. Information is relevant to the question if it helps in | ||
answering the question. Judge it according to the following criteria: | ||
const userMessageOrQueries = | ||
userMessage && userMessage.message.content ? userMessage.message.content : queries.join(','); | ||
- The document is relevant to the question, and the rest of the conversation | ||
- The document has information relevant to the question that is not mentioned, | ||
or more detailed than what is available in the conversation | ||
- The document has a high amount of information relevant to the question compared to other documents | ||
- The document contains new information not mentioned before in the conversation | ||
const newUserMessageContent = | ||
dedent(`Given the question "${userMessageOrQueries}", can you give me a score for how relevant the following documents are? | ||
Question: | ||
${queries.join('\n')} | ||
${JSON.stringify(suggestions, null, 2)}`); | ||
Documents: | ||
${JSON.stringify(indexedSuggestions, null, 2)}`); | ||
|
||
const newUserMessage: Message = { | ||
'@timestamp': new Date().toISOString(), | ||
|
@@ -222,22 +208,13 @@ async function scoreSuggestions({ | |
additionalProperties: false, | ||
properties: { | ||
scores: { | ||
description: 'The document IDs and their scores', | ||
type: 'array', | ||
items: { | ||
type: 'object', | ||
additionalProperties: false, | ||
properties: { | ||
id: { | ||
description: 'The ID of the document', | ||
type: 'string', | ||
}, | ||
score: { | ||
description: 'The score for the document', | ||
type: 'number', | ||
}, | ||
}, | ||
}, | ||
description: `The document IDs and their scores, as CSV. Example: | ||
my_id,7 | ||
my_other_id,3 | ||
my_third_id,4 | ||
`, | ||
type: 'string', | ||
}, | ||
}, | ||
required: ['score'], | ||
|
@@ -249,19 +226,26 @@ async function scoreSuggestions({ | |
( | ||
await client.chat('score_suggestions', { | ||
connectorId, | ||
messages: [extendedSystemMessage, newUserMessage], | ||
messages: [...messages.slice(0, -1), newUserMessage], | ||
functions: [scoreFunction], | ||
functionCall: 'score', | ||
signal, | ||
}) | ||
).pipe(concatenateChatCompletionChunks()) | ||
); | ||
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response); | ||
const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))( | ||
const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))( | ||
scoreFunctionRequest.message.function_call.arguments | ||
); | ||
|
||
resources.logger.debug(`Scores: ${JSON.stringify(scores, null, 2)}`); | ||
const scores = scoresAsString.split('\n').map((line) => { | ||
const [index, score] = line | ||
.split(',') | ||
.map((value) => value.trim()) | ||
.map(Number); | ||
|
||
return { id: suggestions[index].id, score }; | ||
}); | ||
Comment on lines
+241
to
+248
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How confident are we that this is the format the LLM will respond with (seeing we support multiple LLMs soon)? Should we handle the cases where the LLM will respond in non-csv format? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll have a look in the Bedrock PR, but I don't think we should expect too much from or invest too much in LLMs other than OpenAI, until there is an LLM that has similar performance. |
||
|
||
if (scores.length === 0) { | ||
return []; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,15 @@ import apm from 'elastic-apm-node'; | |
import { decode, encode } from 'gpt-tokenizer'; | ||
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash'; | ||
import type OpenAI from 'openai'; | ||
import { filter, isObservable, lastValueFrom, Observable, shareReplay, toArray } from 'rxjs'; | ||
import { | ||
filter, | ||
firstValueFrom, | ||
isObservable, | ||
lastValueFrom, | ||
Observable, | ||
shareReplay, | ||
toArray, | ||
} from 'rxjs'; | ||
import { Readable } from 'stream'; | ||
import { v4 } from 'uuid'; | ||
import { | ||
|
@@ -455,6 +463,8 @@ export class ObservabilityAIAssistantClient { | |
): Promise<Observable<ChatCompletionChunkEvent>> => { | ||
const span = apm.startSpan(`chat ${name}`); | ||
|
||
const spanId = (span?.ids['span.id'] || '').substring(0, 6); | ||
|
||
const messagesForOpenAI: Array< | ||
Omit<OpenAI.ChatCompletionMessageParam, 'role'> & { | ||
role: MessageRole; | ||
|
@@ -490,6 +500,8 @@ export class ObservabilityAIAssistantClient { | |
this.dependencies.logger.debug(`Sending conversation to connector`); | ||
this.dependencies.logger.trace(JSON.stringify(request, null, 2)); | ||
|
||
const now = performance.now(); | ||
|
||
const executeResult = await this.dependencies.actionsClient.execute({ | ||
actionId: connectorId, | ||
params: { | ||
|
@@ -501,7 +513,11 @@ export class ObservabilityAIAssistantClient { | |
}, | ||
}); | ||
|
||
this.dependencies.logger.debug(`Received action client response: ${executeResult.status}`); | ||
this.dependencies.logger.debug( | ||
`Received action client response: ${executeResult.status} (took: ${Math.round( | ||
performance.now() - now | ||
)}ms)${spanId ? ` (${spanId})` : ''}` | ||
); | ||
|
||
if (executeResult.status === 'error' && executeResult?.serviceMessage) { | ||
const tokenLimitRegex = | ||
|
@@ -524,20 +540,34 @@ export class ObservabilityAIAssistantClient { | |
|
||
const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay()); | ||
|
||
if (span) { | ||
lastValueFrom(observable) | ||
.then( | ||
() => { | ||
span.setOutcome('success'); | ||
}, | ||
() => { | ||
span.setOutcome('failure'); | ||
} | ||
) | ||
.finally(() => { | ||
span.end(); | ||
}); | ||
} | ||
firstValueFrom(observable) | ||
.catch(noop) | ||
.finally(() => { | ||
this.dependencies.logger.debug( | ||
`Received first value after ${Math.round(performance.now() - now)}ms${ | ||
spanId ? ` (${spanId})` : '' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these intermediate metrics only available in logs or can we somehow store them in the span ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can, as labels, but not sure if it has tremendous value - adding labels changes the mapping so I'm a little wary of adding more. Let's see if we need it. |
||
}` | ||
); | ||
}); | ||
|
||
lastValueFrom(observable) | ||
.then( | ||
() => { | ||
span?.setOutcome('success'); | ||
}, | ||
() => { | ||
span?.setOutcome('failure'); | ||
} | ||
) | ||
.finally(() => { | ||
this.dependencies.logger.debug( | ||
`Completed response in ${Math.round(performance.now() - now)}ms${ | ||
spanId ? ` (${spanId})` : '' | ||
}` | ||
); | ||
|
||
span?.end(); | ||
}); | ||
|
||
return observable; | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this is a good use case for
findLast
(well-supported and not nearly used enough :D )There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not know this exists 😄 unfortunately TS doesn't know about it either (??)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Argh, I see that now. Odd since CanIUse reports mainstream support in all browsers we care about since at least 2022.
Edit seems like we need to wait until Typescript 5