Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework cache key handling in caching client / generator #5641

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 101 additions & 32 deletions src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,129 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.IO;
using System.Security.Cryptography;
using System.Text.Json;
using Microsoft.Shared.Diagnostics;
#if NET
using System.Threading;
using System.Threading.Tasks;
#endif

#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable SA1202 // Elements should be ordered by access
#pragma warning disable SA1502 // Element should not be on a single line

namespace Microsoft.Extensions.AI;

/// <summary>Provides internal helpers for implementing caching services.</summary>
internal static class CachingHelpers
{
/// <summary>Computes a default cache key for the specified parameters.</summary>
/// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
/// <param name="value">The data with which to compute the key.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
/// <returns>A string that will be used as a cache key.</returns>
public static string GetCacheKey<TValue>(TValue value, JsonSerializerOptions serializerOptions)
=> GetCacheKey(value, false, serializerOptions);

/// <summary>Computes a default cache key for the specified parameters.</summary>
/// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
/// <param name="value">The data with which to compute the key.</param>
/// <param name="flag">Another data item that causes the key to vary.</param>
/// <param name="values">The data with which to compute the key.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
/// <returns>A string that will be used as a cache key.</returns>
public static string GetCacheKey<TValue>(TValue value, bool flag, JsonSerializerOptions serializerOptions)
public static string GetCacheKey(ReadOnlySpan<object?> values, JsonSerializerOptions serializerOptions)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
_ = Throw.IfNull(value);
_ = Throw.IfNull(serializerOptions);
serializerOptions.MakeReadOnly();

var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue)));

if (flag && jsonKeyBytes.Length > 0)
{
// Make an arbitrary change to the hash input based on the flag
// The alternative would be including the flag in "value" in the
// first place, but that's likely to require an extra allocation
// or the inclusion of another type in the JsonSerializerContext.
// This is a micro-optimization we can change at any time.
jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]);
}
Debug.Assert(serializerOptions is not null, "Expected serializer options to be non-null");
Debug.Assert(serializerOptions!.IsReadOnly, "Expected serializer options to already be read-only.");

// The complete JSON representation is excessively long for a cache key, duplicating much of the content
// from the value. So we use a hash of it as the default key, and we rely on collision resistance for security purposes.
// If a collision occurs, we'd serve the cached LLM response for a potentially unrelated prompt, leading to information
// disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit
// invalidating any existing cache entries that may exist in whatever IDistributedCache was in use.
#if NET8_0_OR_GREATER

#if NET
IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance ?? new();
IncrementalHashStream.ThreadStaticInstance = null;

foreach (object? value in values)
{
JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
}

Span<byte> hashData = stackalloc byte[SHA256.HashSizeInBytes];
SHA256.HashData(jsonKeyBytes, hashData);
stream.GetHashAndReset(hashData);
IncrementalHashStream.ThreadStaticInstance = stream;

return Convert.ToHexString(hashData);
#else
MemoryStream stream = new();
foreach (object? value in values)
{
JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
}

using var sha256 = SHA256.Create();
var hashData = sha256.ComputeHash(jsonKeyBytes);
return BitConverter.ToString(hashData).Replace("-", string.Empty);
stream.Position = 0;
var hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);

var chars = new char[hashData.Length * 2];
int destPos = 0;
foreach (byte b in hashData)
{
int div = Math.DivRem(b, 16, out int rem);
chars[destPos++] = ToHexChar(div);
chars[destPos++] = ToHexChar(rem);

static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A');
}

Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array.");

return new string(chars);
#endif
}

