diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 770ffa60cfc..f2de7f92fc8 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; @@ -16,6 +17,12 @@ namespace Microsoft.Extensions.AI; /// public abstract class CachingChatClient : DelegatingChatClient { + /// A boxed value. + private static readonly object _boxedTrue = true; + + /// A boxed value. + private static readonly object _boxedFalse = false; + /// Initializes a new instance of the class. /// The underlying . protected CachingChatClient(IChatClient innerClient) @@ -45,7 +52,7 @@ public override async Task CompleteAsync(IList chat // We're only storing the final result, not the in-flight task, so that we can avoid caching failures // or having problems when one of the callers cancels but others don't. This has the drawback that // concurrent callers might trigger duplicate requests, but that's acceptable. - var cacheKey = GetCacheKey(false, chatMessages, options); + var cacheKey = GetCacheKey(_boxedFalse, chatMessages, options); if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) { @@ -68,7 +75,7 @@ public override async IAsyncEnumerable CompleteSt // we make a streaming request, yielding those results, but then convert those into a non-streaming // result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one. - var cacheKey = GetCacheKey(true, chatMessages, options); + var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options); if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion) { // Yield all of the cached items. @@ -93,7 +100,7 @@ public override async IAsyncEnumerable CompleteSt } else { - var cacheKey = GetCacheKey(true, chatMessages, options); + var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options); if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) { // Yield all of the cached items. @@ -118,14 +125,10 @@ public override async IAsyncEnumerable CompleteSt } } - /// - /// Computes a cache key for the specified call parameters. - /// - /// A flag to indicate if this is a streaming call. - /// The chat content. - /// The chat options to configure the request. - /// A string that will be used as a cache key. - protected abstract string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options); + /// Computes a cache key for the specified values. + /// The values to inform the key. + /// The computed key. + protected abstract string GetCacheKey(params ReadOnlySpan values); /// /// Returns a previously cached , if available. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs index 678e9bd6523..a5bee20fa48 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -20,12 +20,6 @@ namespace Microsoft.Extensions.AI; /// public class DistributedCachingChatClient : CachingChatClient { - /// A boxed value. - private static readonly object _boxedTrue = true; - - /// A boxed value. - private static readonly object _boxedFalse = false; - /// The instance that will be used as the backing store for the cache. private readonly IDistributedCache _storage; @@ -98,15 +92,11 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); } - /// - protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) => - GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]); - - /// Gets a cache key based on the supplied values. + /// Computes a cache key for the specified values. /// The values to inform the key. /// The computed key. - /// This provides the default implementation for . - protected string GetCacheKey(ReadOnlySpan values) + /// The are serialized to JSON using in order to compute the key. + protected override string GetCacheKey(params ReadOnlySpan values) { _jsonSerializerOptions.MakeReadOnly(); return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs index d632431102c..688e4b2353d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -106,13 +106,10 @@ public override async Task> GenerateAsync( return results; } - /// - /// Computes a cache key for the specified call parameters. - /// - /// The for which an embedding is being requested. - /// The options to configure the request. - /// A string that will be used as a cache key. - protected abstract string GetCacheKey(TInput value, EmbeddingGenerationOptions? options); + /// Computes a cache key for the specified values. + /// The values to inform the key. + /// The computed key. + protected abstract string GetCacheKey(params ReadOnlySpan values); /// Returns a previously cached , if available. /// The cache key. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index 6482ed8ed2b..32abb78e18b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -74,15 +74,11 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); } - /// - protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) => - GetCacheKey([value, options]); - - /// Gets a cache key based on the supplied values. + /// Computes a cache key for the specified values. /// The values to inform the key. /// The computed key. - /// This provides the default implementation for . - protected string GetCacheKey(ReadOnlySpan values) + /// The are serialized to JSON using in order to compute the key. + protected override string GetCacheKey(params ReadOnlySpan values) { _jsonSerializerOptions.MakeReadOnly(); return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index dcc6068b3ce..7ace4f2d294 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -815,10 +815,18 @@ private static async Task AssertCompletionsEqualAsync(IReadOnlyList chatMessages, ChatOptions? options) + protected override string GetCacheKey(params ReadOnlySpan values) { - var baseKey = base.GetCacheKey(streaming, chatMessages, options); - return baseKey + options?.AdditionalProperties?["someKey"]?.ToString(); + var baseKey = base.GetCacheKey(values); + foreach (var value in values) + { + if (value is ChatOptions options) + { + return baseKey + options.AdditionalProperties?["someKey"]?.ToString(); + } + } + + return baseKey; } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs index 55cc206ebfc..d32a249c7dc 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs @@ -350,7 +350,18 @@ private static void AssertEmbeddingsEqual(Embedding expected, Embedding> innerGenerator, IDistributedCache storage) : DistributedCachingEmbeddingGenerator>(innerGenerator, storage) { - protected override string GetCacheKey(string value, EmbeddingGenerationOptions? options) => - base.GetCacheKey(value, options) + options?.AdditionalProperties?["someKey"]?.ToString(); + protected override string GetCacheKey(params ReadOnlySpan values) + { + var baseKey = base.GetCacheKey(values); + foreach (var value in values) + { + if (value is EmbeddingGenerationOptions options) + { + return baseKey + options.AdditionalProperties?["someKey"]?.ToString(); + } + } + + return baseKey; + } } }