diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs index 57b9ea76dcb..e1eb5dea3dd 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs @@ -92,5 +92,15 @@ public static async Task RunAsync() calculateTax.Should().BeOfType>(); calculateTax.GetToolCalls().Should().HaveCount(1); calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax)); + + // parallel function calls + var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2"); + calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40 + calculateTaxes.Should().BeOfType>(); + calculateTaxes.GetToolCalls().Should().HaveCount(2); + calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax)); + + // send aggregate message back to llm to get the final result + var finalResult = await agent.SendAsync(calculateTaxes); } } diff --git a/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs index 47dbad55e30..830772c4b07 100644 --- a/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs @@ -169,7 +169,7 @@ public static bool IsSystemMessage(this IMessage message) TextMessage textMessage => textMessage.Content, Message msg => msg.Content, ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null, - AggregateMessage aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null, + AggregateMessage aggregateMessage => string.Join("\n", aggregateMessage.Message2.ToolCalls.Where(x => x.Result is not null).Select(x => x.Result)), _ => null, }; } diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs index 8dcd98ea0ec..7312fd67136 100644 --- a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs @@ -26,6 +26,8 @@ public ToolCall(string functionName, string functionArgs, string result) public string FunctionArguments { get; set; } + public string? ToolCallId { get; set; } + public string? Result { get; set; } public override string ToString() diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs index 2bc02805538..3bc8b2f1c2c 100644 --- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs @@ -128,13 +128,13 @@ private async Task InvokeToolCallMessagesBeforeInvokingAg if (this.functionMap?.TryGetValue(functionName, out var func) is true) { var result = await func(functionArguments); - toolCallResult.Add(new ToolCall(functionName, functionArguments, result)); + toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId }); } else if (this.functionMap is not null) { var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage)); + toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId }); } else { @@ -156,7 +156,7 @@ private async Task InvokeToolCallMessagesAfterInvokingAgentAsync(ToolC if (this.functionMap?.TryGetValue(fName, out var func) is true) { var result = await func(fArgs); - toolCallResult.Add(new ToolCall(fName, fArgs, result)); + toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId }); } } diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 2925a43e16f..106296fced9 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -152,7 +152,7 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse .Where(tc => tc is ChatCompletionsFunctionToolCall) .Select(tc => (ChatCompletionsFunctionToolCall)tc); - var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments)); + var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id }); return new ToolCallMessage(toolCalls, from); } @@ -322,7 +322,7 @@ private IEnumerable ProcessToolCallMessage(IAgent agent, Too throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); } - var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From }; foreach (var tc in toolCall) { @@ -336,7 +336,7 @@ private IEnumerable ProcessToolCallResultMessage(ToolCallRes { return message.ToolCalls .Where(tc => tc.Result is not null) - .Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); + .Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}")); } private IEnumerable ProcessMessage(IAgent agent, Message message) diff --git a/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index d17de56e129..b0f70409d29 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.OpenAI.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -145,7 +145,7 @@ "Type": "Function", "Name": "test", "Arguments": "test", - "Id": "test" + "Id": "test_0" } ], "FunctionCallName": null, @@ -159,7 +159,7 @@ { "Role": "tool", "Content": "result", - "ToolCallId": "test" + "ToolCallId": "test_0" } ] }, @@ -169,12 +169,12 @@ { "Role": "tool", "Content": "test", - "ToolCallId": "result" + "ToolCallId": "result_0" }, { "Role": "tool", "Content": "test", - "ToolCallId": "result" + "ToolCallId": "result_1" } ] }, @@ -190,13 +190,13 @@ "Type": "Function", "Name": "test", "Arguments": "test", - "Id": "test" + "Id": "test_0" }, { "Type": "Function", "Name": "test", "Arguments": "test", - "Id": "test" + "Id": "test_1" } ], "FunctionCallName": null, @@ -216,7 +216,7 @@ "Type": "Function", "Name": "test", "Arguments": "test", - "Id": "test" + "Id": "test_0" } ], "FunctionCallName": null, @@ -225,7 +225,7 @@ { "Role": "tool", "Content": "result", - "ToolCallId": "test" + "ToolCallId": "test_0" } ] } diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs index a8c1d3f7860..71a50608c74 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs @@ -293,6 +293,7 @@ public async Task ItProcessToolCallMessageAsync() chatRequestMessage.ToolCalls.First().Should().BeOfType(); var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First(); functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be("test_0"); functionToolCall.Arguments.Should().Be("test"); return await innerAgent.GenerateReplyAsync(msgs); }) @@ -303,6 +304,41 @@ public async Task ItProcessToolCallMessageAsync() await agent.GenerateReplyAsync([message]); } + [Fact] + public async Task ItProcessParallelToolCallMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + var innerMessage = msgs.Last(); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().BeNullOrEmpty(); + chatRequestMessage.Name.Should().Be("assistant"); + chatRequestMessage.ToolCalls.Count().Should().Be(2); + for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++) + { + chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be($"test_{i}"); + functionToolCall.Arguments.Should().Be("test"); + } + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCalls = new[] + { + new ToolCall("test", "test"), + new ToolCall("test", "test"), + }; + IMessage message = new ToolCallMessage(toolCalls, "assistant"); + await agent.GenerateReplyAsync([message]); + } + [Fact] public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync() { @@ -326,7 +362,7 @@ public async Task ItProcessToolCallResultMessageAsync() innerMessage!.Should().BeOfType>(); var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; chatRequestMessage.Content.Should().Be("result"); - chatRequestMessage.ToolCallId.Should().Be("test"); + chatRequestMessage.ToolCallId.Should().Be("test_0"); return await innerAgent.GenerateReplyAsync(msgs); }) .RegisterMiddleware(middleware); @@ -336,6 +372,37 @@ public async Task ItProcessToolCallResultMessageAsync() await agent.GenerateReplyAsync([message]); } + [Fact] + public async Task ItProcessParallelToolCallResultMessageAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(2); + + for (int i = 0; i < msgs.Count(); i++) + { + var innerMessage = msgs.ElementAt(i); + innerMessage!.Should().BeOfType>(); + var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; + chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be($"test_{i}"); + } + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCalls = new[] + { + new ToolCall("test", "test", "result"), + new ToolCall("test", "test", "result"), + }; + IMessage message = new ToolCallResultMessage(toolCalls, "user"); + await agent.GenerateReplyAsync([message]); + } + [Fact] public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync() { @@ -372,6 +439,7 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync() innerMessage!.Should().BeOfType>(); var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)innerMessage!).Content; chatRequestMessage.Content.Should().Be("result"); + chatRequestMessage.ToolCallId.Should().Be("test_0"); var toolCallMessage = msgs.First(); toolCallMessage!.Should().BeOfType>(); @@ -381,6 +449,7 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync() toolCallRequestMessage.ToolCalls.First().Should().BeOfType(); var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First(); functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be("test_0"); functionToolCall.Arguments.Should().Be("test"); return await innerAgent.GenerateReplyAsync(msgs); }) @@ -393,6 +462,54 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync() await agent.GenerateReplyAsync([aggregateMessage]); } + [Fact] + public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync() + { + var middleware = new OpenAIChatRequestMessageConnector(); + var agent = new EchoAgent("assistant") + .RegisterMiddleware(async (msgs, _, innerAgent, _) => + { + msgs.Count().Should().Be(3); + var toolCallMessage = msgs.First(); + toolCallMessage!.Should().BeOfType>(); + var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope)toolCallMessage!).Content; + toolCallRequestMessage.Content.Should().BeNullOrEmpty(); + toolCallRequestMessage.ToolCalls.Count().Should().Be(2); + + for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++) + { + toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType(); + var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i); + functionToolCall.Name.Should().Be("test"); + functionToolCall.Id.Should().Be($"test_{i}"); + functionToolCall.Arguments.Should().Be("test"); + } + + for (int i = 1; i < msgs.Count(); i++) + { + var toolCallResultMessage = msgs.ElementAt(i); + toolCallResultMessage!.Should().BeOfType>(); + var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope)toolCallResultMessage!).Content; + toolCallResultRequestMessage.Content.Should().Be("result"); + toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}"); + } + + return await innerAgent.GenerateReplyAsync(msgs); + }) + .RegisterMiddleware(middleware); + + // user message + var toolCalls = new[] + { + new ToolCall("test", "test", "result"), + new ToolCall("test", "test", "result"), + }; + var toolCallMessage = new ToolCallMessage(toolCalls, "assistant"); + var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant"); + var aggregateMessage = new AggregateMessage(toolCallMessage, toolCallResultMessage, "assistant"); + await agent.GenerateReplyAsync([aggregateMessage]); + } + [Fact] public async Task ItConvertChatResponseMessageToTextMessageAsync() {