Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change ChatClientBuilder to register singletons and support lambda-less chaining #5642

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,43 @@ namespace Microsoft.Extensions.AI;
/// <summary>A builder for creating pipelines of <see cref="IChatClient"/>.</summary>
public sealed class ChatClientBuilder
{
private Func<IServiceProvider, IChatClient> _innerClientFactory;

/// <summary>The registered client factory instances.</summary>
private List<Func<IServiceProvider, IChatClient, IChatClient>>? _clientFactories;

/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="services">The service provider to use for dependency injection.</param>
public ChatClientBuilder(IServiceProvider? services = null)
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(IChatClient innerClient)
SteveSandersonMS marked this conversation as resolved.
Show resolved Hide resolved
{
Services = services ?? EmptyServiceProvider.Instance;
_ = Throw.IfNull(innerClient);
_innerClientFactory = _ => innerClient;
}

/// <summary>Gets the <see cref="IServiceProvider"/> associated with the builder instance.</summary>
public IServiceProvider Services { get; }
/// <summary>Initializes a new instance of the <see cref="ChatClientBuilder"/> class.</summary>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
public ChatClientBuilder(Func<IServiceProvider, IChatClient> innerClientFactory)
{
_innerClientFactory = Throw.IfNull(innerClientFactory);
}

/// <summary>Completes the pipeline by adding a final <see cref="IChatClient"/> that represents the underlying backend. This is typically a client for an LLM service.</summary>
/// <param name="innerClient">The inner client to use.</param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</returns>
public IChatClient Use(IChatClient innerClient)
/// <summary>Returns an <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</summary>
/// <param name="services">
/// The <see cref="IServiceProvider"/> that should provide services to the <see cref="IChatClient"/> instances.
/// If null, an empty <see cref="IServiceProvider"/> will be used.
/// </param>
/// <returns>An instance of <see cref="IChatClient"/> that represents the entire pipeline.</returns>
public IChatClient Build(IServiceProvider? services = null)
{
var chatClient = Throw.IfNull(innerClient);
services ??= EmptyServiceProvider.Instance;
var chatClient = _innerClientFactory(services);

// To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost.
if (_clientFactories is not null)
{
for (var i = _clientFactories.Count - 1; i >= 0; i--)
{
chatClient = _clientFactories[i](Services, chatClient) ??
chatClient = _clientFactories[i](services, chatClient) ??
throw new InvalidOperationException(
$"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " +
$"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,71 @@ namespace Microsoft.Extensions.DependencyInjection;
public static class ChatClientBuilderServiceCollectionExtensions
{
/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <param name="services">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="clientFactory">The factory to use to construct the <see cref="IChatClient"/> instance.</param>
/// <returns>The <paramref name="services"/> collection.</returns>
/// <remarks>The client is registered as a scoped service.</remarks>
public static IServiceCollection AddChatClient(
this IServiceCollection services,
Func<ChatClientBuilder, IChatClient> clientFactory)
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a singleton service.</remarks>
public static ChatClientBuilder AddChatClient(
this IServiceCollection serviceCollection,
IChatClient innerClient)
=> AddChatClient(serviceCollection, _ => innerClient);

/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <typeparam name="T">The type of the inner <see cref="IChatClient"/> that represents the underlying backend. This will be resolved from the service provider.</typeparam>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a singleton service.</remarks>
public static ChatClientBuilder AddChatClient<T>(
this IServiceCollection serviceCollection)
where T : IChatClient
=> AddChatClient(serviceCollection, services => services.GetRequiredService<T>());
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a singleton service.</remarks>
public static ChatClientBuilder AddChatClient(
this IServiceCollection serviceCollection,
Func<IServiceProvider, IChatClient> innerClientFactory)
{
_ = Throw.IfNull(services);
_ = Throw.IfNull(clientFactory);
_ = Throw.IfNull(serviceCollection);
_ = Throw.IfNull(innerClientFactory);

return services.AddScoped(services =>
clientFactory(new ChatClientBuilder(services)));
var builder = new ChatClientBuilder(innerClientFactory);
_ = serviceCollection.AddSingleton(builder.Build);
return builder;
}

/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <param name="services">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="serviceKey">The key with which to associate the client.</param>
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a scoped service.</remarks>
public static ChatClientBuilder AddKeyedChatClient(
this IServiceCollection serviceCollection,
object serviceKey,
IChatClient innerClient)
=> AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient);

/// <summary>Adds a chat client to the <see cref="IServiceCollection"/>.</summary>
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
/// <param name="serviceKey">The key with which to associate the client.</param>
/// <param name="clientFactory">The factory to use to construct the <see cref="IChatClient"/> instance.</param>
/// <returns>The <paramref name="services"/> collection.</returns>
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner client.</returns>
/// <remarks>The client is registered as a scoped service.</remarks>
public static IServiceCollection AddKeyedChatClient(
this IServiceCollection services,
public static ChatClientBuilder AddKeyedChatClient(
SteveSandersonMS marked this conversation as resolved.
Show resolved Hide resolved
this IServiceCollection serviceCollection,
object serviceKey,
Func<ChatClientBuilder, IChatClient> clientFactory)
Func<IServiceProvider, IChatClient> innerClientFactory)
{
_ = Throw.IfNull(services);
_ = Throw.IfNull(serviceCollection);
_ = Throw.IfNull(serviceKey);
_ = Throw.IfNull(clientFactory);
_ = Throw.IfNull(innerClientFactory);

return services.AddKeyedScoped(serviceKey, (services, _) =>
clientFactory(new ChatClientBuilder(services)));
var builder = new ChatClientBuilder(innerClientFactory);
_ = serviceCollection.AddKeyedSingleton(serviceKey, builder.Build);
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()

Assert.Same(client, chatClient.GetService<ChatCompletionsClient>());

using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(chatClient)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(chatClient);
.Build();

Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,12 @@ public virtual async Task Caching_BeforeFunctionInvocation_AvoidsExtraCalls()
}, "GetTemperature");

// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseFunctionInvocation()
.UseCallCounting()
.Use(CreateChatClient()!);
.Build();

var llmCallCount = chatClient.GetService<CallCountingChatClient>();
var message = new ChatMessage(ChatRole.User, "What is the temperature?");
Expand Down Expand Up @@ -415,12 +415,12 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange
}, "GetTemperature");

// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseFunctionInvocation()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseCallCounting()
.Use(CreateChatClient()!);
.Build();

var llmCallCount = chatClient.GetService<CallCountingChatClient>();
var message = new ChatMessage(ChatRole.User, "What is the temperature?");
Expand Down Expand Up @@ -454,12 +454,12 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedA
}, "GetTemperature");

