From 171a394edf58d5db855c37ab913d47196d6eeb86 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 23 Oct 2024 18:26:11 -0400 Subject: [PATCH] Add EmbeddingGeneratorOptions.Dimensions --- .../Embeddings/EmbeddingGenerationOptions.cs | 20 +++++++++++++++++++ .../OpenAIEmbeddingGenerator.cs | 8 +------- .../EmbeddingGenerationOptionsTests.cs | 16 +++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs index 02875e9de98..27b84273b5b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs @@ -1,11 +1,30 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.Shared.Diagnostics; + namespace Microsoft.Extensions.AI; /// Represents the options for an embedding generation request. public class EmbeddingGenerationOptions { + private int? _dimensions; + + /// Gets or sets the number of dimensions requested in the embedding. + public int? Dimensions + { + get => _dimensions; + set + { + if (value is not null) + { + _ = Throw.IfLessThan(value.Value, 1); + } + + _dimensions = value; + } + } + /// Gets or sets the model ID for the embedding generation request. public string? ModelId { get; set; } @@ -22,6 +41,7 @@ public virtual EmbeddingGenerationOptions Clone() => new() { ModelId = ModelId, + Dimensions = Dimensions, AdditionalProperties = AdditionalProperties?.Clone(), }; } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index 27bf001b3ff..155e047279f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -135,17 +135,11 @@ void IDisposable.Dispose() { OpenAI.Embeddings.EmbeddingGenerationOptions openAIOptions = new() { - Dimensions = _dimensions, + Dimensions = options?.Dimensions ?? _dimensions, }; if (options?.AdditionalProperties is { Count: > 0 } additionalProperties) { - // Allow per-instance dimensions to be overridden by a per-call property - if (additionalProperties.TryGetValue(nameof(openAIOptions.Dimensions), out int? dimensions)) - { - openAIOptions.Dimensions = dimensions; - } - if (additionalProperties.TryGetValue(nameof(openAIOptions.EndUserId), out string? endUserId)) { openAIOptions.EndUserId = endUserId; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs index e9dd45959c7..fbc8b390abf 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Text.Json; using Xunit; @@ -14,10 +15,20 @@ public void Constructor_Parameterless_PropsDefaulted() EmbeddingGenerationOptions options = new(); Assert.Null(options.ModelId); Assert.Null(options.AdditionalProperties); + Assert.Null(options.Dimensions); EmbeddingGenerationOptions clone = options.Clone(); Assert.Null(clone.ModelId); Assert.Null(clone.AdditionalProperties); + Assert.Null(clone.Dimensions); + } + + [Fact] + public void InvalidArgs_Throws() + { + EmbeddingGenerationOptions options = new(); + Assert.Throws(() => options.Dimensions = 0); + Assert.Throws(() => options.Dimensions = -1); } [Fact] @@ -31,13 +42,16 @@ public void Properties_Roundtrip() }; options.ModelId = "modelId"; + options.Dimensions = 1536; options.AdditionalProperties = additionalProps; Assert.Equal("modelId", options.ModelId); + Assert.Equal(1536, options.Dimensions); Assert.Same(additionalProps, options.AdditionalProperties); EmbeddingGenerationOptions clone = options.Clone(); Assert.Equal("modelId", clone.ModelId); + Assert.Equal(1536, clone.Dimensions); Assert.Equal(additionalProps, clone.AdditionalProperties); } @@ -53,6 +67,7 @@ public void JsonSerialization_Roundtrips() options.ModelId = "model"; options.AdditionalProperties = additionalProps; + options.Dimensions = 1536; string json = JsonSerializer.Serialize(options, TestJsonSerializerContext.Default.EmbeddingGenerationOptions); @@ -60,6 +75,7 @@ public void JsonSerialization_Roundtrips() Assert.NotNull(deserialized); Assert.Equal("model", deserialized.ModelId); + Assert.Equal(1536, deserialized.Dimensions); Assert.NotNull(deserialized.AdditionalProperties); Assert.Single(deserialized.AdditionalProperties);