Expose an AIJsonUtilities class in M.E.AI and lower M.E.AI.Abstractions to STJv8 #5513

Merged
@@ -0,0 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.Extensions.AI;

/// <summary>
/// An options class for configuring the JSON schema inference functionality of <see cref="JsonFunctionCallUtilities"/>.
/// </summary>
public sealed class JsonSchemaInferenceOptions
{
/// <summary>
/// Gets the default options instance.
/// </summary>
public static JsonSchemaInferenceOptions Default { get; } = new JsonSchemaInferenceOptions();

/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
public bool IncludeTypeInEnumSchemas { get; init; }

/// <summary>
/// Gets a value indicating whether to generate schemas with the additionalProperties keyword set to false for .NET objects.
/// </summary>
public bool DisallowAdditionalProperties { get; init; }

/// <summary>
/// Gets a value indicating whether to include the $schema keyword in inferred schemas.
/// </summary>
public bool IncludeSchemaKeyword { get; init; }
}
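
For orientation, here is a minimal sketch of how this options type pairs with the schema inference API the PR introduces. It is not part of the diff; the WeatherReport record is an illustrative assumption, and JsonSerializerOptions.Default simply stands in for whatever resolver a caller actually uses.

using System;
using System.Text.Json;
using Microsoft.Extensions.AI;

var inferenceOptions = new JsonSchemaInferenceOptions
{
    IncludeTypeInEnumSchemas = true,
    DisallowAdditionalProperties = true,
    IncludeSchemaKeyword = true,
};

// InferJsonSchema returns the schema as a JsonElement (see how the result is
// serialized in ChatClientStructuredOutputExtensions later in this diff).
JsonElement schema = JsonFunctionCallUtilities.InferJsonSchema(
    typeof(WeatherReport),
    JsonSerializerOptions.Default,
    inferenceOptions: inferenceOptions);

Console.WriteLine(schema.GetRawText());

// Illustrative type; an enum-typed property exercises IncludeTypeInEnumSchemas.
public record WeatherReport(string Summary, DayOfWeek Day);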
@@ -93,7 +93,7 @@ public async Task<ChatCompletion> CompleteAsync(
{
if (toolCall is ChatCompletionsFunctionToolCall ftc && !string.IsNullOrWhiteSpace(ftc.Name))
{
Dictionary<string, object?>? arguments = FunctionCallHelpers.ParseFunctionCallArguments(ftc.Arguments, out Exception? parsingException);
Dictionary<string, object?>? arguments = JsonFunctionCallUtilities.ParseFunctionCallArguments(ftc.Arguments, out Exception? parsingException);

returnMessage.Contents.Add(new FunctionCallContent(toolCall.Id, ftc.Name, arguments)
{
@@ -226,7 +226,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
FunctionCallInfo fci = entry.Value;
if (!string.IsNullOrWhiteSpace(fci.Name))
{
var arguments = FunctionCallHelpers.ParseFunctionCallArguments(
var arguments = JsonFunctionCallUtilities.ParseFunctionCallArguments(
fci.Arguments?.ToString() ?? string.Empty,
out Exception? parsingException);

@@ -371,7 +371,7 @@ private ChatCompletionsFunctionToolDefinition ToAzureAIChatTool(AIFunction aiFun
{
tool.Properties.Add(
parameter.Name,
FunctionCallHelpers.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions));
JsonFunctionCallUtilities.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions));

if (parameter.IsRequired)
{
@@ -428,9 +428,10 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
string? result = resultContent.Result as string;
if (result is null && resultContent.Result is not null)
{
JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;

stephentoub (Member):

Could be a separate change if we decide to do it, but should ToolCallJsonSerializerOptions be made non-nullable like it is on the middleware clients, just defaulting it to the default options?

eiriktsarpalis (Member Author):

It's possible, but it would require moving the defaults to Abstractions so that clients can reference it.

stephentoub (Member), Oct 17, 2024:

I'm not understanding why this would require any movement. I'm simply suggesting changing:

public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; }

to

private JsonSerializerOptions _toolCallOptions = JsonContext.Default.Options;
...
public JsonSerializerOptions ToolCallJsonSerializerOptions
{
    get => _toolCallOptions;
    set => _toolCallOptions = Throw.IfNull(value);
}

and then here instead of:

JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;

it'd just be:

JsonSerializerOptions options = ToolCallJsonSerializerOptions;

stephentoub (Member):

Unless we're trying to avoid exposing this options instance?

eiriktsarpalis (Member Author):

I'm just pointing out that unless we expose a default options instance on Abstractions users will have to define their own contexts to fill in the default _toolCallOptions.

stephentoub (Member):

> I'm just pointing out that unless we expose a default options instance on Abstractions users will have to define their own contexts to fill in the default _toolCallOptions.

I don't understand why. The current code doesn't require that, just substituting JsonContext.Default.Options when the user hasn't supplied their own. I'm simply suggesting moving around where we default back to that.

eiriktsarpalis (Member Author):

I see what you mean. It is appropriate to default to JsonContext.Default.Options in the particular locations where it is being done because they're serializing specific types (JsonElement for FCR and IDictionary<string, object> for FCC) which are known to be defined in the local JsonContext. I don't think it would be safe to assume the same thing in other locations where ToolCallJsonSerializerOptions is being used (e.g. when serializing AdditionalProperties), so if we wanted to default to something in that case we should be using the global options instead.

try
{
result = FunctionCallHelpers.FormatFunctionResultAsJson(resultContent.Result, ToolCallJsonSerializerOptions);
result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
@@ -461,7 +462,8 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
{
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
{
string jsonArguments = FunctionCallHelpers.FormatFunctionParametersAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions);
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
string jsonArguments = JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)));
(toolCalls ??= []).Add(
callRequest.CallId,
new ChatCompletionsFunctionToolCall(
@@ -491,5 +493,7 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab

/// <summary>Source-generated JSON type information.</summary>
[JsonSerializable(typeof(AzureAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
private sealed partial class JsonContext : JsonSerializerContext;
}
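
As an aside on the resolved thread above: a minimal sketch of the fallback pattern being discussed, written as if it ran inside the client where the private JsonContext is visible. The toolCallJsonSerializerOptions local and the argument value are illustrative assumptions, not the PR's code.

// Hypothetical local standing in for the ToolCallJsonSerializerOptions property.
JsonSerializerOptions? toolCallJsonSerializerOptions = null;

// Fall back to the source-generated context only where the serialized types are known to be
// registered on it (JsonElement and IDictionary<string, object?> in this file's JsonContext).
JsonSerializerOptions options = toolCallJsonSerializerOptions ?? JsonContext.Default.Options;

IDictionary<string, object?> arguments = new Dictionary<string, object?>
{
    ["city"] = JsonDocument.Parse("\"Seattle\"").RootElement,
};

string json = JsonSerializer.Serialize(arguments, options.GetTypeInfo(typeof(IDictionary<string, object?>)));

// For arbitrary user-defined types (for example values stored in AdditionalProperties) the local
// context has no metadata, so the same fallback would throw NotSupportedException; that is the
// concern raised above about defaulting the property itself to JsonContext.Default.Options.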
@@ -25,11 +25,7 @@
<InjectSharedEmptyCollections>true</InjectSharedEmptyCollections>
<InjectStringHashOnLegacy>true</InjectStringHashOnLegacy>
</PropertyGroup>

<ItemGroup>
<Compile Include="../Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs" />
</ItemGroup>


<ItemGroup>
<PackageReference Include="Azure.AI.Inference" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
4 changes: 4 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/JsonContext.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Microsoft.Extensions.AI;
@@ -21,4 +23,6 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(OllamaToolCall))]
[JsonSerializable(typeof(OllamaEmbeddingRequest))]
[JsonSerializable(typeof(OllamaEmbeddingResponse))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
internal sealed partial class JsonContext : JsonSerializerContext;
@@ -27,10 +27,6 @@
<InjectStringHashOnLegacy>true</InjectStringHashOnLegacy>
</PropertyGroup>

<ItemGroup>
<Compile Include="../Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="System.Net.Http.Json" />
<PackageReference Include="System.Text.Json" />
@@ -356,20 +356,22 @@ private IEnumerable<OllamaChatRequestMessage> ToOllamaChatRequestMessages(ChatMe
break;

case FunctionCallContent fcc:
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
yield return new OllamaChatRequestMessage
{
Role = "assistant",
Content = JsonSerializer.Serialize(new OllamaFunctionCallContent
{
CallId = fcc.CallId,
Name = fcc.Name,
Arguments = FunctionCallHelpers.FormatFunctionParametersAsJsonElement(fcc.Arguments, ToolCallJsonSerializerOptions),
Arguments = JsonSerializer.SerializeToElement(fcc.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))),
}, JsonContext.Default.OllamaFunctionCallContent)
};
break;

case FunctionResultContent frc:
JsonElement jsonResult = FunctionCallHelpers.FormatFunctionResultAsJsonElement(frc.Result, ToolCallJsonSerializerOptions);
JsonSerializerOptions serializerOptions1 = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
JsonElement jsonResult = JsonSerializer.SerializeToElement(frc.Result, serializerOptions1.GetTypeInfo(typeof(object)));
yield return new OllamaChatRequestMessage
{
Role = "tool",
@@ -400,7 +402,7 @@ private IEnumerable<OllamaChatRequestMessage> ToOllamaChatRequestMessages(ChatMe
{
Properties = function.Metadata.Parameters.ToDictionary(
p => p.Name,
p => FunctionCallHelpers.InferParameterJsonSchema(p, function.Metadata, ToolCallJsonSerializerOptions)),
p => JsonFunctionCallUtilities.InferParameterJsonSchema(p, function.Metadata, ToolCallJsonSerializerOptions)),
Required = function.Metadata.Parameters.Where(p => p.IsRequired).Select(p => p.Name).ToList(),
},
}
@@ -24,10 +24,6 @@
<InjectStringHashOnLegacy>true</InjectStringHashOnLegacy>
</PropertyGroup>

<ItemGroup>
<Compile Include="../Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="OpenAI" />
<PackageReference Include="System.Text.Json" />
@@ -124,7 +124,7 @@ public async Task<ChatCompletion> CompleteAsync(
{
if (!string.IsNullOrWhiteSpace(toolCall.FunctionName))
{
Dictionary<string, object?>? arguments = FunctionCallHelpers.ParseFunctionCallArguments(toolCall.FunctionArguments, out Exception? parsingException);
Dictionary<string, object?>? arguments = JsonFunctionCallUtilities.ParseFunctionCallArguments(toolCall.FunctionArguments, out Exception? parsingException);

returnMessage.Contents.Add(new FunctionCallContent(toolCall.Id, toolCall.FunctionName, arguments)
{
@@ -321,7 +321,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
FunctionCallInfo fci = entry.Value;
if (!string.IsNullOrWhiteSpace(fci.Name))
{
var arguments = FunctionCallHelpers.ParseFunctionCallArguments(
var arguments = JsonFunctionCallUtilities.ParseFunctionCallArguments(
fci.Arguments?.ToString() ?? string.Empty,
out Exception? parsingException);

@@ -501,7 +501,7 @@ private ChatTool ToOpenAIChatTool(AIFunction aiFunction)
{
tool.Properties.Add(
parameter.Name,
FunctionCallHelpers.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions));
JsonFunctionCallUtilities.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions));

if (parameter.IsRequired)
{
@@ -596,9 +596,10 @@ private sealed class OpenAIChatToolJson
string? result = resultContent.Result as string;
if (result is null && resultContent.Result is not null)
{
JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
try
{
result = FunctionCallHelpers.FormatFunctionResultAsJson(resultContent.Result, ToolCallJsonSerializerOptions);
result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object)));
}
catch (NotSupportedException)
{
@@ -655,5 +656,6 @@ private sealed class OpenAIChatToolJson

/// <summary>Source-generated JSON type information.</summary>
[JsonSerializable(typeof(OpenAIChatToolJson))]
[JsonSerializable(typeof(JsonElement))]
private sealed partial class JsonContext : JsonSerializerContext;
}
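
The parsing helper now shared by the OpenAI and Azure.AI.Inference clients above can also be called directly. A small sketch follows; the raw JSON payload is an illustrative assumption, not taken from the diff.

using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;

string rawArguments = "{\"city\":\"Seattle\",\"days\":3}"; // illustrative payload

Dictionary<string, object?>? arguments =
    JsonFunctionCallUtilities.ParseFunctionCallArguments(rawArguments, out Exception? parsingException);

if (parsingException is null && arguments is not null)
{
    // Keys are the function parameter names; parsingException is non-null when the JSON is malformed.
    foreach (KeyValuePair<string, object?> argument in arguments)
    {
        Console.WriteLine($"{argument.Key} = {argument.Value}");
    }
}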
@@ -5,12 +5,9 @@
using System.ComponentModel;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
using static Microsoft.Extensions.AI.FunctionCallHelpers;

namespace Microsoft.Extensions.AI;

@@ -19,6 +16,13 @@ namespace Microsoft.Extensions.AI;
/// </summary>
public static class ChatClientStructuredOutputExtensions
{
private static readonly JsonSchemaInferenceOptions _inferenceOptions = new()
{
IncludeSchemaKeyword = true,
DisallowAdditionalProperties = true,
IncludeTypeInEnumSchemas = true
};

/// <summary>Sends chat messages to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
/// <param name="chatClient">The <see cref="IChatClient"/>.</param>
/// <param name="chatMessages">The chat content to send.</param>
@@ -120,26 +124,12 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(

serializerOptions.MakeReadOnly();

var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), new()
{
TreatNullObliviousAsNonNullable = true,
TransformSchemaNode = static (context, node) =>
{
if (node is JsonObject obj)
{
if (obj.TryGetPropertyValue("enum", out _)
&& !obj.TryGetPropertyValue("type", out _))
{
obj.Insert(0, "type", "string");
}
}
var schemaNode = JsonFunctionCallUtilities.InferJsonSchema(
typeof(T),
serializerOptions,
inferenceOptions: _inferenceOptions);

return node;
},
});
schemaNode.Insert(0, "$schema", "https://json-schema.org/draft/2020-12/schema");
schemaNode.Add("additionalProperties", false);
var schema = JsonSerializer.Serialize(schemaNode, JsonDefaults.Options.GetTypeInfo(typeof(JsonNode)));
var schema = JsonSerializer.Serialize(schemaNode, JsonDefaults.Options.GetTypeInfo(typeof(JsonElement)));

ChatMessage? promptAugmentation = null;
options = (options ?? new()).Clone();
@@ -153,7 +143,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
// the LLM backend is meant to do whatever's needed to explain the schema to the LLM.
options.ResponseFormat = ChatResponseFormat.ForJsonSchema(
schema,
schemaName: SanitizeMetadataName(typeof(T).Name),
schemaName: JsonFunctionCallUtilities.SanitizeMemberName(typeof(T).Name),
schemaDescription: typeof(T).GetCustomAttribute<DescriptionAttribute>()?.Description);
}
else
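
For context on what the rewiring above affects, a hedged sketch of the public entry point whose schema generation changed. The client instance and the WeatherReport type are illustrative assumptions rather than code from this PR.

using System;
using Microsoft.Extensions.AI;

// Any IChatClient implementation would do here; construction is elided in this sketch.
IChatClient client = CreateChatClient();

// CompleteAsync<T> now derives the JSON schema for T via JsonFunctionCallUtilities.InferJsonSchema
// with IncludeSchemaKeyword, DisallowAdditionalProperties, and IncludeTypeInEnumSchemas enabled,
// replacing the hand-built schema node shown in the removed lines above.
ChatCompletion<WeatherReport> completion = await client.CompleteAsync<WeatherReport>(
    [new ChatMessage(ChatRole.User, "What's the weather like tomorrow?")]);

Console.WriteLine($"{completion.Result.Day}: {completion.Result.Summary}");

static IChatClient CreateChatClient() => throw new NotImplementedException("placeholder for a real client");

// Illustrative type; an enum-typed property exercises IncludeTypeInEnumSchemas.
public record WeatherReport(string Summary, DayOfWeek Day);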
@@ -15,7 +15,6 @@
using System.Threading.Tasks;
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;
using static Microsoft.Extensions.AI.FunctionCallHelpers;

namespace Microsoft.Extensions.AI;

@@ -40,7 +39,7 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryCreateOptions?
/// <param name="method">The method to be represented via the created <see cref="AIFunction"/>.</param>
/// <param name="name">The name to use for the <see cref="AIFunction"/>.</param>
/// <param name="description">The description to use for the <see cref="AIFunction"/>.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/> used to marshal function parameters.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/> used to marshal function parameters and return value.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
public static AIFunction Create(Delegate method, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null)
{
@@ -86,7 +85,7 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
/// </param>
/// <param name="name">The name to use for the <see cref="AIFunction"/>.</param>
/// <param name="description">The description to use for the <see cref="AIFunction"/>.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/> used to marshal function parameters.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/> used to marshal function parameters and return value.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null)
{
@@ -147,7 +146,7 @@ public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactory
string? functionName = options.Name;
if (functionName is null)
{
functionName = SanitizeMetadataName(method.Name!);
functionName = JsonFunctionCallUtilities.SanitizeMemberName(method.Name!);

const string AsyncSuffix = "Async";
if (IsAsyncMethod(method) &&
@@ -210,7 +209,7 @@ static bool IsAsyncMethod(MethodInfo method)
{
ParameterType = returnType,
Description = method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
Schema = FunctionCallHelpers.InferReturnParameterJsonSchema(returnType, options.SerializerOptions),
Schema = JsonFunctionCallUtilities.InferJsonSchema(returnType, options.SerializerOptions),
},
AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary<string, object?>.Instance,
JsonSerializerOptions = options.SerializerOptions,
@@ -356,13 +355,13 @@ static bool IsAsyncMethod(MethodInfo method)
DefaultValue = parameter.HasDefaultValue ? parameter.DefaultValue : null,
IsRequired = !parameter.IsOptional,
ParameterType = parameter.ParameterType,
Schema = FunctionCallHelpers.InferParameterJsonSchema(
Schema = JsonFunctionCallUtilities.InferParameterJsonSchema(
parameter.ParameterType,
parameter.Name,
options,
description,
parameter.HasDefaultValue,
parameter.DefaultValue,
options)
parameter.DefaultValue)
};
}

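
To tie these factory changes to the public surface, a small sketch of creating an AIFunction and inspecting the inferred schemas. The GetWeather method is an illustrative assumption, not part of the diff.

using System;
using Microsoft.Extensions.AI;

static string GetWeather(string city) => $"The weather in {city} is sunny."; // illustrative

AIFunction getWeather = AIFunctionFactory.Create(
    GetWeather,
    description: "Gets a short weather report for a city.");

// With no explicit name, the factory falls back to the method name run through
// JsonFunctionCallUtilities.SanitizeMemberName; parameter and return schemas are
// now inferred via JsonFunctionCallUtilities rather than the removed FunctionCallHelpers.
foreach (var parameter in getWeather.Metadata.Parameters)
{
    Console.WriteLine($"{parameter.Name}: {parameter.Schema}");
}

Console.WriteLine($"returns: {getWeather.Metadata.ReturnParameter.Schema}");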
@@ -24,10 +24,6 @@
<DisableMicrosoftExtensionsLoggingSourceGenerator>false</DisableMicrosoftExtensionsLoggingSourceGenerator>
</PropertyGroup>

<ItemGroup>
<Compile Include="../Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />