From 4e95630fa95111166d29d24fe47ffcdea059acbb Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Wed, 10 Jul 2024 15:12:42 -0700 Subject: [PATCH] [.Net] fix #2695 and #2884 (#3069) * add round robin orchestrator * add constructor for orchestrators * add tests * revert change * return single orchestrator * address comment --- dotnet/Directory.Build.props | 1 + .../Agent/AnthropicClientAgent.cs | 21 +- .../src/AutoGen.Anthropic/AnthropicClient.cs | 10 +- dotnet/src/AutoGen.Core/Agent/IAgent.cs | 6 +- .../Extension/GroupChatExtension.cs | 3 +- .../src/AutoGen.Core/GroupChat/GroupChat.cs | 78 +++-- .../{ => GroupChat}/IGroupChat.cs | 0 .../GroupChat/RoundRobinGroupChat.cs | 71 +--- .../Orchestrator/IOrchestrator.cs | 28 ++ .../Orchestrator/RolePlayOrchestrator.cs | 116 ++++++ .../Orchestrator/RoundRobinOrchestrator.cs | 45 +++ .../Orchestrator/WorkflowOrchestrator.cs | 53 +++ .../Agent/MistralClientAgent.cs | 1 + .../DTOs/ChatCompletionRequest.cs | 3 + .../AnthropicClientAgentTest.cs | 24 ++ .../test/AutoGen.Tests/AutoGen.Tests.csproj | 1 + .../Orchestrator/RolePlayOrchestratorTests.cs | 331 ++++++++++++++++++ .../RoundRobinOrchestratorTests.cs | 103 ++++++ .../Orchestrator/WorkflowOrchestratorTests.cs | 112 ++++++ 19 files changed, 901 insertions(+), 106 deletions(-) rename dotnet/src/AutoGen.Core/{ => GroupChat}/IGroupChat.cs (100%) create mode 100644 dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs create mode 100644 dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs create mode 100644 dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs create mode 100644 dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs create mode 100644 dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs create mode 100644 dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs create mode 100644 dotnet/test/AutoGen.Tests/Orchestrator/WorkflowOrchestratorTests.cs diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index 4b3e9441f1e..29e40fff384 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -31,6 +31,7 @@ + diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs index bf05ee97444..a3ecc1ccef6 100644 --- a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs +++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs @@ -67,7 +67,8 @@ private ChatCompletionRequest CreateParameters(IEnumerable messages, G Stream = shouldStream, Temperature = (decimal?)options?.Temperature ?? _temperature, Tools = _tools?.ToList(), - ToolChoice = _toolChoice ?? ToolChoice.Auto + ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null), + StopSequences = options?.StopSequence?.ToArray(), }; chatCompletionRequest.Messages = BuildMessages(messages); @@ -95,6 +96,22 @@ private List BuildMessages(IEnumerable messages) } } - return chatMessages; + // merge messages with the same role + // fixing #2884 + var mergedMessages = chatMessages.Aggregate(new List(), (acc, message) => + { + if (acc.Count > 0 && acc.Last().Role == message.Role) + { + acc.Last().Content.AddRange(message.Content); + } + else + { + acc.Add(message); + } + + return acc; + }); + + return mergedMessages; } } diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs index babcd5302aa..dd35638c4f3 100644 --- a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs +++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // AnthropicClient.cs using System; @@ -90,13 +90,7 @@ public async IAsyncEnumerable StreamingChatCompletionsAs { var res = await JsonSerializer.DeserializeAsync( new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), - cancellationToken: cancellationToken); - - if (res == null) - { - throw new Exception("Failed to deserialize response"); - } - + cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response"); if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) && currentEvent.ContentBlock != null) { diff --git a/dotnet/src/AutoGen.Core/Agent/IAgent.cs b/dotnet/src/AutoGen.Core/Agent/IAgent.cs index b9149008480..34a31055d1b 100644 --- a/dotnet/src/AutoGen.Core/Agent/IAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IAgent.cs @@ -7,10 +7,14 @@ using System.Threading.Tasks; namespace AutoGen.Core; -public interface IAgent + +public interface IAgentMetaInformation { public string Name { get; } +} +public interface IAgent : IAgentMetaInformation +{ /// /// Generate reply /// diff --git a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs index e3e44622c81..45728023b96 100644 --- a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs @@ -100,8 +100,7 @@ internal static IEnumerable ProcessConversationsForRolePlay( var msg = @$"From {x.From}: {x.GetContent()} -round # - {i}"; +round # {i}"; return new TextMessage(Role.User, content: msg); }); diff --git a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs index 5e82931ab65..57e15c18ca6 100644 --- a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs +++ b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs @@ -15,6 +15,7 @@ public class GroupChat : IGroupChat private List agents = new List(); private IEnumerable initializeMessages = new List(); private Graph? workflow = null; + private readonly IOrchestrator orchestrator; public IEnumerable? Messages { get; private set; } @@ -36,6 +37,37 @@ public GroupChat( this.initializeMessages = initializeMessages ?? new List(); this.workflow = workflow; + if (admin is not null) + { + this.orchestrator = new RolePlayOrchestrator(admin, workflow); + } + else if (workflow is not null) + { + this.orchestrator = new WorkflowOrchestrator(workflow); + } + else + { + this.orchestrator = new RoundRobinOrchestrator(); + } + + this.Validation(); + } + + /// + /// Create a group chat which uses the to decide the next speaker(s). + /// + /// + /// + /// + public GroupChat( + IEnumerable members, + IOrchestrator orchestrator, + IEnumerable? initializeMessages = null) + { + this.agents = members.ToList(); + this.initializeMessages = initializeMessages ?? new List(); + this.orchestrator = orchestrator; + this.Validation(); } @@ -64,12 +96,6 @@ private void Validation() throw new Exception("All agents in the workflow must be in the group chat."); } } - - // must provide one of admin or workflow - if (this.admin == null && this.workflow == null) - { - throw new Exception("Must provide one of admin or workflow."); - } } /// @@ -81,6 +107,7 @@ private void Validation() /// current speaker /// conversation history /// next speaker. + [Obsolete("Please use RolePlayOrchestrator or WorkflowOrchestrator")] public async Task SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable conversationHistory) { var agentNames = this.agents.Select(x => x.Name).ToList(); @@ -140,37 +167,40 @@ public void AddInitializeMessage(IMessage message) } public async Task> CallAsync( - IEnumerable? conversationWithName = null, + IEnumerable? chatHistory = null, int maxRound = 10, CancellationToken ct = default) { var conversationHistory = new List(); - if (conversationWithName != null) + conversationHistory.AddRange(this.initializeMessages); + if (chatHistory != null) { - conversationHistory.AddRange(conversationWithName); + conversationHistory.AddRange(chatHistory); } + var roundLeft = maxRound; - var lastSpeaker = conversationHistory.LastOrDefault()?.From switch + while (roundLeft > 0) { - null => this.agents.First(), - _ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"), - }; - var round = 0; - while (round < maxRound) - { - var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory); - var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory); - var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned."); + var orchestratorContext = new OrchestrationContext + { + Candidates = this.agents, + ChatHistory = conversationHistory, + }; + var nextSpeaker = await this.orchestrator.GetNextSpeakerAsync(orchestratorContext, ct); + if (nextSpeaker == null) + { + break; + } + + var result = await nextSpeaker.GenerateReplyAsync(conversationHistory, cancellationToken: ct); conversationHistory.Add(result); - // if message is terminate message, then terminate the conversation - if (result?.IsGroupChatTerminateMessage() ?? false) + if (result.IsGroupChatTerminateMessage()) { - break; + return conversationHistory; } - lastSpeaker = currentSpeaker; - round++; + roundLeft--; } return conversationHistory; diff --git a/dotnet/src/AutoGen.Core/IGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/IGroupChat.cs similarity index 100% rename from dotnet/src/AutoGen.Core/IGroupChat.cs rename to dotnet/src/AutoGen.Core/GroupChat/IGroupChat.cs diff --git a/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs index b8de89b834f..b95cd1958fc 100644 --- a/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs +++ b/dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs @@ -3,9 +3,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; namespace AutoGen.Core; @@ -25,76 +22,12 @@ public SequentialGroupChat(IEnumerable agents, List? initializ /// /// A group chat that allows agents to talk in a round-robin manner. /// -public class RoundRobinGroupChat : IGroupChat +public class RoundRobinGroupChat : GroupChat { - private readonly List agents = new List(); - private readonly List initializeMessages = new List(); - public RoundRobinGroupChat( IEnumerable agents, List? initializeMessages = null) + : base(agents, initializeMessages: initializeMessages) { - this.agents.AddRange(agents); - this.initializeMessages = initializeMessages ?? new List(); - } - - /// - public void AddInitializeMessage(IMessage message) - { - this.SendIntroduction(message); - } - - public async Task> CallAsync( - IEnumerable? conversationWithName = null, - int maxRound = 10, - CancellationToken ct = default) - { - var conversationHistory = new List(); - if (conversationWithName != null) - { - conversationHistory.AddRange(conversationWithName); - } - - var lastSpeaker = conversationHistory.LastOrDefault()?.From switch - { - null => this.agents.First(), - _ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"), - }; - var round = 0; - while (round < maxRound) - { - var currentSpeaker = this.SelectNextSpeaker(lastSpeaker); - var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory); - var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned."); - conversationHistory.Add(result); - - // if message is terminate message, then terminate the conversation - if (result?.IsGroupChatTerminateMessage() ?? false) - { - break; - } - - lastSpeaker = currentSpeaker; - round++; - } - - return conversationHistory; - } - - public void SendIntroduction(IMessage message) - { - this.initializeMessages.Add(message); - } - - private IAgent SelectNextSpeaker(IAgent currentSpeaker) - { - var index = this.agents.IndexOf(currentSpeaker); - if (index == -1) - { - throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker)); - } - - var nextIndex = (index + 1) % this.agents.Count; - return this.agents[nextIndex]; } } diff --git a/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs new file mode 100644 index 00000000000..777834871f6 --- /dev/null +++ b/dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IOrchestrator.cs + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace AutoGen.Core; + +public class OrchestrationContext +{ + public IEnumerable Candidates { get; set; } = Array.Empty(); + + public IEnumerable ChatHistory { get; set; } = Array.Empty(); +} + +public interface IOrchestrator +{ + /// + /// Return the next agent as the next speaker. return null if no agent is selected. + /// + /// orchestration context, such as candidate agents and chat history. + /// cancellation token + public Task GetNextSpeakerAsync( + OrchestrationContext context, + CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs new file mode 100644 index 00000000000..6798f23f2df --- /dev/null +++ b/dotnet/src/AutoGen.Core/Orchestrator/RolePlayOrchestrator.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RolePlayOrchestrator.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace AutoGen.Core; + +public class RolePlayOrchestrator : IOrchestrator +{ + private readonly IAgent admin; + private readonly Graph? workflow = null; + public RolePlayOrchestrator(IAgent admin, Graph? workflow = null) + { + this.admin = admin; + this.workflow = workflow; + } + + public async Task GetNextSpeakerAsync( + OrchestrationContext context, + CancellationToken cancellationToken = default) + { + var candidates = context.Candidates.ToList(); + + if (candidates.Count == 0) + { + return null; + } + + if (candidates.Count == 1) + { + return candidates.First(); + } + + // if there's a workflow + // and the next available agent from the workflow is in the group chat + // then return the next agent from the workflow + if (this.workflow != null) + { + var lastMessage = context.ChatHistory.LastOrDefault(); + if (lastMessage == null) + { + return null; + } + var currentSpeaker = candidates.First(candidates => candidates.Name == lastMessage.From); + var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory); + nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name)); + candidates = nextAgents.ToList(); + if (!candidates.Any()) + { + return null; + } + + if (candidates is { Count: 1 }) + { + return candidates.First(); + } + } + + // In this case, since there are more than one available agents from the workflow for the next speaker + // the admin will be invoked to decide the next speaker + var agentNames = candidates.Select(candidate => candidate.Name); + var rolePlayMessage = new TextMessage(Role.User, + content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation. +The available roles are: +{string.Join(",", agentNames)} + +Each message will start with 'From name:', e.g: +From {agentNames.First()}: +//your message//."); + + var chatHistoryWithName = this.ProcessConversationsForRolePlay(context.ChatHistory); + var messages = new IMessage[] { rolePlayMessage }.Concat(chatHistoryWithName); + + var response = await this.admin.GenerateReplyAsync( + messages: messages, + options: new GenerateReplyOptions + { + Temperature = 0, + MaxToken = 128, + StopSequence = [":"], + Functions = null, + }, + cancellationToken: cancellationToken); + + var name = response.GetContent() ?? throw new Exception("No name is returned."); + + // remove From + name = name!.Substring(5); + var candidate = candidates.FirstOrDefault(x => x.Name!.ToLower() == name.ToLower()); + + if (candidate != null) + { + return candidate; + } + + var errorMessage = $"The response from admin is {name}, which is either not in the candidates list or not in the correct format."; + throw new Exception(errorMessage); + } + + private IEnumerable ProcessConversationsForRolePlay(IEnumerable messages) + { + return messages.Select((x, i) => + { + var msg = @$"From {x.From}: +{x.GetContent()} + +round # {i}"; + + return new TextMessage(Role.User, content: msg); + }); + } +} diff --git a/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs new file mode 100644 index 00000000000..0f8b8e483c6 --- /dev/null +++ b/dotnet/src/AutoGen.Core/Orchestrator/RoundRobinOrchestrator.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RoundRobinOrchestrator.cs + +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace AutoGen.Core; + +/// +/// Return the next agent in a round-robin fashion. +/// +/// If the last message is from one of the candidates, the next agent will be the next candidate in the list. +/// +/// +/// Otherwise, no agent will be selected. In this case, the orchestrator will return an empty list. +/// +/// +/// This orchestrator always return a single agent. +/// +/// +public class RoundRobinOrchestrator : IOrchestrator +{ + public async Task GetNextSpeakerAsync( + OrchestrationContext context, + CancellationToken cancellationToken = default) + { + var lastMessage = context.ChatHistory.LastOrDefault(); + + if (lastMessage == null) + { + return null; + } + + var candidates = context.Candidates.ToList(); + var lastAgentIndex = candidates.FindIndex(a => a.Name == lastMessage.From); + if (lastAgentIndex == -1) + { + return null; + } + + var nextAgentIndex = (lastAgentIndex + 1) % candidates.Count; + return candidates[nextAgentIndex]; + } +} diff --git a/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs b/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs new file mode 100644 index 00000000000..b84850a07c7 --- /dev/null +++ b/dotnet/src/AutoGen.Core/Orchestrator/WorkflowOrchestrator.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// WorkflowOrchestrator.cs + +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace AutoGen.Core; + +public class WorkflowOrchestrator : IOrchestrator +{ + private readonly Graph workflow; + + public WorkflowOrchestrator(Graph workflow) + { + this.workflow = workflow; + } + + public async Task GetNextSpeakerAsync( + OrchestrationContext context, + CancellationToken cancellationToken = default) + { + var lastMessage = context.ChatHistory.LastOrDefault(); + if (lastMessage == null) + { + return null; + } + + var candidates = context.Candidates.ToList(); + var currentSpeaker = candidates.FirstOrDefault(candidates => candidates.Name == lastMessage.From); + + if (currentSpeaker == null) + { + return null; + } + var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory); + nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name)); + candidates = nextAgents.ToList(); + if (!candidates.Any()) + { + return null; + } + + if (candidates is { Count: 1 }) + { + return candidates.First(); + } + else + { + throw new System.Exception("There are more than one available agents from the workflow for the next speaker."); + } + } +} diff --git a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs index ac144854fac..db14d68a121 100644 --- a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs +++ b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs @@ -97,6 +97,7 @@ private ChatCompletionRequest BuildChatRequest(IEnumerable messages, G var chatHistory = BuildChatHistory(messages); var chatRequest = new ChatCompletionRequest(model: _model, messages: chatHistory.ToList(), temperature: options?.Temperature, randomSeed: _randomSeed) { + Stop = options?.StopSequence, MaxTokens = options?.MaxToken, ResponseFormat = _jsonOutput ? new ResponseFormat() { ResponseFormatType = "json_object" } : null, }; diff --git a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs index 71a084673f1..affe2bb6dcc 100644 --- a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs +++ b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionRequest.cs @@ -105,6 +105,9 @@ public class ChatCompletionRequest [JsonPropertyName("random_seed")] public int? RandomSeed { get; set; } + [JsonPropertyName("stop")] + public string[]? Stop { get; set; } + [JsonPropertyName("tools")] public List? Tools { get; set; } diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs index 49cbb54af31..552408f1d05 100644 --- a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs @@ -32,6 +32,30 @@ public async Task AnthropicAgentChatCompletionTestAsync() reply.From.Should().Be(agent.Name); } + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentMergeMessageWithSameRoleTests() + { + // this test is added to fix issue #2884 + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant that convert user message to upper case") + .RegisterMessageConnector(); + + var uppCaseMessage = new TextMessage(Role.User, "abcdefg"); + var anotherUserMessage = new TextMessage(Role.User, "hijklmn"); + var assistantMessage = new TextMessage(Role.Assistant, "opqrst"); + var anotherAssistantMessage = new TextMessage(Role.Assistant, "uvwxyz"); + var yetAnotherUserMessage = new TextMessage(Role.User, "123456"); + + // just make sure it doesn't throw exception + var reply = await agent.SendAsync(chatHistory: [uppCaseMessage, anotherUserMessage, assistantMessage, anotherAssistantMessage, yetAnotherUserMessage]); + reply.GetContent().Should().NotBeNull(); + } + [ApiKeyFact("ANTHROPIC_API_KEY")] public async Task AnthropicAgentTestProcessImageAsync() { diff --git a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj index 4def281ed7b..3dc669b5edd 100644 --- a/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj +++ b/dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj @@ -9,6 +9,7 @@ + diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs new file mode 100644 index 00000000000..eea0a90c1ca --- /dev/null +++ b/dotnet/test/AutoGen.Tests/Orchestrator/RolePlayOrchestratorTests.cs @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RolePlayOrchestratorTests.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Anthropic; +using AutoGen.Anthropic.Extensions; +using AutoGen.Anthropic.Utils; +using AutoGen.Gemini; +using AutoGen.Mistral; +using AutoGen.Mistral.Extension; +using AutoGen.OpenAI; +using AutoGen.OpenAI.Extension; +using Azure.AI.OpenAI; +using FluentAssertions; +using Moq; +using Xunit; + +namespace AutoGen.Tests; + +public class RolePlayOrchestratorTests +{ + [Fact] + public async Task ItReturnNextSpeakerTestAsync() + { + var admin = Mock.Of(); + Mock.Get(admin).Setup(x => x.Name).Returns("Admin"); + Mock.Get(admin).Setup(x => x.GenerateReplyAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, GenerateReplyOptions, CancellationToken>((messages, option, _) => + { + // verify prompt + var rolePlayPrompt = messages.First().GetContent(); + rolePlayPrompt.Should().Contain("You are in a role play game. Carefully read the conversation history and carry on the conversation"); + rolePlayPrompt.Should().Contain("The available roles are:"); + rolePlayPrompt.Should().Contain("Alice,Bob"); + rolePlayPrompt.Should().Contain("From Alice:"); + option.StopSequence.Should().BeEquivalentTo([":"]); + option.Temperature.Should().Be(0); + option.MaxToken.Should().Be(128); + option.Functions.Should().BeNull(); + }) + .ReturnsAsync(new TextMessage(Role.Assistant, "From Alice")); + + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + + var orchestrator = new RolePlayOrchestrator(admin); + var context = new OrchestrationContext + { + Candidates = [alice, bob], + ChatHistory = [], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().Be(alice); + } + + [Fact] + public async Task ItReturnNullWhenNoCandidateIsAvailableAsync() + { + var admin = Mock.Of(); + var orchestrator = new RolePlayOrchestrator(admin); + var context = new OrchestrationContext + { + Candidates = [], + ChatHistory = [], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().BeNull(); + } + + [Fact] + public async Task ItReturnCandidateWhenOnlyOneCandidateIsAvailableAsync() + { + var admin = Mock.Of(); + var alice = new EchoAgent("Alice"); + var orchestrator = new RolePlayOrchestrator(admin); + var context = new OrchestrationContext + { + Candidates = [alice], + ChatHistory = [], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().Be(alice); + } + + [Fact] + public async Task ItThrowExceptionWhenAdminFailsToFollowPromptAsync() + { + var admin = Mock.Of(); + Mock.Get(admin).Setup(x => x.Name).Returns("Admin"); + Mock.Get(admin).Setup(x => x.GenerateReplyAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new TextMessage(Role.Assistant, "I don't know")); // admin fails to follow the prompt and returns an invalid message + + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + + var orchestrator = new RolePlayOrchestrator(admin); + var context = new OrchestrationContext + { + Candidates = [alice, bob], + ChatHistory = [], + }; + + var action = async () => await orchestrator.GetNextSpeakerAsync(context); + + await action.Should().ThrowAsync() + .WithMessage("The response from admin is 't know, which is either not in the candidates list or not in the correct format."); + } + + [Fact] + public async Task ItSelectNextSpeakerFromWorkflowIfProvided() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + var charlie = new EchoAgent("Charlie"); + workflow.AddTransition(Transition.Create(alice, bob)); + workflow.AddTransition(Transition.Create(bob, charlie)); + workflow.AddTransition(Transition.Create(charlie, alice)); + + var admin = Mock.Of(); + var orchestrator = new RolePlayOrchestrator(admin, workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob, charlie], + ChatHistory = + [ + new TextMessage(Role.User, "Hello, Bob", from: "Alice"), + ], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().Be(bob); + } + + [Fact] + public async Task ItReturnNullIfNoAvailableAgentFromWorkflowAsync() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + workflow.AddTransition(Transition.Create(alice, bob)); + + var admin = Mock.Of(); + var orchestrator = new RolePlayOrchestrator(admin, workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob], + ChatHistory = + [ + new TextMessage(Role.User, "Hello, Alice", from: "Bob"), + ], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().BeNull(); + } + + [Fact] + public async Task ItUseCandidatesFromWorflowAsync() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + var charlie = new EchoAgent("Charlie"); + workflow.AddTransition(Transition.Create(alice, bob)); + workflow.AddTransition(Transition.Create(alice, charlie)); + + var admin = Mock.Of(); + Mock.Get(admin).Setup(x => x.GenerateReplyAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, GenerateReplyOptions, CancellationToken>((messages, option, _) => + { + messages.First().IsSystemMessage().Should().BeTrue(); + + // verify prompt + var rolePlayPrompt = messages.First().GetContent(); + rolePlayPrompt.Should().Contain("Bob,Charlie"); + rolePlayPrompt.Should().Contain("From Bob:"); + option.StopSequence.Should().BeEquivalentTo([":"]); + option.Temperature.Should().Be(0); + option.MaxToken.Should().Be(128); + option.Functions.Should().BeEmpty(); + }) + .ReturnsAsync(new TextMessage(Role.Assistant, "From Bob")); + var orchestrator = new RolePlayOrchestrator(admin, workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob], + ChatHistory = + [ + new TextMessage(Role.User, "Hello, Bob", from: "Alice"), + ], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().Be(bob); + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task GPT_3_5_CoderReviewerRunnerTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + modelName: deployName) + .RegisterMessageConnector(); + + await CoderReviewerRunnerTestAsync(openAIChatAgent); + } + + [ApiKeyFact("GOOGLE_GEMINI_API_KEY")] + public async Task GoogleGemini_1_5_flash_001_CoderReviewerRunnerTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set"); + var geminiAgent = new GeminiChatAgent( + name: "gemini", + model: "gemini-1.5-flash-001", + apiKey: apiKey) + .RegisterMessageConnector(); + + await CoderReviewerRunnerTestAsync(geminiAgent); + } + + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task Claude3_Haiku_CoderReviewerRunnerTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Please set ANTHROPIC_API_KEY environment variable."); + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey); + + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku, + systemMessage: "You are a helpful AI assistant that convert user message to upper case") + .RegisterMessageConnector(); + + await CoderReviewerRunnerTestAsync(agent); + } + + [ApiKeyFact("MISTRAL_API_KEY")] + public async Task Mistra_7b_CoderReviewerRunnerTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new InvalidOperationException("MISTRAL_API_KEY is not set."); + var client = new MistralClient(apiKey: apiKey); + + var agent = new MistralClientAgent( + client: client, + name: "MistralClientAgent", + model: "open-mistral-7b") + .RegisterMessageConnector(); + + await CoderReviewerRunnerTestAsync(agent); + } + + /// + /// This test is to mimic the conversation among coder, reviewer and runner. + /// The coder will write the code, the reviewer will review the code, and the runner will run the code. + /// + /// + /// + public async Task CoderReviewerRunnerTestAsync(IAgent admin) + { + var coder = new EchoAgent("Coder"); + var reviewer = new EchoAgent("Reviewer"); + var runner = new EchoAgent("Runner"); + var user = new EchoAgent("User"); + var initializeMessage = new List + { + new TextMessage(Role.User, "Hello, I am user, I will provide the coding task, please write the code first, then review and run it", from: "User"), + new TextMessage(Role.User, "Hello, I am coder, I will write the code", from: "Coder"), + new TextMessage(Role.User, "Hello, I am reviewer, I will review the code", from: "Reviewer"), + new TextMessage(Role.User, "Hello, I am runner, I will run the code", from: "Runner"), + new TextMessage(Role.User, "how to print 'hello world' using C#", from: user.Name), + }; + + var chatHistory = new List() + { + new TextMessage(Role.User, """ + ```csharp + Console.WriteLine("Hello World"); + ``` + """, from: coder.Name), + new TextMessage(Role.User, "The code looks good", from: reviewer.Name), + new TextMessage(Role.User, "The code runs successfully, the output is 'Hello World'", from: runner.Name), + }; + + var orchestrator = new RolePlayOrchestrator(admin); + foreach (var message in chatHistory) + { + var context = new OrchestrationContext + { + Candidates = [coder, reviewer, runner, user], + ChatHistory = initializeMessage, + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker!.Name.Should().Be(message.From); + initializeMessage.Add(message); + } + + // the last next speaker should be the user + var lastSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext + { + Candidates = [coder, reviewer, runner, user], + ChatHistory = initializeMessage, + }); + + lastSpeaker!.Name.Should().Be(user.Name); + } +} diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs new file mode 100644 index 00000000000..e14bf85cf21 --- /dev/null +++ b/dotnet/test/AutoGen.Tests/Orchestrator/RoundRobinOrchestratorTests.cs @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RoundRobinOrchestratorTests.cs + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using FluentAssertions; +using Xunit; + +namespace AutoGen.Tests; + +public class RoundRobinOrchestratorTests +{ + [Fact] + public async Task ItReturnNextAgentAsync() + { + var orchestrator = new RoundRobinOrchestrator(); + var context = new OrchestrationContext + { + Candidates = new List + { + new EchoAgent("Alice"), + new EchoAgent("Bob"), + new EchoAgent("Charlie"), + }, + }; + + var messages = new List + { + new TextMessage(Role.User, "Hello, Alice", from: "Alice"), + new TextMessage(Role.User, "Hello, Bob", from: "Bob"), + new TextMessage(Role.User, "Hello, Charlie", from: "Charlie"), + }; + + var expected = new List { "Bob", "Charlie", "Alice" }; + + var zip = messages.Zip(expected); + + foreach (var (msg, expect) in zip) + { + context.ChatHistory = [msg]; + var nextSpeaker = await orchestrator.GetNextSpeakerAsync(context); + Assert.Equal(expect, nextSpeaker!.Name); + } + } + + [Fact] + public async Task ItReturnNullIfNoCandidates() + { + var orchestrator = new RoundRobinOrchestrator(); + var context = new OrchestrationContext + { + Candidates = new List(), + ChatHistory = new List + { + new TextMessage(Role.User, "Hello, Alice", from: "Alice"), + }, + }; + + var result = await orchestrator.GetNextSpeakerAsync(context); + Assert.Null(result); + } + + [Fact] + public async Task ItReturnNullIfLastMessageIsNotFromCandidates() + { + var orchestrator = new RoundRobinOrchestrator(); + var context = new OrchestrationContext + { + Candidates = new List + { + new EchoAgent("Alice"), + new EchoAgent("Bob"), + new EchoAgent("Charlie"), + }, + ChatHistory = new List + { + new TextMessage(Role.User, "Hello, David", from: "David"), + }, + }; + + var result = await orchestrator.GetNextSpeakerAsync(context); + result.Should().BeNull(); + } + + [Fact] + public async Task ItReturnEmptyListIfNoChatHistory() + { + var orchestrator = new RoundRobinOrchestrator(); + var context = new OrchestrationContext + { + Candidates = new List + { + new EchoAgent("Alice"), + new EchoAgent("Bob"), + new EchoAgent("Charlie"), + }, + }; + + var result = await orchestrator.GetNextSpeakerAsync(context); + result.Should().BeNull(); + } +} diff --git a/dotnet/test/AutoGen.Tests/Orchestrator/WorkflowOrchestratorTests.cs b/dotnet/test/AutoGen.Tests/Orchestrator/WorkflowOrchestratorTests.cs new file mode 100644 index 00000000000..6599566a446 --- /dev/null +++ b/dotnet/test/AutoGen.Tests/Orchestrator/WorkflowOrchestratorTests.cs @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// WorkflowOrchestratorTests.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using FluentAssertions; +using Xunit; + +namespace AutoGen.Tests; + +public class WorkflowOrchestratorTests +{ + [Fact] + public async Task ItReturnNextAgentAsync() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + var charlie = new EchoAgent("Charlie"); + workflow.AddTransition(Transition.Create(alice, bob)); + workflow.AddTransition(Transition.Create(bob, charlie)); + workflow.AddTransition(Transition.Create(charlie, alice)); + var orchestrator = new WorkflowOrchestrator(workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob, charlie] + }; + + var messages = new List + { + new TextMessage(Role.User, "Hello, Alice", from: "Alice"), + new TextMessage(Role.User, "Hello, Bob", from: "Bob"), + new TextMessage(Role.User, "Hello, Charlie", from: "Charlie"), + }; + + var expected = new List { "Bob", "Charlie", "Alice" }; + + var zip = messages.Zip(expected); + + foreach (var (msg, expect) in zip) + { + context.ChatHistory = [msg]; + var result = await orchestrator.GetNextSpeakerAsync(context); + Assert.Equal(expect, result!.Name); + } + } + + [Fact] + public async Task ItReturnNullIfNoCandidates() + { + var workflow = new Graph(); + var orchestrator = new WorkflowOrchestrator(workflow); + var context = new OrchestrationContext + { + Candidates = new List(), + ChatHistory = new List + { + new TextMessage(Role.User, "Hello, Alice", from: "Alice"), + }, + }; + + var nextAgent = await orchestrator.GetNextSpeakerAsync(context); + nextAgent.Should().BeNull(); + } + + [Fact] + public async Task ItReturnNullIfNoAgentIsAvailableFromWorkflowAsync() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + workflow.AddTransition(Transition.Create(alice, bob)); + var orchestrator = new WorkflowOrchestrator(workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob], + ChatHistory = new List + { + new TextMessage(Role.User, "Hello, Bob", from: "Bob"), + }, + }; + + var nextSpeaker = await orchestrator.GetNextSpeakerAsync(context); + nextSpeaker.Should().BeNull(); + } + + [Fact] + public async Task ItThrowExceptionWhenMoreThanOneAvailableAgentsFromWorkflowAsync() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + var charlie = new EchoAgent("Charlie"); + workflow.AddTransition(Transition.Create(alice, bob)); + workflow.AddTransition(Transition.Create(alice, charlie)); + var orchestrator = new WorkflowOrchestrator(workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob, charlie], + ChatHistory = new List + { + new TextMessage(Role.User, "Hello, Bob", from: "Alice"), + }, + }; + + var action = async () => await orchestrator.GetNextSpeakerAsync(context); + + await action.Should().ThrowExactlyAsync().WithMessage("There are more than one available agents from the workflow for the next speaker."); + } +}