Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix client not raising HTTP diagnostic source events #1211

Merged
merged 1 commit into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/Shared/HttpHandlerFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ public static HttpMessageHandler EnsureTelemetryHandler(HttpMessageHandler handl
// so wrap with a handler that is responsible for setting the telemetry header.
if (HasHttpHandlerType(handler, "System.Net.Http.SocketsHttpHandler"))
{
return new TelemetryHeaderHandler(handler);
// Double check telemetry handler hasn't already been added by something else
// like the client factory when it created the primary handler.
if (!HasHttpHandlerType(handler, typeof(TelemetryHeaderHandler).FullName!))
{
return new TelemetryHeaderHandler(handler);
}
}

return handler;
Expand Down
224 changes: 215 additions & 9 deletions src/Shared/TelemetryHeaderHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,243 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;

// Copied with permission from https://github.com/dotnet/runtime/blob/7565d60891e43415f5e81b59e50c52dba46ee0d7/src/libraries/System.Net.Http/src/System/Net/Http/DiagnosticsHandler.cs
namespace Grpc.Shared
{
/// <summary>
/// This handler:
/// 1. Propagates trace headers.
/// 2. Starts and stops System.Net.Http.HttpRequestOut activity.
/// 3. Writes to diagnostics listener.
///
/// These actions are required for OpenTelemetry and for AppInsights to detect HTTP requests.
/// Note: Deprecated diagnostics listener events are still used by AppInsights.
///
/// Usually this logic is handled by https://github.com/dotnet/runtime/blob/7565d60891e43415f5e81b59e50c52dba46ee0d7/src/libraries/System.Net.Http/src/System/Net/Http/DiagnosticsHandler.cs.
/// DiagnosticsHandler is only run when HttpClientHandler is used.
/// If SocketsHttpHandler is used directly then this handler is added as a subsitute.
/// </summary>
internal sealed class TelemetryHeaderHandler : DelegatingHandler
{
public const string RequestIdHeaderName = "Request-Id";
public const string CorrelationContextHeaderName = "Correlation-Context";

public const string TraceParentHeaderName = "traceparent";
public const string TraceStateHeaderName = "tracestate";

public TelemetryHeaderHandler(HttpMessageHandler innerHandler) : base(innerHandler)
{
}

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
if (Activity.Current != null || DiagnosticListener.IsEnabled())
{
return SendAsyncCore(request, cancellationToken);
}

return base.SendAsync(request, cancellationToken);
}

private async Task<HttpResponseMessage> SendAsyncCore(HttpRequestMessage request, CancellationToken cancellationToken)
{
Activity? activity = null;
var diagnosticListener = DiagnosticListener;

// if there is no listener, but propagation is enabled (with previous IsEnabled() check)
// do not write any events just start/stop Activity and propagate Ids
if (!diagnosticListener.IsEnabled())
{
activity = new Activity(ActivityName);
activity.Start();
InjectHeaders(activity, request);

try
{
return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
}
finally
{
activity.Stop();
}
}

var loggingRequestId = Guid.Empty;

// There is a listener. Check if listener wants to be notified about HttpClient Activities
if (diagnosticListener.IsEnabled(ActivityName, request))
{
activity = new Activity(ActivityName);

// Only send start event to users who subscribed for it, but start activity anyway
if (diagnosticListener.IsEnabled(ActivityStartName))
{
diagnosticListener.StartActivity(activity, new ActivityStartData(request));
}
else
{
activity.Start();
}
}
// try to write System.Net.Http.Request event (deprecated)
if (diagnosticListener.IsEnabled(RequestWriteNameDeprecated))
{
var timestamp = Stopwatch.GetTimestamp();
loggingRequestId = Guid.NewGuid();
diagnosticListener.Write(RequestWriteNameDeprecated, new RequestData(request, loggingRequestId, timestamp));
}

// If we are on at all, we propagate current activity information
var currentActivity = Activity.Current;
if (currentActivity != null)
{
InjectHeaders(currentActivity, request);
}

return base.SendAsync(request, cancellationToken);
HttpResponseMessage? response = null;
var taskStatus = TaskStatus.RanToCompletion;
try
{
response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
return response;
}
catch (OperationCanceledException)
{
taskStatus = TaskStatus.Canceled;

// we'll report task status in HttpRequestOut.Stop
throw;
}
catch (Exception ex)
{
taskStatus = TaskStatus.Faulted;

if (diagnosticListener.IsEnabled(ExceptionEventName))
{
// If request was initially instrumented, Activity.Current has all necessary context for logging
// Request is passed to provide some context if instrumentation was disabled and to avoid
// extensive Activity.Tags usage to tunnel request properties
diagnosticListener.Write(ExceptionEventName, new ExceptionData(ex, request));
}
throw;
}
finally
{
// always stop activity if it was started
if (activity != null)
{
diagnosticListener.StopActivity(activity, new ActivityStopData(
response,
// If request is failed or cancelled, there is no response, therefore no information about request;
// pass the request in the payload, so consumers can have it in Stop for failed/canceled requests
// and not retain all requests in Start
request,
taskStatus));
}
// Try to write System.Net.Http.Response event (deprecated)
if (diagnosticListener.IsEnabled(ResponseWriteNameDeprecated))
{
var timestamp = Stopwatch.GetTimestamp();
diagnosticListener.Write(ResponseWriteNameDeprecated,
new ResponseData(
response,
loggingRequestId,
timestamp,
taskStatus));
}
}
}

public static readonly DiagnosticListener DiagnosticListener = new DiagnosticListener(DiagnosticListenerName);

private const string DiagnosticListenerName = "HttpHandlerDiagnosticListener";
private const string RequestWriteNameDeprecated = "System.Net.Http.Request";
private const string ResponseWriteNameDeprecated = "System.Net.Http.Response";

private const string ExceptionEventName = "System.Net.Http.Exception";
private const string ActivityName = "System.Net.Http.HttpRequestOut";
private const string ActivityStartName = "System.Net.Http.HttpRequestOut.Start";

private const string RequestIdHeaderName = "Request-Id";
private const string CorrelationContextHeaderName = "Correlation-Context";

private const string TraceParentHeaderName = "traceparent";
private const string TraceStateHeaderName = "tracestate";

private sealed class ActivityStartData
{
internal ActivityStartData(HttpRequestMessage request)
{
Request = request;
}

public HttpRequestMessage Request { get; }

public override string ToString() => $"{{ {nameof(Request)} = {Request} }}";
}

private sealed class ActivityStopData
{
internal ActivityStopData(HttpResponseMessage? response, HttpRequestMessage request, TaskStatus requestTaskStatus)
{
Response = response;
Request = request;
RequestTaskStatus = requestTaskStatus;
}

public HttpResponseMessage? Response { get; }
public HttpRequestMessage Request { get; }
public TaskStatus RequestTaskStatus { get; }

public override string ToString() => $"{{ {nameof(Response)} = {Response}, {nameof(Request)} = {Request}, {nameof(RequestTaskStatus)} = {RequestTaskStatus} }}";
}

private sealed class ExceptionData
{
internal ExceptionData(Exception exception, HttpRequestMessage request)
{
Exception = exception;
Request = request;
}

public Exception Exception { get; }
public HttpRequestMessage Request { get; }

public override string ToString() => $"{{ {nameof(Exception)} = {Exception}, {nameof(Request)} = {Request} }}";
}

private sealed class RequestData
{
internal RequestData(HttpRequestMessage request, Guid loggingRequestId, long timestamp)
{
Request = request;
LoggingRequestId = loggingRequestId;
Timestamp = timestamp;
}

public HttpRequestMessage Request { get; }
public Guid LoggingRequestId { get; }
public long Timestamp { get; }

public override string ToString() => $"{{ {nameof(Request)} = {Request}, {nameof(LoggingRequestId)} = {LoggingRequestId}, {nameof(Timestamp)} = {Timestamp} }}";
}

private sealed class ResponseData
{
internal ResponseData(HttpResponseMessage? response, Guid loggingRequestId, long timestamp, TaskStatus requestTaskStatus)
{
Response = response;
LoggingRequestId = loggingRequestId;
Timestamp = timestamp;
RequestTaskStatus = requestTaskStatus;
}

public HttpResponseMessage? Response { get; }
public Guid LoggingRequestId { get; }
public long Timestamp { get; }
public TaskStatus RequestTaskStatus { get; }

public override string ToString() => $"{{ {nameof(Response)} = {Response}, {nameof(LoggingRequestId)} = {LoggingRequestId}, {nameof(Timestamp)} = {Timestamp}, {nameof(RequestTaskStatus)} = {RequestTaskStatus} }}";
}

private static void InjectHeaders(Activity currentActivity, HttpRequestMessage request)
Expand Down Expand Up @@ -89,7 +296,6 @@ private static void InjectHeaders(Activity currentActivity, HttpRequestMessage r
}
}
}

}
}
#endif
101 changes: 100 additions & 1 deletion test/FunctionalTests/Client/TelemetryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Http;
using System.Threading.Tasks;
using Greet;
Expand Down Expand Up @@ -83,10 +85,29 @@ Task<HelloReply> UnaryTelemetryHeader(HelloRequest request, ServerCallContext co
var client = CreateClient(clientType, method, handler);

// Act
await client.UnaryCall(new HelloRequest());
#if NET5_0
var result = new List<KeyValuePair<string, object?>>();

using var allSubscription = new AllListenersObserver(new Dictionary<string, IObserver<KeyValuePair<string, object?>>>
{
["HttpHandlerDiagnosticListener"] = new ObserverToList<KeyValuePair<string, object?>>(result)
});
using (DiagnosticListener.AllListeners.Subscribe(allSubscription))
#endif
{
await client.UnaryCall(new HelloRequest());
}

// Assert
Assert.IsNotNull(telemetryHeader);

#if NET5_0
Assert.AreEqual(4, result.Count);
Assert.AreEqual("System.Net.Http.HttpRequestOut.Start", result[0].Key);
Assert.AreEqual("System.Net.Http.Request", result[1].Key);
Assert.AreEqual("System.Net.Http.HttpRequestOut.Stop", result[2].Key);
Assert.AreEqual("System.Net.Http.Response", result[3].Key);
#endif
}

