Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve CachingChatClient's coalescing of streaming updates #5514

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ public AdditionalPropertiesDictionary(IEnumerable<KeyValuePair<string, object?>>
#endif
}

/// <summary>Creates a shallow clone of the properties dictionary.</summary>
/// <returns>
/// A shallow clone of the properties dictionary. The instance will not be the same as the current instance,
/// but it will contain all of the same key-value pairs.
/// </returns>
public AdditionalPropertiesDictionary Clone() => new AdditionalPropertiesDictionary(_dictionary);

/// <inheritdoc />
public object? this[string key]
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public virtual ChatOptions Clone()
ResponseFormat = ResponseFormat,
ModelId = ModelId,
ToolMode = ToolMode,
AdditionalProperties = AdditionalProperties?.Clone(),
};

if (StopSequences is not null)
Expand All @@ -85,11 +86,6 @@ public virtual ChatOptions Clone()
options.Tools = new List<AITool>(Tools);
}

if (AdditionalProperties is not null)
{
options.AdditionalProperties = new(AdditionalProperties);
}

return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,10 @@ public class EmbeddingGenerationOptions
/// The clone will have the same values for all properties as the original instance. Any collections, like <see cref="AdditionalProperties"/>
/// are shallow-cloned, meaning a new collection instance is created, but any references contained by the collections are shared with the original.
/// </remarks>
public virtual EmbeddingGenerationOptions Clone()
{
EmbeddingGenerationOptions options = new()
public virtual EmbeddingGenerationOptions Clone() =>
new()
{
ModelId = ModelId,
AdditionalProperties = AdditionalProperties?.Clone(),
};

if (AdditionalProperties is not null)
{
options.AdditionalProperties = new(AdditionalProperties);
}

return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S127 // "for" loop stop conditions should be invariant
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

namespace Microsoft.Extensions.AI;

/// <summary>
Expand All @@ -21,6 +24,20 @@ protected CachingChatClient(IChatClient innerClient)
{
}

/// <summary>Gets or sets a value indicating whether to coalesce streaming updates.</summary>
/// <remarks>
/// <para>
/// When <see langword="true"/>, the client will attempt to coalesce contiguous streaming updates
/// into a single update, in order to reduce the number of individual items that are yielded on
/// subsequent enumerations of the cached data. When <see langword="false"/>, the updates are
/// kept unaltered.
/// </para>
/// <para>
/// The default is <see langword="true"/>.
/// </para>
/// </remarks>
public bool CoalesceStreamingUpdates { get; set; } = true;

/// <inheritdoc />
public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -50,58 +67,124 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
// Yield all of the cached items.
foreach (var chunk in existingChunks)
{
yield return chunk;
}
}
else
{
var capturedItems = new List<StreamingChatCompletionUpdate>();
StreamingChatCompletionUpdate? previousCoalescedCopy = null;
await foreach (var item in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
yield return item;

// If this item is compatible with the previous one, we will coalesce them in the cache
var previous = capturedItems.Count > 0 ? capturedItems[capturedItems.Count - 1] : null;
if (item.ChoiceIndex == 0
&& item.Contents.Count == 1
&& item.Contents[0] is TextContent currentTextContent
&& previous is { ChoiceIndex: 0 }
&& previous.Role == item.Role
&& previous.Contents is { Count: 1 }
&& previous.Contents[0] is TextContent previousTextContent)
capturedItems.Add(chunk);
yield return chunk;
}

// If the caching client is configured to coalesce streaming updates, do so now within the capturedItems list.
if (CoalesceStreamingUpdates)
{
StringBuilder coalescedText = new();

// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
for (int startInclusive = 0; startInclusive < capturedItems.Count; startInclusive++)
{
if (!ReferenceEquals(previous, previousCoalescedCopy))
// If an item isn't generally coalescable, skip it.
StreamingChatCompletionUpdate update = capturedItems[startInclusive];
if (update.ChoiceIndex != 0 ||
update.Contents.Count != 1 ||
update.Contents[0] is not TextContent textContent)
{
// We don't want to mutate any object that we also yield, since the recipient might
// not expect that. Instead make a copy we can safely mutate.
previousCoalescedCopy = new()
continue;
}

// We found a coalescable item. Look for more contiguous items that are also coalescable with it.
int endExclusive = startInclusive + 1;
for (; endExclusive < capturedItems.Count; endExclusive++)
{
StreamingChatCompletionUpdate next = capturedItems[endExclusive];
if (next.ChoiceIndex != 0 ||
next.Contents.Count != 1 ||
next.Contents[0] is not TextContent ||

// changing role or author would be really strange, but check anyway
(update.Role is not null && next.Role is not null && update.Role != next.Role) ||
(update.AuthorName is not null && next.AuthorName is not null && update.AuthorName != next.AuthorName))
{
Role = previous.Role,
AuthorName = previous.AuthorName,
AdditionalProperties = previous.AdditionalProperties,
ChoiceIndex = previous.ChoiceIndex,
RawRepresentation = previous.RawRepresentation,
Contents = [new TextContent(previousTextContent.Text)]
};

// The last item we captured was before we knew it could be coalesced
// with this one, so replace it with the coalesced copy
capturedItems[capturedItems.Count - 1] = previousCoalescedCopy;
break;
}
}

#pragma warning disable S1643 // Strings should not be concatenated using '+' in a loop
((TextContent)previousCoalescedCopy.Contents[0]).Text += currentTextContent.Text;
#pragma warning restore S1643
}
else
{
capturedItems.Add(item);
// If we couldn't find anything to coalesce, there's nothing to do.
if (endExclusive - startInclusive <= 1)
{
continue;
}

// We found a coalescable run of items. Create a new node to represent the run. We create a new one
// rather than reappropriating one of the existing ones so as not to mutate an item already yielded.
_ = coalescedText.Clear().Append(capturedItems[startInclusive].Text);

TextContent coalescedContent = new(null) // will patch the text after examining all items in the run
{
AdditionalProperties = textContent.AdditionalProperties?.Clone(),
ModelId = textContent.ModelId,
};

StreamingChatCompletionUpdate coalesced = new()
{
AdditionalProperties = update.AdditionalProperties?.Clone(),
AuthorName = update.AuthorName,
CompletionId = update.CompletionId,
Contents = [coalescedContent],
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
Role = update.Role,

// Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used
// to represent multiple, and it won't be serialized anyway.
};

// Replace the starting node with the coalesced node.
capturedItems[startInclusive] = coalesced;

// Now iterate through all the rest of the updates in the run, updating the coalesced node with relevant properties,
// and nulling out the nodes along the way. We do this rather than removing the entry in order to avoid an O(N^2) operation.
// We'll remove all the null entries at the end of the loop, using RemoveAll to do so, which can remove all of
// the nulls in a single O(N) pass.
for (int i = startInclusive + 1; i < endExclusive; i++)
{
// Grab the next item.
StreamingChatCompletionUpdate next = capturedItems[i];
capturedItems[i] = null!;

TextContent nextContent = (TextContent)next.Contents[0];
_ = coalescedText.Append(nextContent.Text);

coalesced.AuthorName ??= next.AuthorName;
coalesced.CompletionId ??= next.CompletionId;
coalesced.CreatedAt ??= next.CreatedAt;
coalesced.FinishReason ??= next.FinishReason;
coalesced.Role ??= next.Role;

coalescedContent.ModelId ??= nextContent.ModelId;
}

// Complete the coalescing by patching the text of the coalesced node.
coalesced.Text = coalescedText.ToString();

// Jump to the last update in the run, so that when we loop around and bump ahead,
// we're at the next update just after the run.
startInclusive = endExclusive - 1;
}

// Remove all of the null slots left over from the coalescing process.
_ = capturedItems.RemoveAll(u => u is null);
}

// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ public class DistributedCachingChatClientTest
{
private readonly TestInMemoryCacheStorage _storage = new();

[Fact]
public void Ctor_ExpectedDefaults()
{
using var innerClient = new TestChatClient();
using var cachingClient = new DistributedCachingChatClient(innerClient, _storage);

Assert.True(cachingClient.CoalesceStreamingUpdates);

cachingClient.CoalesceStreamingUpdates = false;
Assert.False(cachingClient.CoalesceStreamingUpdates);

cachingClient.CoalesceStreamingUpdates = true;
Assert.True(cachingClient.CoalesceStreamingUpdates);
}

[Fact]
public async Task CachesSuccessResultsAsync()
{
Expand Down Expand Up @@ -251,8 +266,11 @@ public async Task StreamingCachesSuccessResultsAsync()
Assert.Equal(2, innerCallCount);
}

[Fact]
public async Task StreamingCoalescesConsecutiveTextChunksAsync()
[Theory]
[InlineData(false)]
[InlineData(true)]
[InlineData(null)]
public async Task StreamingCoalescesConsecutiveTextChunksAsync(bool? coalesce)
{
// Arrange
List<StreamingChatCompletionUpdate> expectedCompletion =
Expand All @@ -274,17 +292,101 @@ public async Task StreamingCoalescesConsecutiveTextChunksAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

if (coalesce is not null)
{
outer.CoalesceStreamingUpdates = coalesce.Value;
}

var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);
await ToListAsync(result1);

// Act
var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);

