Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change ChatClientBuilder to register singletons and support lambda-less chaining #5642

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ using Microsoft.Extensions.AI;
[Description("Gets the current weather")]
string GetCurrentWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining";

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"))
.UseFunctionInvocation()
.Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"));
.Build();

var response = client.CompleteStreamingAsync(
"Should I wear a rain coat?",
Expand All @@ -174,9 +174,9 @@ using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"));
.Build();

string[] prompts = ["What is AI?", "What is .NET?", "What is AI?"];

Expand Down Expand Up @@ -205,9 +205,9 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
.AddConsoleExporter()
.Build();

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"));
.Build();

Console.WriteLine((await client.CompleteAsync("What is AI?")).Message);
```
Expand All @@ -220,9 +220,9 @@ Options may also be baked into an `IChatClient` via the `ConfigureOptions` exten
```csharp
using Microsoft.Extensions.AI;

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434")))
.ConfigureOptions(options => options.ModelId ??= "phi3")
.Use(new OllamaChatClient(new Uri("http://localhost:11434")));
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?")); // will request "phi3"
Console.WriteLine(await client.CompleteAsync("What is AI?", new() { ModelId = "llama3.1" })); // will request "llama3.1"
Expand All @@ -248,11 +248,11 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()

// Explore changing the order of the intermediate "Use" calls to see that impact
// that has on what gets cached, traced, etc.
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"))
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"));
.Build();

ChatOptions options = new()
{
Expand Down Expand Up @@ -341,9 +341,8 @@ using Microsoft.Extensions.Hosting;
// App Setup
var builder = Host.CreateApplicationBuilder();
builder.Services.AddDistributedMemoryCache();
builder.Services.AddChatClient(b => b
.UseDistributedCache()
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")));
builder.Services.AddChatClient(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseDistributedCache();
var host = builder.Build();

// Elsewhere in the app
Expand Down
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.AzureAIInference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseFunctionInvocation()
.Use(azureClient);
.Build();

ChatOptions chatOptions = new()
{
Expand Down Expand Up @@ -120,9 +120,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseDistributedCache(cache)
.Use(azureClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -156,9 +156,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(azureClient);
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
Expand Down Expand Up @@ -196,11 +196,11 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(azureClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -236,10 +236,9 @@ builder.Services.AddSingleton(
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));

builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(services => services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini"))
.UseDistributedCache()
.UseLogging()
.Use(b.Services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini")));
.UseLogging();

var app = builder.Build();

Expand All @@ -261,8 +260,8 @@ builder.Services.AddSingleton(new ChatCompletionsClient(
new("https://models.inference.ai.azure.com"),
new AzureKeyCredential(builder.Configuration["GH_TOKEN"]!)));

builder.Services.AddChatClient(b =>
b.Use(b.Services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini")));
builder.Services.AddChatClient(services =>
services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini"));

var app = builder.Build();

Expand Down
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ using Microsoft.Extensions.AI;

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseFunctionInvocation()
.Use(ollamaClient);
.Build();

ChatOptions chatOptions = new()
{
Expand All @@ -97,9 +97,9 @@ IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDi

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseDistributedCache(cache)
.Use(ollamaClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -128,9 +128,9 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(ollamaClient);
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
Expand Down Expand Up @@ -163,11 +163,11 @@ var chatOptions = new ChatOptions

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(ollamaClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -235,10 +235,9 @@ var builder = Host.CreateApplicationBuilder();
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));

builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"))
.UseDistributedCache()
.UseLogging()
.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1")));
.UseLogging();

var app = builder.Build();

Expand All @@ -254,8 +253,8 @@ using Microsoft.Extensions.AI;

var builder = WebApplication.CreateBuilder(args);

builder.Services.AddChatClient(c =>
c.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1")));
builder.Services.AddChatClient(
new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"));

builder.Services.AddEmbeddingGenerator<string,Embedding<float>>(g =>
g.Use(new OllamaEmbeddingGenerator(endpoint, "all-minilm")));
Expand Down
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseFunctionInvocation()
.Use(openaiClient);
.Build();

ChatOptions chatOptions = new()
{
Expand Down Expand Up @@ -110,9 +110,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseDistributedCache(cache)
.Use(openaiClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -144,9 +144,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(openaiClient);
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
Expand Down Expand Up @@ -182,11 +182,11 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(openaiClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -260,10 +260,9 @@ builder.Services.AddSingleton(new OpenAIClient(Environment.GetEnvironmentVariabl
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));

builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(services => services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"))
.UseDistributedCache()
.UseLogging()
.Use(b.Services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini")));
.UseLogging();

var app = builder.Build();

Expand All @@ -282,8 +281,8 @@ var builder = WebApplication.CreateBuilder(args);

builder.Services.AddSingleton(new OpenAIClient(builder.Configuration["OPENAI_API_KEY"]));

builder.Services.AddChatClient(b =>
b.Use(b.Services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini")));
builder.Services.AddChatClient(services =>
services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"));

builder.Services.AddEmbeddingGenerator<string, Embedding<float>>(g =>
g.Use(g.Services.GetRequiredService<OpenAIClient>().AsEmbeddingGenerator("text-embedding-3-small")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,43 @@ namespace Microsoft.Extensions.AI;
/// <summary>A builder for creating pipelines of <see cref="IChatClient"/>.</summary>
public sealed class ChatClientBuilder
{
private Func<IServiceProvider, IChatClient> _innerClientFactory;

/// <summary>The registered client factory instances.</summary>
private List<Func<IServiceProvider, IChatClient, IChatClient>>? _clientFactories;

/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="services">The service provider to use for dependency injection.</param>
public ChatClientBuilder(IServiceProvider? services = null)
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(IChatClient innerClient)
SteveSandersonMS marked this conversation as resolved.
Show resolved Hide resolved
{
Services = services ?? EmptyServiceProvider.Instance;
_ = Throw.IfNull(innerClient);
_innerClientFactory = _ => innerClient;
}

/// <summary>Gets the <see cref="IServiceProvider"/> associated with the builder instance.</summary>
public IServiceProvider Services { get; }
/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(Func<IServiceProvider, IChatClient> innerClientFactory)
{
_innerClientFactory = Throw.IfNull(innerClientFactory);
}

/// <summary>Completes the pipeline by adding a final <see cref="IChatClient"/> that represents the underlying backend. This is typically a client for an LLM service.</summary>
/// <param name="innerClient">The inner client to use.</param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</returns>
public IChatClient Use(IChatClient innerClient)
/// <summary>Returns an <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</summary>
/// <param name="services">
/// The <see cref="IServiceProvider"/> that should provide services to the <see cref="IChatClient"/> instances.
/// If null, an empty <see cref="IServiceProvider"/> will be used.
/// </param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline.</returns>
public IChatClient Build(IServiceProvider? services = null)
{
var chatClient = Throw.IfNull(innerClient);
services ??= EmptyServiceProvider.Instance;
var chatClient = _innerClientFactory(services);

// To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost.
if (_clientFactories is not null)
{
for (var i = _clientFactories.Count - 1; i >= 0; i--)
{
chatClient = _clientFactories[i](Services, chatClient) ??
chatClient = _clientFactories[i](services, chatClient) ??
throw new InvalidOperationException(
$"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " +
$"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances.");
Expand Down
Loading
Loading