Skip to content

Commit

Permalink
Multi-action renderConversationForModel with cli (#4963)
Browse files Browse the repository at this point in the history
* Multi-action renderConversationForModel with cli

* Replace tool by function

* Snake case as expected by the model

* I can't be git blamed on the fuck it comment :D

* Back to function_calls & improve token count estimation

* Get change from #4973

* Fix we want a single assistant message for all function calls of a given AgentMessageType

* Better naming
  • Loading branch information
PopDaph authored May 6, 2024
1 parent 030edd1 commit 8940f0a
Show file tree
Hide file tree
Showing 7 changed files with 520 additions and 6 deletions.
88 changes: 85 additions & 3 deletions front/admin/cli.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import { ConnectorsAPI, removeNulls } from "@dust-tt/types";
import {
CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
ConnectorsAPI,
removeNulls,
SUPPORTED_MODEL_CONFIGS,
} from "@dust-tt/types";
import { CoreAPI } from "@dust-tt/types";
import { Storage } from "@google-cloud/storage";
import parseArgs from "minimist";
import readline from "readline";

import { getConversation } from "@app/lib/api/assistant/conversation";
import { renderConversationForModelMultiActions } from "@app/lib/api/assistant/generation";
import { getDataSources } from "@app/lib/api/data_sources";
import { renderUserType } from "@app/lib/api/user";
import { Authenticator } from "@app/lib/auth";
Expand Down Expand Up @@ -507,14 +514,87 @@ const eventSchema = async (command: string, args: parseArgs.ParsedArgs) => {
}
};

const conversation = async (command: string, args: parseArgs.ParsedArgs) => {
switch (command) {
case "render-for-model": {
if (!args.wId) {
throw new Error("Missing --wId argument");
}
if (!args.cId) {
throw new Error("Missing --cId argument");
}
const verbose = args.verbose === "true";

const modelId =
args.modelId ?? CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.modelId;
const model = SUPPORTED_MODEL_CONFIGS.find(
(m) => m.modelId === modelId && m.supportsMultiActions
);
if (!model) {
throw new Error(`Model not found: modelId='${modelId}'`);
}

const auth = await Authenticator.internalAdminForWorkspace(args.wId);
const conversation = await getConversation(auth, args.cId as string);

if (!conversation) {
throw new Error(`Conversation not found: cId='${args.cId}'`);
}

const MIN_GENERATION_TOKENS = 2048;
const allowedTokenCount = model.contextSize - MIN_GENERATION_TOKENS;
const prompt = "";

const response = await renderConversationForModelMultiActions({
conversation,
model,
prompt,
allowedTokenCount,
});

if (response.isErr()) {
logger.error(response.error.message);
} else {
logger.info(
{
model,
prompt,
},
"Called renderConversationForModel with params:"
);
const result = response.value;

if (!verbose) {
// For convenience we shorten the content when role = "tool"
result.modelConversation.messages =
result.modelConversation.messages.map((m) => {
if (m.role === "function") {
return {
...m,
content: m.content.slice(0, 200) + "...",
};
}
return m;
});
}

logger.info(result, "Result from renderConversationForModel:");
}
return;
}
}
};

const main = async () => {
const argv = parseArgs(process.argv.slice(2));

if (argv._.length < 2) {
console.log(
"Expects object type and command as first two arguments, eg: `cli workspace create ...`"
);
console.log("Possible object types: `workspace`, `user`, `data-source`");
console.log(
"Possible object types: `workspace`, `user`, `data-source`, `conversation`"
);
return;
}

Expand All @@ -533,9 +613,11 @@ const main = async () => {
case "event-schema":
await eventSchema(command, argv);
return;
case "conversation":
return conversation(command, argv);
default:
console.log(
"Unknown object type, possible values: `workspace`, `user`, `data-source`, `event-schema`"
"Unknown object type, possible values: `workspace`, `user`, `data-source`, `event-schema`, `conversation`"
);
return;
}
Expand Down
30 changes: 30 additions & 0 deletions front/lib/api/assistant/actions/dust_app_run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import type {
DustAppRunErrorEvent,
DustAppRunParamsEvent,
DustAppRunSuccessEvent,
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
ModelMessageType,
} from "@dust-tt/types";
Expand Down Expand Up @@ -47,6 +49,34 @@ export function renderDustAppRunActionForModel(
};
}

export function renderDustAppRunActionFunctionCall(
action: DustAppRunActionType
): FunctionCallType {
return {
id: action.id.toString(), // @todo Daph replace with the actual tool id
type: "function",
function: {
name: action.appName,
arguments: JSON.stringify(action.params),
},
};
}
export function renderDustAppRunActionForMultiActionsModel(
action: DustAppRunActionType
): FunctionMessageTypeModel {
let content = "";

// Note action.output can be any valid JSON including null.
content += `OUTPUT:\n`;
content += `${JSON.stringify(action.output, null, 2)}\n`;

return {
role: "function" as const,
function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id
content,
};
}