#if NET
/// <summary>A write-only <see cref="Stream"/> whose written bytes are appended to an <see cref="IncrementalHash"/>.</summary>
/// <remarks>
/// Every write overload forwards its data straight to the underlying hash. Reading, seeking, and
/// length/position access are not supported and throw <see cref="NotSupportedException"/>.
/// </remarks>
private sealed class IncrementalHashStream : Stream
{
    /// <summary>The SHA-256 <see cref="IncrementalHash"/> that accumulates everything written to this stream.</summary>
    private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);

    /// <summary>A cached <see cref="IncrementalHashStream"/>, one slot per thread.</summary>
    /// <remarks>Only store an instance here once it has been reset, so the next consumer can use it immediately.</remarks>
    [ThreadStatic]
    public static IncrementalHashStream? ThreadStaticInstance;

    /// <summary>Indicates that this stream accepts writes.</summary>
    public override bool CanWrite
    {
        get { return true; }
    }

    /// <summary>Indicates that this stream cannot be read from.</summary>
    public override bool CanRead
    {
        get { return false; }
    }

    /// <summary>Indicates that this stream cannot seek.</summary>
    public override bool CanSeek
    {
        get { return false; }
    }

    /// <summary>Not supported; always throws <see cref="NotSupportedException"/>.</summary>
    public override long Length
    {
        get { throw new NotSupportedException(); }
    }

    /// <summary>Not supported; both accessors always throw <see cref="NotSupportedException"/>.</summary>
    public override long Position
    {
        get { throw new NotSupportedException(); }
        set { throw new NotSupportedException(); }
    }

    /// <summary>Writes the hash of all data appended so far into <paramref name="bytes"/> and resets the hash.</summary>
    /// <param name="bytes">The destination buffer that receives the hash value.</param>
    public void GetHashAndReset(Span<byte> bytes)
    {
        // The number of bytes written is not needed by callers of this helper.
        _ = _hash.GetHashAndReset(bytes);
    }

    /// <summary>Appends a single byte to the hash.</summary>
    public override void WriteByte(byte value)
    {
        _hash.AppendData(new ReadOnlySpan<byte>(in value));
    }

    /// <summary>Appends the specified segment of <paramref name="buffer"/> to the hash.</summary>
    public override void Write(byte[] buffer, int offset, int count)
    {
        _hash.AppendData(buffer, offset, count);
    }

    /// <summary>Appends the contents of <paramref name="buffer"/> to the hash.</summary>
    public override void Write(ReadOnlySpan<byte> buffer)
    {
        _hash.AppendData(buffer);
    }

    /// <summary>Appends the specified segment synchronously and returns an already-completed task.</summary>
    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        // Appending to the hash is synchronous, so the token is not consulted.
        _hash.AppendData(buffer, offset, count);
        return Task.CompletedTask;
    }

    /// <summary>Appends the memory's contents synchronously and returns an already-completed task.</summary>
    public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
    {
        _hash.AppendData(buffer.Span);
        return ValueTask.CompletedTask;
    }

    /// <summary>Does nothing.</summary>
    public override void Flush()
    {
    }

    /// <summary>Does nothing and returns an already-completed task.</summary>
    public override Task FlushAsync(CancellationToken cancellationToken)
    {
        return Task.CompletedTask;
    }

    /// <summary>Not supported; always throws <see cref="NotSupportedException"/>.</summary>
    public override int Read(byte[] buffer, int offset, int count)
    {
        throw new NotSupportedException();
    }

    /// <summary>Not supported; always throws <see cref="NotSupportedException"/>.</summary>
    public override long Seek(long offset, SeekOrigin origin)
    {
        throw new NotSupportedException();
    }

    /// <summary>Not supported; always throws <see cref="NotSupportedException"/>.</summary>
    public override void SetLength(long value)
    {
        throw new NotSupportedException();
    }

    /// <summary>Disposes the underlying <see cref="IncrementalHash"/>, then the base stream.</summary>
    protected override void Dispose(bool disposing)
    {
        _hash.Dispose();
        base.Dispose(disposing);
    }
}
#endif
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
Expand All @@ -19,8 +20,17 @@ namespace Microsoft.Extensions.AI;
/// </remarks>
public class DistributedCachingChatClient : CachingChatClient
{
/// <summary>A boxed <see langword="true"/> value.</summary>
private static readonly object _boxedTrue = true;

/// <summary>A boxed <see langword="false"/> value.</summary>
private static readonly object _boxedFalse = false;

/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
private readonly IDistributedCache _storage;
private JsonSerializerOptions _jsonSerializerOptions;

/// <summary>The <see cref="JsonSerializerOptions"/> to use when serializing cache data.</summary>
private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;

/// <summary>Initializes a new instance of the <see cref="DistributedCachingChatClient"/> class.</summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>.</param>
Expand All @@ -29,7 +39,6 @@ public DistributedCachingChatClient(IChatClient innerClient, IDistributedCache s
: base(innerClient)
{
_storage = Throw.IfNull(storage);
_jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
}

/// <summary>Gets or sets JSON serialization options to use when serializing cache data.</summary>
Expand Down Expand Up @@ -90,13 +99,16 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
}

/// <inheritdoc />
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options)
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(bool, IList{ChatMessage}, ChatOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stephentoub Just a thought - should this function also be overridable? If a derived class were to take over the cache key computation by overriding `GetCacheKey(bool, IList...)` above, would it be surprising that calling `base.GetCacheKey(ReadOnlySpan<object?>)` in further derived classes does something different / uses a completely different hashing algorithm than calling `base.GetCacheKey(bool, IList...)`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be surprising that calling `base.GetCacheKey(ReadOnlySpan<object?>)` in further derived classes does something different / uses a completely different hashing algorithm than calling `base.GetCacheKey(bool, IList...)`?

Not sure I follow. base.GetCacheKey(bool, IList, ...) uses base.GetCacheKey(ROS). The intent is the latter represents the default algorithm for handling multiple objects. An override of GetCacheKey(bool, IList, ...) can delegate to either of the methods on the base.

Could you elaborate on the concern? I'm not clear on what making it virtual would solve. Maybe it needs a different name?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, here is an example of what I meant. Consider the case where we have EvaluationCachingChatClient that is derived from DistributedCachingChatClient and that ships as part of a hypothetical Evaluation library 😄 EvaluationCachingChatClient takes over the cache key computation by overriding GetCacheKey and uses its own custom hashing algorithm for this.

Now consider the case where a consumer of this Evaluation library wants to inherit EvaluationCachingChatClient and tweak the hash computation to include a few additional custom parameters without changing the hashing algorithm used in EvaluationCachingChatClient. If this derived implementation were to call the base.GetCacheKey(ReadOnlySpan<object?>) for this, the key that is returned could look very different from the key returned from base.GetCacheKey(bool, IList...).

Depending on how different the caching algorithms used in the two base types happen to be, the returned keys could have different lengths, a different set of allowed characters (which could be important if the key is used as the name of a file or directory on disk, for example), etc. This could be especially problematic in cases like the one below, where the derived class wants to keep the base implementation unchanged in the default case and override it only under certain conditions.

<Pseudo Code>
class MyCustomEvaluationCachingChatClient : EvaluationCachingChatClient
{
    override GetCacheKey(bool, IList...)
    {
        if (condition)
        {
             // Compute a custom key - I want to use same hashing algorithm as the immediate base type - but the below call uses the hashing algorithm from another base class
             return base.GetCacheKey(<ReadOnlySpan<object?> with custom parameters>);
        }
        
        // Preserve the same key computation as the immediate base type.
        return base.GetCacheKey(bool, IList...);
    }
}

Each level in the inheritance hierarchy has a choice between the following options -

  • only change what parameters are included in the hash key computation without changing the hashing algorithm being used OR
  • completely take over the hash key computation and control both the set of parameters that are included in the hash as well as the hashing algorithm that is used

And regardless of which choice is made in a particular inheritance level, it would be ideal if subsequent derived implementations can also retain the same flexibility to

  • tweak only which parameters are included in the hash while keeping the same hashing algorithm as their immediate base class OR
  • completely take over the hash key computation

Having the ability to individually override GetCacheKey(bool, IList) and GetCacheKey(ReadOnlySpan<object?>) could be one way to accomplish the above. Yes, I agree we should probably give different names to the two functions for clarity in this case.

Another option may be to eliminate GetCacheKey(ReadOnlySpan<object?>) from DistributedCachingChatClient and make it available as a standalone static helper. However, in this case, every level in the inheritance hierarchy that wants to use some custom hashing algorithm would need to make similar static helpers available for their own derived types...

I realize my response is a bit rambling and perhaps I am overthinking it 😄, but I hope this clarifies the concern. Happy to sync up offline to discuss if needed.

{
// While it might be desirable to include ChatOptions in the cache key, it's not always possible,
// since ChatOptions can contain types that are not guaranteed to be serializable or have a stable
// hashcode across multiple calls. So the default cache key is simply the JSON representation of
// the chat contents. Developers may subclass and override this to provide custom rules.
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions);
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
Expand Down Expand Up @@ -74,12 +75,16 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
}

