Skip to content

Commit

Permalink
[.Net] fix #2695 and #2884 (#3069)
Browse files Browse the repository at this point in the history
* add round robin orchestrator

* add constructor for orchestrators

* add tests

* revert change

* return single orchestrator

* address comment
  • Loading branch information
LittleLittleCloud committed Jul 10, 2024
1 parent f55a98f commit 4e95630
Show file tree
Hide file tree
Showing 19 changed files with 901 additions and 106 deletions.
1 change: 1 addition & 0 deletions dotnet/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
<PackageReference Include="Moq" Version="4.20.70" />
</ItemGroup>

<ItemGroup Condition="'$(IsTestProject)' == 'true'">
Expand Down
21 changes: 19 additions & 2 deletions dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
Tools = _tools?.ToList(),
ToolChoice = _toolChoice ?? ToolChoice.Auto
ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null),
StopSequences = options?.StopSequence?.ToArray(),
};

chatCompletionRequest.Messages = BuildMessages(messages);
Expand Down Expand Up @@ -95,6 +96,22 @@ private List<ChatMessage> BuildMessages(IEnumerable<IMessage> messages)
}
}

return chatMessages;
// merge messages with the same role
// fixing #2884
var mergedMessages = chatMessages.Aggregate(new List<ChatMessage>(), (acc, message) =>
{
if (acc.Count > 0 && acc.Last().Role == message.Role)
{
acc.Last().Content.AddRange(message.Content);
}
else
{
acc.Add(message);
}
return acc;
});

return mergedMessages;
}
}
10 changes: 2 additions & 8 deletions dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClient.cs

using System;
Expand Down Expand Up @@ -90,13 +90,7 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
{
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
cancellationToken: cancellationToken);

if (res == null)
{
throw new Exception("Failed to deserialize response");
}

cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
currentEvent.ContentBlock != null)
{
Expand Down
6 changes: 5 additions & 1 deletion dotnet/src/AutoGen.Core/Agent/IAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
using System.Threading.Tasks;

namespace AutoGen.Core;
public interface IAgent

public interface IAgentMetaInformation
{
public string Name { get; }
}

public interface IAgent : IAgentMetaInformation
{
/// <summary>
/// Generate reply
/// </summary>
Expand Down
3 changes: 1 addition & 2 deletions dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ internal static IEnumerable<IMessage> ProcessConversationsForRolePlay(
var msg = @$"From {x.From}:
{x.GetContent()}
<eof_msg>
round #
{i}";
round # {i}";
return new TextMessage(Role.User, content: msg);
});
Expand Down
78 changes: 54 additions & 24 deletions dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class GroupChat : IGroupChat
private List<IAgent> agents = new List<IAgent>();
private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
private Graph? workflow = null;
private readonly IOrchestrator orchestrator;

public IEnumerable<IMessage>? Messages { get; private set; }

Expand All @@ -36,6 +37,37 @@ public GroupChat(
this.initializeMessages = initializeMessages ?? new List<IMessage>();
this.workflow = workflow;

if (admin is not null)
{
this.orchestrator = new RolePlayOrchestrator(admin, workflow);
}
else if (workflow is not null)
{
this.orchestrator = new WorkflowOrchestrator(workflow);
}
else
{
this.orchestrator = new RoundRobinOrchestrator();
}

this.Validation();
}

/// <summary>
/// Create a group chat which uses the <paramref name="orchestrator"/> to decide the next speaker(s).
/// </summary>
/// <param name="members"></param>
/// <param name="orchestrator"></param>
/// <param name="initializeMessages"></param>
public GroupChat(
IEnumerable<IAgent> members,
IOrchestrator orchestrator,
IEnumerable<IMessage>? initializeMessages = null)
{
this.agents = members.ToList();
this.initializeMessages = initializeMessages ?? new List<IMessage>();
this.orchestrator = orchestrator;

this.Validation();
}

Expand Down Expand Up @@ -64,12 +96,6 @@ private void Validation()
throw new Exception("All agents in the workflow must be in the group chat.");
}
}

// must provide one of admin or workflow
if (this.admin == null && this.workflow == null)
{
throw new Exception("Must provide one of admin or workflow.");
}
}

/// <summary>
Expand All @@ -81,6 +107,7 @@ private void Validation()
/// <param name="currentSpeaker">current speaker</param>
/// <param name="conversationHistory">conversation history</param>
/// <returns>next speaker.</returns>
[Obsolete("Please use RolePlayOrchestrator or WorkflowOrchestrator")]
public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
{
var agentNames = this.agents.Select(x => x.Name).ToList();
Expand Down Expand Up @@ -140,37 +167,40 @@ public void AddInitializeMessage(IMessage message)
}