/**
* Params generation.
*/
Expand Down
41 changes: 41 additions & 0 deletions front/lib/api/assistant/actions/process.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type {
AgentConfigurationType,
AgentMessageType,
ConversationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
ModelMessageType,
ProcessActionOutputsType,
Expand Down Expand Up @@ -69,6 +71,45 @@ export function renderProcessActionForModel(
};
}

export function renderProcessActionFunctionCall(
action: ProcessActionType
): FunctionCallType {
return {
id: action.id.toString(), // @todo Daph replace with the actual tool id
type: "function",
function: {
name: "process_data_sources",
arguments: JSON.stringify(action.params),
},
};
}
export function renderProcessActionForMultiActionsModel(
action: ProcessActionType
): FunctionMessageTypeModel {
let content = "";
if (action.outputs === null) {
throw new Error(
"Output not set on process action; this usually means the process action is not finished."
);
}

content += "PROCESSED OUTPUTS:\n";

// TODO(spolu): figure out if we want to add the schema here?

if (action.outputs) {
for (const o of action.outputs.data) {
content += `${JSON.stringify(o)}\n`;
}
}

return {
role: "function" as const,
function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id
content,
};
}

/**
* Params generation.
*/
Expand Down
62 changes: 62 additions & 0 deletions front/lib/api/assistant/actions/retrieval.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import type {
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
ModelMessageType,
RetrievalErrorEvent,
Expand Down Expand Up @@ -149,6 +151,66 @@ export function renderRetrievalActionForModel(
};
}

export function rendeRetrievalActionFunctionCall(
action: RetrievalActionType
): FunctionCallType {
const timeFrame = action.params.relativeTimeFrame;
const params = {
query: action.params.query,
relativeTimeFrame: timeFrame
? `${timeFrame.duration}${timeFrame.unit}`
: "all",
topK: action.params.topK,
};

return {
id: action.id.toString(), // @todo Daph replace with the actual tool id
type: "function",
function: {
name: "search_data_sources",
arguments: JSON.stringify(params),
},
};
}
export function renderRetrievalActionForMultiActionsModel(
action: RetrievalActionType
): FunctionMessageTypeModel {
let content = "";
if (!action.documents) {
throw new Error(
"Documents not set on retrieval action; this usually means the retrieval action is not finished."
);
}
for (const d of action.documents) {
let title = d.documentId;
for (const t of d.tags) {
if (t.startsWith("title:")) {
title = t.substring(6);
break;
}
}

let dataSourceName = d.dataSourceId;
if (d.dataSourceId.startsWith("managed-")) {
dataSourceName = d.dataSourceId.substring(8);
}

content += `TITLE: ${title} (data source: ${dataSourceName})\n`;
content += `REFERENCE: ${d.reference}\n`;
content += `EXTRACTS:\n`;
for (const c of d.chunks) {
content += `${c.text}\n`;
}
content += "\n";
}

return {
role: "function" as const,
function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id
content,
};
}

/**
* Params generation.
*/
Expand Down
33 changes: 33 additions & 0 deletions front/lib/api/assistant/actions/tables_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import type {
AgentMessageType,
ConversationType,
DustAppParameters,
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
ModelMessageType,
Result,
Expand Down Expand Up @@ -44,6 +46,37 @@ export function renderTablesQueryActionForModel(
};
}

export function rendeTablesQueryActionFunctionCall(
action: TablesQueryActionType
): FunctionCallType {
return {
id: action.id.toString(), // @todo Daph replace with the actual tool id
type: "function",
function: {
name: "query_tables",
arguments: JSON.stringify(action.params),
},
};
}
export function renderTablesQueryActionForMultiActionsModel(
action: TablesQueryActionType
): FunctionMessageTypeModel {
let content = "";
if (!action.output) {
throw new Error(
"Output not set on TablesQuery action; execution is likely not finished."
);
}
content += `OUTPUT:\n`;
content += `${JSON.stringify(action.output, null, 2)}\n`;

return {
role: "function" as const,
function_call_id: action.id.toString(), // @todo Daph replace with the actual tool id
content,
};
}

// Internal interface for the retrieval and rendering of a TableQuery action. This should not be
// used outside of api/assistant. We allow a ModelId interface here because we don't have `sId` on
// actions (the `sId` is on the `Message` object linked to the `UserMessage` parent of this action).
Expand Down
Loading

0 comments on commit 8940f0a

Please sign in to comment.