Skip to content

Commit

Permalink
Use the logging generator in LoggingChatClient / LoggingEmbeddingGene…
Browse files Browse the repository at this point in the history
…rator (#5508)
  • Loading branch information
stephentoub authored Oct 11, 2024
1 parent 058d827 commit 85e70b0
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

#pragma warning disable EA0000 // Use source generated logging methods for improved performance
#pragma warning disable CA2254 // Template should be a static expression

namespace Microsoft.Extensions.AI;

/// <summary>A delegating chat client that logs chat operations to an <see cref="ILogger"/>.</summary>
public class LoggingChatClient : DelegatingChatClient
public partial class LoggingChatClient : DelegatingChatClient
{
/// <summary>An <see cref="ILogger"/> instance used for all logging.</summary>
private readonly ILogger _logger;
Expand Down Expand Up @@ -45,7 +42,18 @@ public JsonSerializerOptions JsonSerializerOptions
public override async Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
LogStart(chatMessages, options);
if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogInvokedSensitive(nameof(CompleteAsync), AsJson(chatMessages), AsJson(options), AsJson(Metadata));
}
else
{
LogInvoked(nameof(CompleteAsync));
}
}

try
{
var completion = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
Expand All @@ -54,20 +62,24 @@ public override async Task<ChatCompletion> CompleteAsync(
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (completion, _jsonSerializerOptions), null, static (state, _) =>
$"CompleteAsync completed: {JsonSerializer.Serialize(state.completion, state._jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion)))}");
LogCompletedSensitive(nameof(CompleteAsync), AsJson(completion));
}
else
{
_logger.LogDebug("CompleteAsync completed.");
LogCompleted(nameof(CompleteAsync));
}
}

return completion;
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled(nameof(CompleteAsync));
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "CompleteAsync failed.");
LogInvocationFailed(nameof(CompleteAsync), ex);
throw;
}
}
Expand All @@ -76,16 +88,31 @@ public override async Task<ChatCompletion> CompleteAsync(
public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
LogStart(chatMessages, options);
if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogInvokedSensitive(nameof(CompleteStreamingAsync), AsJson(chatMessages), AsJson(options), AsJson(Metadata));
}
else
{
LogInvoked(nameof(CompleteStreamingAsync));
}
}

IAsyncEnumerator<StreamingChatCompletionUpdate> e;
try
{
e = base.CompleteStreamingAsync(chatMessages, options, cancellationToken).GetAsyncEnumerator(cancellationToken);
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled(nameof(CompleteStreamingAsync));
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "CompleteStreamingAsync failed.");
LogInvocationFailed(nameof(CompleteStreamingAsync), ex);
throw;
}

Expand All @@ -103,52 +130,63 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt

update = e.Current;
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled(nameof(CompleteStreamingAsync));
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "CompleteStreamingAsync failed.");
LogInvocationFailed(nameof(CompleteStreamingAsync), ex);
throw;
}

if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (update, _jsonSerializerOptions), null, static (state, _) =>
$"CompleteStreamingAsync received update: {JsonSerializer.Serialize(state.update, state._jsonSerializerOptions.GetTypeInfo(typeof(StreamingChatCompletionUpdate)))}");
LogStreamingUpdateSensitive(AsJson(update));
}
else
{
_logger.LogDebug("CompleteStreamingAsync received update.");
LogStreamingUpdate();
}
}

yield return update;
}

_logger.LogDebug("CompleteStreamingAsync completed.");
LogCompleted(nameof(CompleteStreamingAsync));
}
finally
{
await e.DisposeAsync().ConfigureAwait(false);
}
}

private void LogStart(IList<ChatMessage> chatMessages, ChatOptions? options, [CallerMemberName] string? methodName = null)
{
if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (methodName, chatMessages, options, this), null, static (state, _) =>
$"{state.methodName} invoked: " +
$"Messages: {JsonSerializer.Serialize(state.chatMessages, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(IList<ChatMessage>)))}. " +
$"Options: {JsonSerializer.Serialize(state.options, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatOptions)))}. " +
$"Metadata: {JsonSerializer.Serialize(state.Item4.Metadata, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatClientMetadata)))}.");
}
else
{
_logger.LogDebug($"{methodName} invoked.");
}
}
}
private string AsJson<T>(T value) => JsonSerializer.Serialize(value, _jsonSerializerOptions.GetTypeInfo(typeof(T)));

[LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")]
private partial void LogInvoked(string methodName);

[LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {ChatMessages}. Options: {ChatOptions}. Metadata: {ChatClientMetadata}.")]
private partial void LogInvokedSensitive(string methodName, string chatMessages, string chatOptions, string chatClientMetadata);

[LoggerMessage(LogLevel.Debug, "{MethodName} completed.")]
private partial void LogCompleted(string methodName);

[LoggerMessage(LogLevel.Trace, "{MethodName} completed: {ChatCompletion}.")]
private partial void LogCompletedSensitive(string methodName, string chatCompletion);

[LoggerMessage(LogLevel.Debug, "CompleteStreamingAsync received update.")]
private partial void LogStreamingUpdate();

[LoggerMessage(LogLevel.Trace, "CompleteStreamingAsync received update: {StreamingChatCompletionUpdate}")]
private partial void LogStreamingUpdateSensitive(string streamingChatCompletionUpdate);

[LoggerMessage(LogLevel.Debug, "{MethodName} canceled.")]
private partial void LogInvocationCanceled(string methodName);

[LoggerMessage(LogLevel.Error, "{MethodName} failed.")]
private partial void LogInvocationFailed(string methodName, Exception error);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

#pragma warning disable EA0000 // Use source generated logging methods for improved performance

namespace Microsoft.Extensions.AI;

/// <summary>A delegating embedding generator that logs embedding generation operations to an <see cref="ILogger"/>.</summary>
/// <typeparam name="TInput">Specifies the type of the input passed to the generator.</typeparam>
/// <typeparam name="TEmbedding">Specifies the type of the embedding instance produced by the generator.</typeparam>
public class LoggingEmbeddingGenerator<TInput, TEmbedding> : DelegatingEmbeddingGenerator<TInput, TEmbedding>
public partial class LoggingEmbeddingGenerator<TInput, TEmbedding> : DelegatingEmbeddingGenerator<TInput, TEmbedding>
where TEmbedding : Embedding
{
/// <summary>An <see cref="ILogger"/> instance used for all logging.</summary>
Expand Down Expand Up @@ -50,33 +48,48 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(IEnume
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (values, options, this), null, static (state, _) =>
"GenerateAsync invoked: " +
$"Values: {JsonSerializer.Serialize(state.values, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(IEnumerable<TInput>)))}. " +
$"Options: {JsonSerializer.Serialize(state.options, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGenerationOptions)))}. " +
$"Metadata: {JsonSerializer.Serialize(state.Item3.Metadata, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGeneratorMetadata)))}.");
LogInvokedSensitive(AsJson(values), AsJson(options), AsJson(Metadata));
}
else
{
_logger.LogDebug("GenerateAsync invoked.");
LogInvoked();
}
}

try
{
var embeddings = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false);

if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("GenerateAsync generated {Count} embedding(s).", embeddings.Count);
}
LogCompleted(embeddings.Count);

return embeddings;
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled();
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "GenerateAsync failed.");
LogInvocationFailed(ex);
throw;
}
}

private string AsJson<T>(T value) => JsonSerializer.Serialize(value, _jsonSerializerOptions.GetTypeInfo(typeof(T)));

[LoggerMessage(LogLevel.Debug, "GenerateAsync invoked.")]
private partial void LogInvoked();

[LoggerMessage(LogLevel.Trace, "GenerateAsync invoked: {Values}. Options: {EmbeddingGenerationOptions}. Metadata: {EmbeddingGeneratorMetadata}.")]
private partial void LogInvokedSensitive(string values, string embeddingGenerationOptions, string embeddingGeneratorMetadata);

[LoggerMessage(LogLevel.Debug, "GenerateAsync generated {EmbeddingsCount} embedding(s).")]
private partial void LogCompleted(int embeddingsCount);

[LoggerMessage(LogLevel.Debug, "GenerateAsync canceled.")]
private partial void LogInvocationCanceled();

[LoggerMessage(LogLevel.Error, "GenerateAsync failed.")]
private partial void LogInvocationFailed(Exception error);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
<PropertyGroup>
<InjectSharedCollectionExtensions>true</InjectSharedCollectionExtensions>
<InjectSharedEmptyCollections>true</InjectSharedEmptyCollections>
<DisableMicrosoftExtensionsLoggingSourceGenerator>false</DisableMicrosoftExtensionsLoggingSourceGenerator>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ await chatClient.CompleteAsync(

Assert.Collection(logger.Entries,
entry => Assert.Contains("What is the current secret number?", entry.Message),
entry => Assert.Contains("\"name\":\"GetSecretNumber\"", entry.Message),
entry => Assert.Contains($"\"result\":{secretNumber}", entry.Message),
entry => Assert.Contains("\"name\": \"GetSecretNumber\"", entry.Message),
entry => Assert.Contains($"\"result\": {secretNumber}", entry.Message),
entry => Assert.Contains(secretNumber.ToString(), entry.Message));
}

Expand All @@ -528,8 +528,8 @@ public virtual async Task Logging_LogsFunctionCalls_Streaming()
}

Assert.Contains(logger.Entries, e => e.Message.Contains("What is the current secret number?"));
Assert.Contains(logger.Entries, e => e.Message.Contains("\"name\":\"GetSecretNumber\""));
Assert.Contains(logger.Entries, e => e.Message.Contains($"\"result\":{secretNumber}"));
Assert.Contains(logger.Entries, e => e.Message.Contains("\"name\": \"GetSecretNumber\""));
Assert.Contains(logger.Entries, e => e.Message.Contains($"\"result\": {secretNumber}"));
}

[ConditionalFact]
Expand Down

0 comments on commit 85e70b0

Please sign in to comment.