diff --git a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionResponse.cs b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionResponse.cs index ff241f8d340..13e29e7139b 100644 --- a/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionResponse.cs +++ b/dotnet/src/AutoGen.Mistral/DTOs/ChatCompletionResponse.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // ChatCompletionResponse.cs using System.Collections.Generic; diff --git a/dotnet/src/AutoGen.Mistral/DTOs/Error.cs b/dotnet/src/AutoGen.Mistral/DTOs/Error.cs index 77eb2d341fb..8bddcfc776c 100644 --- a/dotnet/src/AutoGen.Mistral/DTOs/Error.cs +++ b/dotnet/src/AutoGen.Mistral/DTOs/Error.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Error.cs using System.Text.Json.Serialization; diff --git a/dotnet/src/AutoGen.Mistral/DTOs/Model.cs b/dotnet/src/AutoGen.Mistral/DTOs/Model.cs index 915d2f737ec..70a4b3c997d 100644 --- a/dotnet/src/AutoGen.Mistral/DTOs/Model.cs +++ b/dotnet/src/AutoGen.Mistral/DTOs/Model.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Model.cs using System; diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index ecebe7fc3fa..487a361d7de 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -84,7 +84,7 @@ public async Task GenerateReplyAsync( var settings = this.CreateChatCompletionsOptions(options, messages); var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken); - return new MessageEnvelope(reply.Value.Choices.First().Message, from: this.Name); + return new MessageEnvelope(reply, from: this.Name); } public Task> GenerateStreamingReplyAsync( @@ -101,7 +101,7 @@ private async IAsyncEnumerable StreamingReplyAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { var settings = this.CreateChatCompletionsOptions(options, messages); - var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings); + var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings, cancellationToken); await foreach (var update in response.WithCancellation(cancellationToken)) { if (update.ChoiceIndex > 0) diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index c1581cbec08..3d96855c16c 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -98,6 +98,7 @@ public IMessage PostProcessMessage(IMessage message) Message => message, AggregateMessage => message, IMessage m => PostProcessMessage(m), + IMessage m => PostProcessMessage(m), _ => throw new InvalidOperationException("The type of message is not supported. Must be one of TextMessage, ImageMessage, MultiModalMessage, ToolCallMessage, ToolCallResultMessage, Message, IMessage, AggregateMessage"), }; } @@ -129,15 +130,24 @@ public IMessage PostProcessMessage(IMessage message) private IMessage PostProcessMessage(IMessage message) { - var chatResponseMessage = message.Content; + return PostProcessMessage(message.Content, message.From); + } + + private IMessage PostProcessMessage(IMessage message) + { + return PostProcessMessage(message.Content.Choices[0].Message, message.From); + } + + private IMessage PostProcessMessage(ChatResponseMessage chatResponseMessage, string? from) + { if (chatResponseMessage.Content is string content) { - return new TextMessage(Role.Assistant, content, message.From); + return new TextMessage(Role.Assistant, content, from); } if (chatResponseMessage.FunctionCall is FunctionCall functionCall) { - return new ToolCallMessage(functionCall.Name, functionCall.Arguments, message.From); + return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from); } if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any()) @@ -148,7 +158,7 @@ private IMessage PostProcessMessage(IMessage message) var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments)); - return new ToolCallMessage(toolCalls, message.From); + return new ToolCallMessage(toolCalls, from); } throw new InvalidOperationException("Invalid ChatResponseMessage"); diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs index 8626618fea7..a4753b66871 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs @@ -41,9 +41,10 @@ public async Task BasicConversationTestAsync() var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello")); var reply = await openAIChatAgent.SendAsync(chatMessageContent); - reply.Should().BeOfType>(); - reply.As>().From.Should().Be("assistant"); - reply.As>().Content.Role.Should().Be(ChatRole.Assistant); + reply.Should().BeOfType>(); + reply.As>().From.Should().Be("assistant"); + reply.As>().Content.Choices.First().Message.Role.Should().Be(ChatRole.Assistant); + reply.As>().Content.Usage.TotalTokens.Should().BeGreaterThan(0); // test streaming var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });