Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update M.E.AI.AzureAIInference for its beta2 release #5558

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eng/packages/General.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<Project xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<PackageVersion Include="Azure.AI.Inference" Version="1.0.0-beta.1" />
<PackageVersion Include="Azure.AI.Inference" Version="1.0.0-beta.2" />
<PackageVersion Include="ICSharpCode.Decompiler" Version="8.2.0.7535" />
<PackageVersion Include="Microsoft.Bcl.HashCode" Version="1.1.1" />
<PackageVersion Include="Microsoft.CodeAnalysis.Analyzers" Version="$(MicrosoftCodeAnalysisAnalyzersVersion)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.Extensions.AI;
#endif
[JsonDerivedType(typeof(Embedding<float>), typeDiscriminator: "floats")]
[JsonDerivedType(typeof(Embedding<double>), typeDiscriminator: "doubles")]
[JsonDerivedType(typeof(Embedding<byte>), typeDiscriminator: "bytes")]
[JsonDerivedType(typeof(Embedding<sbyte>), typeDiscriminator: "sbytes")]
public class Embedding
{
/// <summary>Initializes a new instance of the <see cref="Embedding"/> class.</summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Microsoft.Extensions.AI;

/// <summary>Used to create the JSON payload for an AzureAI chat tool description.</summary>
internal sealed class AzureAIChatToolJson
{
    /// <summary>
    /// Gets the shared JSON payload used for a function that has no parameters.
    /// Optimization for the reasonably common case of a parameterless function.
    /// </summary>
    public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray());

    /// <summary>Gets or sets the value serialized as the <c>type</c> JSON property. Defaults to <c>"object"</c>.</summary>
    [JsonPropertyName("type")]
    public string Type { get; set; } = "object";

    /// <summary>Gets or sets the list serialized as the <c>required</c> JSON property. Empty by default.</summary>
    [JsonPropertyName("required")]
    public List<string> Required { get; set; } = [];

