Skip to content

Commit

Permalink
Change ChatClientBuilder to register singletons and support lambda-less chaining (#5642)

Browse files Browse the repository at this point in the history

* Change ChatClientBuilder to register singletons and support lambda-less chaining

* Add generic keyed version

* Improve XML doc

* Update README files

* Remove generic DI registration methods
  • Loading branch information
SteveSandersonMS authored Nov 14, 2024
1 parent d39bf3d commit 56e720c
Show file tree
Hide file tree
Showing 21 changed files with 239 additions and 187 deletions.
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ using Microsoft.Extensions.AI;
[Description("Gets the current weather")]
string GetCurrentWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining";

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"))
.UseFunctionInvocation()
.Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"));
.Build();

var response = client.CompleteStreamingAsync(
"Should I wear a rain coat?",
Expand All @@ -174,9 +174,9 @@ using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"));
.Build();

string[] prompts = ["What is AI?", "What is .NET?", "What is AI?"];

Expand Down Expand Up @@ -205,9 +205,9 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
.AddConsoleExporter()
.Build();

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"));
.Build();

Console.WriteLine((await client.CompleteAsync("What is AI?")).Message);
```
Expand All @@ -220,9 +220,9 @@ Options may also be baked into an `IChatClient` via the `ConfigureOptions` exten
```csharp
using Microsoft.Extensions.AI;

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434")))
.ConfigureOptions(options => options.ModelId ??= "phi3")
.Use(new OllamaChatClient(new Uri("http://localhost:11434")));
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?")); // will request "phi3"
Console.WriteLine(await client.CompleteAsync("What is AI?", new() { ModelId = "llama3.1" })); // will request "llama3.1"
Expand All @@ -248,11 +248,11 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()

// Explore changing the order of the intermediate "Use" calls to see that impact
// that has on what gets cached, traced, etc.
IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"))
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"));
.Build();

ChatOptions options = new()
{
Expand Down Expand Up @@ -341,9 +341,8 @@ using Microsoft.Extensions.Hosting;
// App Setup
var builder = Host.CreateApplicationBuilder();
builder.Services.AddDistributedMemoryCache();
builder.Services.AddChatClient(b => b
.UseDistributedCache()
.Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")));
builder.Services.AddChatClient(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))
.UseDistributedCache();
var host = builder.Build();

// Elsewhere in the app
Expand Down
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.AzureAIInference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseFunctionInvocation()
.Use(azureClient);
.Build();

ChatOptions chatOptions = new()
{
Expand Down Expand Up @@ -120,9 +120,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseDistributedCache(cache)
.Use(azureClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -156,9 +156,9 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(azureClient);
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
Expand Down Expand Up @@ -196,11 +196,11 @@ IChatClient azureClient =
new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(azureClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(azureClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -236,10 +236,9 @@ builder.Services.AddSingleton(
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));

builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(services => services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini"))
.UseDistributedCache()
.UseLogging()
.Use(b.Services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini")));
.UseLogging();

var app = builder.Build();

Expand All @@ -261,8 +260,8 @@ builder.Services.AddSingleton(new ChatCompletionsClient(
new("https://models.inference.ai.azure.com"),
new AzureKeyCredential(builder.Configuration["GH_TOKEN"]!)));

builder.Services.AddChatClient(b =>
b.Use(b.Services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini")));
builder.Services.AddChatClient(services =>
services.GetRequiredService<ChatCompletionsClient>().AsChatClient("gpt-4o-mini"));

var app = builder.Build();

Expand Down
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ using Microsoft.Extensions.AI;

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseFunctionInvocation()
.Use(ollamaClient);
.Build();

ChatOptions chatOptions = new()
{
Expand All @@ -97,9 +97,9 @@ IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDi

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseDistributedCache(cache)
.Use(ollamaClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -128,9 +128,9 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(ollamaClient);
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
Expand Down Expand Up @@ -163,11 +163,11 @@ var chatOptions = new ChatOptions

IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(ollamaClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(ollamaClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -235,10 +235,9 @@ var builder = Host.CreateApplicationBuilder();
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));

builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"))
.UseDistributedCache()
.UseLogging()
.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1")));
.UseLogging();

var app = builder.Build();

Expand All @@ -254,8 +253,8 @@ using Microsoft.Extensions.AI;

var builder = WebApplication.CreateBuilder(args);

builder.Services.AddChatClient(c =>
c.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1")));
builder.Services.AddChatClient(
new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"));

builder.Services.AddEmbeddingGenerator<string,Embedding<float>>(g =>
g.Use(new OllamaEmbeddingGenerator(endpoint, "all-minilm")));
Expand Down
25 changes: 12 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseFunctionInvocation()
.Use(openaiClient);
.Build();

ChatOptions chatOptions = new()
{
Expand Down Expand Up @@ -110,9 +110,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseDistributedCache(cache)
.Use(openaiClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -144,9 +144,9 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(openaiClient);
.Build();

Console.WriteLine(await client.CompleteAsync("What is AI?"));
```
Expand Down Expand Up @@ -182,11 +182,11 @@ IChatClient openaiClient =
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
.AsChatClient("gpt-4o-mini");

IChatClient client = new ChatClientBuilder()
IChatClient client = new ChatClientBuilder(openaiClient)
.UseDistributedCache(cache)
.UseFunctionInvocation()
.UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true)
.Use(openaiClient);
.Build();

for (int i = 0; i < 3; i++)
{
Expand Down Expand Up @@ -260,10 +260,9 @@ builder.Services.AddSingleton(new OpenAIClient(Environment.GetEnvironmentVariabl
builder.Services.AddDistributedMemoryCache();
builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace));

builder.Services.AddChatClient(b => b
builder.Services.AddChatClient(services => services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"))
.UseDistributedCache()
.UseLogging()
.Use(b.Services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini")));
.UseLogging();

var app = builder.Build();

Expand All @@ -282,8 +281,8 @@ var builder = WebApplication.CreateBuilder(args);

builder.Services.AddSingleton(new OpenAIClient(builder.Configuration["OPENAI_API_KEY"]));

builder.Services.AddChatClient(b =>
b.Use(b.Services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini")));
builder.Services.AddChatClient(services =>
services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"));

builder.Services.AddEmbeddingGenerator<string, Embedding<float>>(g =>
g.Use(g.Services.GetRequiredService<OpenAIClient>().AsEmbeddingGenerator("text-embedding-3-small")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,43 @@ namespace Microsoft.Extensions.AI;
/// <summary>A builder for creating pipelines of <see cref="IChatClient"/>.</summary>
public sealed class ChatClientBuilder
{
private Func<IServiceProvider, IChatClient> _innerClientFactory;

/// <summary>The registered client factory instances.</summary>
private List<Func<IServiceProvider, IChatClient, IChatClient>>? _clientFactories;

/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="services">The service provider to use for dependency injection.</param>
public ChatClientBuilder(IServiceProvider? services = null)
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(IChatClient innerClient)
{
Services = services ?? EmptyServiceProvider.Instance;
_ = Throw.IfNull(innerClient);
_innerClientFactory = _ => innerClient;
}

/// <summary>Gets the <see cref="IServiceProvider"/> associated with the builder instance.</summary>
public IServiceProvider Services { get; }
/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(Func<IServiceProvider, IChatClient> innerClientFactory)
{
_innerClientFactory = Throw.IfNull(innerClientFactory);
}

/// <summary>Completes the pipeline by adding a final <see cref="IChatClient"/> that represents the underlying backend. This is typically a client for an LLM service.</summary>
/// <param name="innerClient">The inner client to use.</param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</returns>
public IChatClient Use(IChatClient innerClient)
/// <summary>Returns an <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</summary>
/// <param name="services">
/// The <see cref="IServiceProvider"/> that should provide services to the <see cref="IChatClient"/> instances.
/// If null, an empty <see cref="IServiceProvider"/> will be used.
/// </param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline.</returns>
public IChatClient Build(IServiceProvider? services = null)
{
var chatClient = Throw.IfNull(innerClient);
services ??= EmptyServiceProvider.Instance;
var chatClient = _innerClientFactory(services);

// To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost.
if (_clientFactories is not null)
{
for (var i = _clientFactories.Count - 1; i >= 0; i--)
{
chatClient = _clientFactories[i](Services, chatClient) ??
chatClient = _clientFactories[i](services, chatClient) ??
throw new InvalidOperationException(
$"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " +
$"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances.");
Expand Down
Loading

0 comments on commit 56e720c

Please sign in to comment.