Skip to content

Commit

Permalink
[.Net] fix #2722 (#2723)
Browse files Browse the repository at this point in the history
* fix bug and add tests

* update
  • Loading branch information
LittleLittleCloud committed May 21, 2024
1 parent 31d2d37 commit 3e6f073
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 16 deletions.
10 changes: 10 additions & 0 deletions dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,15 @@ public static async Task RunAsync()
calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTax.GetToolCalls().Should().HaveCount(1);
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));

// parallel function calls
var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40
calculateTaxes.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTaxes.GetToolCalls().Should().HaveCount(2);
calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));

// send aggregate message back to llm to get the final result
var finalResult = await agent.SendAsync(calculateTaxes);
}
}
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public static bool IsSystemMessage(this IMessage message)
TextMessage textMessage => textMessage.Content,
Message msg => msg.Content,
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => string.Join("\n", aggregateMessage.Message2.ToolCalls.Where(x => x.Result is not null).Select(x => x.Result)),
_ => null,
};
}
Expand Down
2 changes: 2 additions & 0 deletions dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public ToolCall(string functionName, string functionArgs, string result)

public string FunctionArguments { get; set; }

public string? ToolCallId { get; set; }

public string? Result { get; set; }

public override string ToString()
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ private async Task<ToolCallResultMessage> InvokeToolCallMessagesBeforeInvokingAg
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
{
var result = await func(functionArguments);
toolCallResult.Add(new ToolCall(functionName, functionArguments, result));
toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
}
else if (this.functionMap is not null)
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";

toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage));
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
}
else
{
Expand All @@ -156,7 +156,7 @@ private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolC
if (this.functionMap?.TryGetValue(fName, out var func) is true)
{
var result = await func(fArgs);
toolCallResult.Add(new ToolCall(fName, fArgs, result));
toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse
.Where(tc => tc is ChatCompletionsFunctionToolCall)
.Select(tc => (ChatCompletionsFunctionToolCall)tc);

var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments));
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });

return new ToolCallMessage(toolCalls, from);
}
Expand Down Expand Up @@ -322,7 +322,7 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, Too
throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
}

var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
foreach (var tc in toolCall)
{
Expand All @@ -336,7 +336,7 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallResultMessage(ToolCallRes
{
return message.ToolCalls
.Where(tc => tc.Result is not null)
.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
.Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
}

private IEnumerable<ChatRequestMessage> ProcessMessage(IAgent agent, Message message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
}
],
"FunctionCallName": null,
Expand All @@ -159,7 +159,7 @@
{
"Role": "tool",
"Content": "result",
"ToolCallId": "test"
"ToolCallId": "test_0"
}
]
},
Expand All @@ -169,12 +169,12 @@
{
"Role": "tool",
"Content": "test",
"ToolCallId": "result"
"ToolCallId": "result_0"
},
{
"Role": "tool",
"Content": "test",
"ToolCallId": "result"
"ToolCallId": "result_1"
}
]
},
Expand All @@ -190,13 +190,13 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
},
{
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_1"
}
],
"FunctionCallName": null,
Expand All @@ -216,7 +216,7 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
}
],
"FunctionCallName": null,
Expand All @@ -225,7 +225,7 @@
{
"Role": "tool",
"Content": "result",
"ToolCallId": "test"
"ToolCallId": "test_0"
}
]
}
Expand Down
119 changes: 118 additions & 1 deletion dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ public async Task ItProcessToolCallMessageAsync()
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be("test_0");
functionToolCall.Arguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
Expand All @@ -303,6 +304,41 @@ public async Task ItProcessToolCallMessageAsync()
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItProcessParallelToolCallMessageAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
functionToolCall.Arguments.Should().Be("test");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);

// user message
var toolCalls = new[]
{
new ToolCall("test", "test"),
new ToolCall("test", "test"),
};
IMessage message = new ToolCallMessage(toolCalls, "assistant");
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
{
Expand All @@ -326,7 +362,7 @@ public async Task ItProcessToolCallResultMessageAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test");
chatRequestMessage.ToolCallId.Should().Be("test_0");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
Expand All @@ -336,6 +372,37 @@ public async Task ItProcessToolCallResultMessageAsync()
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItProcessParallelToolCallResultMessageAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
msgs.Count().Should().Be(2);
for (int i = 0; i < msgs.Count(); i++)
{
var innerMessage = msgs.ElementAt(i);
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);

// user message
var toolCalls = new[]
{
new ToolCall("test", "test", "result"),
new ToolCall("test", "test", "result"),
};
IMessage message = new ToolCallResultMessage(toolCalls, "user");
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
{
Expand Down Expand Up @@ -372,6 +439,7 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test_0");
var toolCallMessage = msgs.First();
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
Expand All @@ -381,6 +449,7 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be("test_0");
functionToolCall.Arguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
Expand All @@ -393,6 +462,54 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
await agent.GenerateReplyAsync([aggregateMessage]);
}

[Fact]
public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
msgs.Count().Should().Be(3);
var toolCallMessage = msgs.First();
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
toolCallRequestMessage.Content.Should().BeNullOrEmpty();
toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
{
toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
functionToolCall.Arguments.Should().Be("test");
}
for (int i = 1; i < msgs.Count(); i++)
{
var toolCallResultMessage = msgs.ElementAt(i);
toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
toolCallResultRequestMessage.Content.Should().Be("result");
toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);

// user message
var toolCalls = new[]
{
new ToolCall("test", "test", "result"),
new ToolCall("test", "test", "result"),
};
var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
await agent.GenerateReplyAsync([aggregateMessage]);
}

[Fact]
public async Task ItConvertChatResponseMessageToTextMessageAsync()
{
Expand Down

0 comments on commit 3e6f073

Please sign in to comment.