Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocket Keep-Alive Ping and Timeout (minimal) implementation #105841

Merged
merged 14 commits into from
Aug 9, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading;

namespace System.Net.WebSockets
{
/// <summary>
/// Central repository for default values used in WebSocket settings. Not all settings are relevant
/// to or configurable by all WebSocket implementations.
/// </summary>
internal static partial class WebSocketDefaults
{
public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.Zero;
public static readonly TimeSpan DefaultClientKeepAliveInterval = TimeSpan.FromSeconds(30);

public static readonly TimeSpan DefaultKeepAliveTimeout = Timeout.InfiniteTimeSpan;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ internal static partial class WebSocketValidate
SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~");

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
{
// Exception order:
// 1. WebSocketException(InvalidState) -- if invalid state
// 2. ObjectDisposedException

string? invalidStateMessage = GetInvalidStateMessage(currentState, validStates);
if (invalidStateMessage is null) // state is valid
{
// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
return;
}

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage);
}

internal static string? GetInvalidStateMessage(WebSocketState currentState, WebSocketState[] validStates)
{
string validStatesText = string.Empty;

Expand All @@ -47,18 +64,14 @@ internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDis
{
if (currentState == validState)
{
// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
return;
return null;
}
}

validStatesText = string.Join(", ", validStates);
}

throw new WebSocketException(
WebSocketError.InvalidState,
SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText));
return SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText);
}

internal static void ValidateSubprotocol(string subProtocol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,21 @@ public class Http2LoopbackConnection : GenericLoopbackConnection
private readonly TimeSpan _timeout;
private int _lastStreamId;
private bool _expectClientDisconnect;
private readonly Action<string>? _debugLog;
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved

private readonly byte[] _prefix = new byte[24];
public string PrefixString => Encoding.UTF8.GetString(_prefix, 0, _prefix.Length);
public bool IsInvalid => _connectionSocket == null;
public Stream Stream => _connectionStream;
public Task<bool> SettingAckWaiter => _ignoredSettingsAckPromise?.Task;

private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse)
private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse, Action<string>? debugLog = null)
{
_connectionSocket = socket;
_connectionStream = stream;
_timeout = timeout;
_transparentPingResponse = transparentPingResponse;
_debugLog = debugLog;
}

public override string ToString()
Expand Down Expand Up @@ -83,7 +85,7 @@ public static async Task<Http2LoopbackConnection> CreateAsync(SocketWrapper sock
stream = sslStream;
}

var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse);
var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse, httpOptions.DebugLog);
await con.ReadPrefixAsync().ConfigureAwait(false);

return con;
Expand Down Expand Up @@ -368,6 +370,7 @@ public async Task WaitForConnectionShutdownAsync(bool ignoreUnexpectedFrames = f
// and will ignore any errors if client has already shutdown
public async Task ShutdownIgnoringErrorsAsync(int lastStreamId, ProtocolErrors errorCode = ProtocolErrors.NO_ERROR)
{
_debugLog?.Invoke($"Http2LoopbackConnection.ShutdownIgnoringErrorsAsync() with lastStreamId={lastStreamId}, errorCode={errorCode}");
try
{
await SendGoAway(lastStreamId, errorCode).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ private Http2LoopbackConnection Connection
}
}

public Action<string>? DebugLog => _options.DebugLog;
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved

public static readonly TimeSpan Timeout = TimeSpan.FromSeconds(30);

public override Uri Address
Expand Down Expand Up @@ -186,6 +188,8 @@ public class Http2Options : GenericLoopbackOptions

public bool EnableTransparentPingResponse { get; set; } = true;

public Action<string>? DebugLog { get; set; }
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved

public Http2Options()
{
SslProtocols = SslProtocols.Tls12;
Expand Down
146 changes: 136 additions & 10 deletions src/libraries/Common/tests/TestUtilities/TestEventListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.IO;
using System.Text;
Expand Down Expand Up @@ -31,6 +32,7 @@ public sealed class TestEventListener : EventListener
"Private.InternalDiagnostics.System.Net.Sockets",
"Private.InternalDiagnostics.System.Net.Security",
"Private.InternalDiagnostics.System.Net.Quic",
"Private.InternalDiagnostics.System.Net.WebSockets",
"Private.InternalDiagnostics.System.Net.Http.WinHttpHandler",
"Private.InternalDiagnostics.System.Net.HttpListener",
"Private.InternalDiagnostics.System.Net.Mail",
Expand All @@ -41,36 +43,39 @@ public sealed class TestEventListener : EventListener

private readonly Action<string> _writeFunc;
private readonly HashSet<string> _sourceNames;
private readonly bool _enableActivityId;

// Until https://github.com/dotnet/runtime/issues/63979 is solved.
private List<EventSource> _eventSources = new List<EventSource>();

public TestEventListener(TextWriter output, params string[] sourceNames)
: this(str => output.WriteLine(str), sourceNames)
: this(output.WriteLine, sourceNames)
{ }

public TestEventListener(ITestOutputHelper output, params string[] sourceNames)
: this(str => output.WriteLine(str), sourceNames)
: this(output.WriteLine, sourceNames)
{ }

public TestEventListener(Action<string> writeFunc, params string[] sourceNames)
: this(writeFunc, enableActivityId: false, sourceNames)
{ }

public TestEventListener(Action<string> writeFunc, bool enableActivityId, params string[] sourceNames)
{
List<EventSource> eventSources = _eventSources;

lock (this)
{
_writeFunc = writeFunc;
_sourceNames = new HashSet<string>(sourceNames);
_enableActivityId = enableActivityId;
_eventSources = null;
}

// eventSources were populated in the base ctor and are now owned by this thread, enable them now.
foreach (EventSource eventSource in eventSources)
{
if (_sourceNames.Contains(eventSource.Name))
{
EnableEvents(eventSource, EventLevel.LogAlways);
}
EnableEventSource(eventSource);
}
}

Expand All @@ -90,20 +95,42 @@ protected override void OnEventSourceCreated(EventSource eventSource)
}

// Second pass called after our ctor, allow logging for specified source names.
EnableEventSource(eventSource);
}

