[Security Assistant] AI Assistant - Better Solution for OSS models (#10416) #194166

Merged on Oct 7, 2024. 25 commits; the diff below shows changes from 19 of them.

Commits
bc4228d  OSS LLM (e40pud, Sep 16, 2024)
4d5c841  OSS LLMs streaming fixes (e40pud, Sep 16, 2024)
45c8ac0  Use api URL to verify OSS llms vs OpenAI (e40pud, Sep 17, 2024)
98abd53  Prompting (e40pud, Sep 19, 2024)
7a7f7e2  Use provider type to better identify OSS model - handles case with th… (e40pud, Sep 19, 2024)
6903909  Enable `NaturalLanguageESQLTool` for OSS models like Llama (e40pud, Sep 24, 2024)
7436f34  Fix the issue with extra escape backslash characters which breaks the… (e40pud, Sep 24, 2024)
ae638e6  Revert streaming events parsing (e40pud, Sep 24, 2024)
a03e5ea  Simplified OSS model streaming (e40pud, Sep 24, 2024)
5d2e1f2  Add OSS model specific prompt to the tool description (e40pud, Sep 25, 2024)
f9eb9d7  Brush up implementation and add some unit tests (e40pud, Sep 26, 2024)
351c2f7  Remove redundant code (e40pud, Sep 27, 2024)
d78af80  Merge branch 'main' into security/genai/10416-OSS-models (e40pud, Sep 27, 2024)
8689cc2  Merge branch 'main' into security/genai/10416-OSS-models (elasticmachine, Sep 30, 2024)
046c3c5  Merge branch 'main' into security/genai/10416-OSS-models (e40pud, Oct 1, 2024)
c0a00eb  Merge branch 'main' into security/genai/10416-OSS-models (e40pud, Oct 2, 2024)
69543ce  Make sure we log evaluation results and errors (e40pud, Oct 2, 2024)
5c7bf7f  Merge branch 'main' into security/genai/10416-OSS-models (e40pud, Oct 3, 2024)
db9fd74  Merge branch 'main' into security/genai/10416-OSS-models (e40pud, Oct 4, 2024)
d0458eb  Merge branch 'main' into security/genai/10416-OSS-models (e40pud, Oct 4, 2024)
26cb4c3  Merge branch 'main' into security/genai/10416-OSS-models (elasticmachine, Oct 7, 2024)
513d4bf  Review feedback: long time request issue (e40pud, Oct 7, 2024)
1d60365  Update x-pack/plugins/elastic_assistant/server/routes/utils.ts (e40pud, Oct 7, 2024)
0f96231  Review feedback: naming (e40pud, Oct 7, 2024)
59fb225  Review feedback: re-instantiate AbortController after the `abort` (e40pud, Oct 7, 2024)
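A note on the last commit (59fb225): an AbortController is single-use. Once `abort()` has been called, its signal stays aborted, so any later request wired to the same controller fails immediately. The fix is to create a fresh controller after each abort. A standalone sketch of the pattern (names are illustrative, not the PR's actual code):

```typescript
// Once abort() fires, a controller's signal stays aborted forever, so a new
// AbortController must replace it before the next request is issued.
let abortController = new AbortController();

async function startRequest(url: string): Promise<Response> {
  // Always read the signal from the *current* controller.
  return fetch(url, { signal: abortController.signal });
}

function cancelRequest(): void {
  abortController.abort();
  // Re-instantiate so the next startRequest() is not rejected
  // immediately with an AbortError.
  abortController = new AbortController();
}
```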
@@ -45,6 +45,7 @@ export interface AgentExecutorParams<T extends boolean> {
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
llmType?: string;
isOssModel?: boolean;
logger: Logger;
inference: InferenceServerStart;
onNewReplacements?: (newReplacements: Replacements) => void;
@@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
isOssModel: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
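The new `isOssModel` channel follows the same reducer convention as the graph's other boolean channels: an update overrides the current value only when one is actually provided. A standalone illustration:

```typescript
// LangGraph-style channel reducer: an incoming update y wins only when it is
// defined; otherwise the existing value x is kept.
const reduceBool = (x: boolean, y?: boolean): boolean => y ?? x;

reduceBool(false, true); // => true  (explicit update wins)
reduceBool(true, undefined); // => true  (no update; value persists)
```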
@@ -24,6 +24,7 @@ interface StreamGraphParams {
assistantGraph: DefaultAssistantGraph;
inputs: GraphInputs;
logger: Logger;
isOssModel?: boolean;
onLlmResponse?: OnLlmResponse;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
traceOptions?: TraceOptions;
@@ -36,6 +37,7 @@ interface StreamGraphParams {
* @param assistantGraph
* @param inputs
* @param logger
* @param isOssModel
* @param onLlmResponse
* @param request
* @param traceOptions
@@ -45,6 +47,7 @@ export const streamGraph = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
@@ -80,8 +83,8 @@ export const streamGraph = async ({
};

if (
(inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
inputs?.bedrockChatEnabled
inputs.isOssModel ||
((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled)
) {
const stream = await assistantGraph.streamEvents(
inputs,
@@ -92,7 +95,9 @@
version: 'v2',
streamMode: 'values',
},
inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
inputs.isOssModel || inputs?.llmType === 'bedrock'
? { includeNames: ['Summarizer'] }
: undefined
);

for await (const { event, data, tags } of stream) {
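The effect of the branch above: OSS models now take the same event-streaming path as Bedrock and Gemini chat, and for OSS and Bedrock only events from the `Summarizer` node are forwarded, so intermediate tool-selection output never reaches the client. A self-contained sketch of the consuming side (the event shape and `push` helper are assumptions modeled on this diff, not the PR's actual code):

```typescript
// v2 streamEvents emit 'on_chat_model_stream' once per streamed chunk; with
// includeNames: ['Summarizer'], only the final answer node produces them.
type GraphEvent = {
  event: string;
  data: { chunk?: { content: string } };
  tags?: string[];
};

async function forwardFinalAnswer(
  stream: AsyncIterable<GraphEvent>,
  push: (p: { payload: string; type: 'content' }) => void
): Promise<void> {
  for await (const { event, data } of stream) {
    if (event === 'on_chat_model_stream' && data.chunk) {
      // Forward only the final answer's content chunks to the client.
      push({ payload: data.chunk.content, type: 'content' });
    }
  }
}
```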
@@ -36,6 +36,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
inference,
langChainMessages,
llmType,
isOssModel,
logger: parentLogger,
isStream = false,
onLlmResponse,
@@ -48,7 +49,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
responseLanguage = 'English',
}) => {
const logger = parentLogger.get('defaultAssistantGraph');
const isOpenAI = llmType === 'openai';
const isOpenAI = llmType === 'openai' && !isOssModel;
const llmClass = getLlmClass(llmType, bedrockChatEnabled);

/**
@@ -111,7 +112,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
};

const tools: StructuredTool[] = assistantTools.flatMap(
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? []
);

// If KB enabled, fetch for any KB IndexEntries and generate a tool for each
@@ -166,6 +167,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
conversationId,
llmType,
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
};

@@ -175,6 +177,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
@@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase {
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);

const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
const hasRespondStep =
state.isStream &&
(state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock'));

return {
hasRespondStep,
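The `hasRespondStep` flag computed above decides whether the graph ends at the agent node or routes through a final respond step; it is that respond step (the `Summarizer` seen in `streamGraph`) whose output gets streamed to the client. A minimal sketch of the routing decision (node names other than `Summarizer` are hypothetical, not the PR's actual graph wiring):

```typescript
// Conditional-edge router, as a LangGraph-style graph definition might use
// it: with a respond step, the agent's final state flows through a
// summarization node before the run ends; without one, the agent node
// terminates the run directly.
type NextNode = 'respond' | '__end__';

const routeAfterAgent = (hasRespondStep: boolean): NextNode =>
  hasRespondStep ? 'respond' : '__end__';

// OSS model, streaming: route through the respond step.
console.log(routeAfterAgent(true)); // => 'respond'
```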
@@ -15,3 +15,59 @@ export const GEMINI_SYSTEM_PROMPT =
`ALWAYS use the provided tools, as they have access to the latest data and syntax.` +
"The final response is the only output the user sees and should be a complete answer to the user's question. Do not leave out important tool output. The final response should never be empty. Don't forget to use tools.";
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;

export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools:

{tools}

The tool action_input should ALWAYS follow the tool JSON schema args.

Valid "action" values: "Final Answer" or {tool_names}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).

Provide only ONE action per $JSON_BLOB, as shown:

\`\`\`

{{

"action": $TOOL_NAME,

"action_input": $TOOL_INPUT

}}

\`\`\`

Follow this format:

Question: input question to answer

Thought: consider previous and subsequent steps

Action:

\`\`\`

$JSON_BLOB

\`\`\`

Observation: action result

... (repeat Thought/Action/Observation N times)

Thought: I know what to respond

Action:

\`\`\`

{{

"action": "Final Answer",

"action_input": "Final response to human"}}

Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
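To make the contract concrete, a compliant exchange could look like the following (the tool name comes from this PR's `NaturalLanguageESQLTool`; the `action_input` shape and the ES|QL output are invented for illustration):

Question: Which hosts saw failed logins in the last 24 hours?

Thought: I should ask the ES|QL tool to generate a query.

Action:
```
{"action": "NaturalLanguageESQLTool", "action_input": {"question": "failed logins by host over the last 24 hours"}}
```

Observation: FROM logs-* | WHERE event.outcome == `failure` | STATS count = COUNT(*) BY host.name

Thought: I know what to respond

Action:
```
{"action": "Final Answer", "action_input": "This ES|QL query counts failed logins per host over the last 24 hours: ..."}
```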
@@ -10,6 +10,7 @@ import {
BEDROCK_SYSTEM_PROMPT,
DEFAULT_SYSTEM_PROMPT,
GEMINI_SYSTEM_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
} from './nodes/translations';

export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
@@ -24,61 +25,7 @@ export const systemPrompts = {
openai: DEFAULT_SYSTEM_PROMPT,
bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
gemini: `${DEFAULT_SYSTEM_PROMPT} ${GEMINI_SYSTEM_PROMPT}`,
structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools:

{tools}

The tool action_input should ALWAYS follow the tool JSON schema args.

Valid "action" values: "Final Answer" or {tool_names}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).

Provide only ONE action per $JSON_BLOB, as shown:

\`\`\`

{{

"action": $TOOL_NAME,

"action_input": $TOOL_INPUT

}}

\`\`\`

Follow this format:

Question: input question to answer

Thought: consider previous and subsequent steps

Action:

\`\`\`

$JSON_BLOB

\`\`\`

Observation: action result

... (repeat Thought/Action/Observation N times)

Thought: I know what to respond

Action:

\`\`\`

{{

"action": "Final Answer",

"action_input": "Final response to human"}}

Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`,
structuredChat: STRUCTURED_SYSTEM_PROMPT,
};

export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);
@@ -20,6 +20,7 @@ export interface GraphInputs {
conversationId?: string;
llmType?: string;
isStream?: boolean;
isOssModel?: boolean;
input: string;
responseLanguage?: string;
}
@@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase {
lastNode: string;
hasRespondStep: boolean;
isStream: boolean;
isOssModel: boolean;
bedrockChatEnabled: boolean;
llmType: string;
responseLanguage: string;
@@ -30,6 +30,7 @@ import {
} from '../helpers';
import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers';
import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types';
import { isOpenSourceModel } from '../utils';

export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => {
return `CONTEXT:\n"""\n${context}\n"""`;
@@ -99,7 +100,9 @@ export const chatCompleteRoute = (
const actions = ctx.elasticAssistant.actions;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai';
const connector = connectors.length > 0 ? connectors[0] : undefined;
actionTypeId = connector?.actionTypeId ?? '.gen-ai';
const isOssModel = isOpenSourceModel(connector);

// replacements
const anonymizationFieldsRes =
@@ -192,6 +195,7 @@ export const chatCompleteRoute = (
actionsClient,
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
context: ctx,
getElser,
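The `isOpenSourceModel` helper imported from `../utils` is not part of this excerpt. Going by the commit messages above ("Use api URL to verify OSS llms vs OpenAI", "Use provider type to better identify OSS model"), it classifies a connector from its configuration. The sketch below captures that idea with assumed field names; it is not the PR's actual implementation:

```typescript
// Minimal stand-in for Kibana's connector shape (assumption).
interface ConnectorLike {
  actionTypeId: string;
  config?: { apiProvider?: string; apiUrl?: string };
}

// Sketch: treat a .gen-ai connector as an OSS model when it points at an
// OpenAI-compatible endpoint that is not OpenAI itself (e.g. hosted Llama).
export const isOpenSourceModel = (connector?: ConnectorLike): boolean => {
  if (connector == null) return false;

  // Provider-type check: an "Other" provider marks an OpenAI-compatible,
  // non-OpenAI backend.
  const isOtherProvider = connector.config?.apiProvider === 'Other OpenAI';

  // API-URL check: anything not pointing at api.openai.com counts as OSS.
  const apiUrl = connector.config?.apiUrl;
  const isNonOpenAiUrl =
    apiUrl != null && !apiUrl.startsWith('https://api.openai.com');

  return isOtherProvider || isNonOpenAiUrl;
};
```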
@@ -46,7 +46,7 @@ import {
openAIFunctionAgentPrompt,
structuredChatAgentPrompt,
} from '../../lib/langchain/graphs/default_assistant_graph/prompts';
import { getLlmClass, getLlmType } from '../utils';
import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils';

const DEFAULT_SIZE = 20;
const ROUTE_HANDLER_TIMEOUT = 10 * 60 * 1000; // 10 * 60 seconds = 10 minutes
@@ -174,10 +176,12 @@ export const postEvaluateRoute = (
name: string;
graph: DefaultAssistantGraph;
llmType: string | undefined;
isOssModel: boolean | undefined;
}> = await Promise.all(
connectors.map(async (connector) => {
const llmType = getLlmType(connector.actionTypeId);
const isOpenAI = llmType === 'openai';
const isOssModel = isOpenSourceModel(connector);
const isOpenAI = llmType === 'openai' && !isOssModel;
const llmClass = getLlmClass(llmType, true);
const createLlmInstance = () =>
new llmClass({
@@ -232,6 +234,7 @@ export const postEvaluateRoute = (
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
llm,
isOssModel,
logger,
modelExists: isEnabledKnowledgeBase,
request: skeletonRequest,
@@ -274,6 +277,7 @@ export const postEvaluateRoute = (
return {
name: `${runName} - ${connector.name}`,
llmType,
isOssModel,
graph: getDefaultAssistantGraph({
agentRunnable,
dataClients,
@@ -287,7 +291,7 @@
);

// Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
await asyncForEach(graphs, async ({ name, graph, llmType }) => {
await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => {
// Wrapper function for invoking the graph (to parse different input/output formats)
const predict = async (input: { input: string }) => {
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
@@ -300,6 +304,7 @@
llmType,
bedrockChatEnabled: true,
isStreaming: false,
isOssModel,
}, // TODO: Update to use the correct input format per dataset type
{
runName,
@@ -310,15 +315,20 @@
return output;
};

const evalOutput = await evaluate(predict, {
evaluate(predict, {
data: datasetName ?? '',
evaluators: [], // Evals to be managed in LangSmith for now
experimentPrefix: name,
client: new Client({ apiKey: langSmithApiKey }),
// prevent rate limiting and unexpected multiple experiment runs
maxConcurrency: 5,
});
logger.debug(`runResp:\n ${JSON.stringify(evalOutput, null, 2)}`);
})
.then((output) => {
logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`);
})
.catch((err) => {
logger.error(`evaluation error:\n ${JSON.stringify(err, null, 2)}`);
});
});

return response.ok({
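Commit 513d4bf ("long time request issue") is visible just above: `await evaluate(...)` became a detached `evaluate(...).then(...).catch(...)`, so the route responds immediately while the LangSmith run continues in the background and its outcome is still logged. The general pattern, as a standalone sketch (names are illustrative):

```typescript
// Fire-and-forget with explicit logging: start the long-running task, let
// the HTTP handler return, and record success or failure when it settles.
interface LoggerLike {
  debug: (msg: string) => void;
  error: (msg: string) => void;
}

function runDetached<T>(task: () => Promise<T>, logger: LoggerLike): void {
  task()
    .then((result) =>
      logger.debug(`task finished:\n${JSON.stringify(result, null, 2)}`)
    )
    // The .catch() keeps the detached promise from rejecting unhandled.
    .catch((err) => logger.error(`task failed:\n${String(err)}`));
}
```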
3 changes: 3 additions & 0 deletions x-pack/plugins/elastic_assistant/server/routes/helpers.ts
@@ -322,6 +322,7 @@ export interface LangChainExecuteParams {
actionTypeId: string;
connectorId: string;
inference: InferenceServerStart;
isOssModel?: boolean;
conversationId?: string;
context: AwaitedProperties<
Pick<ElasticAssistantRequestHandlerContext, 'elasticAssistant' | 'licensing' | 'core'>
@@ -348,6 +349,7 @@ export const langChainExecute = async ({
telemetry,
actionTypeId,
connectorId,
isOssModel,
context,
actionsClient,
inference,
@@ -412,6 +414,7 @@
inference,
isStream,
llmType: getLlmType(actionTypeId),
isOssModel,
langChainMessages,
logger,
onNewReplacements,