Skip to content

Commit

Permalink
Add UseEmbeddingGenerationOptions (#5594)
Browse files Browse the repository at this point in the history
* Add UseEmbeddingGenerationOptions

Counterpart to UseChatOptions

* Document/test null options returned from callback
  • Loading branch information
stephentoub authored Nov 1, 2024
1 parent 53783e7 commit a12664e
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI;
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="ChatOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new ChatOptions() { MaxTokens = 1000 }</c>. To provide
/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example
/// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new ChatOptions() { MaxTokens = 1000 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
Expand All @@ -31,6 +31,9 @@ namespace Microsoft.Extensions.AI;
/// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next client in the pipeline.
/// </para>
/// <para>
/// The provided implementation of <see cref="IChatClient"/> is thread-safe for concurrent use so long as the employed configuration
/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the
/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance.
Expand All @@ -39,15 +42,15 @@ namespace Microsoft.Extensions.AI;
public sealed class ConfigureOptionsChatClient : DelegatingChatClient
{
/// <summary>The callback delegate used to configure options.</summary>
private readonly Func<ChatOptions?, ChatOptions> _configureOptions;
private readonly Func<ChatOptions?, ChatOptions?> _configureOptions;

/// <summary>Initializes a new instance of the <see cref="ConfigureOptionsChatClient"/> class with the specified <paramref name="configureOptions"/> callback.</summary>
/// <param name="innerClient">The inner client.</param>
/// <param name="configureOptions">
/// The delegate to invoke to configure the <see cref="ChatOptions"/> instance. It is passed the caller-supplied <see cref="ChatOptions"/>
/// instance and should return the configured <see cref="ChatOptions"/> instance to use.
/// </param>
public ConfigureOptionsChatClient(IChatClient innerClient, Func<ChatOptions?, ChatOptions> configureOptions)
public ConfigureOptionsChatClient(IChatClient innerClient, Func<ChatOptions?, ChatOptions?> configureOptions)
: base(innerClient)
{
_configureOptions = Throw.IfNull(configureOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ public static class ConfigureOptionsChatClientBuilderExtensions
/// </param>
/// <returns>The <paramref name="builder"/>.</returns>
/// <remarks>
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="ChatOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new ChatOptions() { MaxTokens = 1000 }</c>. To provide
/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example
/// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new ChatOptions() { MaxTokens = 1000 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
Expand All @@ -35,9 +36,13 @@ public static class ConfigureOptionsChatClientBuilderExtensions
/// return newOptions;
/// }
/// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next client in the pipeline.
/// </para>
/// </remarks>
public static ChatClientBuilder UseChatOptions(
this ChatClientBuilder builder, Func<ChatOptions?, ChatOptions> configureOptions)
this ChatClientBuilder builder, Func<ChatOptions?, ChatOptions?> configureOptions)
{
_ = Throw.IfNull(builder);
_ = Throw.IfNull(configureOptions);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

#pragma warning disable SA1629 // Documentation text should end with a period

namespace Microsoft.Extensions.AI;

/// <summary>A delegating embedding generator that updates or replaces the <see cref="EmbeddingGenerationOptions"/> used by the remainder of the pipeline.</summary>
/// <typeparam name="TInput">Specifies the type of the input passed to the generator.</typeparam>
/// <typeparam name="TEmbedding">Specifies the type of the embedding instance produced by the generator.</typeparam>
/// <remarks>
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="EmbeddingGenerationOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. To provide
/// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
/// <c>
/// options =>
/// {
/// var newOptions = options?.Clone() ?? new();
/// newOptions.Dimensions = 100;
/// return newOptions;
/// }
/// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next generator in the pipeline.
/// </para>
/// <para>
/// The provided implementation of <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> is thread-safe for concurrent use so long as the employed configuration
/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the
/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance.
/// </para>
/// </remarks>
public sealed class ConfigureOptionsEmbeddingGenerator<TInput, TEmbedding> : DelegatingEmbeddingGenerator<TInput, TEmbedding>
where TEmbedding : Embedding
{
/// <summary>The callback delegate used to configure options.</summary>
private readonly Func<EmbeddingGenerationOptions?, EmbeddingGenerationOptions?> _configureOptions;

/// <summary>
/// Initializes a new instance of the <see cref="ConfigureOptionsEmbeddingGenerator{TInput, TEmbedding}"/> class with the
/// specified <paramref name="configureOptions"/> callback.
/// </summary>
/// <param name="innerGenerator">The inner generator.</param>
/// <param name="configureOptions">
/// The delegate to invoke to configure the <see cref="EmbeddingGenerationOptions"/> instance. It is passed the caller-supplied
/// <see cref="EmbeddingGenerationOptions"/> instance and should return the configured <see cref="EmbeddingGenerationOptions"/> instance to use.
/// </param>
public ConfigureOptionsEmbeddingGenerator(
IEmbeddingGenerator<TInput, TEmbedding> innerGenerator,
Func<EmbeddingGenerationOptions?, EmbeddingGenerationOptions?> configureOptions)
: base(innerGenerator)
{
_configureOptions = Throw.IfNull(configureOptions);
}

/// <inheritdoc/>
public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
IEnumerable<TInput> values,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
{
return await base.GenerateAsync(values, _configureOptions(options), cancellationToken).ConfigureAwait(false);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using Microsoft.Shared.Diagnostics;

#pragma warning disable SA1629 // Documentation text should end with a period

namespace Microsoft.Extensions.AI;

/// <summary>Provides extensions for configuring <see cref="ConfigureOptionsEmbeddingGenerator{TInput, TEmbedding}"/> instances.</summary>
public static class ConfigureOptionsEmbeddingGeneratorBuilderExtensions
{
/// <summary>
/// Adds a callback that updates or replaces <see cref="EmbeddingGenerationOptions"/>. This can be used to set default options.
/// </summary>
/// <typeparam name="TInput">Specifies the type of the input passed to the generator.</typeparam>
/// <typeparam name="TEmbedding">Specifies the type of the embedding instance produced by the generator.</typeparam>
/// <param name="builder">The <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/>.</param>
/// <param name="configureOptions">
/// The delegate to invoke to configure the <see cref="EmbeddingGenerationOptions"/> instance. It is passed the caller-supplied
/// <see cref="EmbeddingGenerationOptions"/> instance and should return the configured <see cref="EmbeddingGenerationOptions"/> instance to use.
/// </param>
/// <returns>The <paramref name="builder"/>.</returns>
/// <remarks>
/// <para>
/// The configuration callback is invoked with the caller-supplied <see cref="EmbeddingGenerationOptions"/> instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example <c>_ => new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. To provide
/// a new instance only if the caller-supplied instance is <see langword="null"/>, the callback may conditionally return a new instance, for example
/// <c>options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }</c>. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
/// <c>
/// options =>
/// {
/// var newOptions = options?.Clone() ?? new();
/// newOptions.Dimensions = 100;
/// return newOptions;
/// }
/// </c>
/// </para>
/// <para>
/// The callback may return <see langword="null"/>, in which case a <see langword="null"/> options will be passed to the next generator in the pipeline.
/// </para>
/// </remarks>
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> UseEmbeddingGenerationOptions<TInput, TEmbedding>(
this EmbeddingGeneratorBuilder<TInput, TEmbedding> builder,
Func<EmbeddingGenerationOptions?, EmbeddingGenerationOptions?> configureOptions)
where TEmbedding : Embedding
{
_ = Throw.IfNull(builder);
_ = Throw.IfNull(configureOptions);

return builder.Use(innerGenerator => new ConfigureOptionsEmbeddingGenerator<TInput, TEmbedding>(innerGenerator, configureOptions));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ public void UseChatOptions_InvalidArgs_Throws()
Assert.Throws<ArgumentNullException>("configureOptions", () => builder.UseChatOptions(null!));
}

[Fact]
public async Task ConfigureOptions_ReturnedInstancePassedToNextClient()
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned)
{
ChatOptions providedOptions = new();
ChatOptions returnedOptions = new();
ChatOptions? returnedOptions = nullReturned ? null : new();
ChatCompletion expectedCompletion = new(Array.Empty<ChatMessage>());
var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray();
using CancellationTokenSource cts = new();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.Extensions.AI;

public class ConfigureOptionsEmbeddingGeneratorTests
{
[Fact]
public void ConfigureOptionsEmbeddingGenerator_InvalidArgs_Throws()
{
Assert.Throws<ArgumentNullException>("innerGenerator", () => new ConfigureOptionsEmbeddingGenerator<string, Embedding<float>>(null!, _ => new EmbeddingGenerationOptions()));
Assert.Throws<ArgumentNullException>("configureOptions", () => new ConfigureOptionsEmbeddingGenerator<string, Embedding<float>>(new TestEmbeddingGenerator(), null!));
}

[Fact]
public void UseEmbeddingGenerationOptions_InvalidArgs_Throws()
{
var builder = new EmbeddingGeneratorBuilder<string, Embedding<float>>();
Assert.Throws<ArgumentNullException>("configureOptions", () => builder.UseEmbeddingGenerationOptions(null!));
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned)
{
EmbeddingGenerationOptions providedOptions = new();
EmbeddingGenerationOptions? returnedOptions = nullReturned ? null : new();
GeneratedEmbeddings<Embedding<float>> expectedEmbeddings = [];
using CancellationTokenSource cts = new();

using IEmbeddingGenerator<string, Embedding<float>> innerGenerator = new TestEmbeddingGenerator
{
GenerateAsyncCallback = (inputs, options, cancellationToken) =>
{
Assert.Same(returnedOptions, options);
Assert.Equal(cts.Token, cancellationToken);
return Task.FromResult(expectedEmbeddings);
}
};

using var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
.UseEmbeddingGenerationOptions(options =>
{
Assert.Same(providedOptions, options);
return returnedOptions;
})
.Use(innerGenerator);

var embeddings = await generator.GenerateAsync([], providedOptions, cts.Token);
Assert.Same(expectedEmbeddings, embeddings);
}
}

0 comments on commit a12664e

Please sign in to comment.