Skip to content

Commit

Permalink
Fix a few issues in IChatClient implementations (#5549)
Browse files Browse the repository at this point in the history
* Fix a few issues in IChatClient implementations

- Avoid null arg exception when constructing system message with null text
- Avoid empty exception when constructing user message with no parts
- Use all parts rather than just first text part for system message
- Handle assistant messages with both content and tools
- Avoid unnecessarily trying to weed out duplicate call ids

* Address PR feedback

- Normalize null to string.Empty in TextContent
- Ensure GetContentParts always produces at least one part, even if empty text content
  • Loading branch information
stephentoub authored Oct 22, 2024
1 parent 2dd959f commit 7cac12b
Show file tree
Hide file tree
Showing 7 changed files with 510 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;

namespace Microsoft.Extensions.AI;

/// <summary>
/// Represents text content in a chat.
/// </summary>
public sealed class TextContent : AIContent
{
private string? _text;

/// <summary>
/// Initializes a new instance of the <see cref="TextContent"/> class.
/// </summary>
/// <param name="text">The text content.</param>
public TextContent(string? text)
{
Text = text;
_text = text;
}

/// <summary>
/// Gets or sets the text content.
/// </summary>
public string? Text { get; set; }
[AllowNull]
public string Text
{
get => _text ?? string.Empty;
set => _text = value;
}

/// <inheritdoc/>
public override string ToString() => Text ?? string.Empty;
public override string ToString() => Text;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
Expand Down Expand Up @@ -410,13 +409,13 @@ private sealed class AzureAIChatToolJson
private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerable<ChatMessage> inputs)
{
// Maps all of the M.E.AI types to the corresponding AzureAI types.
// Unrecognized content is ignored.
// Unrecognized or non-processable content is ignored.

foreach (ChatMessage input in inputs)
{
if (input.Role == ChatRole.System)
{
yield return new ChatRequestSystemMessage(input.Text);
yield return new ChatRequestSystemMessage(input.Text ?? string.Empty);
}
else if (input.Role == ChatRole.Tool)
{
Expand Down Expand Up @@ -444,52 +443,64 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
}
else if (input.Role == ChatRole.User)
{
yield return new ChatRequestUserMessage(input.Contents.Select(static (AIContent item) => item switch
{
TextContent textContent => new ChatMessageTextContentItem(textContent.Text),
ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType) :
imageContent.Uri is string uri ? new ChatMessageImageContentItem(new Uri(uri)) :
(ChatMessageContentItem?)null,
_ => null,
}).Where(c => c is not null));
yield return new ChatRequestUserMessage(GetContentParts(input.Contents));
}
else if (input.Role == ChatRole.Assistant)
{
Dictionary<string, ChatCompletionsToolCall>? toolCalls = null;
// TODO: ChatRequestAssistantMessage only enables text content currently.
// Update it with other content types when it supports that.
ChatRequestAssistantMessage message = new()
{
Content = input.Text
};

foreach (var content in input.Contents)
{
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
if (content is FunctionCallContent { CallId: not null } callRequest)
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
string jsonArguments = JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)));
(toolCalls ??= []).Add(
message.ToolCalls.Add(new ChatCompletionsFunctionToolCall(
callRequest.CallId,
new ChatCompletionsFunctionToolCall(
callRequest.CallId,
callRequest.Name,
jsonArguments));
callRequest.Name,
JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)))));
}
}

ChatRequestAssistantMessage message = new();
if (toolCalls is not null)
{
foreach (var entry in toolCalls)
{
message.ToolCalls.Add(entry.Value);
}
}
else
{
message.Content = input.Text;
}

yield return message;
}
}
}

/// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentItem"/>.</summary>
private static List<ChatMessageContentItem> GetContentParts(IList<AIContent> contents)
{
List<ChatMessageContentItem> parts = [];
foreach (var content in contents)
{
switch (content)
{
case TextContent textContent:
(parts ??= []).Add(new ChatMessageTextContentItem(textContent.Text));
break;

case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
(parts ??= []).Add(new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType));
break;

case ImageContent imageContent when imageContent.Uri is string uri:
(parts ??= []).Add(new ChatMessageImageContentItem(new Uri(uri)));
break;
}
}

if (parts.Count == 0)
{
parts.Add(new ChatMessageTextContentItem(string.Empty));
}

return parts;
}