// First call executes the function and calls the LLM
using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.ConfigureOptions(options => options.Tools = [getTemperature])
.UseFunctionInvocation()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.UseCallCounting()
.Use(CreateChatClient()!);
.Build();

var llmCallCount = chatClient.GetService<CallCountingChatClient>();
var message = new ChatMessage(ChatRole.User, "What is the temperature?");
Expand Down Expand Up @@ -573,9 +573,9 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.AddInMemoryExporter(activities)
.Build();

var chatClient = new ChatClientBuilder()
var chatClient = new ChatClientBuilder(CreateChatClient()!)
.UseOpenTelemetry(sourceName: sourceName)
.Use(CreateChatClient()!);
.Build();

var response = await chatClient.CompleteAsync([new(ChatRole.User, "What's the biggest animal?")]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit()
}
};

using var client = new ChatClientBuilder()
using var client = new ChatClientBuilder(innerClient)
.UseChatReducer(new TokenCountingChatReducer(_gpt4oTokenizer, 40))
.Use(innerClient);
.Build();

List<ChatMessage> messages =
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ public async Task PromptBasedFunctionCalling_NoArgs()
{
SkipIfNotEnabled();

using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.UseFunctionInvocation()
.UsePromptBasedFunctionCalling()
.Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient))
.Use(CreateChatClient()!);
.Build();

var secretNumber = 42;
var response = await chatClient.CompleteAsync("What is the current secret number? Answer with digits only.", new ChatOptions
Expand All @@ -61,11 +61,11 @@ public async Task PromptBasedFunctionCalling_WithArgs()
{
SkipIfNotEnabled();

using var chatClient = new ChatClientBuilder()
using var chatClient = new ChatClientBuilder(CreateChatClient()!)
.UseFunctionInvocation()
.UsePromptBasedFunctionCalling()
.Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient))
.Use(CreateChatClient()!);
.Build();

