Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Mark Message as obsolete and add ToolCallAggregateMessage type #2716

Merged
merged 11 commits into from
May 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public async Task CreateOpenAIChatAgentAsync()
new TextMessage(Role.Assistant, "Hello", from: "user"),
],
from: "user"),
new Message(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
new TextMessage(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
};

foreach (var message in messages)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,19 @@ public static async Task RunAsync()
// talk to the assistant agent
var upperCase = await agent.SendAsync("convert to upper case: hello world");
upperCase.GetContent()?.Should().Be("HELLO WORLD");
upperCase.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
upperCase.Should().BeOfType<ToolCallAggregateMessage>();
upperCase.GetToolCalls().Should().HaveCount(1);
upperCase.GetToolCalls().First().FunctionName.Should().Be(nameof(UpperCase));

var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e");
concatString.GetContent()?.Should().Be("a b c d e");
concatString.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
concatString.Should().BeOfType<ToolCallAggregateMessage>();
concatString.GetToolCalls().Should().HaveCount(1);
concatString.GetToolCalls().First().FunctionName.Should().Be(nameof(ConcatString));

var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1");
calculateTax.GetContent().Should().Be("tax is 10");
calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTax.Should().BeOfType<ToolCallAggregateMessage>();
calculateTax.GetToolCalls().Should().HaveCount(1);
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
}
Expand Down
74 changes: 28 additions & 46 deletions dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Example05_Dalle_And_GPT4V.cs

using AutoGen;
using AutoGen.Core;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
using FluentAssertions;
using autogen = AutoGen.LLMConfigAPI;
Expand Down Expand Up @@ -66,50 +67,39 @@ public static async Task RunAsync()
File.Delete(imagePath);
}

var dalleAgent = new AssistantAgent(
name: "dalle",
systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url",
llmConfig: new ConversableAgentConfig
{
Temperature = 0,
ConfigList = gpt35Config,
FunctionContracts = new[]
{
instance.GenerateImageFunctionContract,
},
},
var generateImageFunctionMiddleware = new FunctionCallMiddleware(
functions: [instance.GenerateImageFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ nameof(GenerateImage), instance.GenerateImageWrapper },
})
});
var dalleAgent = new OpenAIChatAgent(
openAIClient: openAIClient,
modelName: "gpt-3.5-turbo",
name: "dalle",
systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(generateImageFunctionMiddleware)
.RegisterMiddleware(async (msgs, option, agent, ct) =>
{
// if last message contains [TERMINATE], then find the last image url and terminate the conversation
if (msgs.Last().GetContent()?.Contains("TERMINATE") is true)
if (msgs.Any(msg => msg.GetContent()?.ToLower().Contains("approve") is true))
{
var lastMessageWithImage = msgs.Last(msg => msg is ImageMessage) as ImageMessage;
var lastImageUrl = lastMessageWithImage.Url;
Console.WriteLine($"download image from {lastImageUrl} to {imagePath}");
var httpClient = new HttpClient();
var imageBytes = await httpClient.GetByteArrayAsync(lastImageUrl);
File.WriteAllBytes(imagePath, imageBytes);

var messageContent = $@"{GroupChatExtension.TERMINATE}

{lastImageUrl}";
return new TextMessage(Role.Assistant, messageContent)
{
From = "dalle",
};
return new TextMessage(Role.Assistant, $"The image satisfies the condition, conversation is terminated. {GroupChatExtension.TERMINATE}");
}

var reply = await agent.GenerateReplyAsync(msgs, option, ct);
var msgsWithoutImage = msgs.Where(msg => msg is not ImageMessage).ToList();
var reply = await agent.GenerateReplyAsync(msgsWithoutImage, option, ct);
LittleLittleCloud marked this conversation as resolved.
Show resolved Hide resolved

if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION"))
{
var imageUrl = content.Split("\n").Last();
var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From);

Console.WriteLine($"download image from {imageUrl} to {imagePath}");
var httpClient = new HttpClient();
var imageBytes = await httpClient.GetByteArrayAsync(imageUrl);
LittleLittleCloud marked this conversation as resolved.
Show resolved Hide resolved
File.WriteAllBytes(imagePath, imageBytes);

return imageMessage;
}
else
Expand All @@ -119,33 +109,25 @@ public static async Task RunAsync()
})
.RegisterPrintMessage();

