From 8940f0ac284a3fd8821606d6989c71a2c9da3eb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daphn=C3=A9=20Popin?= Date: Mon, 6 May 2024 09:15:37 +0200 Subject: [PATCH] Multi-action renderConversationForModel with cli (#4963) * Multi-action renderConversationForModel with cli * Replace tool by function * Snake case as expected by the model * I can't be git blamed on the fuck it comment :D * Back to function_calls & improve token count estimation * Get change from #4973 * Fix we want a single assistant message for all function calls of a given AgentMessageType * Better naming --- front/admin/cli.ts | 88 ++++++- .../lib/api/assistant/actions/dust_app_run.ts | 30 +++ front/lib/api/assistant/actions/process.ts | 41 ++++ front/lib/api/assistant/actions/retrieval.ts | 62 +++++ .../lib/api/assistant/actions/tables_query.ts | 33 +++ front/lib/api/assistant/generation.ts | 223 +++++++++++++++++- .../src/front/lib/api/assistant/generation.ts | 49 ++++ 7 files changed, 520 insertions(+), 6 deletions(-) diff --git a/front/admin/cli.ts b/front/admin/cli.ts index 6e07dd18b706..cab599712396 100644 --- a/front/admin/cli.ts +++ b/front/admin/cli.ts @@ -1,9 +1,16 @@ -import { ConnectorsAPI, removeNulls } from "@dust-tt/types"; +import { + CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG, + ConnectorsAPI, + removeNulls, + SUPPORTED_MODEL_CONFIGS, +} from "@dust-tt/types"; import { CoreAPI } from "@dust-tt/types"; import { Storage } from "@google-cloud/storage"; import parseArgs from "minimist"; import readline from "readline"; +import { getConversation } from "@app/lib/api/assistant/conversation"; +import { renderConversationForModelMultiActions } from "@app/lib/api/assistant/generation"; import { getDataSources } from "@app/lib/api/data_sources"; import { renderUserType } from "@app/lib/api/user"; import { Authenticator } from "@app/lib/auth"; @@ -507,6 +514,77 @@ const eventSchema = async (command: string, args: parseArgs.ParsedArgs) => { } }; +const conversation = async (command: string, args: 
parseArgs.ParsedArgs) => { + switch (command) { + case "render-for-model": { + if (!args.wId) { + throw new Error("Missing --wId argument"); + } + if (!args.cId) { + throw new Error("Missing --cId argument"); + } + const verbose = args.verbose === "true"; + + const modelId = + args.modelId ?? CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.modelId; + const model = SUPPORTED_MODEL_CONFIGS.find( + (m) => m.modelId === modelId && m.supportsMultiActions + ); + if (!model) { + throw new Error(`Model not found: modelId='${modelId}'`); + } + + const auth = await Authenticator.internalAdminForWorkspace(args.wId); + const conversation = await getConversation(auth, args.cId as string); + + if (!conversation) { + throw new Error(`Conversation not found: cId='${args.cId}'`); + } + + const MIN_GENERATION_TOKENS = 2048; + const allowedTokenCount = model.contextSize - MIN_GENERATION_TOKENS; + const prompt = ""; + + const response = await renderConversationForModelMultiActions({ + conversation, + model, + prompt, + allowedTokenCount, + }); + + if (response.isErr()) { + logger.error(response.error.message); + } else { + logger.info( + { + model, + prompt, + }, + "Called renderConversationForModel with params:" + ); + const result = response.value; + + if (!verbose) { + // For convenience we shorten the content when role = "function" + result.modelConversation.messages = + result.modelConversation.messages.map((m) => { + if (m.role === "function") { + return { + ...m, + content: m.content.slice(0, 200) + "...", + }; + } + return m; + }); + } + + logger.info(result, "Result from renderConversationForModel:"); + } + return; + } + } +}; + const main = async () => { const argv = parseArgs(process.argv.slice(2)); @@ -514,7 +592,9 @@ const main = async () => { console.log( "Expects object type and command as first two arguments, eg: `cli workspace create ...`" ); - console.log("Possible object types: `workspace`, `user`, `data-source`"); + console.log( + "Possible object types: `workspace`, `user`, 
`data-source`, `conversation`" + ); return; } @@ -533,9 +613,11 @@ const main = async () => { case "event-schema": await eventSchema(command, argv); return; + case "conversation": + return conversation(command, argv); default: console.log( - "Unknown object type, possible values: `workspace`, `user`, `data-source`, `event-schema`" + "Unknown object type, possible values: `workspace`, `user`, `data-source`, `event-schema`, `conversation`" ); return; } diff --git a/front/lib/api/assistant/actions/dust_app_run.ts b/front/lib/api/assistant/actions/dust_app_run.ts index e247cd9ab461..cf11a757c5af 100644 --- a/front/lib/api/assistant/actions/dust_app_run.ts +++ b/front/lib/api/assistant/actions/dust_app_run.ts @@ -4,6 +4,8 @@ import type { DustAppRunErrorEvent, DustAppRunParamsEvent, DustAppRunSuccessEvent, + FunctionCallType, + FunctionMessageTypeModel, ModelId, ModelMessageType, } from "@dust-tt/types"; @@ -47,6 +49,34 @@ export function renderDustAppRunActionForModel( }; } +export function renderDustAppRunActionFunctionCall( + action: DustAppRunActionType +): FunctionCallType { + return { + id: action.id.toString(), // @todo Daph replace with the actual tool id + type: "function", + function: { + name: action.appName, + arguments: JSON.stringify(action.params), + }, + }; +} +export function renderDustAppRunActionForMultiActionsModel( + action: DustAppRunActionType +): FunctionMessageTypeModel { + let content = ""; + + // Note action.output can be any valid JSON including null. + content += `OUTPUT:\n`; + content += `${JSON.stringify(action.output, null, 2)}\n`; + + return { + role: "function" as const, + function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id + content, + }; +} + /** * Params generation. 
*/ diff --git a/front/lib/api/assistant/actions/process.ts b/front/lib/api/assistant/actions/process.ts index 229a955c88bb..3be75e47b61c 100644 --- a/front/lib/api/assistant/actions/process.ts +++ b/front/lib/api/assistant/actions/process.ts @@ -3,6 +3,8 @@ import type { AgentConfigurationType, AgentMessageType, ConversationType, + FunctionCallType, + FunctionMessageTypeModel, ModelId, ModelMessageType, ProcessActionOutputsType, @@ -69,6 +71,45 @@ export function renderProcessActionForModel( }; } +export function renderProcessActionFunctionCall( + action: ProcessActionType +): FunctionCallType { + return { + id: action.id.toString(), // @todo Daph replace with the actual tool id + type: "function", + function: { + name: "process_data_sources", + arguments: JSON.stringify(action.params), + }, + }; +} +export function renderProcessActionForMultiActionsModel( + action: ProcessActionType +): FunctionMessageTypeModel { + let content = ""; + if (action.outputs === null) { + throw new Error( + "Output not set on process action; this usually means the process action is not finished." + ); + } + + content += "PROCESSED OUTPUTS:\n"; + + // TODO(spolu): figure out if we want to add the schema here? + + if (action.outputs) { + for (const o of action.outputs.data) { + content += `${JSON.stringify(o)}\n`; + } + } + + return { + role: "function" as const, + function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id + content, + }; +} + /** * Params generation. 
*/ diff --git a/front/lib/api/assistant/actions/retrieval.ts b/front/lib/api/assistant/actions/retrieval.ts index 23fdad21671e..544ad31fa9e2 100644 --- a/front/lib/api/assistant/actions/retrieval.ts +++ b/front/lib/api/assistant/actions/retrieval.ts @@ -1,4 +1,6 @@ import type { + FunctionCallType, + FunctionMessageTypeModel, ModelId, ModelMessageType, RetrievalErrorEvent, @@ -149,6 +151,66 @@ export function renderRetrievalActionForModel( }; } +export function rendeRetrievalActionFunctionCall( + action: RetrievalActionType +): FunctionCallType { + const timeFrame = action.params.relativeTimeFrame; + const params = { + query: action.params.query, + relativeTimeFrame: timeFrame + ? `${timeFrame.duration}${timeFrame.unit}` + : "all", + topK: action.params.topK, + }; + + return { + id: action.id.toString(), // @todo Daph replace with the actual tool id + type: "function", + function: { + name: "search_data_sources", + arguments: JSON.stringify(params), + }, + }; +} +export function renderRetrievalActionForMultiActionsModel( + action: RetrievalActionType +): FunctionMessageTypeModel { + let content = ""; + if (!action.documents) { + throw new Error( + "Documents not set on retrieval action; this usually means the retrieval action is not finished." + ); + } + for (const d of action.documents) { + let title = d.documentId; + for (const t of d.tags) { + if (t.startsWith("title:")) { + title = t.substring(6); + break; + } + } + + let dataSourceName = d.dataSourceId; + if (d.dataSourceId.startsWith("managed-")) { + dataSourceName = d.dataSourceId.substring(8); + } + + content += `TITLE: ${title} (data source: ${dataSourceName})\n`; + content += `REFERENCE: ${d.reference}\n`; + content += `EXTRACTS:\n`; + for (const c of d.chunks) { + content += `${c.text}\n`; + } + content += "\n"; + } + + return { + role: "function" as const, + function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id + content, + }; +} + /** * Params generation. 
*/ diff --git a/front/lib/api/assistant/actions/tables_query.ts b/front/lib/api/assistant/actions/tables_query.ts index b20509f9ce9c..a6c4175ade3a 100644 --- a/front/lib/api/assistant/actions/tables_query.ts +++ b/front/lib/api/assistant/actions/tables_query.ts @@ -4,6 +4,8 @@ import type { AgentMessageType, ConversationType, DustAppParameters, + FunctionCallType, + FunctionMessageTypeModel, ModelId, ModelMessageType, Result, @@ -44,6 +46,37 @@ export function renderTablesQueryActionForModel( }; } +export function rendeTablesQueryActionFunctionCall( + action: TablesQueryActionType +): FunctionCallType { + return { + id: action.id.toString(), // @todo Daph replace with the actual tool id + type: "function", + function: { + name: "query_tables", + arguments: JSON.stringify(action.params), + }, + }; +} +export function renderTablesQueryActionForMultiActionsModel( + action: TablesQueryActionType +): FunctionMessageTypeModel { + let content = ""; + if (!action.output) { + throw new Error( + "Output not set on TablesQuery action; execution is likely not finished." + ); + } + content += `OUTPUT:\n`; + content += `${JSON.stringify(action.output, null, 2)}\n`; + + return { + role: "function" as const, + function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id + content, + }; +} + // Internal interface for the retrieval and rendering of a TableQuery action. This should not be // used outside of api/assistant. We allow a ModelId interface here because we don't have `sId` on // actions (the `sId` is on the `Message` object linked to the `UserMessage` parent of this action). 
diff --git a/front/lib/api/assistant/generation.ts b/front/lib/api/assistant/generation.ts index 73644b99e1b0..a279bb74d2d5 100644 --- a/front/lib/api/assistant/generation.ts +++ b/front/lib/api/assistant/generation.ts @@ -1,13 +1,16 @@ import type { AgentConfigurationType, AgentMessageType, + ContentFragmentMessageTypeModel, ConversationType, GenerationCancelEvent, GenerationErrorEvent, GenerationSuccessEvent, GenerationTokensEvent, ModelConversationType, + ModelConversationTypeMultiActions, ModelMessageType, + ModelMessageTypeMultiActions, Result, UserMessageType, } from "@dust-tt/types"; @@ -25,17 +28,32 @@ import { isTablesQueryActionType, isUserMessageType, Ok, + removeNulls, } from "@dust-tt/types"; import moment from "moment-timezone"; import { runActionStreamed } from "@app/lib/actions/server"; -import { renderDustAppRunActionForModel } from "@app/lib/api/assistant/actions/dust_app_run"; -import { renderProcessActionForModel } from "@app/lib/api/assistant/actions/process"; import { + renderDustAppRunActionForModel, + renderDustAppRunActionForMultiActionsModel, + renderDustAppRunActionFunctionCall, +} from "@app/lib/api/assistant/actions/dust_app_run"; +import { + renderProcessActionForModel, + renderProcessActionForMultiActionsModel, + renderProcessActionFunctionCall, +} from "@app/lib/api/assistant/actions/process"; +import { + rendeRetrievalActionFunctionCall, renderRetrievalActionForModel, + renderRetrievalActionForMultiActionsModel, retrievalMetaPrompt, } from "@app/lib/api/assistant/actions/retrieval"; -import { renderTablesQueryActionForModel } from "@app/lib/api/assistant/actions/tables_query"; +import { + renderTablesQueryActionForModel, + renderTablesQueryActionForMultiActionsModel, + rendeTablesQueryActionFunctionCall, +} from "@app/lib/api/assistant/actions/tables_query"; import { getAgentConfigurations } from "@app/lib/api/assistant/configuration"; import { getSupportedModelConfig, isLargeModel } from "@app/lib/assistant"; import type { 
Authenticator } from "@app/lib/auth"; @@ -255,6 +273,205 @@ export async function renderConversationForModel({ }); } +export async function renderConversationForModelMultiActions({ + conversation, + model, + prompt, + allowedTokenCount, +}: { + conversation: ConversationType; + model: { providerId: string; modelId: string }; + prompt: string; + allowedTokenCount: number; +}): Promise< + Result< + { + modelConversation: ModelConversationTypeMultiActions; + tokensUsed: number; + }, + Error + > +> { + const messages: ModelMessageTypeMultiActions[] = []; + + // Render all messages and all actions. + for (let i = conversation.content.length - 1; i >= 0; i--) { + const versions = conversation.content[i]; + const m = versions[versions.length - 1]; + + if (isAgentMessageType(m)) { + if (m.content) { + messages.unshift({ + role: "assistant" as const, + name: m.configuration.name, + content: m.content, + }); + } + + const actions = removeNulls([m.action]); // Should be replaced with `m.actions` once we have it on AgentMessageType. 
+ const function_calls = []; + const function_messages = []; + + for (const action of actions) { + if (isRetrievalActionType(action)) { + function_messages.unshift( + renderRetrievalActionForMultiActionsModel(action) + ); + function_calls.unshift(rendeRetrievalActionFunctionCall(action)); + } else if (isDustAppRunActionType(action)) { + function_messages.unshift( + renderDustAppRunActionForMultiActionsModel(action) + ); + function_calls.unshift(renderDustAppRunActionFunctionCall(action)); + } else if (isTablesQueryActionType(action)) { + function_messages.unshift( + renderTablesQueryActionForMultiActionsModel(action) + ); + function_calls.unshift(rendeTablesQueryActionFunctionCall(action)); + } else if (isProcessActionType(action)) { + function_messages.unshift( + renderProcessActionForMultiActionsModel(action) + ); + function_calls.unshift(renderProcessActionFunctionCall(action)); + } else { + assertNever(action); + } + } + + if (function_calls.length > 0) { + messages.unshift({ + role: "assistant", + content: null, + function_calls, + }); + } + } else if (isUserMessageType(m)) { + // Replace all `:mention[{name}]{.*}` with `@name`. + const content = m.content.replaceAll( + /:mention\[([^\]]+)\]\{[^}]+\}/g, + (_, name) => { + return `@${name}`; + } + ); + messages.unshift({ + role: "user" as const, + name: m.context.fullName || m.context.username, + content, + }); + } else if (isContentFragmentType(m)) { + try { + const content = await getContentFragmentText({ + workspaceId: conversation.owner.sId, + conversationId: conversation.sId, + messageId: m.sId, + }); + messages.unshift({ + role: "content_fragment", + name: `inject_${m.contentType}`, + content: + `TITLE: ${m.title}\n` + + `TYPE: ${m.contentType}${ + m.contentType === "file_attachment" ? 
" (user provided)" : "" + }\n` + + `CONTENT:\n${content}`, + }); + } catch (error) { + logger.error( + { + error, + workspaceId: conversation.owner.sId, + conversationId: conversation.sId, + messageId: m.sId, + }, + "Failed to retrieve content fragment text" + ); + return new Err(new Error("Failed to retrieve content fragment text")); + } + } else { + assertNever(m); + } + } + + // Compute in parallel the token count for each message and the prompt. + const [messagesCountRes, promptCountRes] = await Promise.all([ + Promise.all( + messages.map((m) => { + let text = `${m.role} ${"name" in m ? m.name : ""} ${m.content ?? ""}`; + if ("function_calls" in m) { + text += m.function_calls + .map((f) => `${f.function.name} ${f.function.arguments}`) + .join(" "); + } + return tokenCountForText(text, model); + }) + ), + tokenCountForText(prompt, model), + ]); + + if (promptCountRes.isErr()) { + return new Err(promptCountRes.error); + } + + // We initialize `tokensUsed` to the prompt tokens + a bit of buffer for message rendering + // approximations, 64 tokens seems small enough and ample enough. + const tokensMargin = 64; + let tokensUsed = promptCountRes.value + tokensMargin; + + // Go backward and accumulate as much as we can within allowedTokenCount. + const selected: ModelMessageTypeMultiActions[] = []; + const truncationMessage = `... (content truncated)`; + const approxTruncMsgTokenCount = truncationMessage.length / 3; + + for (let i = messages.length - 1; i >= 0; i--) { + const r = messagesCountRes[i]; + if (r.isErr()) { + return new Err(r.error); + } + const c = r.value; + if (tokensUsed + c <= allowedTokenCount) { + tokensUsed += c; + selected.unshift(messages[i]); + } else if ( + // When a content fragment has more than the remaining number of tokens, we split it. + messages[i].role === "content_fragment" && + // Allow at least tokensMargin tokens in addition to the truncation message. 
+ tokensUsed + approxTruncMsgTokenCount + tokensMargin < allowedTokenCount + ) { + const msg = messages[i] as ContentFragmentMessageTypeModel; + const remainingTokens = + allowedTokenCount - tokensUsed - approxTruncMsgTokenCount; + const contentRes = await tokenSplit(msg.content, model, remainingTokens); + if (contentRes.isErr()) { + return new Err(contentRes.error); + } + selected.unshift({ + ...msg, + content: contentRes.value + truncationMessage, + }); + tokensUsed += remainingTokens; + break; + } else { + break; + } + } + + while (selected.length > 0 && selected[0].role === "assistant") { + const tokenCountRes = messagesCountRes[messages.length - selected.length]; + if (tokenCountRes.isErr()) { + return new Err(tokenCountRes.error); + } + tokensUsed -= tokenCountRes.value; + selected.shift(); + } + + return new Ok({ + modelConversation: { + messages: selected, + }, + tokensUsed, + }); +} + /** * Generation execution. */ diff --git a/types/src/front/lib/api/assistant/generation.ts b/types/src/front/lib/api/assistant/generation.ts index e4f6f669b2e7..3914e278ef5a 100644 --- a/types/src/front/lib/api/assistant/generation.ts +++ b/types/src/front/lib/api/assistant/generation.ts @@ -12,6 +12,55 @@ export type ModelConversationType = { messages: ModelMessageType[]; }; +export type ContentFragmentMessageTypeModel = { + role: "content_fragment"; + name: string; + content: string; +}; + +export type UserMessageTypeModel = { + role: "user"; + name: string; + content: string; +}; + +export type FunctionCallType = { + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; +}; + +export type AssistantFunctionCallMessageTypeModel = { + role: "assistant"; + content: string | null; + function_calls: FunctionCallType[]; +}; +export type AssistantContentMessageTypeModel = { + role: "assistant"; + name: string; + content: string; +}; + +export type FunctionMessageTypeModel = { + role: "function"; + function_call_id: string; + content: string; +}; 
+ +export type ModelMessageTypeMultiActions = + | ContentFragmentMessageTypeModel + | UserMessageTypeModel + | AssistantFunctionCallMessageTypeModel + | AssistantContentMessageTypeModel + | FunctionMessageTypeModel; + +export type ModelConversationTypeMultiActions = { + messages: ModelMessageTypeMultiActions[]; +}; + /** * Generation execution. */