Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Add AutoGen.AzureAIInference #3332

Merged
merged 5 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dotnet/AutoGen.sln
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sa
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference", "src\AutoGen.AzureAIInference\AutoGen.AzureAIInference.csproj", "{5C45981D-1319-4C25-935C-83D411CB28DF}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.AzureAIInference.Tests", "test\AutoGen.AzureAIInference.Tests\AutoGen.AzureAIInference.Tests.csproj", "{5970868F-831E-418F-89A9-4EC599563E16}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Tests.Share", "test\AutoGen.Test.Share\AutoGen.Tests.Share.csproj", "{143725E2-206C-4D37-93E4-9EDF699826B2}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -194,6 +200,18 @@ Global
{12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU
{12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU
{12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.Build.0 = Release|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.Build.0 = Release|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.Build.0 = Debug|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -229,6 +247,9 @@ Global
{6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
{5C45981D-1319-4C25-935C-83D411CB28DF} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{5970868F-831E-418F-89A9-4EC599563E16} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
{143725E2-206C-4D37-93E4-9EDF699826B2} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
Expand Down
1 change: 1 addition & 0 deletions dotnet/eng/Version.props
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<MicrosoftASPNETCoreVersion>8.0.4</MicrosoftASPNETCoreVersion>
<GoogleCloudAPIPlatformVersion>3.0.0</GoogleCloudAPIPlatformVersion>
<JsonSchemaVersion>4.3.0.2</JsonSchemaVersion>
<AzureAIInferenceVersion>1.0.0-beta.1</AzureAIInferenceVersion>
<PowershellSDKVersion>7.4.4</PowershellSDKVersion>
</PropertyGroup>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionsClientAgent.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.AzureAIInference.Extension;
using AutoGen.Core;
using Azure.AI.Inference;

namespace AutoGen.AzureAIInference;

/// <summary>
/// ChatCompletions client agent. This agent is a thin wrapper around <see cref="ChatCompletionsClient"/> to provide a simple interface for chat completions.
/// <para><see cref="ChatCompletionsClientAgent" /> supports the following message types:</para>
/// <list type="bullet">
/// <item>
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message.
/// </item>
/// </list>
/// <para><see cref="ChatCompletionsClientAgent" /> returns the following message types:</para>
/// <list type="bullet">
/// <item>
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message.
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update.
/// </item>
/// </list>
/// </summary>
public class ChatCompletionsClientAgent : IStreamingAgent
{
private readonly ChatCompletionsClient chatCompletionsClient;
private readonly ChatCompletionsOptions options;
private readonly string systemMessage;

/// <summary>
/// Create a new instance of <see cref="ChatCompletionsClientAgent"/>.
/// </summary>
/// <param name="chatCompletionsClient">chat completions client</param>
/// <param name="name">agent name</param>
/// <param name="modelName">model name. e.g. gpt-turbo-3.5</param>
/// <param name="systemMessage">system message</param>
/// <param name="temperature">temperature</param>
/// <param name="maxTokens">max tokens to generated</param>
/// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormatJSON"/> to enable json mode.</param>
/// <param name="seed">seed to use, set it to enable deterministic output</param>
/// <param name="functions">functions</param>
public ChatCompletionsClientAgent(
ChatCompletionsClient chatCompletionsClient,
string name,
string modelName,
string systemMessage = "You are a helpful AI assistant",
float temperature = 0.7f,
int maxTokens = 1024,
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null)
: this(
chatCompletionsClient: chatCompletionsClient,
name: name,
options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions),
systemMessage: systemMessage)
{
}

/// <summary>
/// Create a new instance of <see cref="ChatCompletionsClientAgent"/>.
/// </summary>
/// <param name="chatCompletionsClient">chat completions client</param>
/// <param name="name">agent name</param>
/// <param name="systemMessage">system message</param>
/// <param name="options">chat completion option. The option can't contain messages</param>
public ChatCompletionsClientAgent(
ChatCompletionsClient chatCompletionsClient,
string name,
ChatCompletionsOptions options,
string systemMessage = "You are a helpful AI assistant")
{
if (options.Messages is { Count: > 0 })
{
throw new ArgumentException("Messages should not be provided in options");
}

this.chatCompletionsClient = chatCompletionsClient;
this.Name = name;
this.options = options;
this.systemMessage = systemMessage;
}

public string Name { get; }

public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var settings = this.CreateChatCompletionsOptions(options, messages);
var reply = await this.chatCompletionsClient.CompleteAsync(settings, cancellationToken: cancellationToken);

return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
}

public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var settings = this.CreateChatCompletionsOptions(options, messages);
var response = await this.chatCompletionsClient.CompleteStreamingAsync(settings, cancellationToken);
await foreach (var update in response.WithCancellation(cancellationToken))
{
yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name);
}
}

private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages)
{
var oaiMessages = messages.Select(m => m switch
{
IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content,
_ => throw new ArgumentException("Invalid message type")
});

// add system message if there's no system message in messages
if (!oaiMessages.Any(m => m is ChatRequestSystemMessage))
{
oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages);
}

