Skip to content

Commit

Permalink
Remove duplicate GetCacheKey methods (#5651)
Browse files Browse the repository at this point in the history
* Remove duplicate GetCacheKey methods

Consolidate to only the `ReadOnlySpan<object>`-based method.

* Update XML comments to say that the values are serialized
  • Loading branch information
stephentoub authored Nov 18, 2024
1 parent f085689 commit 06edb3c
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
Expand All @@ -16,6 +17,12 @@ namespace Microsoft.Extensions.AI;
/// </summary>
public abstract class CachingChatClient : DelegatingChatClient
{
/// <summary>A boxed <see langword="true"/> value.</summary>
private static readonly object _boxedTrue = true;

/// <summary>A boxed <see langword="false"/> value.</summary>
private static readonly object _boxedFalse = false;

/// <summary>Initializes a new instance of the <see cref="CachingChatClient"/> class.</summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>.</param>
protected CachingChatClient(IChatClient innerClient)
Expand Down Expand Up @@ -45,7 +52,7 @@ public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chat
// We're only storing the final result, not the in-flight task, so that we can avoid caching failures
// or having problems when one of the callers cancels but others don't. This has the drawback that
// concurrent callers might trigger duplicate requests, but that's acceptable.
var cacheKey = GetCacheKey(false, chatMessages, options);
var cacheKey = GetCacheKey(_boxedFalse, chatMessages, options);

if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
{
Expand All @@ -68,7 +75,7 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
// we make a streaming request, yielding those results, but then convert those into a non-streaming
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.

var cacheKey = GetCacheKey(true, chatMessages, options);
var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion)
{
// Yield all of the cached items.
Expand All @@ -93,7 +100,7 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
}
else
{
var cacheKey = GetCacheKey(true, chatMessages, options);
var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
// Yield all of the cached items.
Expand All @@ -118,14 +125,10 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
}
}

/// <summary>
/// Computes a cache key for the specified call parameters.
/// </summary>
/// <param name="streaming">A flag to indicate if this is a streaming call.</param>
/// <param name="chatMessages">The chat content.</param>
/// <param name="options">The chat options to configure the request.</param>
/// <returns>A string that will be used as a cache key.</returns>
protected abstract string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options);
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
protected abstract string GetCacheKey(params ReadOnlySpan<object?> values);

/// <summary>
/// Returns a previously cached <see cref="ChatCompletion"/>, if available.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ namespace Microsoft.Extensions.AI;
/// </remarks>
public class DistributedCachingChatClient : CachingChatClient
{
/// <summary>A boxed <see langword="true"/> value.</summary>
private static readonly object _boxedTrue = true;

/// <summary>A boxed <see langword="false"/> value.</summary>
private static readonly object _boxedFalse = false;

/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
private readonly IDistributedCache _storage;

Expand Down Expand Up @@ -98,15 +92,11 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(bool, IList{ChatMessage}, ChatOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
/// <remarks>The <paramref name="values"/> are serialized to JSON using <see cref="JsonSerializerOptions"/> in order to compute the key.</remarks>
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,10 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
return results;
}

/// <summary>
/// Computes a cache key for the specified call parameters.
/// </summary>
/// <param name="value">The <typeparamref name="TInput"/> for which an embedding is being requested.</param>
/// <param name="options">The options to configure the request.</param>
/// <returns>A string that will be used as a cache key.</returns>
protected abstract string GetCacheKey(TInput value, EmbeddingGenerationOptions? options);
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
protected abstract string GetCacheKey(params ReadOnlySpan<object?> values);

/// <summary>Returns a previously cached <see cref="Embedding{TEmbedding}"/>, if available.</summary>
/// <param name="key">The cache key.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,11 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
GetCacheKey([value, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(TInput, EmbeddingGenerationOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
/// <remarks>The <paramref name="values"/> are serialized to JSON using <see cref="JsonSerializerOptions"/> in order to compute the key.</remarks>
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,10 +815,18 @@ private static async Task AssertCompletionsEqualAsync(IReadOnlyList<StreamingCha
private sealed class CachingChatClientWithCustomKey(IChatClient innerClient, IDistributedCache storage)
: DistributedCachingChatClient(innerClient, storage)
{
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options)
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
var baseKey = base.GetCacheKey(streaming, chatMessages, options);
return baseKey + options?.AdditionalProperties?["someKey"]?.ToString();
var baseKey = base.GetCacheKey(values);
foreach (var value in values)
{
if (value is ChatOptions options)
{
return baseKey + options.AdditionalProperties?["someKey"]?.ToString();
}
}

return baseKey;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,18 @@ private static void AssertEmbeddingsEqual(Embedding<float> expected, Embedding<f
private sealed class CachingEmbeddingGeneratorWithCustomKey(IEmbeddingGenerator<string, Embedding<float>> innerGenerator, IDistributedCache storage)
: DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, storage)
{
protected override string GetCacheKey(string value, EmbeddingGenerationOptions? options) =>
base.GetCacheKey(value, options) + options?.AdditionalProperties?["someKey"]?.ToString();
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
var baseKey = base.GetCacheKey(values);
foreach (var value in values)
{
if (value is EmbeddingGenerationOptions options)
{
return baseKey + options.AdditionalProperties?["someKey"]?.ToString();
}
}

return baseKey;
}
}
}

0 comments on commit 06edb3c

Please sign in to comment.