💄 adopt the latest tool result api (#50)
rebornix authored Oct 21, 2024
1 parent b690710 commit 74ba8d8
Showing 5 changed files with 113 additions and 134 deletions.
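
At a glance, the commit renders tool output with the ToolResult element from @vscode/prompt-tsx together with plain LanguageModelTextPart content, lets tools signal failure by throwing instead of encoding an error part, and swaps the hand-rolled toVsCodeChatMessages in dataAgent.ts for the library export. A minimal sketch of the adopted rendering pattern (editor's illustration, not code from this commit; it assumes the prompt-tsx JSX factory is configured as in this repository, and renderToolResult is an invented name):

```tsx
// Sketch only: hand the whole LanguageModelToolResult to <ToolResult>
// instead of unpacking mime-typed values from it by hand.
import { ToolMessage, ToolResult } from '@vscode/prompt-tsx/dist/base/promptElements';
import * as vscode from 'vscode';

function renderToolResult(toolCallId: string, result: vscode.LanguageModelToolResult) {
	// <ToolResult> renders the result's text and prompt-tsx parts into the prompt.
	return (
		<ToolMessage toolCallId={toolCallId}>
			<ToolResult data={result} />
		</ToolMessage>
	);
}
```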
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -6,7 +6,7 @@
"author": "Microsoft Corporation",
"homepage": "https://github.com/microsoft/Advanced-Data-Analysis-for-Copilot",
"icon": "images/icon.png",
"version": "0.1.2",
"version": "0.1.3",
"engines": {
"vscode": "^1.95.0"
},
140 changes: 69 additions & 71 deletions src/base.tsx
@@ -14,7 +14,7 @@ import {
PromptSizing,
UserMessage,
} from '@vscode/prompt-tsx';
import { Chunk, TextChunk, ToolCall, ToolMessage } from '@vscode/prompt-tsx/dist/base/promptElements';
import { Chunk, TextChunk, ToolCall, ToolMessage, ToolResult } from '@vscode/prompt-tsx/dist/base/promptElements';
import * as path from 'path';
import * as vscode from "vscode";
import { logger } from './logger';
@@ -393,95 +393,76 @@ class ToolCalls extends PromptElement<ToolCallsProps, void> {
};
}

private async _renderOneToolCall(toolCall: vscode.LanguageModelToolCallPart, resultsFromCurrentRound: Record<string, vscode.LanguageModelToolResult>, sizing: PromptSizing, toolInvocationToken: vscode.ChatParticipantToolToken | undefined): Promise<{ promptPiece: PromptPiece, hasError: boolean, size: number }> {
private async _renderOneToolCall(toolCall: vscode.LanguageModelToolCallPart, resultsFromCurrentRound: Record<string, vscode.LanguageModelToolResult | Error>, sizing: PromptSizing, toolInvocationToken: vscode.ChatParticipantToolToken | undefined): Promise<{ promptPiece: PromptPiece, hasError: boolean, size: number }> {
const tool = vscode.lm.tools.find((tool) => tool.name === toolCall.name);
if (!tool) {
logger.error(`Tool not found: ${toolCall.name}`);
return { promptPiece: <ToolMessage toolCallId={getToolCallId(toolCall)}>Tool not found</ToolMessage>, hasError: false, size: await sizing.countTokens('Tool not found') };
}

const toolResult = resultsFromCurrentRound[getToolCallId(toolCall)] || await this._getToolCallResult(tool, toolCall, toolInvocationToken, sizing);
const toolResult = await this._getToolCallResult(tool, toolCall, resultsFromCurrentRound, toolInvocationToken, sizing);

const error = getToolResultValue<Error>(toolResult, 'application/vnd.code.notebook.error');
if (error) {
const errorContent = [error.name || '', error.message || '', error.stack || ''].filter((part) => part).join('\n');
if (isError(toolResult)) {
const errorContent = [toolResult.name || '', toolResult.message || '', toolResult.stack || ''].filter((part) => part).join('\n');
const errorMessage = `The tool returned an error, analyze this error and attempt to resolve this. Error: ${errorContent}`;

const result = new vscode.LanguageModelToolResult([new vscode.LanguageModelTextPart(errorMessage)]);
const size = await sizing.countTokens(errorMessage);
return {
promptPiece: <ToolMessage toolCallId={getToolCallId(toolCall)}>
<meta value={new ToolResultMetadata(getToolCallId(toolCall), toolResult)}></meta>
<meta value={new ToolResultMetadata(getToolCallId(toolCall), result)}></meta>
<TextChunk>{errorMessage}</TextChunk>
</ToolMessage>, hasError: true, size: size
};
}

const image = getToolResultValue<string>(toolResult, 'image/png');
const plainText = getToolResultValue<string>(toolResult, 'text/plain');
if (image) {
const imageOutput = await this._processImageOutput(toolCall.name, image, sizing);
const promptSize = await this._countToolCallResultsize(toolResult, sizing);

if (plainText) {
const text = plainText;
const textPromptSize = await sizing.countTokens(text);
return {
promptPiece: <ToolMessage toolCallId={getToolCallId(toolCall)}>
<meta value={new ToolResultMetadata(getToolCallId(toolCall), toolResult)}></meta>
<ToolResult data={toolResult}/>
</ToolMessage>, hasError: false, size: promptSize
};
}

return {
promptPiece: <Chunk>
<ToolMessage toolCallId={getToolCallId(toolCall)}>
<meta value={new ToolResultMetadata(getToolCallId(toolCall), toolResult)}></meta>
{text}
{imageOutput.result}
</ToolMessage>
<UserMessage>{imageOutput.additionalUserMessage}</UserMessage>
</Chunk>, hasError: false, size: imageOutput.size + textPromptSize
};
} else {
return {
promptPiece: <Chunk>
<ToolMessage toolCallId={getToolCallId(toolCall)}>
<meta value={new ToolResultMetadata(getToolCallId(toolCall), toolResult)}></meta>
{imageOutput.result}
</ToolMessage>
<UserMessage>{imageOutput.additionalUserMessage}</UserMessage>
</Chunk>, hasError: false, size: imageOutput.size
};
}
private async _getToolCallResult(tool: vscode.LanguageModelToolInformation, toolCall: vscode.LanguageModelToolCallPart, resultsFromCurrentRound: Record<string, vscode.LanguageModelToolResult | Error>, toolInvocationToken: vscode.ChatParticipantToolToken | undefined, sizing: PromptSizing) {
if (resultsFromCurrentRound[getToolCallId(toolCall)]) {
return resultsFromCurrentRound[getToolCallId(toolCall)];
}

if (plainText) {
const text = plainText;
const promptSize = await sizing.countTokens(text);
const token = new vscode.CancellationTokenSource().token;
try {
const toolResult = await vscode.lm.invokeTool(
tool.name,
{
parameters: toolCall.parameters,
toolInvocationToken: toolInvocationToken,
tokenizationOptions: {
tokenBudget: sizing.tokenBudget,
countTokens: async (text, token) => {
return sizing.countTokens(text, token);
}
}
},
token
);

return {
promptPiece: <ToolMessage toolCallId={getToolCallId(toolCall)}>
<meta value={new ToolResultMetadata(getToolCallId(toolCall), toolResult)}></meta>
{text}
</ToolMessage>, hasError: false, size: promptSize
};
return toolResult as vscode.LanguageModelToolResult;
} catch (e: unknown) {
const error = e as Error;
return error;
}

return { promptPiece: <></>, hasError: false, size: 0 };
}

private async _getToolCallResult(tool: vscode.LanguageModelToolInformation, toolCall: vscode.LanguageModelToolCallPart, toolInvocationToken: vscode.ChatParticipantToolToken | undefined, sizing: PromptSizing) {
const token = new vscode.CancellationTokenSource().token;

const toolResult = await vscode.lm.invokeTool(
tool.name,
{
parameters: toolCall.parameters,
toolInvocationToken: toolInvocationToken,
tokenizationOptions: {
tokenBudget: sizing.tokenBudget,
countTokens: async (text, token) => {
return sizing.countTokens(text, token);
}
}
},
token
);
private async _countToolCallResultsize(toolResult: vscode.LanguageModelToolResult, sizing: PromptSizing) {
let size = 0;
for (const part of toolResult.content) {
if (part instanceof vscode.LanguageModelTextPart) {
size += await sizing.countTokens(part.value);
}
}

return toolResult;
return size;
}

private async _processImageOutput(toolCallName: string, base64Png: string, sizing: PromptSizing) {
@@ -539,18 +520,35 @@ export class ToolResultMetadata extends PromptMetadata {
}
}

export function getToolResultValue<T>(result: vscode.LanguageModelToolResult | undefined, mime: string): T | undefined {
export function isError(e: unknown): e is Error {
return e instanceof Error || (
typeof e === 'object' &&
e !== null &&
typeof (e as Error).message === 'string' &&
typeof (e as Error).name === 'string'
);
}

export function isTextPart(e: unknown): e is vscode.LanguageModelTextPart {
return e instanceof vscode.LanguageModelTextPart || !!((e as vscode.LanguageModelTextPart).value);
}

export function getToolResultValue<T>(result: vscode.LanguageModelToolResult | Error | undefined, mime: string): T | undefined {
if (!result) {
return;
}
const item = result.content.filter(c => c instanceof vscode.LanguageModelPromptTsxPart).find(c => c.mime === mime);
if (!item && mime === 'text/plain') {
return result.content.find(c => c instanceof vscode.LanguageModelTextPart)?.value as T;

if ((result as vscode.LanguageModelToolResult).content) {
const content = (result as vscode.LanguageModelToolResult).content;
const item = content.filter(c => (c instanceof vscode.LanguageModelPromptTsxPart)).find(c => c.mime === mime);
if (!item && mime === 'text/plain') {
return content.filter(c => isTextPart(c)).map(c => c.value).join('\n') as unknown as T;
}
return item?.value as T;
}
return item?.value as T;
}

export function getToolCallId(tool: vscode.LanguageModelToolCallPart) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return tool.callId || (tool as any).toolCallId
return tool.callId || (tool as any).toolCallId;
}
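
A note on the helpers above: getToolResultValue now accepts either a LanguageModelToolResult or an Error, and for 'text/plain' it joins every text part rather than returning only the first; errors thrown by a tool are detected with isError. A small usage sketch (editor's illustration; the sample values are invented):

```ts
// Sketch only: behaviour of the helpers exported from src/base.tsx above.
import * as vscode from 'vscode';
import { getToolResultValue, isError } from './base';

const result = new vscode.LanguageModelToolResult([
	new vscode.LanguageModelTextPart('first line'),
	new vscode.LanguageModelTextPart('second line'),
]);

// With no prompt-tsx part for 'text/plain', all text parts are joined with '\n':
// "first line\nsecond line"
const text = getToolResultValue<string>(result, 'text/plain');
console.log(text);

// A failed tool call now surfaces as a thrown Error; isError narrows the union.
const outcome: vscode.LanguageModelToolResult | Error = new Error('kernel crashed');
if (isError(outcome)) {
	console.log(outcome.message);
}
```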
54 changes: 3 additions & 51 deletions src/dataAgent.ts
@@ -2,7 +2,7 @@
* Copyright (c) Microsoft Corporation and GitHub. All rights reserved.
*--------------------------------------------------------------------------------------------*/

import { ChatMessage, ChatRole, HTMLTracer, PromptRenderer } from '@vscode/prompt-tsx';
import { ChatMessage, HTMLTracer, PromptRenderer, toVsCodeChatMessages } from '@vscode/prompt-tsx';
import * as vscode from 'vscode';
import { DataAgentPrompt, getToolCallId, PromptProps, ToolCallRound, ToolResultMetadata, TsxToolUserMetadata } from './base';
import { Exporter } from './exportCommand';
@@ -14,54 +14,6 @@ const MODEL_SELECTOR: vscode.LanguageModelChatSelector = {
family: 'gpt-4o'
};

export function toVsCodeChatMessages(messages: ChatMessage[], toolResultMetadata: ToolResultMetadata[]) {
return messages.map(m => {
switch (m.role) {
case ChatRole.Assistant:
{
const message: vscode.LanguageModelChatMessage = vscode.LanguageModelChatMessage.Assistant(
m.content,
m.name
);
if (m.tool_calls) {
message.content2 = [m.content];
message.content2.push(
...m.tool_calls.map(
tc =>
new vscode.LanguageModelToolCallPart(tc.function.name, tc.id, JSON.parse(tc.function.arguments))
)
);
}
return message;
}
case ChatRole.User:
return vscode.LanguageModelChatMessage.User(m.content, m.name);
case ChatRole.Function: {
const message: vscode.LanguageModelChatMessage = vscode.LanguageModelChatMessage.User('');
// const content = toolResultMetadata.find(c => c.toolCallId === m.tool_call_id)?.result.content;
// if (m.tool_call_id && content) {
// message.content2 = [new vscode.LanguageModelToolResultPart(m.tool_call_id!, content)];
// }
// message.content2 = [new vscode.LanguageModelToolResultPart(m.name, m.content)];
return message;
}
case ChatRole.Tool: {
{
const message: vscode.LanguageModelChatMessage = vscode.LanguageModelChatMessage.User(m.content);
const content = toolResultMetadata.find(c => c.toolCallId === m.tool_call_id)?.result.content;
if (m.tool_call_id && content) {
message.content2 = [new vscode.LanguageModelToolResultPart(m.tool_call_id, content)];
}
return message;
}
}
default:
throw new Error(
`Converting chat message with role ${m.role} to VS Code chat message is not supported.`
);
}
});
}

export class DataAgent implements vscode.Disposable {
private _disposables: vscode.Disposable[] = [];
@@ -131,7 +83,7 @@ export class DataAgent implements vscode.Disposable {
};

const result = await this._renderMessages(chat, { userQuery: request.prompt, references: request.references, history: chatContext.history, currentToolCallRounds: [], toolInvocationToken: request.toolInvocationToken, extensionContext: this.extensionContext }, stream);
let messages = toVsCodeChatMessages(result.messages, []);
let messages = toVsCodeChatMessages(result.messages);
const toolReferences = [...request.toolReferences];
const toolCallRounds: ToolCallRound[] = [];

@@ -177,7 +129,7 @@

const result = await this._renderMessages(chat, { userQuery: request.prompt, references: request.references, history: chatContext.history, currentToolCallRounds: toolCallRounds, toolInvocationToken: request.toolInvocationToken, extensionContext: this.extensionContext }, stream);
const toolResultMetadata = result.metadata.getAll(ToolResultMetadata)
messages = toVsCodeChatMessages(result.messages, toolResultMetadata);
messages = toVsCodeChatMessages(result.messages);
logger.info('Token count', result.tokenCount);
if (toolResultMetadata?.length) {
toolResultMetadata.forEach(meta => {
47 changes: 38 additions & 9 deletions src/tools.ts
@@ -8,6 +8,7 @@ import type { Kernel } from '../pyodide/node/index';
import { logger } from './logger';

export const ErrorMime = 'application/vnd.code.notebook.error';
const ImagePrefix = `8a59d504`;

interface IFindFilesParameters {
pattern: string;
@@ -59,7 +60,7 @@ export class RunPythonTool implements vscode.LanguageModelTool<IRunPythonParamet
public static Id = 'dachat_data_runPython';
private _kernel: Kernel;
private pendingRequests: Promise<unknown> = Promise.resolve();
constructor(context: vscode.ExtensionContext) {
constructor(readonly context: vscode.ExtensionContext) {
const pyodidePath = vscode.Uri.joinPath(context.extensionUri, 'pyodide');
const kernelPath = vscode.Uri.joinPath(pyodidePath, 'node', 'index.js').fsPath;
const workerPath = vscode.Uri.joinPath(pyodidePath, 'node', 'comlink.worker.js').fsPath;
@@ -104,17 +105,11 @@ export class RunPythonTool implements vscode.LanguageModelTool<IRunPythonParamet
}

if (result && result['image/png']) {
content.push(new vscode.LanguageModelPromptTsxPart(result['image/png'], 'image/png'));
content.push(await this._processImageOutput(result['image/png']));
}

if (result && result['application/vnd.code.notebook.error']) {
const error = result['application/vnd.code.notebook.error'] as Error;
// We need to ensure we pass back plain objects to VS Code that can be serialized..
content.push(new vscode.LanguageModelPromptTsxPart({
name: error.name || '',
message: error.message || '',
stack: error.stack || ''
}, 'application/vnd.code.notebook.error'));
throw result['application/vnd.code.notebook.error'] as Error;
}
return new vscode.LanguageModelToolResult(content);
}
@@ -128,6 +123,40 @@ export class RunPythonTool implements vscode.LanguageModelTool<IRunPythonParamet
invocationMessage: `Executing Code${reasonMessage}`
};
}

private async _processImageOutput(base64Png: string) {
const userMessageWithWithImageFromToolCall = `Return this image link in your response. Do not modify the markdown image link at all. The path is already absolute local file path, do not put "https" or "blob" in the link`;
if (this.context.storageUri) {
const imagePath = await this._saveImage(this.context.storageUri, RunPythonTool.Id, Buffer.from(base64Png, 'base64'));
if (imagePath) {
const markdownTextForImage = `The image generated from the code is ![${RunPythonTool.Id} result](${imagePath}). You can give this markdown link to users!`;
return new vscode.LanguageModelTextPart(markdownTextForImage + '\n' + userMessageWithWithImageFromToolCall);
}
}

const markdownTextForImage = `![${RunPythonTool.Id} result](data:image/png;base64,${base64Png})`;
return new vscode.LanguageModelTextPart(markdownTextForImage + '\n' + userMessageWithWithImageFromToolCall);
}

private async _saveImage(storageUri: vscode.Uri, tool: string, imageBuffer: Buffer): Promise<string | undefined> {
try {
await vscode.workspace.fs.stat(storageUri);
} catch {
await vscode.workspace.fs.createDirectory(storageUri);
}

const storagePath = storageUri.fsPath;
const imagePath = path.join(storagePath, `result-${tool}-${ImagePrefix}-${Date.now()}.png`);
const imageUri = vscode.Uri.file(imagePath);
try {
await vscode.workspace.fs.writeFile(imageUri, imageBuffer);
const encodedPath = encodeURI(imageUri.fsPath);
return encodedPath;
} catch (ex) {
logger.error('Error saving image', ex);
return undefined;
}
}
}
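
For context on the tools.ts change above: on failure RunPythonTool now throws the notebook error rather than packing it into a LanguageModelPromptTsxPart, and the caller (_getToolCallResult in src/base.tsx) catches it. A minimal sketch of a tool written against the same API surface used in this diff (editor's illustration; EchoTool is an invented example, not part of this repository, and later VS Code releases renamed the parameters property to input):

```ts
// Sketch only: success returns plain text parts, failure is signalled by throwing.
import * as vscode from 'vscode';

interface IEchoParameters {
	text: string;
}

class EchoTool implements vscode.LanguageModelTool<IEchoParameters> {
	async invoke(
		options: vscode.LanguageModelToolInvocationOptions<IEchoParameters>,
		_token: vscode.CancellationToken
	): Promise<vscode.LanguageModelToolResult> {
		if (!options.parameters.text) {
			// The caller converts this into an error message for the model.
			throw new Error('No text provided');
		}
		return new vscode.LanguageModelToolResult([
			new vscode.LanguageModelTextPart(options.parameters.text),
		]);
	}
}
```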

/**
