Skip to content

Commit

Permalink
Squash changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLuong98 committed Jun 9, 2024
1 parent 85ad929 commit 101f482
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 37 deletions.
3 changes: 2 additions & 1 deletion dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public sealed class AnthropicClient : IDisposable

private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters = { new ContentBaseConverter() }
};

private static readonly JsonSerializerOptions JsonDeserializerOptions = new()
Expand Down
11 changes: 8 additions & 3 deletions dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

using System.Text.Json.Serialization;
using System.Collections.Generic;

namespace AutoGen.Anthropic.DTO;

using System.Collections.Generic;

public class ChatCompletionRequest
{
[JsonPropertyName("model")]
Expand Down Expand Up @@ -50,9 +49,15 @@ public class ChatMessage
public string Role { get; set; }

[JsonPropertyName("content")]
public string Content { get; set; }
public List<ContentBase> Content { get; set; }

public ChatMessage(string role, string content)
{
Role = role;
Content = new List<ContentBase>() { new TextContent { Text = content } };
}

public ChatMessage(string role, List<ContentBase> content)
{
Role = role;
Content = content;
Expand Down
112 changes: 91 additions & 21 deletions dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -19,7 +20,7 @@ public class AnthropicMessageConnector : IStreamingMiddleware
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chatMessages = await ProcessMessageAsync(messages, agent);
var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);

return response is IMessage<ChatCompletionResponse> chatMessage
Expand All @@ -31,7 +32,7 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext c
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chatMessages = await ProcessMessageAsync(messages, agent);

await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
Expand All @@ -53,60 +54,78 @@ public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext c
private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> chatMessage,
IStreamingAgent agent)
{
Delta? delta = chatMessage.Content.Delta;
var delta = chatMessage.Content.Delta;
return delta != null && !string.IsNullOrEmpty(delta.Text)
? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name)
: null;
}

private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessage> messages, IAgent agent)
{
return messages.SelectMany<IMessage, IMessage>(m =>
var processedMessages = new List<IMessage>();

foreach (var message in messages)
{
return m switch
var processedMessage = message switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
_ => [m],

ImageMessage imageMessage =>
new MessageEnvelope<ChatMessage>(new ChatMessage("user",
new ContentBase[] { new ImageContent { Source = await ProcessImageSourceAsync(imageMessage) } }
.ToList()),
from: agent.Name),

MultiModalMessage multiModalMessage => await ProcessMultiModalMessageAsync(multiModalMessage, agent),
_ => message,
};
});

processedMessages.Add(processedMessage);
}

return processedMessages;
}

private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from)
{
if (response.Content is null)
{
throw new ArgumentNullException(nameof(response.Content));
}

if (response.Content.Count != 1)
{
throw new NotSupportedException($"{nameof(response.Content)} != 1");
}

return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name);
}

private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
private IMessage<ChatMessage> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
IEnumerable<ChatMessage> messages;
ChatMessage messages;

if (textMessage.From == agent.Name)
{
messages = [new ChatMessage(
"assistant", textMessage.Content)];
messages = new ChatMessage(
"assistant", textMessage.Content);
}
else if (textMessage.From is null)
{
if (textMessage.Role == Role.User)
{
messages = [new ChatMessage(
"user", textMessage.Content)];
messages = new ChatMessage(
"user", textMessage.Content);
}
else if (textMessage.Role == Role.Assistant)
{
messages = [new ChatMessage(
"assistant", textMessage.Content)];
messages = new ChatMessage(
"assistant", textMessage.Content);
}
else if (textMessage.Role == Role.System)
{
messages = [new ChatMessage(
"system", textMessage.Content)];
messages = new ChatMessage(
"system", textMessage.Content);
}
else
{
Expand All @@ -116,10 +135,61 @@ private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMe
else
{
// if from is not null, then the message is from user
messages = [new ChatMessage(
"user", textMessage.Content)];
messages = new ChatMessage(
"user", textMessage.Content);
}

return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: textMessage.From));
return new MessageEnvelope<ChatMessage>(messages, from: textMessage.From);
}

