Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove duplicate GetCacheKey methods #5651

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
Expand All @@ -16,6 +17,12 @@ namespace Microsoft.Extensions.AI;
/// </summary>
public abstract class CachingChatClient : DelegatingChatClient
{
/// <summary>A boxed <see langword="true"/> value.</summary>
private static readonly object _boxedTrue = true;

/// <summary>A boxed <see langword="false"/> value.</summary>
private static readonly object _boxedFalse = false;

/// <summary>Initializes a new instance of the <see cref="CachingChatClient"/> class.</summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>.</param>
protected CachingChatClient(IChatClient innerClient)
Expand Down Expand Up @@ -45,7 +52,7 @@ public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chat
// We're only storing the final result, not the in-flight task, so that we can avoid caching failures
// or having problems when one of the callers cancels but others don't. This has the drawback that
// concurrent callers might trigger duplicate requests, but that's acceptable.
var cacheKey = GetCacheKey(false, chatMessages, options);
var cacheKey = GetCacheKey(_boxedFalse, chatMessages, options);

if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
{
Expand All @@ -68,7 +75,7 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
// we make a streaming request, yielding those results, but then convert those into a non-streaming
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.

var cacheKey = GetCacheKey(true, chatMessages, options);
var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion)
{
// Yield all of the cached items.
Expand All @@ -93,7 +100,7 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
}
else
{
var cacheKey = GetCacheKey(true, chatMessages, options);
var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
// Yield all of the cached items.
Expand All @@ -118,14 +125,10 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
}
}

/// <summary>
/// Computes a cache key for the specified call parameters.
/// </summary>
/// <param name="streaming">A flag to indicate if this is a streaming call.</param>
/// <param name="chatMessages">The chat content.</param>
/// <param name="options">The chat options to configure the request.</param>
/// <returns>A string that will be used as a cache key.</returns>
protected abstract string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options);
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
protected abstract string GetCacheKey(params ReadOnlySpan<object?> values);

/// <summary>
/// Returns a previously cached <see cref="ChatCompletion"/>, if available.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ namespace Microsoft.Extensions.AI;
/// </remarks>
public class DistributedCachingChatClient : CachingChatClient
{
/// <summary>A boxed <see langword="true"/> value.</summary>
private static readonly object _boxedTrue = true;

/// <summary>A boxed <see langword="false"/> value.</summary>
private static readonly object _boxedFalse = false;

/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
private readonly IDistributedCache _storage;

Expand Down Expand Up @@ -98,15 +92,11 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(bool, IList{ChatMessage}, ChatOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
/// <remarks>The <paramref name="values"/> are serialized to JSON using <see cref="JsonSerializerOptions"/> in order to compute the key.</remarks>
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,10 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
return results;
}

/// <summary>
/// Computes a cache key for the specified call parameters.
/// </summary>
/// <param name="value">The <typeparamref name="TInput"/> for which an embedding is being requested.</param>
/// <param name="options">The options to configure the request.</param>
/// <returns>A string that will be used as a cache key.</returns>
protected abstract string GetCacheKey(TInput value, EmbeddingGenerationOptions? options);
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
protected abstract string GetCacheKey(params ReadOnlySpan<object?> values);

/// <summary>Returns a previously cached <see cref="Embedding{TEmbedding}"/>, if available.</summary>
/// <param name="key">The cache key.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,11 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
GetCacheKey([value, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <summary>Computes a cache key for the specified values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(TInput, EmbeddingGenerationOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
/// <remarks>The <paramref name="values"/> are serialized to JSON using <see cref="JsonSerializerOptions"/> in order to compute the key.</remarks>
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,10 +815,18 @@ private static async Task AssertCompletionsEqualAsync(IReadOnlyList<StreamingCha
private sealed class CachingChatClientWithCustomKey(IChatClient innerClient, IDistributedCache storage)
: DistributedCachingChatClient(innerClient, storage)
{
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options)
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
var baseKey = base.GetCacheKey(streaming, chatMessages, options);
return baseKey + options?.AdditionalProperties?["someKey"]?.ToString();
var baseKey = base.GetCacheKey(values);
foreach (var value in values)
{
if (value is ChatOptions options)
{
return baseKey + options.AdditionalProperties?["someKey"]?.ToString();
}
}

return baseKey;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,18 @@ private static void AssertEmbeddingsEqual(Embedding<float> expected, Embedding<f
private sealed class CachingEmbeddingGeneratorWithCustomKey(IEmbeddingGenerator<string, Embedding<float>> innerGenerator, IDistributedCache storage)
: DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, storage)
{
protected override string GetCacheKey(string value, EmbeddingGenerationOptions? options) =>
base.GetCacheKey(value, options) + options?.AdditionalProperties?["someKey"]?.ToString();
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
var baseKey = base.GetCacheKey(values);
foreach (var value in values)
{
if (value is EmbeddingGenerationOptions options)
{
return baseKey + options.AdditionalProperties?["someKey"]?.ToString();
}
}

return baseKey;
}
}
}
Loading