From 1a4a54f107eb2651d951934050455a96bface2a8 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 20 Nov 2024 14:06:53 -0500 Subject: [PATCH] Fix a few FunctionInvocationChatClient streaming issues (#5680) - The non-streaming path explicitly throws if the response contains multiple choices. The streaming path wasn't doing the same and was instead silently producing bad results. - The streaming path was yielding function call content _and_ adding them to the chat history. It should only have been doing the latter. This fixes both issues. We also had close to zero test coverage in our FunctionInvocationChatClient tests for streaming, only for non-streaming. This also fixes that. --- .../FunctionInvokingChatClient.cs | 51 +- .../FunctionInvokingChatClientTests.cs | 476 ++++++++++++------ 2 files changed, 371 insertions(+), 156 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 70fddc68718..e1e4542d5d0 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -216,7 +216,7 @@ public override async Task CompleteAsync(IList chat // doesn't realize this and is wasting their budget requesting extra choices we'd never use. if (response.Choices.Count > 1) { - throw new InvalidOperationException($"Automatic function call invocation only accepts a single choice, but {response.Choices.Count} choices were received."); + ThrowForMultipleChoices(); } // Extract any function call contents on the first choice. If there are none, we're done. @@ -301,22 +301,47 @@ public override async IAsyncEnumerable CompleteSt _ = Throw.IfNull(chatMessages); HashSet? messagesToRemove = null; + List functionCallContents = []; + int? choice; try { for (int iteration = 0; ; iteration++) { - List? functionCallContents = null; - await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + choice = null; + functionCallContents.Clear(); + await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) { // We're going to emit all StreamingChatMessage items upstream, even ones that represent - // function calls, because a given StreamingChatMessage can contain other content too. - yield return chunk; + // function calls, because a given StreamingChatMessage can contain other content, too. + // And if we yield the function calls, and the consumer adds all the content into a message + // that's then added into history, they'll end up with function call contents that aren't + // directly paired with function result contents, which may cause issues for some models + // when the history is later sent again. + + // Find all the FCCs. We need to track these separately in order to be able to process them later. + int preFccCount = functionCallContents.Count; + functionCallContents.AddRange(update.Contents.OfType()); + + // If there were any, remove them from the update. We do this before yielding the update so + // that we're not modifying an instance already provided back to the caller. + int addedFccs = functionCallContents.Count - preFccCount; + if (addedFccs > preFccCount) + { + update.Contents = addedFccs == update.Contents.Count ? + [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); + } - foreach (var item in chunk.Contents.OfType()) + // Only one choice is allowed with automatic function calling. + if (choice is null) + { + choice = update.ChoiceIndex; + } + else if (choice != update.ChoiceIndex) { - functionCallContents ??= []; - functionCallContents.Add(item); + ThrowForMultipleChoices(); } + + yield return update; } // If there are no tools to call, or for any other reason we should stop, return the response. @@ -373,6 +398,16 @@ public override async IAsyncEnumerable CompleteSt } } + /// Throws an exception when multiple choices are received. + private static void ThrowForMultipleChoices() + { + // If there's more than one choice, we don't know which one to add to chat history, or which + // of their function calls to process. This should not happen except if the developer has + // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer + // doesn't realize this and is wasting their budget requesting extra choices we'd never use. + throw new InvalidOperationException("Automatic function call invocation only accepts a single choice, but multiple choices were received."); + } + /// /// Removes all of the messages in from /// and all of the content in from the messages in . diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index d9df2fc89e3..da983243acb 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -12,6 +12,8 @@ using OpenTelemetry.Trace; using Xunit; +#pragma warning disable SA1118 // Parameter should not span multiple lines + namespace Microsoft.Extensions.AI; public class FunctionInvokingChatClientTests @@ -41,14 +43,16 @@ public async Task SupportsSingleFunctionCallPerRequestAsync() { var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create(() => "Result 1", "Func1"), AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), AIFunctionFactory.Create((int i) => { }, "VoidReturn"), ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), @@ -57,7 +61,11 @@ await InvokeAndAssertAsync(options, [ new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), new ChatMessage(ChatRole.Assistant, "world"), - ]); + ]; + + await InvokeAndAssertAsync(options, plan); + + await InvokeAndAssertStreamingAsync(options, plan); } [Theory] @@ -67,31 +75,46 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn { var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create((int i) => "Result 1", "Func1"), AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), ] }; - await InvokeAndAssertAsync(options, [ + + List plan = + [ new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId1", "Func1"), new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 34 } }), new FunctionCallContent("callId3", "Func2", arguments: new Dictionary { { "i", 56 } }), ]), - new ChatMessage(ChatRole.Tool, [ + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId1", "Func1", result: "Result 1"), new FunctionResultContent("callId2", "Func2", result: "Result 2: 34"), new FunctionResultContent("callId3", "Func2", result: "Result 2: 56"), ]), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId4", "Func2", arguments: new Dictionary { { "i", 78 } }), - new FunctionCallContent("callId5", "Func1")]), - new ChatMessage(ChatRole.Tool, [ + new FunctionCallContent("callId5", "Func1") + ]), + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId4", "Func2", result: "Result 2: 78"), - new FunctionResultContent("callId5", "Func1", result: "Result 1")]), + new FunctionResultContent("callId5", "Func1", result: "Result 1") + ]), new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation })); + ]; + + Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } [Fact] @@ -101,7 +124,8 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create((string arg) => { barrier.SignalAndWait(); @@ -110,18 +134,27 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), ]), - new ChatMessage(ChatRole.Tool, [ + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId1", "Func", result: "hellohello"), new FunctionResultContent("callId2", "Func", result: "worldworld"), ]), new ChatMessage(ChatRole.Assistant, "done"), - ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true })); + ]; + + Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } [Fact] @@ -131,7 +164,8 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create(async (string arg) => { Interlocked.Increment(ref activeCount); @@ -143,18 +177,25 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), ]), - new ChatMessage(ChatRole.Tool, [ + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId1", "Func", result: "hellohello"), new FunctionResultContent("callId2", "Func", result: "worldworld"), ]), new ChatMessage(ChatRole.Assistant, "done"), - ]); + ]; + + await InvokeAndAssertAsync(options, plan); + + await InvokeAndAssertStreamingAsync(options, plan); } [Theory] @@ -172,36 +213,40 @@ public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunc ] }; -#pragma warning disable SA1118 // Parameter should not span multiple lines - var finalChat = await InvokeAndAssertAsync( - options, - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "world"), - ], - expected: keepFunctionCallingMessages ? - null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, "world") - ], - configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); -#pragma warning restore SA1118 - - IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + List? expected = keepFunctionCallingMessages ? null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, "world") + ]; + + Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); + + Validate(await InvokeAndAssertAsync(options, plan, expected, configure)); + Validate(await InvokeAndAssertStreamingAsync(options, plan, expected, configure)); + + void Validate(List finalChat) { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } } } @@ -220,37 +265,56 @@ public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunct ] }; -#pragma warning disable SA1118 // Parameter should not span multiple lines - var finalChat = await InvokeAndAssertAsync(options, - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } }), new TextContent("more")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "world"), - ], - expected: keepFunctionCallingMessages ? - null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), - new ChatMessage(ChatRole.Assistant, "more"), - new ChatMessage(ChatRole.Assistant, "world"), - ], - configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); -#pragma warning restore SA1118 - - IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } }), new TextContent("more")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); + +#pragma warning disable SA1005, S125 + Validate(await InvokeAndAssertAsync(options, plan, keepFunctionCallingMessages ? null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Assistant, "more"), + new ChatMessage(ChatRole.Assistant, "world"), + ], configure)); + + Validate(await InvokeAndAssertStreamingAsync(options, plan, keepFunctionCallingMessages ? + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), + ] : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), + ], configure)); + + void Validate(List finalChat) { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } } } @@ -267,12 +331,19 @@ public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedEr ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: detailedErrors ? "Error: Function failed. Exception: Oh no!" : "Error: Function failed.")]), new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors })); + ]; + + Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } [Fact] @@ -281,28 +352,36 @@ public async Task RejectsMultipleChoicesAsync() var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1"); var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2"); + var expected = new ChatCompletion( + [ + new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]), + new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]), + ]); + using var innerClient = new TestChatClient { CompleteAsyncCallback = async (chatContents, options, cancellationToken) => { await Task.Yield(); - - return new ChatCompletion( - [ - new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]), - new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]), - ]); - } + return expected; + }, + CompleteStreamingAsyncCallback = (chatContents, options, cancellationToken) => + YieldAsync(expected.ToStreamingChatCompletionUpdates()), }; IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build(); List chat = [new ChatMessage(ChatRole.User, "hello")]; - var ex = await Assert.ThrowsAsync( - () => service.CompleteAsync(chat, new ChatOptions { Tools = [func1, func2] })); + ChatOptions options = new() { Tools = [func1, func2] }; - Assert.Contains("only accepts a single choice", ex.Message); - Assert.Single(chat); // It didn't add anything to the chat history + Validate(await Assert.ThrowsAsync(() => service.CompleteAsync(chat, options))); + Validate(await Assert.ThrowsAsync(() => service.CompleteStreamingAsync(chat, options).ToChatCompletionAsync())); + + void Validate(Exception ex) + { + Assert.Contains("only accepts a single choice", ex.Message); + Assert.Single(chat); // It didn't add anything to the chat history + } } [Theory] @@ -311,39 +390,51 @@ public async Task RejectsMultipleChoicesAsync() [InlineData(LogLevel.Information)] public async Task FunctionInvocationsLogged(LogLevel level) { - using CapturingLoggerProvider clp = new(); - - ServiceCollection c = new(); - c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); - var services = c.BuildServiceProvider(); + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; var options = new ChatOptions { Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] }; - await InvokeAndAssertAsync(options, [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(c => new FunctionInvokingChatClient(c, services.GetRequiredService>()))); + Func configure = b => + b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>())); - if (level is LogLevel.Trace) - { - Assert.Collection(clp.Logger.Entries, - entry => Assert.True(entry.Message.Contains("Invoking Func1({") && entry.Message.Contains("\"arg1\": \"value1\"")), - entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && entry.Message.Contains("Result: \"Result 1\""))); - } - else if (level is LogLevel.Debug) - { - Assert.Collection(clp.Logger.Entries, - entry => Assert.True(entry.Message.Contains("Invoking Func1") && !entry.Message.Contains("arg1")), - entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && !entry.Message.Contains("Result"))); - } - else + await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services)); + + await InvokeAsync(services => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure, services: services)); + + async Task InvokeAsync(Func work) { - Assert.Empty(clp.Logger.Entries); + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + + await work(c.BuildServiceProvider()); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("Invoking Func1({") && entry.Message.Contains("\"arg1\": \"value1\"")), + entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && entry.Message.Contains("Result: \"Result 1\""))); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("Invoking Func1") && !entry.Message.Contains("arg1")), + entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && !entry.Message.Contains("Result"))); + } + else + { + Assert.Empty(clp.Logger.Entries); + } } } @@ -353,38 +444,51 @@ await InvokeAndAssertAsync(options, [ public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry) { string sourceName = Guid.NewGuid().ToString(); - var activities = new List(); - using TracerProvider? tracerProvider = enableTelemetry ? - OpenTelemetry.Sdk.CreateTracerProviderBuilder() - .AddSource(sourceName) - .AddInMemoryExporter(activities) - .Build() : - null; - - var options = new ChatOptions - { - Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] - }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(c => - new FunctionInvokingChatClient( - new OpenTelemetryChatClient(c, sourceName: sourceName)))); + ]; - if (enableTelemetry) + ChatOptions options = new() { - Assert.Collection(activities, - activity => Assert.Equal("chat", activity.DisplayName), - activity => Assert.Equal("Func1", activity.DisplayName), - activity => Assert.Equal("chat", activity.DisplayName)); - } - else + Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] + }; + + Func configure = b => b.Use(c => + new FunctionInvokingChatClient( + new OpenTelemetryChatClient(c, sourceName: sourceName))); + + await InvokeAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: configure)); + + await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure)); + + async Task InvokeAsync(Func work) { - Assert.Empty(activities); + var activities = new List(); + using TracerProvider? tracerProvider = enableTelemetry ? + OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build() : + null; + + await work(); + + if (enableTelemetry) + { + Assert.Collection(activities, + activity => Assert.Equal("chat", activity.DisplayName), + activity => Assert.Equal("Func1", activity.DisplayName), + activity => Assert.Equal("chat", activity.DisplayName)); + } + else + { + Assert.Empty(activities); + } } } @@ -392,7 +496,8 @@ private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan, List? expected = null, - Func? configurePipeline = null) + Func? configurePipeline = null, + IServiceProvider? services = null) { Assert.NotEmpty(plan); @@ -400,7 +505,6 @@ private static async Task> InvokeAndAssertAsync( using CancellationTokenSource cts = new(); List chat = [plan[0]]; - int i = 0; using var innerClient = new TestChatClient { @@ -411,11 +515,11 @@ private static async Task> InvokeAndAssertAsync( await Task.Yield(); - return new ChatCompletion([plan[contents.Count]]); + return new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])); } }; - IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(); + IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.CompleteAsync(chat, options, cts.Token); chat.Add(result.Message); @@ -423,7 +527,7 @@ private static async Task> InvokeAndAssertAsync( expected ??= plan; Assert.NotNull(result); Assert.Equal(expected.Count, chat.Count); - for (; i < expected.Count; i++) + for (int i = 0; i < expected.Count; i++) { var expectedMessage = expected[i]; var chatMessage = chat[i]; @@ -456,4 +560,80 @@ private static async Task> InvokeAndAssertAsync( return chat; } + + private static async Task> InvokeAndAssertStreamingAsync( + ChatOptions options, + List plan, + List? expected = null, + Func? configurePipeline = null, + IServiceProvider? services = null) + { + Assert.NotEmpty(plan); + + configurePipeline ??= static b => b.UseFunctionInvocation(); + + using CancellationTokenSource cts = new(); + List chat = [plan[0]]; + + using var innerClient = new TestChatClient + { + CompleteStreamingAsyncCallback = (contents, actualOptions, actualCancellationToken) => + { + Assert.Same(chat, contents); + Assert.Equal(cts.Token, actualCancellationToken); + + return YieldAsync(new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToStreamingChatCompletionUpdates()); + } + }; + + IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); + + var result = await service.CompleteStreamingAsync(chat, options, cts.Token).ToChatCompletionAsync(); + chat.Add(result.Message); + + expected ??= plan; + Assert.NotNull(result); + Assert.Equal(expected.Count, chat.Count); + for (int i = 0; i < expected.Count; i++) + { + var expectedMessage = expected[i]; + var chatMessage = chat[i]; + + Assert.Equal(expectedMessage.Role, chatMessage.Role); + Assert.Equal(expectedMessage.Text, chatMessage.Text); + Assert.Equal(expectedMessage.GetType(), chatMessage.GetType()); + + Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count); + for (int j = 0; j < expectedMessage.Contents.Count; j++) + { + var expectedItem = expectedMessage.Contents[j]; + var chatItem = chatMessage.Contents[j]; + + Assert.Equal(expectedItem.GetType(), chatItem.GetType()); + Assert.Equal(expectedItem.ToString(), chatItem.ToString()); + if (expectedItem is FunctionCallContent expectedFunctionCall) + { + var chatFunctionCall = (FunctionCallContent)chatItem; + Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name); + AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments); + } + else if (expectedItem is FunctionResultContent expectedFunctionResult) + { + var chatFunctionResult = (FunctionResultContent)chatItem; + AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result); + } + } + } + + return chat; + } + + private static async IAsyncEnumerable YieldAsync(params T[] items) + { + await Task.Yield(); + foreach (var item in items) + { + yield return item; + } + } }