private TestClient<HelloRequest, HelloReply> CreateClient(ClientType clientType, Method<HelloRequest, HelloReply> method, HttpMessageHandler? handler)
Expand Down Expand Up @@ -133,5 +154,83 @@ public enum ClientType
Channel,
ClientFactory
}

internal class AllListenersObserver : IObserver<DiagnosticListener>, IDisposable
{
private readonly Dictionary<string, IObserver<KeyValuePair<string, object?>>> _observers;
private readonly List<IDisposable> _subscriptions;

public AllListenersObserver(Dictionary<string, IObserver<KeyValuePair<string, object?>>> observers)
{
_observers = observers;
_subscriptions = new List<IDisposable>();
}

public bool Completed { get; private set; }

public void Dispose()
{
foreach (var subscription in _subscriptions)
{
subscription.Dispose();
}
}

public void OnCompleted()
{
Completed = true;
}

public void OnError(Exception error)
{
throw new Exception("Observer error", error);
}

public void OnNext(DiagnosticListener value)
{
if (value?.Name != null && _observers.TryGetValue(value.Name, out var observer))
{
_subscriptions.Add(value.Subscribe(observer));
}
}
}

internal class ObserverToList<T> : IObserver<T>
{
public ObserverToList(List<T> output, Predicate<T>? filter = null, string? name = null)
{
_output = output;
_output.Clear();
_filter = filter;
_name = name;
}

public bool Completed { get; private set; }

#region private
public void OnCompleted()
{
Completed = true;
}

public void OnError(Exception error)
{
Assert.True(false, "Error happened on IObserver");
}

public void OnNext(T value)
{
Assert.False(Completed);
if (_filter == null || _filter(value))
{
_output.Add(value);
}
}

private List<T> _output;
private Predicate<T>? _filter;
private string? _name; // for debugging
#endregion
}
}
}