public async Task<IEnumerable<IMessage>> CallAsync(
IEnumerable<IMessage>? conversationWithName = null,
IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
if (conversationWithName != null)
conversationHistory.AddRange(this.initializeMessages);
if (chatHistory != null)
{
conversationHistory.AddRange(conversationWithName);
conversationHistory.AddRange(chatHistory);
}
var roundLeft = maxRound;

var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
while (roundLeft > 0)
{
null => this.agents.First(),
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
};
var round = 0;
while (round < maxRound)
{
var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory);
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
var orchestratorContext = new OrchestrationContext
{
Candidates = this.agents,
ChatHistory = conversationHistory,
};
var nextSpeaker = await this.orchestrator.GetNextSpeakerAsync(orchestratorContext, ct);
if (nextSpeaker == null)
{
break;
}

var result = await nextSpeaker.GenerateReplyAsync(conversationHistory, cancellationToken: ct);
conversationHistory.Add(result);

// if message is terminate message, then terminate the conversation
if (result?.IsGroupChatTerminateMessage() ?? false)
if (result.IsGroupChatTerminateMessage())
{
break;
return conversationHistory;
}

lastSpeaker = currentSpeaker;
round++;
roundLeft--;
}

return conversationHistory;
Expand Down
File renamed without changes.
71 changes: 2 additions & 69 deletions dotnet/src/AutoGen.Core/GroupChat/RoundRobinGroupChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core;

Expand All @@ -25,76 +22,12 @@ public SequentialGroupChat(IEnumerable<IAgent> agents, List<IMessage>? initializ
/// <summary>
/// A group chat that allows agents to talk in a round-robin manner.
/// </summary>
public class RoundRobinGroupChat : IGroupChat
public class RoundRobinGroupChat : GroupChat
{
private readonly List<IAgent> agents = new List<IAgent>();
private readonly List<IMessage> initializeMessages = new List<IMessage>();

public RoundRobinGroupChat(
IEnumerable<IAgent> agents,
List<IMessage>? initializeMessages = null)
: base(agents, initializeMessages: initializeMessages)
{
this.agents.AddRange(agents);
this.initializeMessages = initializeMessages ?? new List<IMessage>();
}

/// <inheritdoc />
public void AddInitializeMessage(IMessage message)
{
this.SendIntroduction(message);
}

public async Task<IEnumerable<IMessage>> CallAsync(
IEnumerable<IMessage>? conversationWithName = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
if (conversationWithName != null)
{
conversationHistory.AddRange(conversationWithName);
}

var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
{
null => this.agents.First(),
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
};
var round = 0;
while (round < maxRound)
{
var currentSpeaker = this.SelectNextSpeaker(lastSpeaker);
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
conversationHistory.Add(result);

// if message is terminate message, then terminate the conversation
if (result?.IsGroupChatTerminateMessage() ?? false)
{
break;
}

lastSpeaker = currentSpeaker;
round++;
}

return conversationHistory;
}

public void SendIntroduction(IMessage message)
{
this.initializeMessages.Add(message);
}

private IAgent SelectNextSpeaker(IAgent currentSpeaker)
{
var index = this.agents.IndexOf(currentSpeaker);
if (index == -1)
{
throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker));
}

var nextIndex = (index + 1) % this.agents.Count;
return this.agents[nextIndex];
}
}
28 changes: 28 additions & 0 deletions dotnet/src/AutoGen.Core/Orchestrator/IOrchestrator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IOrchestrator.cs

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core;

public class OrchestrationContext
{
public IEnumerable<IAgent> Candidates { get; set; } = Array.Empty<IAgent>();

public IEnumerable<IMessage> ChatHistory { get; set; } = Array.Empty<IMessage>();
}

public interface IOrchestrator
{
/// <summary>
/// Return the next agent as the next speaker. return null if no agent is selected.
/// </summary>
/// <param name="context">orchestration context, such as candidate agents and chat history.</param>
/// <param name="cancellationToken">cancellation token</param>
public Task<IAgent?> GetNextSpeakerAsync(
OrchestrationContext context,
CancellationToken cancellationToken = default);
}
Loading

0 comments on commit 4e95630

Please sign in to comment.