Skip to content

Commit

Permalink
Replace STJ boilerplate in the leaf clients with AIJsonUtilities call…
Browse files Browse the repository at this point in the history
…s. (#5630)

* Replace STJ boilerplate in the leaf clients with AIJsonUtilities calls.

* Address feedback.

* Address feedback.

* Remove redundant using
  • Loading branch information
eiriktsarpalis authored Nov 13, 2024
1 parent ad3b5d0 commit 73962c6
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.Inference;
Expand All @@ -27,6 +28,9 @@ public sealed class AzureAIInferenceChatClient : IChatClient
/// <summary>The underlying <see cref="ChatCompletionsClient" />.</summary>
private readonly ChatCompletionsClient _chatCompletionsClient;

/// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;

/// <summary>Initializes a new instance of the <see cref="AzureAIInferenceChatClient"/> class for the specified <see cref="ChatCompletionsClient"/>.</summary>
/// <param name="chatCompletionsClient">The underlying client.</param>
/// <param name="modelId">The ID of the model to use. If null, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
Expand All @@ -51,7 +55,11 @@ public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, s
}

/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
get => _toolCallJsonSerializerOptions;
set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
}

/// <inheritdoc />
public ChatClientMetadata Metadata { get; }
Expand Down Expand Up @@ -304,7 +312,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
// These properties are strongly typed on ChatOptions but not on ChatCompletionsOptions.
if (options.TopK is int topK)
{
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, JsonContext.Default.Int32));
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(int))));
}

if (options.AdditionalProperties is { } props)
Expand All @@ -317,7 +325,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
default:
if (prop.Value is not null)
{
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, JsonContext.GetTypeInfo(prop.Value.GetType(), ToolCallJsonSerializerOptions));
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
result.AdditionalProperties[prop.Key] = new BinaryData(data);
}

Expand Down Expand Up @@ -419,7 +427,7 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
{
try
{
result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
result = JsonSerializer.Serialize(resultContent.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
Expand Down Expand Up @@ -449,7 +457,7 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
callRequest.CallId,
new FunctionCall(
callRequest.Name,
JsonSerializer.Serialize(callRequest.Arguments, JsonContext.GetTypeInfo(typeof(IDictionary<string, object>), ToolCallJsonSerializerOptions)))));
JsonSerializer.Serialize(callRequest.Arguments, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object>))))));
}
}

Expand Down Expand Up @@ -490,5 +498,6 @@ private static List<ChatMessageContentItem> GetContentParts(IList<AIContent> con

private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private EmbeddingsOptions ToAzureAIOptions(IEnumerable<string> inputs, Embedding
{
if (prop.Value is not null)
{
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, JsonContext.GetTypeInfo(prop.Value.GetType(), null));
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
result.AdditionalProperties[prop.Key] = new BinaryData(data);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;

namespace Microsoft.Extensions.AI;

Expand All @@ -16,55 +12,4 @@ namespace Microsoft.Extensions.AI;
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(AzureAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(long))]
[JsonSerializable(typeof(float))]
[JsonSerializable(typeof(double))]
[JsonSerializable(typeof(bool))]
[JsonSerializable(typeof(float[]))]
[JsonSerializable(typeof(byte[]))]
[JsonSerializable(typeof(sbyte[]))]
internal sealed partial class JsonContext : JsonSerializerContext
{
/// <summary>Gets the <see cref="JsonSerializerOptions"/> singleton used as the default in JSON serialization operations.</summary>
private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions();

/// <summary>Gets JSON type information for the specified type.</summary>
/// <remarks>
/// This first tries to get the type information from <paramref name="firstOptions"/>,
/// falling back to <see cref="_defaultToolJsonOptions"/> if it can't.
/// </remarks>
public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) =>
firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
info :
_defaultToolJsonOptions.GetTypeInfo(type);

/// <summary>Creates the default <see cref="JsonSerializerOptions"/> to use for serialization-related operations.</summary>
[UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.

if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};

options.MakeReadOnly();
return options;
}

return Default.Options;
}
}
internal sealed partial class JsonContext : JsonSerializerContext;
4 changes: 0 additions & 4 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/JsonContext.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Microsoft.Extensions.AI;
Expand All @@ -23,6 +21,4 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(OllamaToolCall))]
[JsonSerializable(typeof(OllamaEmbeddingRequest))]
[JsonSerializable(typeof(OllamaEmbeddingResponse))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
internal sealed partial class JsonContext : JsonSerializerContext;
15 changes: 10 additions & 5 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public sealed class OllamaChatClient : IChatClient
/// <summary>The <see cref="HttpClient"/> to use for sending requests.</summary>
private readonly HttpClient _httpClient;

/// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;

/// <summary>Initializes a new instance of the <see cref="OllamaChatClient"/> class.</summary>
/// <param name="endpoint">The endpoint URI where Ollama is hosted.</param>
/// <param name="modelId">
Expand Down Expand Up @@ -66,7 +69,11 @@ public OllamaChatClient(Uri endpoint, string? modelId = null, HttpClient? httpCl
public ChatClientMetadata Metadata { get; }

/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
get => _toolCallJsonSerializerOptions;
set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
}

/// <inheritdoc />
public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -388,24 +395,22 @@ private IEnumerable<OllamaChatRequestMessage> ToOllamaChatRequestMessages(ChatMe

case FunctionCallContent fcc:
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
yield return new OllamaChatRequestMessage
{
Role = "assistant",
Content = JsonSerializer.Serialize(new OllamaFunctionCallContent
{
CallId = fcc.CallId,
Name = fcc.Name,
Arguments = JsonSerializer.SerializeToElement(fcc.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))),
Arguments = JsonSerializer.SerializeToElement(fcc.Arguments, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))),
}, JsonContext.Default.OllamaFunctionCallContent)
};
break;
}