// clone the options by serializing and deserializing
var json = JsonSerializer.Serialize(this.options);
var settings = JsonSerializer.Deserialize<ChatCompletionsOptions>(json) ?? throw new InvalidOperationException("Failed to clone options");

foreach (var m in oaiMessages)
{
settings.Messages.Add(m);
}

settings.Temperature = options?.Temperature ?? settings.Temperature;
settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens;

foreach (var functions in this.options.Tools)
{
settings.Tools.Add(functions);
}

foreach (var stopSequence in this.options.StopSequences)
{
settings.StopSequences.Add(stopSequence);
}

var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToAzureAIInferenceFunctionDefinition()).ToList();
if (openAIFunctionDefinitions is { Count: > 0 })
{
foreach (var f in openAIFunctionDefinitions)
{
settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
}
}

if (options?.StopSequence is var sequence && sequence is { Length: > 0 })
{
foreach (var seq in sequence)
{
settings.StopSequences.Add(seq);
}
}

return settings;
}

private static ChatCompletionsOptions CreateChatCompletionOptions(
string modelName,
float temperature = 0.7f,
int maxTokens = 1024,
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null)
{
var options = new ChatCompletionsOptions()
{
Model = modelName,
Temperature = temperature,
MaxTokens = maxTokens,
Seed = seed,
ResponseFormat = responseFormat,
};

if (functions is not null)
{
foreach (var f in functions)
{
options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
}
}

return options;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>$(PackageTargetFrameworks)</TargetFrameworks>
<RootNamespace>AutoGen.AzureAIInference</RootNamespace>
</PropertyGroup>

<Import Project="$(RepoRoot)/nuget/nuget-package.props" />

<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.AzureAIInference</Title>
<Description>
Azure AI Inference Intergration for AutoGen.
</Description>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.Inference" Version="$(AzureAIInferenceVersion)" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatComptionClientAgentExtension.cs

using AutoGen.Core;

namespace AutoGen.AzureAIInference.Extension;

public static class ChatComptionClientAgentExtension
{
/// <summary>
/// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="ChatCompletionsClientAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector(
this ChatCompletionsClientAgent agent, AzureAIInferenceChatRequestMessageConnector? connector = null)
{
if (connector == null)
{
connector = new AzureAIInferenceChatRequestMessageConnector();
}

return agent.RegisterStreamingMiddleware(connector);
}

/// <summary>
/// Register an <see cref="AzureAIInferenceChatRequestMessageConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="ChatCompletionsClientAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="AzureAIInferenceChatRequestMessageConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<ChatCompletionsClientAgent> RegisterMessageConnector(
this MiddlewareStreamingAgent<ChatCompletionsClientAgent> agent, AzureAIInferenceChatRequestMessageConnector? connector = null)
{
if (connector == null)
{
connector = new AzureAIInferenceChatRequestMessageConnector();
}

return agent.RegisterStreamingMiddleware(connector);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContractExtension.cs

using System;
using System.Collections.Generic;
using AutoGen.Core;
using Azure.AI.Inference;
using Json.Schema;
using Json.Schema.Generation;

namespace AutoGen.AzureAIInference.Extension;

public static class FunctionContractExtension
{
/// <summary>
/// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in gpt funciton call.
/// </summary>
/// <param name="functionContract">function contract</param>
/// <returns><see cref="FunctionDefinition"/></returns>
public static FunctionDefinition ToAzureAIInferenceFunctionDefinition(this FunctionContract functionContract)
{
var functionDefinition = new FunctionDefinition
{
Name = functionContract.Name,
Description = functionContract.Description,
};
var requiredParameterNames = new List<string>();
var propertiesSchemas = new Dictionary<string, JsonSchema>();
var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
foreach (var param in functionContract.Parameters ?? [])
{
if (param.Name is null)
{
throw new InvalidOperationException("Parameter name cannot be null");
}

var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
if (param.Description != null)
{
schemaBuilder = schemaBuilder.Description(param.Description);
}

if (param.IsRequired)
{
requiredParameterNames.Add(param.Name);
}

var schema = schemaBuilder.Build();
propertiesSchemas[param.Name] = schema;

}
propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);

var option = new System.Text.Json.JsonSerializerOptions()
{
PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
};

functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option);

return functionDefinition;
}
}
Loading
Loading