Skip to content

Commit

Permalink
add cancellation token to transition check lambda (#3132)
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud committed Jul 15, 2024
1 parent 96146aa commit 178bb8d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
45 changes: 32 additions & 13 deletions dotnet/src/AutoGen.Core/GroupChat/Graph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core;
Expand All @@ -12,11 +13,7 @@ public class Graph
{
private readonly List<Transition> transitions = new List<Transition>();

public Graph()
{
}

public Graph(IEnumerable<Transition>? transitions)
public Graph(IEnumerable<Transition>? transitions = null)
{
if (transitions != null)
{
Expand All @@ -40,13 +37,13 @@ public void AddTransition(Transition transition)
/// <param name="fromAgent">the from agent</param>
/// <param name="messages">messages</param>
/// <returns>A list of agents that the messages can be transit to</returns>
public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages)
public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages, CancellationToken ct = default)
{
var nextAgents = new List<IAgent>();
var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty<Transition>();
foreach (var transition in availableTransitions)
{
if (await transition.CanTransitionAsync(messages))
if (await transition.CanTransitionAsync(messages, ct))
{
nextAgents.Add(transition.To);
}
Expand All @@ -63,7 +60,7 @@ public class Transition
{
private readonly IAgent _from;
private readonly IAgent _to;
private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? _canTransition;
private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>>? _canTransition;

/// <summary>
/// Create a new instance of <see cref="Transition"/>.
Expand All @@ -73,22 +70,44 @@ public class Transition
/// <param name="from">from agent</param>
/// <param name="to">to agent</param>
/// <param name="canTransitionAsync">detect if the transition is allowed, default to be always true</param>
internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>>? canTransitionAsync = null)
{
_from = from;
_to = to;
_canTransition = canTransitionAsync;
}

/// <summary>
/// Create a new instance of <see cref="Transition"/> without transition condition check.
/// </summary>
/// <returns><see cref="Transition"/></returns>"
public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to)
where TFromAgent : IAgent
where TToAgent : IAgent
{
return new Transition(from, to, (fromAgent, toAgent, messages, _) => Task.FromResult(true));
}

/// <summary>
/// Create a new instance of <see cref="Transition"/>.
/// </summary>
/// <returns><see cref="Transition"/></returns>"
public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>> canTransitionAsync)
where TFromAgent : IAgent
where TToAgent : IAgent
{
return new Transition(from, to, (fromAgent, toAgent, messages, _) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages));
}

/// <summary>
/// Create a new instance of <see cref="Transition"/> with cancellation token.
/// </summary>
/// <returns><see cref="Transition"/></returns>"
public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, CancellationToken, Task<bool>> canTransitionAsync)
where TFromAgent : IAgent
where TToAgent : IAgent
{
return new Transition(from, to, (fromAgent, toAgent, messages) => canTransitionAsync?.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages) ?? Task.FromResult(true));
return new Transition(from, to, (fromAgent, toAgent, messages, ct) => canTransitionAsync.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages, ct));
}

public IAgent From => _from;
Expand All @@ -99,13 +118,13 @@ public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent
/// Check if the transition is allowed.
/// </summary>
/// <param name="messages">messages</param>
public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages)
public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages, CancellationToken ct = default)
{
if (_canTransition == null)
{
return Task.FromResult(true);
}

return _canTransition(this.From, this.To, messages);
return _canTransition(this.From, this.To, messages, ct);
}
}
4 changes: 2 additions & 2 deletions dotnet/test/AutoGen.Tests/WorkflowTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public async Task TransitionTestAsync()
var alice = new EchoAgent("alice");
var bob = new EchoAgent("bob");

var aliceToBob = Transition.Create(alice, bob, async (from, to, messages) =>
var aliceToBob = Transition.Create(alice, bob, async (from, to, messages, _) =>
{
if (messages.Any(m => m.GetContent() == "Hello"))
{
Expand All @@ -30,7 +30,7 @@ public async Task TransitionTestAsync()
var canTransit = await aliceToBob.CanTransitionAsync([]);
canTransit.Should().BeFalse();

canTransit = await aliceToBob.CanTransitionAsync(new[] { new TextMessage(Role.Assistant, "Hello") });
canTransit = await aliceToBob.CanTransitionAsync([new TextMessage(Role.Assistant, "Hello")]);
canTransit.Should().BeTrue();

// if no function is provided, it should always return true
Expand Down

0 comments on commit 178bb8d

Please sign in to comment.