    /// <summary>Gets or sets the map serialized as the <c>properties</c> JSON property. Empty by default.</summary>
    [JsonPropertyName("properties")]
    public Dictionary<string, JsonElement> Properties { get; set; } = [];
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.Inference;
Expand All @@ -20,8 +19,9 @@
namespace Microsoft.Extensions.AI;

/// <summary>An <see cref="IChatClient"/> for an Azure AI Inference <see cref="ChatCompletionsClient"/>.</summary>
public sealed partial class AzureAIInferenceChatClient : IChatClient
public sealed class AzureAIInferenceChatClient : IChatClient
{
/// <summary>A default schema to use when a parameter lacks a pre-defined schema.</summary>
private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;

/// <summary>The underlying <see cref="ChatCompletionsClient" />.</summary>
Expand Down Expand Up @@ -77,43 +77,33 @@ public async Task<ChatCompletion> CompleteAsync(
List<ChatMessage> returnMessages = [];

// Populate its content from those in the response content.
ChatFinishReason? finishReason = null;
foreach (var choice in response.Choices)
ChatMessage message = new()
{
ChatMessage returnMessage = new()
{
RawRepresentation = choice,
Role = ToChatRole(choice.Message.Role),
AdditionalProperties = new() { [nameof(choice.Index)] = choice.Index },
};
RawRepresentation = response,
Role = ToChatRole(response.Role),
};

finishReason ??= ToFinishReason(choice.FinishReason);
if (response.Content is string content)
{
message.Text = content;
}

if (choice.Message.ToolCalls is { Count: > 0 } toolCalls)
if (response.ToolCalls is { Count: > 0 } toolCalls)
{
foreach (var toolCall in toolCalls)
{
foreach (var toolCall in toolCalls)
if (toolCall is ChatCompletionsToolCall ftc && !string.IsNullOrWhiteSpace(ftc.Name))
{
if (toolCall is ChatCompletionsFunctionToolCall ftc && !string.IsNullOrWhiteSpace(ftc.Name))
{
FunctionCallContent callContent = ParseCallContentFromJsonString(ftc.Arguments, toolCall.Id, ftc.Name);
callContent.RawRepresentation = toolCall;
FunctionCallContent callContent = ParseCallContentFromJsonString(ftc.Arguments, toolCall.Id, ftc.Name);
callContent.RawRepresentation = toolCall;

returnMessage.Contents.Add(callContent);
}
message.Contents.Add(callContent);
}
}

if (!string.IsNullOrEmpty(choice.Message.Content))
{
returnMessage.Contents.Add(new TextContent(choice.Message.Content)
{
RawRepresentation = choice.Message
});
}

returnMessages.Add(returnMessage);
}

returnMessages.Add(message);

UsageDetails? usage = null;
if (response.Usage is CompletionsUsage completionsUsage)
{
Expand All @@ -128,11 +118,11 @@ public async Task<ChatCompletion> CompleteAsync(
// Wrap the content in a ChatCompletion to return.
return new ChatCompletion(returnMessages)
{
RawRepresentation = response,
CompletionId = response.Id,
CreatedAt = response.Created,
ModelId = response.Model,
FinishReason = finishReason,
FinishReason = ToFinishReason(response.FinishReason),
RawRepresentation = response,
Usage = usage,
};
}
Expand All @@ -143,13 +133,13 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
{
_ = Throw.IfNull(chatMessages);

Dictionary<int, FunctionCallInfo>? functionCallInfos = null;
Dictionary<string, FunctionCallInfo>? functionCallInfos = null;
ChatRole? streamedRole = default;
ChatFinishReason? finishReason = default;
string? completionId = null;
DateTimeOffset? createdAt = null;
string? modelId = null;
string? authorName = null;
string lastCallId = string.Empty;

// Process each update as it arrives
var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false);
Expand All @@ -161,12 +151,10 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
completionId ??= chatCompletionUpdate.Id;
createdAt ??= chatCompletionUpdate.Created;
modelId ??= chatCompletionUpdate.Model;
authorName ??= chatCompletionUpdate.AuthorName;

// Create the response content object.
StreamingChatCompletionUpdate completionUpdate = new()
{
AuthorName = authorName,
CompletionId = chatCompletionUpdate.Id,
CreatedAt = chatCompletionUpdate.Created,
FinishReason = finishReason,
Expand All @@ -182,34 +170,52 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
}

// Transfer over tool call updates.
if (chatCompletionUpdate.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate)
if (chatCompletionUpdate.ToolCallUpdate is { } toolCallUpdate)
{
// TODO https://github.com/Azure/azure-sdk-for-net/issues/46830: Azure.AI.Inference
// has removed the Index property from ToolCallUpdate. It's now impossible via the
// exposed APIs to correctly handle multiple parallel tool calls, as the CallId is
// often null for anything other than the first update for a given call, and Index
// isn't available to correlate which updates are for which call. This is a temporary
// workaround to at least make a single tool call work, and to make multiple tool calls
// work when their updates aren't interleaved.
if (toolCallUpdate.Id is not null)
{
lastCallId = toolCallUpdate.Id;
}

functionCallInfos ??= [];
if (!functionCallInfos.TryGetValue(toolCallUpdate.ToolCallIndex, out FunctionCallInfo? existing))
if (!functionCallInfos.TryGetValue(lastCallId, out FunctionCallInfo? existing))
{
functionCallInfos[toolCallUpdate.ToolCallIndex] = existing = new();
functionCallInfos[lastCallId] = existing = new();
}

existing.CallId ??= toolCallUpdate.Id;
existing.Name ??= toolCallUpdate.Name;
if (toolCallUpdate.ArgumentsUpdate is not null)
existing.Name ??= toolCallUpdate.Function.Name;
if (toolCallUpdate.Function.Arguments is { } arguments)
{
_ = (existing.Arguments ??= new()).Append(toolCallUpdate.ArgumentsUpdate);
_ = (existing.Arguments ??= new()).Append(arguments);
}
}

if (chatCompletionUpdate.Usage is { } usage)
{
completionUpdate.Contents.Add(new UsageContent(new()
{
InputTokenCount = usage.PromptTokens,
OutputTokenCount = usage.CompletionTokens,
TotalTokenCount = usage.TotalTokens,
}));
}

// Now yield the item.
yield return completionUpdate;
}

// TODO: Add usage as content when it's exposed by Azure.AI.Inference.

// Now that we've received all updates, combine any for function calls into a single item to yield.
if (functionCallInfos is not null)
{
var completionUpdate = new StreamingChatCompletionUpdate
{
AuthorName = authorName,
CompletionId = completionId,
CreatedAt = createdAt,
FinishReason = finishReason,
Expand All @@ -224,7 +230,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
{
FunctionCallContent callContent = ParseCallContentFromJsonString(
fci.Arguments?.ToString() ?? string.Empty,
fci.CallId!,
entry.Key,
fci.Name!);
completionUpdate.Contents.Add(callContent);
}
Expand All @@ -243,7 +249,6 @@ void IDisposable.Dispose()
/// <summary>POCO representing function calling info. Used to concatenate information for a single function call from across multiple streaming updates.</summary>
private sealed class FunctionCallInfo
{
/// <summary>Identifier of the tool call; taken from the first streaming update that supplies a non-null Id.</summary>
public string? CallId;
/// <summary>Name of the function being called; taken from the first streaming update that supplies a non-null name.</summary>
public string? Name;
/// <summary>Accumulated argument text; created lazily and appended to as each update delivers another fragment.</summary>
public StringBuilder? Arguments;
}
Expand Down Expand Up @@ -292,7 +297,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
// These properties are strongly-typed on ChatOptions but not on ChatCompletionsOptions.
if (options.TopK is int topK)
{
result.AdditionalProperties["top_k"] = BinaryData.FromObjectAsJson(topK, JsonContext.Default.Options);
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, JsonContext.Default.Int32));
}

if (options.AdditionalProperties is { } props)
Expand All @@ -310,7 +315,8 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
default:
if (prop.Value is not null)
{
result.AdditionalProperties[prop.Key] = BinaryData.FromObjectAsJson(prop.Value, ToolCallJsonSerializerOptions);
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, JsonContext.GetTypeInfo(prop.Value.GetType(), ToolCallJsonSerializerOptions));
result.AdditionalProperties[prop.Key] = new BinaryData(data);
}

break;
Expand Down Expand Up @@ -356,7 +362,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
}

/// <summary>Converts an Extensions function to an AzureAI chat tool.</summary>
private static ChatCompletionsFunctionToolDefinition ToAzureAIChatTool(AIFunction aiFunction)
private static ChatCompletionsToolDefinition ToAzureAIChatTool(AIFunction aiFunction)
{
BinaryData resultParameters = AzureAIChatToolJson.ZeroFunctionParametersSchema;

Expand All @@ -381,28 +387,11 @@ private static ChatCompletionsFunctionToolDefinition ToAzureAIChatTool(AIFunctio
JsonSerializer.SerializeToUtf8Bytes(tool, JsonContext.Default.AzureAIChatToolJson));
}

return new()
return new(new FunctionDefinition(aiFunction.Metadata.Name)
{
Name = aiFunction.Metadata.Name,
Description = aiFunction.Metadata.Description,
Parameters = resultParameters,
};
}

/// <summary>Used to create the JSON payload for an AzureAI chat tool description.</summary>
private sealed class AzureAIChatToolJson
{
/// <summary>Gets a singleton JSON data for empty parameters. Optimization for the reasonably common case of a parameterless function.</summary>
public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray());

[JsonPropertyName("type")]
public string Type { get; set; } = "object";

[JsonPropertyName("required")]
public List<string> Required { get; set; } = [];

[JsonPropertyName("properties")]
public Dictionary<string, JsonElement> Properties { get; set; } = [];
});
}

