From 858e70c172bd206c15a48c253219c8f32056b47c Mon Sep 17 00:00:00 2001 From: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:51:30 +1100 Subject: [PATCH] [8.x] [Security Assistant] AI Assistant - Better Solution for OSS models (#10416) (#194166) (#195324) # Backport This will backport the following commits from `main` to `8.x`: - [[Security Assistant] AI Assistant - Better Solution for OSS models (#10416) (#194166)](https://github.com/elastic/kibana/pull/194166) ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) Co-authored-by: Ievgen Sorokopud --- .../impl/assistant/use_send_message/index.tsx | 16 ++ .../use_send_message/translations.ts | 15 ++ .../server/lib/langchain/executors/types.ts | 1 + .../graphs/default_assistant_graph/graph.ts | 4 + .../graphs/default_assistant_graph/helpers.ts | 11 +- .../graphs/default_assistant_graph/index.ts | 7 +- .../nodes/model_input.ts | 4 +- .../nodes/translations.ts | 56 ++++++ .../graphs/default_assistant_graph/prompts.ts | 57 +----- .../graphs/default_assistant_graph/types.ts | 2 + .../server/routes/chat/chat_complete_route.ts | 6 +- .../server/routes/evaluate/post_evaluate.ts | 22 ++- .../server/routes/helpers.ts | 3 + .../routes/post_actions_connector_execute.ts | 5 + .../server/routes/utils.test.ts | 69 ++++++++ .../elastic_assistant/server/routes/utils.ts | 28 +++ .../plugins/elastic_assistant/server/types.ts | 1 + .../esql_language_knowledge_base/common.ts | 15 ++ .../esql_language_knowledge_base_tool.test.ts | 23 +++ .../esql_language_knowledge_base_tool.ts | 12 +- .../nl_to_esql_tool.test.ts | 162 ++++++++++++++++++ .../nl_to_esql_tool.ts | 7 +- 22 files changed, 452 insertions(+), 74 deletions(-) create mode 100644 x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts create mode 100644 x-pack/plugins/elastic_assistant/server/routes/utils.test.ts create mode 100644 x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts create mode 100644 x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx index 93bd03607e71f..438b2282371d9 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx @@ -10,6 +10,16 @@ import { useCallback, useRef, useState } from 'react'; import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common'; import { useAssistantContext } from '../../assistant_context'; import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api'; +import * as i18n from './translations'; + +/** + * TODO: This is a workaround to solve the issue with long-running server tasks while chatting with the assistant. + * Some models (like Llama 3.1 70B) can perform poorly and be slow, which leads to a long time to handle the request. + * The `core-http-browser` has a timeout of two minutes, after which it will retry the request. In combination with a slow model, this can lead to + * a situation where the core http client initiates the same request again and again. + * To avoid this, we abort the http request after a timeout that is slightly below two minutes. 
+ */ +const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds interface SendMessageProps { apiConfig: ApiConfig; @@ -38,6 +48,11 @@ export const useSendMessage = (): UseSendMessage => { async ({ apiConfig, http, message, conversationId, replacements }: SendMessageProps) => { setIsLoading(true); + const timeoutId = setTimeout(() => { + abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR); + abortController.current = new AbortController(); + }, EXECUTE_ACTION_TIMEOUT); + try { return await fetchConnectorExecuteAction({ conversationId, @@ -52,6 +67,7 @@ export const useSendMessage = (): UseSendMessage => { traceOptions, }); } finally { + clearTimeout(timeoutId); setIsLoading(false); } }, diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts new file mode 100644 index 0000000000000..1185d8cfdbc65 --- /dev/null +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; + +export const FETCH_MESSAGE_TIMEOUT_ERROR = i18n.translate( + 'xpack.elasticAssistant.assistant.useSendMessage.fetchMessageTimeoutError', + { + defaultMessage: 'Assistant could not respond in time. Please try again later.', + } +); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 2d86d05447916..5761201849c09 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -45,6 +45,7 @@ export interface AgentExecutorParams { esClient: ElasticsearchClient; langChainMessages: BaseMessage[]; llmType?: string; + isOssModel?: boolean; logger: Logger; inference: InferenceServerStart; onNewReplacements?: (newReplacements: Replacements) => void; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 8395076ad62ee..8f2f713c170ed 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({ value: (x: boolean, y?: boolean) => y ?? x, default: () => false, }, + isOssModel: { + value: (x: boolean, y?: boolean) => y ?? x, + default: () => false, + }, conversation: { value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) => y ?? 
x, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 93890f9dfb121..840b2a9ac8ce0 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -24,6 +24,7 @@ interface StreamGraphParams { assistantGraph: DefaultAssistantGraph; inputs: GraphInputs; logger: Logger; + isOssModel?: boolean; onLlmResponse?: OnLlmResponse; request: KibanaRequest; traceOptions?: TraceOptions; @@ -36,6 +37,7 @@ interface StreamGraphParams { * @param assistantGraph * @param inputs * @param logger + * @param isOssModel * @param onLlmResponse * @param request * @param traceOptions @@ -45,6 +47,7 @@ export const streamGraph = async ({ assistantGraph, inputs, logger, + isOssModel, onLlmResponse, request, traceOptions, @@ -80,8 +83,8 @@ export const streamGraph = async ({ }; if ( - (inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && - inputs?.bedrockChatEnabled + inputs.isOssModel || + ((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled) ) { const stream = await assistantGraph.streamEvents( inputs, @@ -92,7 +95,9 @@ export const streamGraph = async ({ version: 'v2', streamMode: 'values', }, - inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined + inputs.isOssModel || inputs?.llmType === 'bedrock' + ? { includeNames: ['Summarizer'] } + : undefined ); for await (const { event, data, tags } of stream) { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index dee23f202b3d4..daec22b436474 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -36,6 +36,7 @@ export const callAssistantGraph: AgentExecutor = async ({ inference, langChainMessages, llmType, + isOssModel, logger: parentLogger, isStream = false, onLlmResponse, @@ -48,7 +49,7 @@ export const callAssistantGraph: AgentExecutor = async ({ responseLanguage = 'English', }) => { const logger = parentLogger.get('defaultAssistantGraph'); - const isOpenAI = llmType === 'openai'; + const isOpenAI = llmType === 'openai' && !isOssModel; const llmClass = getLlmClass(llmType, bedrockChatEnabled); /** @@ -111,7 +112,7 @@ export const callAssistantGraph: AgentExecutor = async ({ }; const tools: StructuredTool[] = assistantTools.flatMap( - (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? [] + (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? 
[] ); // If KB enabled, fetch for any KB IndexEntries and generate a tool for each @@ -166,6 +167,7 @@ export const callAssistantGraph: AgentExecutor = async ({ conversationId, llmType, isStream, + isOssModel, input: latestMessage[0]?.content as string, }; @@ -175,6 +177,7 @@ export const callAssistantGraph: AgentExecutor = async ({ assistantGraph, inputs, logger, + isOssModel, onLlmResponse, request, traceOptions, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts index f634d10f5cd4a..5f46e1ad2a741 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts @@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase { export function modelInput({ logger, state }: ModelInputParams): Partial { logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`); - const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock'; + const hasRespondStep = + state.isStream && + (state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock')); return { hasRespondStep, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts index 9eedce48ba69d..e55e1081e6474 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts @@ -18,3 +18,59 @@ const KB_CATCH = export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`; export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return tags in the response, but make sure to include tags content in the response. Do not reflect on the quality of the returned search results in your response.`; export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`; + +export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools: + +{tools} + +The tool action_input should ALWAYS follow the tool JSON schema args. + +Valid "action" values: "Final Answer" or {tool_names} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args). + +Provide only ONE action per $JSON_BLOB, as shown: + +\`\`\` + +{{ + + "action": $TOOL_NAME, + + "action_input": $TOOL_INPUT + +}} + +\`\`\` + +Follow this format: + +Question: input question to answer + +Thought: consider previous and subsequent steps + +Action: + +\`\`\` + +$JSON_BLOB + +\`\`\` + +Observation: action result + +... (repeat Thought/Action/Observation N times) + +Thought: I know what to respond + +Action: + +\`\`\` + +{{ + + "action": "Final Answer", + + "action_input": "Final response to human"}} + +Begin! 
Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts index 4a7b1fd46ccb8..883047ed7b9df 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts @@ -11,6 +11,7 @@ import { DEFAULT_SYSTEM_PROMPT, GEMINI_SYSTEM_PROMPT, GEMINI_USER_PROMPT, + STRUCTURED_SYSTEM_PROMPT, } from './nodes/translations'; export const formatPrompt = (prompt: string, additionalPrompt?: string) => @@ -26,61 +27,7 @@ export const systemPrompts = { bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`, // The default prompt overwhelms gemini, do not prepend gemini: GEMINI_SYSTEM_PROMPT, - structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools: - -{tools} - -The tool action_input should ALWAYS follow the tool JSON schema args. - -Valid "action" values: "Final Answer" or {tool_names} - -Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args). - -Provide only ONE action per $JSON_BLOB, as shown: - -\`\`\` - -{{ - - "action": $TOOL_NAME, - - "action_input": $TOOL_INPUT - -}} - -\`\`\` - -Follow this format: - -Question: input question to answer - -Thought: consider previous and subsequent steps - -Action: - -\`\`\` - -$JSON_BLOB - -\`\`\` - -Observation: action result - -... (repeat Thought/Action/Observation N times) - -Thought: I know what to respond - -Action: - -\`\`\` - -{{ - - "action": "Final Answer", - - "action_input": "Final response to human"}} - -Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. 
Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`, + structuredChat: STRUCTURED_SYSTEM_PROMPT, }; export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts index 17d06b0f7042e..69632be2ffdcd 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts @@ -20,6 +20,7 @@ export interface GraphInputs { conversationId?: string; llmType?: string; isStream?: boolean; + isOssModel?: boolean; input: string; responseLanguage?: string; } @@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase { lastNode: string; hasRespondStep: boolean; isStream: boolean; + isOssModel: boolean; bedrockChatEnabled: boolean; llmType: string; responseLanguage: string; diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index dd90241809015..47f6f1a486957 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -30,6 +30,7 @@ import { } from '../helpers'; import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers'; import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types'; +import { isOpenSourceModel } from '../utils'; export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => { return `CONTEXT:\n"""\n${context}\n"""`; @@ -99,7 +100,9 @@ export const chatCompleteRoute = ( const actions = ctx.elasticAssistant.actions; const actionsClient = await actions.getActionsClientWithRequest(request); const connectors = await actionsClient.getBulk({ ids: [connectorId] }); - actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; + const connector = connectors.length > 0 ? connectors[0] : undefined; + actionTypeId = connector?.actionTypeId ?? '.gen-ai'; + const isOssModel = isOpenSourceModel(connector); // replacements const anonymizationFieldsRes = @@ -192,6 +195,7 @@ export const chatCompleteRoute = ( actionsClient, actionTypeId, connectorId, + isOssModel, conversationId: conversationId ?? 
newConversation?.id, context: ctx, getElser, diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index c0c7bf3f6bc4e..59436070a7125 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -46,7 +46,7 @@ import { openAIFunctionAgentPrompt, structuredChatAgentPrompt, } from '../../lib/langchain/graphs/default_assistant_graph/prompts'; -import { getLlmClass, getLlmType } from '../utils'; +import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils'; const DEFAULT_SIZE = 20; const ROUTE_HANDLER_TIMEOUT = 10 * 60 * 1000; // 10 * 60 seconds = 10 minutes @@ -174,10 +174,12 @@ export const postEvaluateRoute = ( name: string; graph: DefaultAssistantGraph; llmType: string | undefined; + isOssModel: boolean | undefined; }> = await Promise.all( connectors.map(async (connector) => { const llmType = getLlmType(connector.actionTypeId); - const isOpenAI = llmType === 'openai'; + const isOssModel = isOpenSourceModel(connector); + const isOpenAI = llmType === 'openai' && !isOssModel; const llmClass = getLlmClass(llmType, true); const createLlmInstance = () => new llmClass({ @@ -232,6 +234,7 @@ export const postEvaluateRoute = ( isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, llm, + isOssModel, logger, modelExists: isEnabledKnowledgeBase, request: skeletonRequest, @@ -274,6 +277,7 @@ export const postEvaluateRoute = ( return { name: `${runName} - ${connector.name}`, llmType, + isOssModel, graph: getDefaultAssistantGraph({ agentRunnable, dataClients, @@ -287,7 +291,7 @@ export const postEvaluateRoute = ( ); // Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector) - await asyncForEach(graphs, async ({ name, graph, llmType }) => { + await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => { // Wrapper function for invoking the graph (to parse different input/output formats) const predict = async (input: { input: string }) => { logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`); @@ -300,6 +304,7 @@ export const postEvaluateRoute = ( llmType, bedrockChatEnabled: true, isStreaming: false, + isOssModel, }, // TODO: Update to use the correct input format per dataset type { runName, @@ -310,15 +315,20 @@ export const postEvaluateRoute = ( return output; }; - const evalOutput = await evaluate(predict, { + evaluate(predict, { data: datasetName ?? 
'', evaluators: [], // Evals to be managed in LangSmith for now experimentPrefix: name, client: new Client({ apiKey: langSmithApiKey }), // prevent rate limiting and unexpected multiple experiment runs maxConcurrency: 5, - }); - logger.debug(`runResp:\n ${JSON.stringify(evalOutput, null, 2)}`); + }) + .then((output) => { + logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`); + }) + .catch((err) => { + logger.error(`evaluation error:\n ${JSON.stringify(err, null, 2)}`); + }); }); return response.ok({ diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 2c0c56c73a2b3..ebd9fd996dfe1 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -322,6 +322,7 @@ export interface LangChainExecuteParams { actionTypeId: string; connectorId: string; inference: InferenceServerStart; + isOssModel?: boolean; conversationId?: string; context: AwaitedProperties< Pick @@ -348,6 +349,7 @@ export const langChainExecute = async ({ telemetry, actionTypeId, connectorId, + isOssModel, context, actionsClient, inference, @@ -412,6 +414,7 @@ export const langChainExecute = async ({ inference, isStream, llmType: getLlmType(actionTypeId), + isOssModel, langChainMessages, logger, onNewReplacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 736d60ff666b0..4b65b5bb3f1e5 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -29,6 +29,7 @@ import { getSystemPromptFromUserConversation, langChainExecute, } from './helpers'; +import { isOpenSourceModel } from './utils'; export const postActionsConnectorExecuteRoute = ( router: IRouter, @@ -94,6 +95,9 @@ export const postActionsConnectorExecuteRoute = ( const actions = ctx.elasticAssistant.actions; const inference = ctx.elasticAssistant.inference; const actionsClient = await actions.getActionsClientWithRequest(request); + const connectors = await actionsClient.getBulk({ ids: [connectorId] }); + const connector = connectors.length > 0 ? connectors[0] : undefined; + const isOssModel = isOpenSourceModel(connector); const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); @@ -129,6 +133,7 @@ export const postActionsConnectorExecuteRoute = ( actionsClient, actionTypeId, connectorId, + isOssModel, conversationId, context: ctx, getElser, diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts new file mode 100644 index 0000000000000..3ca1b8edb5036 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import { isOpenSourceModel } from './utils'; +import { + OPENAI_CHAT_URL, + OpenAiProviderType, +} from '@kbn/stack-connectors-plugin/common/openai/constants'; + +describe('Utils', () => { + describe('isOpenSourceModel', () => { + it('should return `false` when connector is undefined', async () => { + const isOpenModel = isOpenSourceModel(); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a Bedrock connector', async () => { + const connector = { actionTypeId: '.bedrock' } as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a Gemini connector', async () => { + const connector = { actionTypeId: '.gemini' } as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is an OpenAI connector and API url is not specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is an OpenAI connector and OpenAI API url is specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { apiUrl: OPENAI_CHAT_URL }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is an AzureOpenAI connector', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { apiProvider: OpenAiProviderType.AzureAi }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `true` when connector is an OpenAI connector and non-OpenAI API url is specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { apiUrl: 'https://elastic.llm.com/llama/chat/completions' }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(true); + }); + }); +}); diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 651a809e1a56e..5811109b94ede 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -19,6 +19,11 @@ import { ActionsClientSimpleChatModel, ActionsClientChatVertexAI, } from '@kbn/langchain/server'; +import { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import { + OPENAI_CHAT_URL, + OpenAiProviderType, +} from '@kbn/stack-connectors-plugin/common/openai/constants'; import { CustomHttpRequestError } from './custom_http_request_error'; export interface OutputError { @@ -189,3 +194,26 @@ export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) => : llmType === 'gemini' && bedrockChatEnabled ? ActionsClientChatVertexAI : ActionsClientSimpleChatModel; + +export const isOpenSourceModel = (connector?: Connector): boolean => { + if (connector == null) { + return false; + } + + const llmType = getLlmType(connector.actionTypeId); + const connectorApiUrl = connector.config?.apiUrl + ? (connector.config.apiUrl as string) + : undefined; + const connectorApiProvider = connector.config?.apiProvider + ? 
(connector.config?.apiProvider as OpenAiProviderType) + : undefined; + + const isOpenAiType = llmType === 'openai'; + const isOpenAI = + isOpenAiType && + (!connectorApiUrl || + connectorApiUrl === OPENAI_CHAT_URL || + connectorApiProvider === OpenAiProviderType.AzureAi); + + return isOpenAiType && !isOpenAI; +}; diff --git a/x-pack/plugins/elastic_assistant/server/types.ts b/x-pack/plugins/elastic_assistant/server/types.ts index af8d019539a66..9062bc5a434b1 100755 --- a/x-pack/plugins/elastic_assistant/server/types.ts +++ b/x-pack/plugins/elastic_assistant/server/types.ts @@ -244,6 +244,7 @@ export interface AssistantToolParams { kbDataClient?: AIAssistantKnowledgeBaseDataClient; langChainTimeout?: number; llm?: ActionsClientLlm | AssistantToolLlm; + isOssModel?: boolean; logger: Logger; modelExists: boolean; onNewReplacements?: (newReplacements: Replacements) => void; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts new file mode 100644 index 0000000000000..ee2bee8fab806 --- /dev/null +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const getPromptSuffixForOssModel = (toolName: string) => ` + When using the ${toolName} tool, ALWAYS pass the user's questions directly as input into the tool. + + Always return the value from the ${toolName} tool as is. + + The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks. 
+ + It is important that the ES|QL query is preceded by a new line.`; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts index 7eeb11e8df37a..589c95e8483bf 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts @@ -12,6 +12,7 @@ import type { KibanaRequest } from '@kbn/core-http-server'; import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen'; import { loggerMock } from '@kbn/logging-mocks'; import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base'; +import { getPromptSuffixForOssModel } from './common'; describe('EsqlLanguageKnowledgeBaseTool', () => { const kbDataClient = jest.fn() as unknown as AIAssistantKnowledgeBaseDataClient; @@ -108,5 +109,27 @@ describe('EsqlLanguageKnowledgeBaseTool', () => { expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']); }); + + it('should return tool with the expected description for OSS model', () => { + const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: true, + ...rest, + }) as DynamicTool; + + expect(tool.description).toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool')); + }); + + it('should return tool with the expected description for non-OSS model', () => { + const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: false, + ...rest, + }) as DynamicTool; + + expect(tool.description).not.toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool')); + }); }); }); diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts index 6bf116c28719a..37e037898cd20 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts @@ -14,12 +14,15 @@ import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import { ESQL_RESOURCE } from '@kbn/elastic-assistant-plugin/server/routes/knowledge_base/constants'; import { APP_UI_ID } from '../../../../common'; +import { getPromptSuffixForOssModel } from './common'; + +const TOOL_NAME = 'ESQLKnowledgeBaseTool'; const toolDetails = { + id: 'esql-knowledge-base-tool', + name: TOOL_NAME, description: 'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the user query on a single line, with no other text. Your answer will be parsed as JSON, so never use quotes within the output and instead use backticks. 
Do not add any additional text to describe your output.', - id: 'esql-knowledge-base-tool', - name: 'ESQLKnowledgeBaseTool', }; export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = { ...toolDetails, @@ -31,12 +34,13 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = { getTool(params: AssistantToolParams) { if (!this.isSupported(params)) return null; - const { kbDataClient } = params as AssistantToolParams; + const { kbDataClient, isOssModel } = params as AssistantToolParams; if (kbDataClient == null) return null; return new DynamicStructuredTool({ name: toolDetails.name, - description: toolDetails.description, + description: + toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''), schema: z.object({ question: z.string().describe(`The user's exact question about ESQL`), }), diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts new file mode 100644 index 0000000000000..f078bccb24a36 --- /dev/null +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { RetrievalQAChain } from 'langchain/chains'; +import type { DynamicTool } from '@langchain/core/tools'; +import { NL_TO_ESQL_TOOL } from './nl_to_esql_tool'; +import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; +import type { KibanaRequest } from '@kbn/core-http-server'; +import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen'; +import { loggerMock } from '@kbn/logging-mocks'; +import { getPromptSuffixForOssModel } from './common'; +import type { InferenceServerStart } from '@kbn/inference-plugin/server'; + +describe('NaturalLanguageESQLTool', () => { + const chain = {} as RetrievalQAChain; + const esClient = { + search: jest.fn().mockResolvedValue({}), + } as unknown as ElasticsearchClient; + const request = { + body: { + isEnabledKnowledgeBase: false, + alertsIndexPattern: '.alerts-security.alerts-default', + allow: ['@timestamp', 'cloud.availability_zone', 'user.name'], + allowReplacement: ['user.name'], + replacements: { key: 'value' }, + size: 20, + }, + } as unknown as KibanaRequest; + const logger = loggerMock.create(); + const inference = {} as InferenceServerStart; + const connectorId = 'fake-connector'; + const rest = { + chain, + esClient, + logger, + request, + inference, + connectorId, + }; + + describe('isSupported', () => { + it('returns false if isEnabledKnowledgeBase is false', () => { + const params = { + isEnabledKnowledgeBase: false, + modelExists: true, + ...rest, + }; + + expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false); + }); + + it('returns false if modelExists is false (the ELSER model is not installed)', () => { + const params = { + isEnabledKnowledgeBase: true, + modelExists: false, // <-- ELSER model is not installed + ...rest, + }; + + expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false); + }); + + it('returns true if isEnabledKnowledgeBase and modelExists are true', () => { + const params = { + isEnabledKnowledgeBase: true, + 
modelExists: true, + ...rest, + }; + + expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(true); + }); + }); + + describe('getTool', () => { + it('returns null if isEnabledKnowledgeBase is false', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: false, + modelExists: true, + ...rest, + }); + + expect(tool).toBeNull(); + }); + + it('returns null if modelExists is false (the ELSER model is not installed)', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: false, // <-- ELSER model is not installed + ...rest, + }); + + expect(tool).toBeNull(); + }); + + it('returns null if inference plugin is not provided', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + inference: undefined, + }); + + expect(tool).toBeNull(); + }); + + it('returns null if connectorId is not provided', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + connectorId: undefined, + }); + + expect(tool).toBeNull(); + }); + + it('should return a Tool instance if isEnabledKnowledgeBase and modelExists are true', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + }); + + expect(tool?.name).toEqual('NaturalLanguageESQLTool'); + }); + + it('should return a tool with the expected tags', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + }) as DynamicTool; + + expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']); + }); + + it('should return tool with the expected description for OSS model', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: true, + ...rest, + }) as DynamicTool; + + expect(tool.description).toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool')); + }); + + it('should return tool with the expected description for non-OSS model', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: false, + ...rest, + }) as DynamicTool; + + expect(tool.description).not.toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool')); + }); + }); +}); diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts index a26d16607ac46..96b865efeaed4 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts @@ -11,6 +11,7 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant- import { lastValueFrom } from 'rxjs'; import { naturalLanguageToEsql } from '@kbn/inference-plugin/server'; import { APP_UI_ID } from '../../../../common'; +import { getPromptSuffixForOssModel } from './common'; export type ESQLToolParams = AssistantToolParams; @@ -37,7 +38,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { getTool(params: ESQLToolParams) { if (!this.isSupported(params)) return null; - const { connectorId, inference, logger, request } = params as ESQLToolParams; + const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams; if (inference == null || connectorId == null) return null; const 
callNaturalLanguageToEsql = async (question: string) => { @@ -46,6 +47,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { client: inference.getClient({ request }), connectorId, input: question, + ...(isOssModel ? { functionCalling: 'simulated' } : {}), logger: { debug: (source) => { logger.debug(typeof source === 'function' ? source() : source); @@ -57,7 +59,8 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { return new DynamicStructuredTool({ name: toolDetails.name, - description: toolDetails.description, + description: + toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''), schema: z.object({ question: z.string().describe(`The user's exact question about ESQL`), }),
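
Reviewer note: the sketch below is a minimal, framework-free restatement of the abort-after-timeout pattern introduced in `useSendMessage` above. It uses plain `fetch` and an illustrative `sendWithTimeout` helper as assumptions; it is not the Kibana http client API, and the real hook keeps its `AbortController` in a ref and re-creates it after each abort so subsequent sends get a fresh controller.

```typescript
// Sketch of the abort-after-timeout pattern from this patch (illustrative names).
// Aborting just under the browser http client's two-minute window prevents a slow
// OSS model from triggering the client's automatic retry and duplicate requests.
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // slightly below the 2-minute client timeout

async function sendWithTimeout(url: string, body: unknown): Promise<Response> {
  const controller = new AbortController();
  const timeoutId = setTimeout(
    // Abort with a user-facing reason so the failure surfaces as a readable error.
    () => controller.abort('Assistant could not respond in time. Please try again later.'),
    EXECUTE_ACTION_TIMEOUT
  );
  try {
    return await fetch(url, {
      method: 'POST',
      body: JSON.stringify(body),
      signal: controller.signal,
    });
  } finally {
    clearTimeout(timeoutId); // always clear, whether the request completed or was aborted
  }
}
```

The `finally` block mirrors the hook's change: the timer is cleared on both success and failure, so a fast response never leaves a stale abort pending.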