Skip to content

Commit

Permalink
Improve EmbeddingGeneratorExtensions
Browse files Browse the repository at this point in the history
- Renames GenerateAsync extension method (not the interface method) to be GenerateEmbeddingAsync, since it produces a single TEmbedding
- Adds GenerateEmbeddingVectorAsync, which returns a `ReadOnlyMemory<T>`
- Adds a GenerateAndZipEmbeddingsAsync, which creates a `List<KeyValuePair<TInput, TEmbedding>>` that pairs the inputs with the outputs.
  • Loading branch information
stephentoub committed Oct 22, 2024
1 parent 651546f commit 41abe1e
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,35 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TValue,TEmbedding}"/> instances.</summary>
/// <summary>Provides a collection of static methods for extending <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> instances.</summary>
public static class EmbeddingGeneratorExtensions
{
/// <summary>Generates an embedding from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TValue">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embedding to generate.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="value">A value from which an embedding will be generated.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(
this IEmbeddingGenerator<TValue, TEmbedding> generator,
TValue value,
/// <returns>
/// The generated embedding for the specified <paramref name="value"/>.
/// </returns>
/// <remarks>
/// This operation is equivalent to using <see cref="IEmbeddingGenerator{TInput, TEmbedding}.GenerateAsync"/> with a
/// collection composed of the single <paramref name="value"/> and then returning the first embedding element from the
/// resulting <see cref="GeneratedEmbeddings{TEmbedding}"/> collection.
/// </remarks>
public static async Task<TEmbedding> GenerateEmbeddingAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
where TEmbedding : Embedding
Expand All @@ -37,4 +46,64 @@ public static async Task<TEmbedding> GenerateAsync<TValue, TEmbedding>(

return embeddings[0];
}

/// <summary>
/// Generates a single embedding for <paramref name="value"/> and returns only its vector data.
/// </summary>
/// <typeparam name="TInput">The type of the value the embedding is generated from.</typeparam>
/// <typeparam name="TEmbedding">The numeric element type of the resulting embedding vector.</typeparam>
/// <param name="generator">The embedding generator to invoke.</param>
/// <param name="value">The value for which an embedding is generated.</param>
/// <param name="options">Optional embedding generation options used to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The <see cref="Embedding{T}.Vector"/> of the embedding generated for <paramref name="value"/>.</returns>
/// <remarks>
/// This is a convenience wrapper: it forwards all arguments to <see cref="GenerateEmbeddingAsync"/> and
/// then exposes the <see cref="Embedding{T}.Vector"/> of the <see cref="Embedding{T}"/> that call produces.
/// </remarks>
public static async Task<ReadOnlyMemory<TEmbedding>> GenerateEmbeddingVectorAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, Embedding<TEmbedding>> generator,
TInput value,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
{
    // Argument validation is left to GenerateEmbeddingAsync, which receives the same generator instance.
    Embedding<TEmbedding> generated = await generator
        .GenerateEmbeddingAsync(value, options, cancellationToken)
        .ConfigureAwait(false);

    return generated.Vector;
}

/// <summary>
/// Generates embeddings for each of the supplied <paramref name="values"/> and produces a list that pairs
/// each input with its resulting embedding.
/// </summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The type of embedding to generate.</typeparam>
/// <param name="generator">The embedding generator.</param>
/// <param name="values">The collection of values for which to generate embeddings.</param>
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>
/// A list whose element at each index pairs the input at that index (as the key) with the embedding
/// the generator returned at that same index (as the value).
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> or <paramref name="values"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">The generator returned a number of embeddings that differs from the number of inputs.</exception>
/// <remarks>
/// When <paramref name="values"/> is not already an <see cref="IList{T}"/>, it is copied into a list once, and that
/// list is what is handed to the generator. The source sequence is therefore enumerated a single time, and the
/// inputs paired in the result are exactly the inputs the generator received.
/// </remarks>
public static async Task<IList<KeyValuePair<TInput, TEmbedding>>> GenerateAndZipEmbeddingsAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator,
IEnumerable<TInput> values,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
where TEmbedding : Embedding
{
    _ = Throw.IfNull(generator);
    _ = Throw.IfNull(values);

    // Materialize once so inputs can be indexed when zipping with the results.
    IList<TInput> inputs = values as IList<TInput> ?? values.ToList();

    // Pass the materialized list (not the original enumerable) so a lazy or one-shot
    // sequence is not enumerated a second time with potentially different contents.
    var embeddings = await generator.GenerateAsync(inputs, options, cancellationToken).ConfigureAwait(false);
    if (embeddings.Count != inputs.Count)
    {
        throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputs.Count}).");
    }

    List<KeyValuePair<TInput, TEmbedding>> results = new(embeddings.Count);
    for (int i = 0; i < embeddings.Count; i++)
    {
        results.Add(new KeyValuePair<TInput, TEmbedding>(inputs[i], embeddings[i]));
    }

    return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

Expand All @@ -12,7 +13,9 @@ public class EmbeddingGeneratorExtensionsTests
[Fact]
public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
{
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateEmbeddingVectorAsync("hello"));
await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAndZipEmbeddingsAsync(["hello"]));
}

[Fact]
Expand All @@ -26,6 +29,35 @@ public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result])
};

Assert.Same(result, await service.GenerateAsync("hello"));
Assert.Same(result, await service.GenerateEmbeddingAsync("hello"));
Assert.Equal(result.Vector, await service.GenerateEmbeddingVectorAsync("hello"));
}