case FunctionResultContent frc:
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
JsonElement jsonResult = JsonSerializer.SerializeToElement(frc.Result, serializerOptions.GetTypeInfo(typeof(object)));
JsonElement jsonResult = JsonSerializer.SerializeToElement(frc.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
yield return new OllamaChatRequestMessage
{
Role = "tool",
Expand Down
65 changes: 15 additions & 50 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
Expand Down Expand Up @@ -38,6 +37,9 @@ public sealed partial class OpenAIChatClient : IChatClient
/// <summary>The underlying <see cref="ChatClient" />.</summary>
private readonly ChatClient _chatClient;

/// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;

/// <summary>Initializes a new instance of the <see cref="OpenAIChatClient"/> class for the specified <see cref="OpenAIClient"/>.</summary>
/// <param name="openAIClient">The underlying client.</param>
/// <param name="modelId">The model to use.</param>
Expand Down Expand Up @@ -80,7 +82,11 @@ public OpenAIChatClient(ChatClient chatClient)
}

/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
get => _toolCallJsonSerializerOptions;
set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
}

/// <inheritdoc />
public ChatClientMetadata Metadata { get; }
Expand Down Expand Up @@ -593,7 +599,7 @@ private sealed class OpenAIChatToolJson
{
try
{
result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
result = JsonSerializer.Serialize(resultContent.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
Expand Down Expand Up @@ -622,7 +628,7 @@ private sealed class OpenAIChatToolJson
callRequest.Name,
new(JsonSerializer.SerializeToUtf8Bytes(
callRequest.Arguments,
JsonContext.GetTypeInfo(typeof(IDictionary<string, object?>), ToolCallJsonSerializerOptions)))));
ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))))));
}
}

Expand Down Expand Up @@ -668,60 +674,19 @@ private static List<ChatMessageContentPart> GetContentParts(IList<AIContent> con

private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);

private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8Json, string callId, string name) =>
FunctionCallContent.CreateFromParsedArguments(ut8Json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);

/// <summary>Source-generated JSON type information.</summary>
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(OpenAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
private sealed partial class JsonContext : JsonSerializerContext
{
/// <summary>Gets the <see cref="JsonSerializerOptions"/> singleton used as the default in JSON serialization operations.</summary>
private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions();

/// <summary>Gets JSON type information for the specified type.</summary>
/// <remarks>
/// This first tries to get the type information from <paramref name="firstOptions"/>,
/// falling back to <see cref="_defaultToolJsonOptions"/> if it can't.
/// </remarks>
public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) =>
firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
info :
_defaultToolJsonOptions.GetTypeInfo(type);

/// <summary>Creates the default <see cref="JsonSerializerOptions"/> to use for serialization-related operations.</summary>
[UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
// Otherwise, use the source-generated options to enable trimming and Native AOT.

if (JsonSerializer.IsReflectionEnabledByDefault)
{
// Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true,
};

options.MakeReadOnly();
return options;
}

return Default.Options;
}
}
private sealed partial class JsonContext : JsonSerializerContext;
}
Loading

0 comments on commit 73962c6

Please sign in to comment.