Skip to content

Commit

Permalink
Improve CachingChatClient's coalescing of streaming updates (#5514)
Browse files Browse the repository at this point in the history
* Improve CachingChatClient's coalescing of streaming updates

- Avoid O(N^2) memory allocation in the length of the received text
- Propagate additional metadata from coalesced nodes
- Propagate metadata on the coalesced TextContent, like ModelId
- Expose whether to coalesce as a setting on the client

* Remove dictionary merging until we have evidence it's warranted
  • Loading branch information
stephentoub authored Oct 12, 2024
1 parent 6249779 commit e0c9a82
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ public AdditionalPropertiesDictionary(IEnumerable<KeyValuePair<string, object?>>
#endif
}

/// <summary>Creates a shallow clone of the properties dictionary.</summary>
/// <returns>
/// A new <see cref="AdditionalPropertiesDictionary"/> instance (never the same reference as this one)
/// containing all of the same key-value pairs. The values themselves are not cloned; they are shared
/// with the original dictionary.
/// </returns>
public AdditionalPropertiesDictionary Clone() => new(_dictionary);

/// <inheritdoc />
public object? this[string key]
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public virtual ChatOptions Clone()
ResponseFormat = ResponseFormat,
ModelId = ModelId,
ToolMode = ToolMode,
AdditionalProperties = AdditionalProperties?.Clone(),
};

if (StopSequences is not null)
Expand All @@ -85,11 +86,6 @@ public virtual ChatOptions Clone()
options.Tools = new List<AITool>(Tools);
}

if (AdditionalProperties is not null)
{
options.AdditionalProperties = new(AdditionalProperties);
}