/// <summary>Converts an Extensions chat message enumerable to an AzureAI chat message enumerable.</summary>
Expand All @@ -426,10 +415,9 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
string? result = resultContent.Result as string;
if (result is null && resultContent.Result is not null)
{
JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
try
{
result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object)));
result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
}
catch (NotSupportedException)
{
Expand All @@ -449,20 +437,17 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
{
// TODO: ChatRequestAssistantMessage only enables text content currently.
// Update it with other content types when it supports that.
ChatRequestAssistantMessage message = new()
{
Content = input.Text
};
ChatRequestAssistantMessage message = new(input.Text ?? string.Empty);

foreach (var content in input.Contents)
{
if (content is FunctionCallContent { CallId: not null } callRequest)
{
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
message.ToolCalls.Add(new ChatCompletionsFunctionToolCall(
callRequest.CallId,
callRequest.Name,
JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)))));
message.ToolCalls.Add(new ChatCompletionsToolCall(
callRequest.CallId,
new FunctionCall(
callRequest.Name,
JsonSerializer.Serialize(callRequest.Arguments, JsonContext.GetTypeInfo(typeof(IDictionary<string, object>), ToolCallJsonSerializerOptions)))));
}
}

Expand Down Expand Up @@ -504,11 +489,4 @@ private static List<ChatMessageContentItem> GetContentParts(IList<AIContent> con
private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name)
{
    // Hand the raw argument JSON to FunctionCallContent.CreateFromParsedArguments together with
    // a parser that deserializes it using the JsonContext type info for the string/object dictionary.
    return FunctionCallContent.CreateFromParsedArguments(
        json,
        callId,
        name,
        argumentParser: static payload => JsonSerializer.Deserialize(payload, JsonContext.Default.IDictionaryStringObject)!);
}

/// <summary>Source-generated JSON type information.</summary>
[JsonSerializable(typeof(AzureAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(int))]
private sealed partial class JsonContext : JsonSerializerContext;
}
Loading
Loading