private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
Expand Down
66 changes: 42 additions & 24 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
Expand Down Expand Up @@ -569,13 +568,16 @@ private sealed class OpenAIChatToolJson
private IEnumerable<OpenAI.Chat.ChatMessage> ToOpenAIChatMessages(IEnumerable<ChatMessage> inputs)
{
// Maps all of the M.E.AI types to the corresponding OpenAI types.
// Unrecognized content is ignored.
// Unrecognized or non-processable content is ignored.

foreach (ChatMessage input in inputs)
{
if (input.Role == ChatRole.System)
if (input.Role == ChatRole.System || input.Role == ChatRole.User)
{
yield return new SystemChatMessage(input.Text) { ParticipantName = input.AuthorName };
var parts = GetContentParts(input.Contents);
yield return input.Role == ChatRole.System ?
new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
new UserChatMessage(parts) { ParticipantName = input.AuthorName };
}
else if (input.Role == ChatRole.Tool)
{
Expand All @@ -601,39 +603,25 @@ private sealed class OpenAIChatToolJson
}
}
}
else if (input.Role == ChatRole.User)
{
yield return new UserChatMessage(input.Contents.Select(static (AIContent item) => item switch
{
TextContent textContent => ChatMessageContentPart.CreateTextPart(textContent.Text),
ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType) :
imageContent.Uri is string uri ? ChatMessageContentPart.CreateImagePart(new Uri(uri)) :
null,
_ => null,
}).Where(c => c is not null))
{ ParticipantName = input.AuthorName };
}
else if (input.Role == ChatRole.Assistant)
{
Dictionary<string, ChatToolCall>? toolCalls = null;
AssistantChatMessage message = new(GetContentParts(input.Contents))
{
ParticipantName = input.AuthorName
};

foreach (var content in input.Contents)
{
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
if (content is FunctionCallContent { CallId: not null } callRequest)
{
(toolCalls ??= []).Add(
callRequest.CallId,
message.ToolCalls.Add(
ChatToolCall.CreateFunctionToolCall(
callRequest.CallId,
callRequest.Name,
BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions)));
}
}

AssistantChatMessage message = toolCalls is not null ?
new(toolCalls.Values) { ParticipantName = input.AuthorName } :
new(input.Text) { ParticipantName = input.AuthorName };

if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true)
{
message.Refusal = refusal;
Expand All @@ -644,6 +632,36 @@ private sealed class OpenAIChatToolJson
}
}

/// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentPart"/>.</summary>
private static List<ChatMessageContentPart> GetContentParts(IList<AIContent> contents)
{
List<ChatMessageContentPart> parts = [];
foreach (var content in contents)
{
switch (content)
{
case TextContent textContent:
(parts ??= []).Add(ChatMessageContentPart.CreateTextPart(textContent.Text));
break;

case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
(parts ??= []).Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType));
break;

case ImageContent imageContent when imageContent.Uri is string uri:
(parts ??= []).Add(ChatMessageContentPart.CreateImagePart(new Uri(uri)));
break;
}
}

if (parts.Count == 0)
{
parts.Add(ChatMessageContentPart.CreateTextPart(string.Empty));
}

return parts;
}

private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public void Constructor_String_PropsDefault(string? text)
TextContent c = new(text);
Assert.Null(c.RawRepresentation);
Assert.Null(c.AdditionalProperties);
Assert.Equal(text, c.Text);
Assert.Equal(text ?? string.Empty, c.Text);
}

[Fact]
Expand All @@ -34,13 +34,17 @@ public void Constructor_PropsRoundtrip()
c.AdditionalProperties = props;
Assert.Same(props, c.AdditionalProperties);

Assert.Null(c.Text);
Assert.Equal(string.Empty, c.Text);
c.Text = "text";
Assert.Equal("text", c.Text);
Assert.Equal("text", c.ToString());

c.Text = null;
Assert.Null(c.Text);
Assert.Equal(string.Empty, c.Text);
Assert.Equal(string.Empty, c.ToString());

c.Text = string.Empty;
Assert.Equal(string.Empty, c.Text);
Assert.Equal(string.Empty, c.ToString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,89 @@ public async Task MultipleMessages_NonStreaming()
Assert.Equal(57, response.Usage.TotalTokenCount);
}

[Fact]
public async Task NullAssistantText_ContentSkipped_NonStreaming()
{
const string Input = """
{
"messages": [
{
"role": "assistant"
},
{
"content": [
{
"text": "hello!",
"type": "text"
}
],
"role": "user"
}
],
"model": "gpt-4o-mini"
}
""";

const string Output = """
{
"id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P",
"object": "chat.completion",
"created": 1727894187,
"model": "gpt-4o-mini-2024-07-18",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello.",
"refusal": null
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 42,
"completion_tokens": 15,
"total_tokens": 57,
"prompt_tokens_details": {
"cached_tokens": 0
},
"completion_tokens_details": {
"reasoning_tokens": 0
}
},
"system_fingerprint": "fp_f85bea6784"
}
""";

using VerbatimHttpHandler handler = new(Input, Output);
using HttpClient httpClient = new(handler);
using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini");

List<ChatMessage> messages =
[
new(ChatRole.Assistant, (string?)null),
new(ChatRole.User, "hello!"),
];

var response = await client.CompleteAsync(messages);
Assert.NotNull(response);

Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId);
Assert.Equal("Hello.", response.Message.Text);
Assert.Single(response.Message.Contents);
Assert.Equal(ChatRole.Assistant, response.Message.Role);
Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId);
Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt);
Assert.Equal(ChatFinishReason.Stop, response.FinishReason);

Assert.NotNull(response.Usage);
Assert.Equal(42, response.Usage.InputTokenCount);
Assert.Equal(15, response.Usage.OutputTokenCount);
Assert.Equal(57, response.Usage.TotalTokenCount);
}

[Fact]
public async Task FunctionCallContent_NonStreaming()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ private int CountTokens(ChatMessage message)
int sum = 0;
foreach (AIContent content in message.Contents)
{
if ((content as TextContent)?.Text is string text)
if (content is TextContent text)
{
sum += _tokenizer.CountTokens(text);
sum += _tokenizer.CountTokens(text.Text);
}
}

Expand Down
Loading

0 comments on commit 7cac12b

Please sign in to comment.