var stockPriceTool = AIFunctionFactory.Create([Description("Returns the stock price for a given ticker symbol")] (
[Description("The ticker symbol")] string symbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()
Assert.Same(client, client.GetService<OllamaChatClient>());
Assert.Same(client, client.GetService<IChatClient>());

using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(client)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(client);
.Build();

Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient()

Assert.NotNull(chatClient.GetService<ChatClient>());

using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(chatClient)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(chatClient);
.Build();

Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
Expand All @@ -119,11 +119,11 @@ public void GetService_ChatClient_SuccessfullyReturnsUnderlyingClient()
Assert.Same(chatClient, chatClient.GetService<IChatClient>());
Assert.Same(openAIClient, chatClient.GetService<ChatClient>());

using IChatClient pipeline = new ChatClientBuilder()
using IChatClient pipeline = new ChatClientBuilder(chatClient)
.UseFunctionInvocation()
.UseOpenTelemetry()
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
.Use(chatClient);
.Build();

Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,38 @@ public class ChatClientBuilderTest
public void PassesServiceProviderToFactories()
{
var expectedServiceProvider = new ServiceCollection().BuildServiceProvider();
using TestChatClient expectedResult = new();
var builder = new ChatClientBuilder(expectedServiceProvider);
using TestChatClient expectedInnerClient = new();
using TestChatClient expectedOuterClient = new();

var builder = new ChatClientBuilder(services =>
{
Assert.Same(expectedServiceProvider, services);
return expectedInnerClient;
});

builder.Use((serviceProvider, innerClient) =>
{
Assert.Same(expectedServiceProvider, serviceProvider);
return expectedResult;
Assert.Same(expectedInnerClient, innerClient);
return expectedOuterClient;
});

using TestChatClient innerClient = new();
Assert.Equal(expectedResult, builder.Use(innerClient: innerClient));
Assert.Same(expectedOuterClient, builder.Build(expectedServiceProvider));
}

[Fact]
public void BuildsPipelineInOrderAdded()
{
// Arrange
using TestChatClient expectedInnerClient = new();
var builder = new ChatClientBuilder();
var builder = new ChatClientBuilder(expectedInnerClient);

builder.Use(next => new InnerClientCapturingChatClient("First", next));
builder.Use(next => new InnerClientCapturingChatClient("Second", next));
builder.Use(next => new InnerClientCapturingChatClient("Third", next));

// Act
var first = (InnerClientCapturingChatClient)builder.Use(expectedInnerClient);
var first = (InnerClientCapturingChatClient)builder.Build();

// Assert
Assert.Equal("First", first.Name);
Expand All @@ -52,23 +58,22 @@ public void BuildsPipelineInOrderAdded()
[Fact]
public void DoesNotAcceptNullInnerService()
{
Assert.Throws<ArgumentNullException>(() => new ChatClientBuilder().Use((IChatClient)null!));
Assert.Throws<ArgumentNullException>(() => new ChatClientBuilder((IChatClient)null!));
}

[Fact]
public void DoesNotAcceptNullFactories()
{
ChatClientBuilder builder = new();
Assert.Throws<ArgumentNullException>(() => builder.Use((Func<IChatClient, IChatClient>)null!));
Assert.Throws<ArgumentNullException>(() => builder.Use((Func<IServiceProvider, IChatClient, IChatClient>)null!));
Assert.Throws<ArgumentNullException>(() => new ChatClientBuilder((Func<IServiceProvider, IChatClient>)null!));
}

[Fact]
public void DoesNotAllowFactoriesToReturnNull()
{
ChatClientBuilder builder = new();
using var innerClient = new TestChatClient();
ChatClientBuilder builder = new(innerClient);
builder.Use(_ => null!);
var ex = Assert.Throws<InvalidOperationException>(() => builder.Use(new TestChatClient()));
var ex = Assert.Throws<InvalidOperationException>(() => builder.Build());
Assert.Contains("entry at index 0", ex.Message);
}

Expand Down
Loading
Loading