// Verifies that GenerateAndZipEmbeddingsAsync returns one pair per input, in order, where each
// key is the original input and each value is the very embedding instance the generator produced.
[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(10)]
public async Task GenerateAndZipEmbeddingsAsync_ReturnsExpectedList(int count)
{
    // Arrange: input k is "hello k"; embedding k holds the four floats k, k+1, k+2, k+3.
    var inputs = new string[count];
    var embeddings = new Embedding<float>[count];
    for (int index = 0; index < count; index++)
    {
        inputs[index] = $"hello {index}";

        var vector = new float[4];
        for (int offset = 0; offset < vector.Length; offset++)
        {
            vector[offset] = index + offset;
        }

        embeddings[index] = new Embedding<float>(vector);
    }

    using TestEmbeddingGenerator service = new()
    {
        // The fake generator ignores its arguments and always hands back the prepared embeddings.
        GenerateAsyncCallback = (values, options, cancellationToken) =>
            Task.FromResult<GeneratedEmbeddings<Embedding<float>>>(new(embeddings))
    };

    // Act
    var results = await service.GenerateAndZipEmbeddingsAsync(inputs);

    // Assert
    Assert.NotNull(results);
    Assert.Equal(count, results.Count);

    int position = 0;
    while (position < count)
    {
        var pair = results[position];
        Assert.Equal(inputs[position], pair.Key);
        Assert.Same(embeddings[position], pair.Value);
        position++;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ public virtual async Task Caching_SameOutputsForSameInput()
.Use(CreateEmbeddingGenerator()!);

string input = "Red, White, and Blue";
var embedding1 = await generator.GenerateAsync(input);
var embedding2 = await generator.GenerateAsync(input);
var embedding3 = await generator.GenerateAsync(input + "... and Green");
var embedding4 = await generator.GenerateAsync(input);
var embedding1 = await generator.GenerateEmbeddingAsync(input);
var embedding2 = await generator.GenerateEmbeddingAsync(input);
var embedding3 = await generator.GenerateEmbeddingAsync(input + "... and Green");
var embedding4 = await generator.GenerateEmbeddingAsync(input);

var callCounter = generator.GetService<CallCountingEmbeddingGenerator>();
Assert.NotNull(callCounter);
Expand All @@ -114,7 +114,7 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
.UseOpenTelemetry(sourceName: sourceName)
.Use(CreateEmbeddingGenerator()!);

_ = await embeddingGenerator.GenerateAsync("Hello, world!");
_ = await embeddingGenerator.GenerateEmbeddingAsync("Hello, world!");

Assert.Single(activities);
var activity = activities.Single();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ public async Task CachesSuccessResultsAsync()
};

// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateAsync("abc");
var result1 = await outer.GenerateEmbeddingAsync("abc");
AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);

// Act
var result2 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateEmbeddingAsync("abc");

// Assert
Assert.Equal(1, innerCallCount);
Expand Down Expand Up @@ -134,8 +134,8 @@ public async Task AllowsConcurrentCallsAsync()
};

// Act 1: Concurrent calls before resolution are passed into the inner client
var result1 = outer.GenerateAsync("abc");
var result2 = outer.GenerateAsync("abc");
var result1 = outer.GenerateEmbeddingAsync("abc");
var result2 = outer.GenerateEmbeddingAsync("abc");

// Assert 1
Assert.Equal(2, innerCallCount);
Expand All @@ -146,7 +146,7 @@ public async Task AllowsConcurrentCallsAsync()
AssertEmbeddingsEqual(_expectedEmbedding, await result2);

// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateAsync("abc");
var result3 = await outer.GenerateEmbeddingAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}
Expand All @@ -169,12 +169,12 @@ public async Task DoesNotCacheExceptionResultsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
var ex1 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));
Assert.Equal("some failure", ex1.Message);
Assert.Equal(1, innerCallCount);

// Act
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateAsync("abc"));
var ex2 = await Assert.ThrowsAsync<InvalidTimeZoneException>(() => outer.GenerateEmbeddingAsync("abc"));

// Assert
Assert.NotSame(ex1, ex2);
Expand Down Expand Up @@ -207,15 +207,15 @@ public async Task DoesNotCacheCanceledResultsAsync()
};

// First call gets cancelled
var result1 = outer.GenerateAsync("abc");
var result1 = outer.GenerateEmbeddingAsync("abc");
Assert.False(result1.IsCompleted);
Assert.Equal(1, innerCallCount);
resolutionTcs.SetCanceled();
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => result1);
Assert.True(result1.IsCanceled);

// Act/Assert: Second call can succeed
var result2 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateEmbeddingAsync("abc");
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
Expand All @@ -241,11 +241,11 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -277,11 +277,11 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
};

// Act: Call with two different options
var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});
Expand Down Expand Up @@ -315,7 +315,7 @@ public async Task CanResolveIDistributedCacheFromDI()

// Act: Make a request that should populate the cache
Assert.Empty(_storage.Keys);
var result = await outer.GenerateAsync("abc");
var result = await outer.GenerateEmbeddingAsync("abc");

// Assert
Assert.NotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level)
.UseLogging()
.Use(innerGenerator);

await generator.GenerateAsync("Blue whale");
await generator.GenerateEmbeddingAsync("Blue whale");

if (level is LogLevel.Trace)
{
Expand Down

0 comments on commit 41abe1e

Please sign in to comment.