private void EnableEventSource(EventSource eventSource)
{
if (_sourceNames.Contains(eventSource.Name))
{
EnableEvents(eventSource, EventLevel.LogAlways);
}
else if (_enableActivityId && eventSource.Name == "System.Threading.Tasks.TplEventSource")
{
EnableEvents(eventSource, EventLevel.LogAlways, (EventKeywords)0x80 /* TasksFlowActivityIds */);
}
}

protected override void OnEventWritten(EventWrittenEventArgs eventData)
{
StringBuilder sb = new StringBuilder().
StringBuilder sb = new StringBuilder();

#if NET || NETSTANDARD2_1_OR_GREATER
Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}[{eventData.EventName}] ");
#else
Append($"[{eventData.EventName}] ");
sb.Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}");
if (_enableActivityId)
{
if (eventData.ActivityId != Guid.Empty)
{
string activityId = ActivityHelpers.ActivityPathString(eventData.ActivityId);
sb.Append($" {activityId} {new string('-', activityId.Length / 2 - 1 )} ");
}
else
{
sb.Append(" / ");
}
}
#endif
sb.Append($"[{eventData.EventName}] ");

for (int i = 0; i < eventData.Payload?.Count; i++)
{
if (i > 0)
Expand All @@ -116,4 +143,103 @@ protected override void OnEventWritten(EventWrittenEventArgs eventData)
}
catch { }
}

// From https://gist.github.com/MihaZupan/cc63ee68b4146892f2e5b640ed57bc09
private static class ActivityHelpers
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
{
private enum NumberListCodes : byte
{
End = 0x0,
LastImmediateValue = 0xA,
PrefixCode = 0xB,
MultiByte1 = 0xC,
}

public static unsafe bool IsActivityPath(Guid guid)
{
uint* uintPtr = (uint*)&guid;
uint sum = uintPtr[0] + uintPtr[1] + uintPtr[2] + 0x599D99AD;
return ((sum & 0xFFF00000) == (uintPtr[3] & 0xFFF00000));
}

public static unsafe string ActivityPathString(Guid guid)
=> IsActivityPath(guid) ? CreateActivityPathString(guid) : guid.ToString();

internal static unsafe string CreateActivityPathString(Guid guid)
{
Debug.Assert(IsActivityPath(guid));

StringBuilder sb = new StringBuilder();

byte* bytePtr = (byte*)&guid;
byte* endPtr = bytePtr + 12;
char separator = '/';
while (bytePtr < endPtr)
{
uint nibble = (uint)(*bytePtr >> 4);
bool secondNibble = false;
NextNibble:
if (nibble == (uint)NumberListCodes.End)
{
break;
}
if (nibble <= (uint)NumberListCodes.LastImmediateValue)
{
sb.Append('/').Append(nibble);
if (!secondNibble)
{
nibble = (uint)(*bytePtr & 0xF);
secondNibble = true;
goto NextNibble;
}
bytePtr++;
continue;
}
else if (nibble == (uint)NumberListCodes.PrefixCode)
{
if (!secondNibble)
{
nibble = (uint)(*bytePtr & 0xF);
}
else
{
bytePtr++;
if (endPtr <= bytePtr)
{
break;
}
nibble = (uint)(*bytePtr >> 4);
}
if (nibble < (uint)NumberListCodes.MultiByte1)
{
return guid.ToString();
}
separator = '$';
}
Debug.Assert((uint)NumberListCodes.MultiByte1 <= nibble);
uint numBytes = nibble - (uint)NumberListCodes.MultiByte1;
uint value = 0;
if (!secondNibble)
{
value = (uint)(*bytePtr & 0xF);
}
bytePtr++;
numBytes++;
if (endPtr < bytePtr + numBytes)
{
break;
}
for (int i = (int)numBytes - 1; 0 <= i; --i)
{
value = (value << 8) + bytePtr[i];
}
sb.Append(separator).Append(value);

bytePtr += numBytes;
}

sb.Append('/');
return sb.ToString();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ internal ClientWebSocketOptions() { }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.TimeSpan KeepAliveInterval { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.IWebProxy? Proxy { get { throw null; } set { } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<Compile Include="System\Net\WebSockets\ClientWebSocketOptions.cs" />
<Compile Include="System\Net\WebSockets\WebSocketHandle.Managed.cs" />
<Compile Include="$(CommonPath)System\Net\HttpKnownHeaderNames.cs" Link="Common\System\Net\HttpKnownHeaderNames.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketDefaults.cs" Link="Common\System\Net\WebSockets\WebSocketDefaults.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'browser'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ public TimeSpan KeepAliveInterval
set => throw new PlatformNotSupportedException();
}

[UnsupportedOSPlatform("browser")]
public TimeSpan KeepAliveTimeout
{
get => throw new PlatformNotSupportedException();
set => throw new PlatformNotSupportedException();
}

[UnsupportedOSPlatform("browser")]
public WebSocketDeflateOptions? DangerousDeflateOptions
{
Expand Down
Loading
Loading