var gpt4VAgent = new AssistantAgent(
var gpt4VAgent = new OpenAIChatAgent(
openAIClient: openAIClient,
name: "gpt4v",
modelName: "gpt-4-vision-preview",
systemMessage: @"You are a critism that provide feedback to DALL-E agent.
Carefully check the image generated by DALL-E agent and provide feedback.
If the image satisfies the condition, then terminate the conversation by saying [TERMINATE].
If the image satisfies the condition, then say [APPROVE].
Otherwise, provide detailed feedback to DALL-E agent so it can generate better image.

The image should satisfy the following conditions:
- There should be a cat and a mouse in the image
- The cat should be chasing after the mouse
",
llmConfig: new ConversableAgentConfig
{
Temperature = 0,
ConfigList = gpt4vConfig,
})
- The cat should be chasing after the mouse")
.RegisterMessageConnector()
.RegisterPrintMessage();

IEnumerable<IMessage> conversation = new List<IMessage>()
{
new TextMessage(Role.User, "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse")
};
var maxRound = 20;
await gpt4VAgent.InitiateChatAsync(
receiver: dalleAgent,
message: "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse",
maxRound: maxRound);
maxRound: 10);

File.Exists(imagePath).Should().BeTrue();
}
Expand Down
4 changes: 1 addition & 3 deletions dotnet/sample/AutoGen.BasicSamples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs

using AutoGen.BasicSample;
Console.ReadLine();
await Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.RunAsync();
await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
LittleLittleCloud marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 19 additions & 9 deletions dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MessageExtension.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
Expand All @@ -15,7 +16,9 @@ public static string FormatMessage(this IMessage message)
{
return message switch
{
#pragma warning disable CS0618 // deprecated
Message msg => msg.FormatMessage(),
#pragma warning restore CS0618 // deprecated
TextMessage textMessage => textMessage.FormatMessage(),
ImageMessage imageMessage => imageMessage.FormatMessage(),
ToolCallMessage toolCallMessage => toolCallMessage.FormatMessage(),
Expand Down Expand Up @@ -110,6 +113,8 @@ public static string FormatMessage(this AggregateMessage<ToolCallMessage, ToolCa

return sb.ToString();
}

[Obsolete("This method is deprecated, please use the extension method FormatMessage(this IMessage message) instead.")]
public static string FormatMessage(this Message message)
{
var sb = new StringBuilder();
Expand Down Expand Up @@ -149,15 +154,16 @@ public static bool IsSystemMessage(this IMessage message)
return message switch
{
TextMessage textMessage => textMessage.Role == Role.System,
#pragma warning disable CS0618 // deprecated
Message msg => msg.Role == Role.System,
#pragma warning restore CS0618 // deprecated
_ => false,
};
}

/// <summary>
/// Get the content from the message
/// <para>if the message is a <see cref="Message"/> or <see cref="TextMessage"/>, return the content</para>
/// <para>if the message is a <see cref="ToolCallResultMessage"/> and only contains one function call, return the result of that function call</para>
/// <para>if the message implements <see cref="ICanGetTextContent"/>, return the content from the message by calling <see cref="ICanGetTextContent.GetContent()"/></para>
/// <para>if the message is a <see cref="AggregateMessage{ToolCallMessage, ToolCallResultMessage}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/> and the second message only contains one function call, return the result of that function call</para>
/// <para>for all other situation, return null.</para>
/// </summary>
Expand All @@ -166,9 +172,10 @@ public static bool IsSystemMessage(this IMessage message)
{
return message switch
{
TextMessage textMessage => textMessage.Content,
ICanGetTextContent canGetTextContent => canGetTextContent.GetContent(),
#pragma warning disable CS0618 // deprecated
Message msg => msg.Content,
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
#pragma warning restore CS0618 // deprecated
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
_ => null,
};
Expand All @@ -182,7 +189,9 @@ public static bool IsSystemMessage(this IMessage message)
return message switch
{
TextMessage textMessage => textMessage.Role,
#pragma warning disable CS0618 // deprecated
Message msg => msg.Role,
#pragma warning restore CS0618 // deprecated
ImageMessage img => img.Role,
MultiModalMessage multiModal => multiModal.Role,
_ => null,
Expand All @@ -191,8 +200,7 @@ public static bool IsSystemMessage(this IMessage message)

/// <summary>
/// Return the tool calls from the message if it's available.
/// <para>if the message is a <see cref="ToolCallMessage"/>, return its tool calls</para>
/// <para>if the message is a <see cref="Message"/> and the function name and function arguments are available, return a list of tool call with one item</para>
/// <para>if the message implements <see cref="ICanGetToolCalls"/>, return the tool calls from the message by calling <see cref="ICanGetToolCalls.GetToolCalls()"/></para>
/// <para>if the message is a <see cref="AggregateMessage{ToolCallMessage, ToolCallResultMessage}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>, return the tool calls from the first message</para>
/// </summary>
/// <param name="message"></param>
Expand All @@ -201,11 +209,13 @@ public static bool IsSystemMessage(this IMessage message)
{
return message switch
{
ToolCallMessage toolCallMessage => toolCallMessage.ToolCalls,
ICanGetToolCalls canGetToolCalls => canGetToolCalls.GetToolCalls().ToList(),
#pragma warning disable CS0618 // deprecated
Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null
? msg.Content is not null ? new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) }
: new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments) }
? msg.Content is not null ? [new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content)]
: new List<ToolCall> { new(msg.FunctionName, msg.FunctionArguments) }
: null,
#pragma warning restore CS0618 // deprecated
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message1.ToolCalls,
_ => null,
};
Expand Down
22 changes: 21 additions & 1 deletion dotnet/src/AutoGen.Core/Message/IMessage.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs

using System.Collections.Generic;

namespace AutoGen.Core;

/// <summary>
Expand Down Expand Up @@ -29,7 +31,7 @@ namespace AutoGen.Core;
/// <item>
/// <see cref="AggregateMessage{TMessage1, TMessage2}"/>: an aggregate message type that contains two message types.
/// This type is useful when you want to combine two message types into one unique message type. One example is when invoking a tool call and you want to return both <see cref="ToolCallMessage"/> and <see cref="ToolCallResultMessage"/>.
/// One example of how this type is used in AutoGen is <see cref="FunctionCallMiddleware"/>
/// One example of how this type is used in AutoGen is <see cref="FunctionCallMiddleware"/> and its return message <see cref="ToolCallAggregateMessage"/>
/// </item>
/// </list>
/// </summary>
Expand All @@ -41,6 +43,24 @@ public interface IMessage<out T> : IMessage, IStreamingMessage<T>
{
}

/// <summary>
/// The interface for messages that can get text content.
/// This interface will be used by <see cref="MessageExtension.GetContent(IMessage)"/> to get the content from the message.
/// </summary>
public interface ICanGetTextContent : IMessage, IStreamingMessage
{
public string? GetContent();
}

/// <summary>
/// The interface for messages that can get a list of <see cref="ToolCall"/>
/// </summary>
public interface ICanGetToolCalls : IMessage, IStreamingMessage
{
public IEnumerable<ToolCall> GetToolCalls();
}


public interface IStreamingMessage
{
string? From { get; set; }
Expand Down
2 changes: 2 additions & 0 deletions dotnet/src/AutoGen.Core/Message/Message.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Message.cs

using System;
using System.Collections.Generic;

namespace AutoGen.Core;

[Obsolete("This message class is deprecated, please use a specific AutoGen built-in message type instead. For more information, please visit https://microsoft.github.io/autogen-for-net/articles/Built-in-messages.html")]
public class Message : IMessage
{
public Message(
Expand Down
14 changes: 12 additions & 2 deletions dotnet/src/AutoGen.Core/Message/TextMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace AutoGen.Core;

public class TextMessage : IMessage, IStreamingMessage
public class TextMessage : IMessage, IStreamingMessage, ICanGetTextContent
LittleLittleCloud marked this conversation as resolved.
Show resolved Hide resolved
{
public TextMessage(Role role, string content, string? from = null)
{
Expand Down Expand Up @@ -44,9 +44,14 @@ public override string ToString()
{
return $"TextMessage({this.Role}, {this.Content}, {this.From})";
}

public string? GetContent()
{
return this.Content;
}
}

public class TextMessageUpdate : IStreamingMessage
public class TextMessageUpdate : IStreamingMessage, ICanGetTextContent
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
Expand All @@ -60,4 +65,9 @@ public TextMessageUpdate(Role role, string? content, string? from = null)
public string? From { get; set; }

public Role Role { get; set; }

public string? GetContent()
{
return this.Content;
}
}
28 changes: 28 additions & 0 deletions dotnet/src/AutoGen.Core/Message/ToolCallAggregateMessage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallAggregateMessage.cs

using System.Collections.Generic;

namespace AutoGen.Core;

/// <summary>
/// An aggregate message that contains a tool call message and a tool call result message.
/// This message type is used by <see cref="FunctionCallMiddleware"/> to return both <see cref="ToolCallMessage"/> and <see cref="ToolCallResultMessage"/>.
/// </summary>
public class ToolCallAggregateMessage : AggregateMessage<ToolCallMessage, ToolCallResultMessage>, ICanGetTextContent, ICanGetToolCalls
{
public ToolCallAggregateMessage(ToolCallMessage message1, ToolCallResultMessage message2, string? from = null)
: base(message1, message2, from)
{
}

public string? GetContent()
{
return this.Message2.GetContent();
}

public IEnumerable<ToolCall> GetToolCalls()
{
return this.Message1.GetToolCalls();
}
}
7 changes: 6 additions & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public override string ToString()
}
}

public class ToolCallMessage : IMessage
public class ToolCallMessage : IMessage, ICanGetToolCalls
{
public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
Expand Down Expand Up @@ -89,6 +89,11 @@ public override string ToString()

return sb.ToString();
}

public IEnumerable<ToolCall> GetToolCalls()
{
return this.ToolCalls;
}
}

public class ToolCallMessageUpdate : IStreamingMessage
Expand Down
Loading
Loading