// Assert
if (coalesce is null or true)
{
Assert.Collection(await ToListAsync(result2),
c => Assert.Equal("This becomes one chunk", c.Text),
c => Assert.IsType<FunctionCallContent>(Assert.Single(c.Contents)),
c => Assert.Equal("... and this becomes another one.", c.Text));
}
else
{
Assert.Collection(await ToListAsync(result2),
c => Assert.Equal("This", c.Text),
c => Assert.Equal(" becomes one chunk", c.Text),
c => Assert.IsType<FunctionCallContent>(Assert.Single(c.Contents)),
c => Assert.Equal("... and this", c.Text),
c => Assert.Equal(" becomes another", c.Text),
c => Assert.Equal(" one.", c.Text));
}
}

[Fact]
public async Task StreamingCoalescingPropagatesMetadataAsync()
{
// Arrange
List<StreamingChatCompletionUpdate> expectedCompletion =
[
new() { Role = ChatRole.Assistant, Contents = [new TextContent("Hello")] },
new() { Role = ChatRole.Assistant, Contents = [new TextContent(" world, ") { ModelId = "some model" }] },
new()
{
Role = ChatRole.Assistant,
Contents =
[
new TextContent("how ")
{
ModelId = "some other model",
AdditionalProperties = new() { ["a"] = "b", ["c"] = "d" },
}
]
},
new()
{
Role = ChatRole.Assistant,
Contents =
[
new TextContent("are you?")
{
AdditionalProperties = new() { ["e"] = "f", ["g"] = "h" },
}
],
CreatedAt = DateTime.Parse("2024-10-11T19:23:36.0152137Z"),
CompletionId = "12345",
AuthorName = "Someone",
FinishReason = ChatFinishReason.Length,
},
];

using var testClient = new TestChatClient
{
CompleteStreamingAsyncCallback = delegate { return ToAsyncEnumerableAsync(expectedCompletion); }
};
using var outer = new DistributedCachingChatClient(testClient, _storage)
{
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);
await ToListAsync(result1);

// Act
var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);

// Assert
Assert.Collection(await ToListAsync(result2),
c => Assert.Equal("This becomes one chunk", c.Text),
c => Assert.IsType<FunctionCallContent>(Assert.Single(c.Contents)),
c => Assert.Equal("... and this becomes another one.", c.Text));
var items = await ToListAsync(result2);
var item = Assert.Single(items);
Assert.Equal("Hello world, how are you?", item.Text);
Assert.Equal("12345", item.CompletionId);
Assert.Equal("Someone", item.AuthorName);
Assert.Equal(ChatFinishReason.Length, item.FinishReason);
Assert.Equal(DateTime.Parse("2024-10-11T19:23:36.0152137Z"), item.CreatedAt);

var content = Assert.IsType<TextContent>(Assert.Single(item.Contents));
Assert.Equal("Hello world, how are you?", content.Text);
Assert.Equal("some model", content.ModelId);
}

[Fact]
Expand Down
Loading