Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] fix #2722 #2723

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,15 @@ public static async Task RunAsync()
calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTax.GetToolCalls().Should().HaveCount(1);
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));

// parallel function calls
var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40
calculateTaxes.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTaxes.GetToolCalls().Should().HaveCount(2);
calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));

// send aggregate message back to llm to get the final result
var finalResult = await agent.SendAsync(calculateTaxes);
}
}
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public static bool IsSystemMessage(this IMessage message)
TextMessage textMessage => textMessage.Content,
Message msg => msg.Content,
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => string.Join("\n", aggregateMessage.Message2.ToolCalls.Where(x => x.Result is not null).Select(x => x.Result)),
_ => null,
};
}
Expand Down
2 changes: 2 additions & 0 deletions dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public ToolCall(string functionName, string functionArgs, string result)

public string FunctionArguments { get; set; }

public string? ToolCallId { get; set; }

public string? Result { get; set; }

public override string ToString()
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ private async Task<ToolCallResultMessage> InvokeToolCallMessagesBeforeInvokingAg
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
{
var result = await func(functionArguments);
toolCallResult.Add(new ToolCall(functionName, functionArguments, result));
toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
}
else if (this.functionMap is not null)
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";

toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage));
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
}
else
{
Expand All @@ -156,7 +156,7 @@ private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolC
if (this.functionMap?.TryGetValue(fName, out var func) is true)
{
var result = await func(fArgs);
toolCallResult.Add(new ToolCall(fName, fArgs, result));
toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private IMessage PostProcessChatResponseMessage(ChatResponseMessage chatResponse
.Where(tc => tc is ChatCompletionsFunctionToolCall)
.Select(tc => (ChatCompletionsFunctionToolCall)tc);

var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments));
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });

return new ToolCallMessage(toolCalls, from);
}
Expand Down Expand Up @@ -322,7 +322,7 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallMessage(IAgent agent, Too
throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
}

var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
foreach (var tc in toolCall)
{
Expand All @@ -336,7 +336,7 @@ private IEnumerable<ChatRequestMessage> ProcessToolCallResultMessage(ToolCallRes
{
return message.ToolCalls
.Where(tc => tc.Result is not null)
.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
.Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
}

private IEnumerable<ChatRequestMessage> ProcessMessage(IAgent agent, Message message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
}
],
"FunctionCallName": null,
Expand All @@ -159,7 +159,7 @@
{
"Role": "tool",
"Content": "result",
"ToolCallId": "test"
"ToolCallId": "test_0"
}
]
},
Expand All @@ -169,12 +169,12 @@
{
"Role": "tool",
"Content": "test",
"ToolCallId": "result"
"ToolCallId": "result_0"
},
{
"Role": "tool",
"Content": "test",
"ToolCallId": "result"
"ToolCallId": "result_1"
}
]
},
Expand All @@ -190,13 +190,13 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
},
{
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_1"
}
],
"FunctionCallName": null,
Expand All @@ -216,7 +216,7 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
}
],
"FunctionCallName": null,
Expand All @@ -225,7 +225,7 @@
{
"Role": "tool",
"Content": "result",
"ToolCallId": "test"
"ToolCallId": "test_0"
}
]
}
Expand Down
119 changes: 118 additions & 1 deletion dotnet/test/AutoGen.OpenAI.Tests/OpenAIMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ public async Task ItProcessToolCallMessageAsync()
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be("test_0");
functionToolCall.Arguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
Expand All @@ -303,6 +304,41 @@ public async Task ItProcessToolCallMessageAsync()
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItProcessParallelToolCallMessageAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
functionToolCall.Arguments.Should().Be("test");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);

// user message
var toolCalls = new[]
{
new ToolCall("test", "test"),
new ToolCall("test", "test"),
};
IMessage message = new ToolCallMessage(toolCalls, "assistant");
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
{
Expand All @@ -326,7 +362,7 @@ public async Task ItProcessToolCallResultMessageAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test");
chatRequestMessage.ToolCallId.Should().Be("test_0");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
Expand All @@ -336,6 +372,37 @@ public async Task ItProcessToolCallResultMessageAsync()
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItProcessParallelToolCallResultMessageAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
msgs.Count().Should().Be(2);

for (int i = 0; i < msgs.Count(); i++)
{
var innerMessage = msgs.ElementAt(i);
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);

// user message
var toolCalls = new[]
{
new ToolCall("test", "test", "result"),
new ToolCall("test", "test", "result"),
};
IMessage message = new ToolCallResultMessage(toolCalls, "user");
await agent.GenerateReplyAsync([message]);
}

[Fact]
public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
{
Expand Down Expand Up @@ -372,6 +439,7 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test_0");

var toolCallMessage = msgs.First();
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
Expand All @@ -381,6 +449,7 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be("test_0");
functionToolCall.Arguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
Expand All @@ -393,6 +462,54 @@ public async Task ItProcessFunctionCallMiddlewareMessageFromAssistantAsync()
await agent.GenerateReplyAsync([aggregateMessage]);
}

[Fact]
public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
msgs.Count().Should().Be(3);
var toolCallMessage = msgs.First();
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
toolCallRequestMessage.Content.Should().BeNullOrEmpty();
toolCallRequestMessage.ToolCalls.Count().Should().Be(2);

for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
{
toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
functionToolCall.Arguments.Should().Be("test");
}

for (int i = 1; i < msgs.Count(); i++)
{
var toolCallResultMessage = msgs.ElementAt(i);
toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
toolCallResultRequestMessage.Content.Should().Be("result");
toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
}

return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);

// user message
var toolCalls = new[]
{
new ToolCall("test", "test", "result"),
new ToolCall("test", "test", "result"),
};
var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
await agent.GenerateReplyAsync([aggregateMessage]);
}

[Fact]
public async Task ItConvertChatResponseMessageToTextMessageAsync()
{
Expand Down
Loading