return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,10 @@ public class EmbeddingGenerationOptions
/// The clone will have the same values for all properties as the original instance. Any collections, like <see cref="AdditionalProperties"/>
/// are shallow-cloned, meaning a new collection instance is created, but any references contained by the collections are shared with the original.
/// </remarks>
public virtual EmbeddingGenerationOptions Clone()
{
EmbeddingGenerationOptions options = new()
public virtual EmbeddingGenerationOptions Clone()
{
    // Shallow clone: AdditionalProperties gets a fresh dictionary instance via Clone(),
    // but any references it contains remain shared with the original options.
    return new EmbeddingGenerationOptions
    {
        ModelId = ModelId,
        AdditionalProperties = AdditionalProperties?.Clone(),
    };
}

if (AdditionalProperties is not null)
{
options.AdditionalProperties = new(AdditionalProperties);
}

return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S127 // "for" loop stop conditions should be invariant

namespace Microsoft.Extensions.AI;

/// <summary>
Expand All @@ -21,6 +24,20 @@ protected CachingChatClient(IChatClient innerClient)
{
}

/// <summary>Gets or sets a value indicating whether to coalesce streaming updates.</summary>
/// <remarks>
/// <para>
/// When <see langword="true"/>, the client attempts to coalesce contiguous streaming updates
/// into a single update before caching, reducing the number of individual items yielded on
/// subsequent enumerations of the cached data. When <see langword="false"/>, the updates are
/// cached exactly as received, unaltered.
/// </para>
/// <para>
/// The default is <see langword="true"/>.
/// </para>
/// </remarks>
public bool CoalesceStreamingUpdates { get; set; } = true;

/// <inheritdoc />
public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -50,58 +67,124 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
// Yield all of the cached items.
foreach (var chunk in existingChunks)
{
yield return chunk;
}
}
else
{
var capturedItems = new List<StreamingChatCompletionUpdate>();
StreamingChatCompletionUpdate? previousCoalescedCopy = null;
await foreach (var item in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
yield return item;

// If this item is compatible with the previous one, we will coalesce them in the cache
var previous = capturedItems.Count > 0 ? capturedItems[capturedItems.Count - 1] : null;
if (item.ChoiceIndex == 0
&& item.Contents.Count == 1
&& item.Contents[0] is TextContent currentTextContent
&& previous is { ChoiceIndex: 0 }
&& previous.Role == item.Role
&& previous.Contents is { Count: 1 }
&& previous.Contents[0] is TextContent previousTextContent)
capturedItems.Add(chunk);
yield return chunk;
}

// If the caching client is configured to coalesce streaming updates, do so now within the capturedItems list.
if (CoalesceStreamingUpdates)
{
StringBuilder coalescedText = new();

// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
for (int startInclusive = 0; startInclusive < capturedItems.Count; startInclusive++)
{
if (!ReferenceEquals(previous, previousCoalescedCopy))
// If an item isn't generally coalescable, skip it.
StreamingChatCompletionUpdate update = capturedItems[startInclusive];
if (update.ChoiceIndex != 0 ||
update.Contents.Count != 1 ||
update.Contents[0] is not TextContent textContent)
{
// We don't want to mutate any object that we also yield, since the recipient might
// not expect that. Instead make a copy we can safely mutate.
previousCoalescedCopy = new()
continue;
}

// We found a coalescable item. Look for more contiguous items that are also coalescable with it.
int endExclusive = startInclusive + 1;
for (; endExclusive < capturedItems.Count; endExclusive++)
{
StreamingChatCompletionUpdate next = capturedItems[endExclusive];
if (next.ChoiceIndex != 0 ||
next.Contents.Count != 1 ||
next.Contents[0] is not TextContent ||

// changing role or author would be really strange, but check anyway
(update.Role is not null && next.Role is not null && update.Role != next.Role) ||
(update.AuthorName is not null && next.AuthorName is not null && update.AuthorName != next.AuthorName))
{
Role = previous.Role,
AuthorName = previous.AuthorName,
AdditionalProperties = previous.AdditionalProperties,
ChoiceIndex = previous.ChoiceIndex,
RawRepresentation = previous.RawRepresentation,
Contents = [new TextContent(previousTextContent.Text)]
};

// The last item we captured was before we knew it could be coalesced
// with this one, so replace it with the coalesced copy
capturedItems[capturedItems.Count - 1] = previousCoalescedCopy;
break;
}
}

#pragma warning disable S1643 // Strings should not be concatenated using '+' in a loop
((TextContent)previousCoalescedCopy.Contents[0]).Text += currentTextContent.Text;
#pragma warning restore S1643
}
else
{
capturedItems.Add(item);
// If we couldn't find anything to coalesce, there's nothing to do.
if (endExclusive - startInclusive <= 1)
{
continue;
}

// We found a coalescable run of items. Create a new node to represent the run. We create a new one
// rather than reappropriating one of the existing ones so as not to mutate an item already yielded.
_ = coalescedText.Clear().Append(capturedItems[startInclusive].Text);

TextContent coalescedContent = new(null) // will patch the text after examining all items in the run
{
AdditionalProperties = textContent.AdditionalProperties?.Clone(),
ModelId = textContent.ModelId,
};

StreamingChatCompletionUpdate coalesced = new()
{
AdditionalProperties = update.AdditionalProperties?.Clone(),
AuthorName = update.AuthorName,
CompletionId = update.CompletionId,
Contents = [coalescedContent],
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
Role = update.Role,

// Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used
// to represent multiple, and it won't be serialized anyway.
};

// Replace the starting node with the coalesced node.
capturedItems[startInclusive] = coalesced;

// Now iterate through all the rest of the updates in the run, updating the coalesced node with relevant properties,
// and nulling out the nodes along the way. We do this rather than removing the entry in order to avoid an O(N^2) operation.
// We'll remove all the null entries at the end of the loop, using RemoveAll to do so, which can remove all of
// the nulls in a single O(N) pass.
for (int i = startInclusive + 1; i < endExclusive; i++)
{
// Grab the next item.
StreamingChatCompletionUpdate next = capturedItems[i];
capturedItems[i] = null!;

TextContent nextContent = (TextContent)next.Contents[0];
_ = coalescedText.Append(nextContent.Text);

coalesced.AuthorName ??= next.AuthorName;
coalesced.CompletionId ??= next.CompletionId;
coalesced.CreatedAt ??= next.CreatedAt;
coalesced.FinishReason ??= next.FinishReason;
coalesced.Role ??= next.Role;

coalescedContent.ModelId ??= nextContent.ModelId;
}

// Complete the coalescing by patching the text of the coalesced node.
coalesced.Text = coalescedText.ToString();

// Jump to the last update in the run, so that when we loop around and bump ahead,
// we're at the next update just after the run.
startInclusive = endExclusive - 1;
}

// Remove all of the null slots left over from the coalescing process.
_ = capturedItems.RemoveAll(u => u is null);
}

// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ public class DistributedCachingChatClientTest
{
private readonly TestInMemoryCacheStorage _storage = new();

[Fact]
public void Ctor_ExpectedDefaults()
{
    // Coalescing should be enabled by default, and the property should round-trip both values.
    using var inner = new TestChatClient();
    using var client = new DistributedCachingChatClient(inner, _storage);

    Assert.True(client.CoalesceStreamingUpdates);

    client.CoalesceStreamingUpdates = false;
    Assert.False(client.CoalesceStreamingUpdates);

    client.CoalesceStreamingUpdates = true;
    Assert.True(client.CoalesceStreamingUpdates);
}

[Fact]
public async Task CachesSuccessResultsAsync()
{
Expand Down Expand Up @@ -251,8 +266,11 @@ public async Task StreamingCachesSuccessResultsAsync()
Assert.Equal(2, innerCallCount);
}

[Fact]
public async Task StreamingCoalescesConsecutiveTextChunksAsync()
[Theory]
[InlineData(false)]
[InlineData(true)]
[InlineData(null)]
public async Task StreamingCoalescesConsecutiveTextChunksAsync(bool? coalesce)
{
// Arrange
List<StreamingChatCompletionUpdate> expectedCompletion =
Expand All @@ -274,17 +292,101 @@ public async Task StreamingCoalescesConsecutiveTextChunksAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

if (coalesce is not null)
{
outer.CoalesceStreamingUpdates = coalesce.Value;
}

var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);
await ToListAsync(result1);

// Act
var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);

// Assert
if (coalesce is null or true)
{
Assert.Collection(await ToListAsync(result2),
c => Assert.Equal("This becomes one chunk", c.Text),
c => Assert.IsType<FunctionCallContent>(Assert.Single(c.Contents)),
c => Assert.Equal("... and this becomes another one.", c.Text));
}
else
{
Assert.Collection(await ToListAsync(result2),
c => Assert.Equal("This", c.Text),
c => Assert.Equal(" becomes one chunk", c.Text),
c => Assert.IsType<FunctionCallContent>(Assert.Single(c.Contents)),
c => Assert.Equal("... and this", c.Text),
c => Assert.Equal(" becomes another", c.Text),
c => Assert.Equal(" one.", c.Text));
}
}

[Fact]
public async Task StreamingCoalescingPropagatesMetadataAsync()
{
    // Verifies that when consecutive text chunks are coalesced into a single cached update,
    // metadata from later updates in the run (CompletionId, AuthorName, FinishReason, CreatedAt)
    // is propagated onto the coalesced update, and that the first non-null content ModelId wins.

    // Arrange
    List<StreamingChatCompletionUpdate> expectedCompletion =
    [
        new() { Role = ChatRole.Assistant, Contents = [new TextContent("Hello")] },
        new() { Role = ChatRole.Assistant, Contents = [new TextContent(" world, ") { ModelId = "some model" }] },
        new()
        {
            Role = ChatRole.Assistant,
            Contents =
            [
                new TextContent("how ")
                {
                    ModelId = "some other model",
                    AdditionalProperties = new() { ["a"] = "b", ["c"] = "d" },
                }
            ]
        },
        new()
        {
            Role = ChatRole.Assistant,
            Contents =
            [
                new TextContent("are you?")
                {
                    AdditionalProperties = new() { ["e"] = "f", ["g"] = "h" },
                }
            ],
            CreatedAt = DateTime.Parse("2024-10-11T19:23:36.0152137Z"),
            CompletionId = "12345",
            AuthorName = "Someone",
            FinishReason = ChatFinishReason.Length,
        },
    ];

    using var testClient = new TestChatClient
    {
        CompleteStreamingAsyncCallback = delegate { return ToAsyncEnumerableAsync(expectedCompletion); }
    };
    using var outer = new DistributedCachingChatClient(testClient, _storage)
    {
        JsonSerializerOptions = TestJsonSerializerContext.Default.Options
    };

    // Populate the cache with the first streaming call.
    var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);
    await ToListAsync(result1);

    // Act: the second call is served from the cache, where coalescing has been applied.
    var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]);

    // Assert
    var items = await ToListAsync(result2);
    var item = Assert.Single(items);
    Assert.Equal("Hello world, how are you?", item.Text);
    Assert.Equal("12345", item.CompletionId);
    Assert.Equal("Someone", item.AuthorName);
    Assert.Equal(ChatFinishReason.Length, item.FinishReason);
    Assert.Equal(DateTime.Parse("2024-10-11T19:23:36.0152137Z"), item.CreatedAt);

    var content = Assert.IsType<TextContent>(Assert.Single(item.Contents));
    Assert.Equal("Hello world, how are you?", content.Text);
    Assert.Equal("some model", content.ModelId);
}
}

[Fact]
Expand Down

0 comments on commit e0c9a82

Please sign in to comment.