/// <inheritdoc />
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options)
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
GetCacheKey([value, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(TInput, EmbeddingGenerationOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
{
// While it might be desirable to include options in the cache key, it's not always possible,
// since options can contain types that are not guaranteed to be serializable or have a stable
// hashcode across multiple calls. So the default cache key is simply the JSON representation of
// the value. Developers may subclass and override this to provide custom rules.
return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions);
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ public async Task StreamingDoesNotCacheCanceledResultsAsync()
}

[Fact]
public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
public async Task CacheKeyVariesByChatOptionsAsync()
{
// Arrange
var innerCallCount = 0;
Expand All @@ -546,20 +546,35 @@ public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

// Act: Call with two different ChatOptions
// Act: Call with two different ChatOptions that have the same values
var result1 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 1" } }
});
var result2 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 2" } }
AdditionalProperties = new() { { "someKey", "value 1" } }
});

// Assert: Same result
Assert.Equal(1, innerCallCount);
Assert.Equal("value 1", result1.Message.Text);
Assert.Equal("value 1", result2.Message.Text);

// Act: Call with two different ChatOptions that have different values
var result3 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 1" } }
});
var result4 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 2" } }
});

// Assert: Different results
Assert.Equal(2, innerCallCount);
Assert.Equal("value 1", result3.Message.Text);
Assert.Equal("value 2", result4.Message.Text);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public async Task DoesNotCacheCanceledResultsAsync()
}

[Fact]
public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
public async Task CacheKeyVariesByEmbeddingOptionsAsync()
{
// Arrange
var innerCallCount = 0;
Expand All @@ -232,28 +232,43 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
{
innerCallCount++;
await Task.Yield();
return [_expectedEmbedding];
return [new(((string)options!.AdditionalProperties!["someKey"]!).Select(c => (float)c).ToArray())];
}
};
using var outer = new DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, _storage)
{
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

// Act: Call with two different options
// Act: Call with two different EmbeddingGenerationOptions that have the same values
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
AdditionalProperties = new() { ["someKey"] = "value 1" }
});

// Assert: Same result
Assert.Equal(1, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result1);
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2);

// Act: Call with two different EmbeddingGenerationOptions that have different values
var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});

// Assert: Different result
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result3);
AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(Dictionary<string, string>))]
[JsonSerializable(typeof(DayOfWeek[]))]
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
Loading