private async Task<IMessage> ProcessMultiModalMessageAsync(MultiModalMessage multiModalMessage, IAgent agent)
{
var content = new List<ContentBase>();
foreach (var message in multiModalMessage.Content)
{
switch (message)
{
case TextMessage textMessage when textMessage.GetContent() is not null:
content.Add(new TextContent { Text = textMessage.GetContent() });
break;
case ImageMessage imageMessage:
content.Add(new ImageContent() { Source = await ProcessImageSourceAsync(imageMessage) });
break;
}
}

var chatMessage = new ChatMessage("user", content);
return MessageEnvelope.Create(chatMessage, agent.Name);
}

private async Task<ImageSource> ProcessImageSourceAsync(ImageMessage imageMessage)
{
if (imageMessage.Data != null)
{
return new ImageSource
{
MediaType = imageMessage.Data.MediaType,
Data = Convert.ToBase64String(imageMessage.Data.ToArray())
};
}

if (imageMessage.Url is null)
{
throw new InvalidOperationException("Invalid ImageMessage, the data or url must be provided");
}

var uri = new Uri(imageMessage.Url);
using var client = new HttpClient();
var response = client.GetAsync(uri).Result;
if (!response.IsSuccessStatusCode)
{
throw new HttpRequestException($"Failed to download the image from {uri}");
}

return new ImageSource
{
MediaType = "image/jpeg",
Data = Convert.ToBase64String(await response.Content.ReadAsByteArrayAsync())
};
}
}
1 change: 0 additions & 1 deletion dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat

public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{

return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}

Expand Down
95 changes: 86 additions & 9 deletions dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClientAgentTest.cs

using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
using AutoGen.Core;
using AutoGen.Tests;
using Xunit.Abstractions;
using FluentAssertions;

namespace AutoGen.Anthropic;
namespace AutoGen.Anthropic.Tests;

public class AnthropicClientAgentTest
{
private readonly ITestOutputHelper _output;

public AnthropicClientAgentTest(ITestOutputHelper output) => _output = output;

[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentChatCompletionTestAsync()
{
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);

var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku,
systemMessage: "You are a helpful AI assistant that convert user message to upper case")
.RegisterMessageConnector();

var uppCaseMessage = new TextMessage(Role.User, "abcdefg");

var reply = await agent.SendAsync(chatHistory: new[] { uppCaseMessage });

reply.GetContent().Should().Contain("ABCDEFG");
reply.From.Should().Be(agent.Name);
}

[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestProcessImageAsync()
{
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku).RegisterMessageConnector();

var singleAgentTest = new SingleAgentTest(_output);
await singleAgentTest.UpperCaseTestAsync(agent);
await singleAgentTest.UpperCaseStreamingTestAsync(agent);
var base64Image = await AnthropicTestUtils.Base64FromImageAsync("square.png");
var imageMessage = new ChatMessage("user",
[new ImageContent { Source = new ImageSource { MediaType = "image/png", Data = base64Image } }]);

var messages = new IMessage[] { MessageEnvelope.Create(imageMessage) };

// test streaming
foreach (var message in messages)
{
var reply = agent.GenerateStreamingReplyAsync([message]);

await foreach (var streamingMessage in reply)
{
streamingMessage.Should().BeOfType<TextMessageUpdate>();
streamingMessage.As<TextMessageUpdate>().From.Should().Be(agent.Name);
}
}
}

[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestMultiModalAsync()
{
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku)
.RegisterMessageConnector();

var image = Path.Combine("images", "square.png");
var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png");
var imageMessage = new ImageMessage(Role.User, binaryData);
var textMessage = new TextMessage(Role.User, "What's in this image?");
var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);

var reply = await agent.SendAsync(multiModalMessage);
reply.Should().BeOfType<TextMessage>();
reply.GetRole().Should().Be(Role.Assistant);
reply.GetContent().Should().NotBeNullOrEmpty();
reply.From.Should().Be(agent.Name);
}

[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestImageMessageAsync()
{
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku,
systemMessage: "You are a helpful AI assistant that is capable of determining what an image is. Tell me a brief description of the image."
)
.RegisterMessageConnector();

var image = Path.Combine("images", "square.png");
var binaryData = BinaryData.FromBytes(await File.ReadAllBytesAsync(image), "image/png");
var imageMessage = new ImageMessage(Role.User, binaryData);

var reply = await agent.SendAsync(imageMessage);
reply.Should().BeOfType<TextMessage>();
reply.GetRole().Should().Be(Role.Assistant);
reply.GetContent().Should().NotBeNullOrEmpty();
reply.From.Should().Be(agent.Name);
}
}
Loading

0 comments on commit 101f482

Please sign in to comment.