[Security Assistant] AI Assistant - Better Solution for OSS models (#10416) #194166

Merged · 25 commits merged into main from security/genai/10416-OSS-models on Oct 7, 2024
Commits (25)
bc4228d
OSS LLM
e40pud Sep 16, 2024
4d5c841
OSS LLMs streaming fixes
e40pud Sep 16, 2024
45c8ac0
Use api URL to verify OSS llms vs OpenAI
e40pud Sep 17, 2024
98abd53
Prompting
e40pud Sep 19, 2024
7a7f7e2
Use provider type to better identify OSS model - handles case with th…
e40pud Sep 19, 2024
6903909
Enable `NaturalLanguageESQLTool` for OSS models like Llama
e40pud Sep 24, 2024
7436f34
Fix the issue with extra escape backslash characters which breaks the…
e40pud Sep 24, 2024
ae638e6
Revert streaming events parsing
e40pud Sep 24, 2024
a03e5ea
Simplified OSS model streaming
e40pud Sep 24, 2024
5d2e1f2
Add OSS model specific prompt to the tool description
e40pud Sep 25, 2024
f9eb9d7
Brush up implementation and add some unit tests
e40pud Sep 26, 2024
351c2f7
Remove redundant code
e40pud Sep 27, 2024
d78af80
Merge branch 'main' into security/genai/10416-OSS-models
e40pud Sep 27, 2024
8689cc2
Merge branch 'main' into security/genai/10416-OSS-models
elasticmachine Sep 30, 2024
046c3c5
Merge branch 'main' into security/genai/10416-OSS-models
e40pud Oct 1, 2024
c0a00eb
Merge branch 'main' into security/genai/10416-OSS-models
e40pud Oct 2, 2024
69543ce
Make sure we log evaluation results and errors
e40pud Oct 2, 2024
5c7bf7f
Merge branch 'main' into security/genai/10416-OSS-models
e40pud Oct 3, 2024
db9fd74
Merge branch 'main' into security/genai/10416-OSS-models
e40pud Oct 4, 2024
d0458eb
Merge branch 'main' into security/genai/10416-OSS-models
e40pud Oct 4, 2024
26cb4c3
Merge branch 'main' into security/genai/10416-OSS-models
elasticmachine Oct 7, 2024
513d4bf
Review feedback: long time request issue
e40pud Oct 7, 2024
1d60365
Update x-pack/plugins/elastic_assistant/server/routes/utils.ts
e40pud Oct 7, 2024
0f96231
Review feedback: naming
e40pud Oct 7, 2024
59fb225
Review feedback: re-instantiate AbortController after the `abort`
e40pud Oct 7, 2024
---
@@ -10,6 +10,16 @@ import { useCallback, useRef, useState } from 'react';
import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common';
import { useAssistantContext } from '../../assistant_context';
import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api';
import * as i18n from './translations';

/**
 * TODO: This is a workaround for long-running server tasks while chatting with the assistant.
 * Some models (like Llama 3.1 70B) can be slow, so a single request may take a long time to handle.
 * The `core-http-browser` client has a two-minute timeout, after which it retries the request. Combined with a slow model,
 * this can cause the core HTTP client to initiate the same request again and again.
 * To avoid this, we abort the HTTP request after a timeout slightly below two minutes.
*/
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds

interface SendMessageProps {
apiConfig: ApiConfig;
@@ -38,6 +48,11 @@ export const useSendMessage = (): UseSendMessage => {
async ({ apiConfig, http, message, conversationId, replacements }: SendMessageProps) => {
setIsLoading(true);

const timeoutId = setTimeout(() => {
abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR);
abortController.current = new AbortController();
}, EXECUTE_ACTION_TIMEOUT);

try {
return await fetchConnectorExecuteAction({
conversationId,
@@ -52,6 +67,7 @@
traceOptions,
});
} finally {
clearTimeout(timeoutId);
setIsLoading(false);
}
},
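The hunk above pairs a `setTimeout` with the shared `AbortController`: when the timer fires, the in-flight request is aborted with a localized message and the controller is re-instantiated so the next message gets a fresh signal. A minimal standalone sketch of the same pattern, assuming a plain `fetch` call rather than `fetchConnectorExecuteAction`:

```ts
// Sketch of the timeout-plus-abort pattern from the hook above.
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // just under core-http-browser's 2-minute retry window

let abortController = new AbortController();

async function sendWithTimeout(url: string, body: unknown): Promise<Response> {
  const signal = abortController.signal;
  // Abort the request if it outlives the timeout, then re-instantiate the
  // controller so subsequent requests are not born already-aborted.
  const timeoutId = setTimeout(() => {
    abortController.abort('Assistant could not respond in time.');
    abortController = new AbortController();
  }, EXECUTE_ACTION_TIMEOUT);

  try {
    return await fetch(url, { method: 'POST', body: JSON.stringify(body), signal });
  } finally {
    clearTimeout(timeoutId); // always clear the timer, on success or abort
  }
}
```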
---
@@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { i18n } from '@kbn/i18n';

export const FETCH_MESSAGE_TIMEOUT_ERROR = i18n.translate(
'xpack.elasticAssistant.assistant.useSendMessage.fetchMessageTimeoutError',
{
defaultMessage: 'Assistant could not respond in time. Please try again later.',
}
);
---
@@ -45,6 +45,7 @@ export interface AgentExecutorParams<T extends boolean> {
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
llmType?: string;
isOssModel?: boolean;
logger: Logger;
inference: InferenceServerStart;
onNewReplacements?: (newReplacements: Replacements) => void;
---
@@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
isOssModel: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
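Each of these graph state channels uses a last-write-wins reducer: `y ?? x` keeps the current value unless a node supplies a new one, and `default` seeds the channel. A small illustration of the reducer's behavior in isolation (plain TypeScript, not LangGraph's API):

```ts
// The channel reducer above, in isolation: keep x unless an update y arrives.
const lastWriteWins = (x: boolean, y?: boolean): boolean => y ?? x;

let isOssModel = false;                            // default: () => false
isOssModel = lastWriteWins(isOssModel, undefined); // no update -> stays false
isOssModel = lastWriteWins(isOssModel, true);      // a node sets the flag -> true
isOssModel = lastWriteWins(isOssModel, undefined); // later nodes preserve it -> true
```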
---
@@ -24,6 +24,7 @@ interface StreamGraphParams {
assistantGraph: DefaultAssistantGraph;
inputs: GraphInputs;
logger: Logger;
isOssModel?: boolean;
onLlmResponse?: OnLlmResponse;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
traceOptions?: TraceOptions;
@@ -36,6 +37,7 @@
* @param assistantGraph
* @param inputs
* @param logger
* @param isOssModel
* @param onLlmResponse
* @param request
* @param traceOptions
@@ -45,6 +47,7 @@ export const streamGraph = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
@@ -80,8 +83,8 @@
};

if (
(inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
inputs?.bedrockChatEnabled
inputs.isOssModel ||
((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled)
) {
const stream = await assistantGraph.streamEvents(
inputs,
@@ -92,7 +95,9 @@
version: 'v2',
streamMode: 'values',
},
inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
inputs.isOssModel || inputs?.llmType === 'bedrock'
? { includeNames: ['Summarizer'] }
: undefined
);

for await (const { event, data, tags } of stream) {
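With this change, OSS models take the same `streamEvents` path as Bedrock: the `includeNames: ['Summarizer']` filter surfaces only the final summarizer node's tokens, so intermediate agent steps are not streamed to the client. A hedged sketch of consuming such a filtered stream, assuming the LangGraph v2 event shapes used above and a hypothetical `push` helper that writes tokens to the response:

```ts
// Sketch: stream only the "Summarizer" node's chat-model chunks to the client.
const stream = await assistantGraph.streamEvents(
  inputs,
  { version: 'v2', streamMode: 'values' },
  { includeNames: ['Summarizer'] } // drop events emitted by every other node
);

for await (const { event, data } of stream) {
  if (event === 'on_chat_model_stream') {
    const token: string = data.chunk?.content ?? ''; // partial message content
    if (token) push(token); // `push` is a hypothetical write-to-response helper
  }
}
```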
---
@@ -36,6 +36,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
inference,
langChainMessages,
llmType,
isOssModel,
logger: parentLogger,
isStream = false,
onLlmResponse,
@@ -48,7 +49,7 @@
responseLanguage = 'English',
}) => {
const logger = parentLogger.get('defaultAssistantGraph');
const isOpenAI = llmType === 'openai';
const isOpenAI = llmType === 'openai' && !isOssModel;
const llmClass = getLlmClass(llmType, bedrockChatEnabled);

/**
@@ -111,7 +112,7 @@
};

const tools: StructuredTool[] = assistantTools.flatMap(
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? []
);

// If KB enabled, fetch for any KB IndexEntries and generate a tool for each
@@ -166,6 +167,7 @@
conversationId,
llmType,
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
};

@@ -175,6 +177,7 @@
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
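Spreading `isOssModel` into `getTool` lets each tool adapt to the model family; per the "Add OSS model specific prompt to the tool description" commit, one use is appending extra instructions to a tool's description. A hypothetical sketch of a tool factory doing so, not the actual `NaturalLanguageESQLTool` implementation:

```ts
import { DynamicStructuredTool } from '@langchain/core/tools';
import { z } from 'zod';

// Hypothetical stand-in for the real ES|QL generation chain.
const esqlFromQuestion = async (question: string): Promise<string> =>
  `FROM logs-* | WHERE message LIKE "%${question}%" | LIMIT 10`;

// OSS models often need more explicit tool instructions than hosted models,
// so the factory appends an OSS-specific hint to the description.
const getTool = ({ isOssModel }: { isOssModel?: boolean }) =>
  new DynamicStructuredTool({
    name: 'NaturalLanguageESQLTool',
    description:
      'Translate a natural-language question into an ES|QL query.' +
      (isOssModel ? ' ALWAYS return the generated query exactly as is, without extra escaping.' : ''),
    schema: z.object({ question: z.string() }),
    func: async ({ question }) => esqlFromQuestion(question),
  });
```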
---
@@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase {
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);

const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
const hasRespondStep =
state.isStream &&
(state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock'));

return {
hasRespondStep,
---
@@ -18,3 +18,59 @@ const KB_CATCH =
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;

export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools:

{tools}

The tool action_input should ALWAYS follow the tool JSON schema args.

Valid "action" values: "Final Answer" or {tool_names}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).

Provide only ONE action per $JSON_BLOB, as shown:

\`\`\`

{{

"action": $TOOL_NAME,

"action_input": $TOOL_INPUT

}}

\`\`\`

Follow this format:

Question: input question to answer

Thought: consider previous and subsequent steps

Action:

\`\`\`

$JSON_BLOB

\`\`\`

Observation: action result

... (repeat Thought/Action/Observation N times)

Thought: I know what to respond

Action:

\`\`\`

{{

"action": "Final Answer",

"action_input": "Final response to human"}}

Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
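
This prompt constrains every model turn to a single JSON blob inside an Action block, which helps make tool-calling workable for OSS models that lack native function-calling. A minimal sketch of how an agent loop might parse such a turn (a hypothetical parser, not LangChain's structured-chat implementation):

```ts
interface AgentAction {
  action: string;        // "Final Answer" or one of {tool_names}
  action_input: unknown; // must match the chosen tool's JSON schema args
}

// Hypothetical parser: pull the single JSON blob out of an Action turn.
function parseStructuredTurn(turn: string): AgentAction {
  const match = turn.match(/Action:\s*```(?:json)?\s*([\s\S]+?)\s*```/);
  if (!match) throw new Error('Model turn did not contain an Action JSON blob');
  return JSON.parse(match[1]) as AgentAction;
}

// A "Final Answer" action ends the Thought/Action/Observation loop.
const action = parseStructuredTurn(
  'Thought: I know what to respond\nAction:```{"action": "Final Answer", "action_input": "Done"}```'
);
```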
---
@@ -11,6 +11,7 @@ import {
DEFAULT_SYSTEM_PROMPT,
GEMINI_SYSTEM_PROMPT,
GEMINI_USER_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
} from './nodes/translations';

export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
@@ -26,61 +27,7 @@ export const systemPrompts = {
bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
// The default prompt overwhelms gemini, do not prepend
gemini: GEMINI_SYSTEM_PROMPT,
structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools:

{tools}

The tool action_input should ALWAYS follow the tool JSON schema args.

Valid "action" values: "Final Answer" or {tool_names}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).

Provide only ONE action per $JSON_BLOB, as shown:

\`\`\`

{{

"action": $TOOL_NAME,

"action_input": $TOOL_INPUT

}}

\`\`\`

Follow this format:

Question: input question to answer

Thought: consider previous and subsequent steps

Action:

\`\`\`

$JSON_BLOB

\`\`\`

Observation: action result

... (repeat Thought/Action/Observation N times)

Thought: I know what to respond

Action:

\`\`\`

{{

"action": "Final Answer",

"action_input": "Final response to human"}}

Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`,
structuredChat: STRUCTURED_SYSTEM_PROMPT,
};

export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);
---
@@ -20,6 +20,7 @@ export interface GraphInputs {
conversationId?: string;
llmType?: string;
isStream?: boolean;
isOssModel?: boolean;
input: string;
responseLanguage?: string;
}
@@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase {
lastNode: string;
hasRespondStep: boolean;
isStream: boolean;
isOssModel: boolean;
bedrockChatEnabled: boolean;
llmType: string;
responseLanguage: string;
---
@@ -30,6 +30,7 @@ import {
} from '../helpers';
import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers';
import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types';
import { isOpenSourceModel } from '../utils';

export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => {
return `CONTEXT:\n"""\n${context}\n"""`;
@@ -99,7 +100,9 @@ export const chatCompleteRoute = (
const actions = ctx.elasticAssistant.actions;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai';
const connector = connectors.length > 0 ? connectors[0] : undefined;
actionTypeId = connector?.actionTypeId ?? '.gen-ai';
const isOssModel = isOpenSourceModel(connector);

// replacements
const anonymizationFieldsRes =
@@ -192,6 +195,7 @@
actionsClient,
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
context: ctx,
getElser,
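The `isOssModel` flag originates here from `isOpenSourceModel(connector)`, imported from the `utils.ts` this PR also touches. Judging from the commit messages ("Use api URL to verify OSS llms vs OpenAI", "Use provider type to better identify OSS model"), the check combines the connector's provider type and API URL. A hedged reconstruction of what such a predicate might look like — an assumption for illustration, not the actual implementation:

```ts
// Minimal connector shape for this sketch; the real type comes from the
// actions plugin.
interface ConnectorLike {
  actionTypeId: string;
  config?: { apiProvider?: string; apiUrl?: string };
}

// Hypothetical reconstruction: an OpenAI-type (.gen-ai) connector counts as
// an OSS model when its provider is "Other" or its API URL is not OpenAI's.
const isOpenSourceModel = (connector?: ConnectorLike): boolean => {
  if (!connector || connector.actionTypeId !== '.gen-ai') return false;
  const { apiProvider, apiUrl } = connector.config ?? {};
  const isOtherProvider = apiProvider === 'Other';
  const isNonOpenAiUrl = !!apiUrl && !apiUrl.includes('api.openai.com');
  return isOtherProvider || isNonOpenAiUrl;
};
```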