From dd9de54b3499b85bf70f0d73dc2234e15aef1aec Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 1 Aug 2024 13:49:16 +0100 Subject: [PATCH 01/13] Add Keep-Alive Ping and Timeout implementation --- .../Net/WebSockets/WebSocketDefaults.cs | 19 + .../ref/System.Net.WebSockets.Client.cs | 2 + .../src/System.Net.WebSockets.Client.csproj | 1 + .../ClientWebSocketOptions.cs | 7 + .../Net/WebSockets/ClientWebSocketOptions.cs | 20 +- .../Net/WebSockets/WebSocketHandle.Managed.cs | 1 + .../tests/ClientWebSocketOptionsTests.cs | 19 + .../tests/CloseTest.cs | 11 +- .../tests/WebSocketHelper.cs | 34 +- .../ref/System.Net.WebSockets.cs | 1 + .../src/Resources/Strings.resx | 3 + .../src/System.Net.WebSockets.csproj | 9 + .../src/System/Net/WebSockets/AsyncMutex.cs | 11 + .../WebSockets/ManagedWebSocket.KeepAlive.cs | 315 ++++++++++++++ .../System/Net/WebSockets/ManagedWebSocket.cs | 396 +++++++++++++----- .../WebSockets/NetEventSource.WebSockets.cs | 286 +++++++++++++ .../src/System/Net/WebSockets/WebSocket.cs | 6 +- .../WebSockets/WebSocketCreationOptions.cs | 22 +- .../tests/System.Net.WebSockets.Tests.csproj | 2 + .../tests/WebSocketCloseTests.cs | 25 ++ .../tests/WebSocketKeepAliveTests.cs | 280 +++++++++++++ .../tests/WebSocketTests.cs | 21 + 22 files changed, 1377 insertions(+), 114 deletions(-) create mode 100644 src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs new file mode 100644 index 0000000000000..8f2c768fd4bc9 --- /dev/null +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; + +namespace System.Net.WebSockets +{ + /// + /// Central repository for default values used in WebSocket settings. Not all settings are relevant + /// to or configurable by all WebSocket implementations. + /// + internal static partial class WebSocketDefaults + { + public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.Zero; + public static readonly TimeSpan DefaultClientKeepAliveInterval = TimeSpan.FromSeconds(30); + + public static readonly TimeSpan DefaultKeepAliveTimeout = Timeout.InfiniteTimeSpan; + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index edb4eb043bcb0..88bf41f23fbce 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -42,6 +42,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index a309737d6917d..cbad5c01b6da0 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -27,6 +27,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 59096fc864d3a..aa8164d1099c2 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -122,6 +122,13 @@ public TimeSpan KeepAliveInterval set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public TimeSpan KeepAliveTimeout + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public WebSocketDeflateOptions? DangerousDeflateOptions { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index f78882f8b005d..3639bf8caaf4d 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -14,7 +14,8 @@ namespace System.Net.WebSockets public sealed class ClientWebSocketOptions { private bool _isReadOnly; // After ConnectAsync is called the options cannot be modified. - private TimeSpan _keepAliveInterval = WebSocket.DefaultKeepAliveInterval; + private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultClientKeepAliveInterval; + private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout; private bool _useDefaultCredentials; private ICredentials? _credentials; private IWebProxy? _proxy; @@ -188,6 +189,23 @@ public TimeSpan KeepAliveInterval } } + [UnsupportedOSPlatform("browser")] + public TimeSpan KeepAliveTimeout + { + get => _keepAliveTimeout; + set + { + ThrowIfReadOnly(); + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(value), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, + Timeout.InfiniteTimeSpan.ToString())); + } + _keepAliveTimeout = value; + } + } + /// /// Gets or sets the options for the per-message-deflate extension. /// When present, the options are sent to the server during the handshake phase. If the server diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 3301bfead64c7..96e091663ff79 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -198,6 +198,7 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio IsServer = false, SubProtocol = subprotocol, KeepAliveInterval = options.KeepAliveInterval, + KeepAliveTimeout = options.KeepAliveTimeout, DangerousDeflateOptions = negotiatedDeflateOptions }); _negotiatedDeflateOptions = negotiatedDeflateOptions; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs index b2137a7faa7a2..7a39f2423cad8 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs @@ -142,6 +142,25 @@ public static void KeepAliveInterval_Roundtrips() AssertExtensions.Throws("value", () => cws.Options.KeepAliveInterval = TimeSpan.MinValue); } + [ConditionalFact(nameof(WebSocketsSupported))] + [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] + public static void KeepAliveTimeout_Roundtrips() + { + var cws = new ClientWebSocket(); + Assert.True(cws.Options.KeepAliveTimeout == Timeout.InfiniteTimeSpan); + + cws.Options.KeepAliveTimeout = TimeSpan.Zero; + Assert.Equal(TimeSpan.Zero, cws.Options.KeepAliveTimeout); + + cws.Options.KeepAliveTimeout = TimeSpan.MaxValue; + Assert.Equal(TimeSpan.MaxValue, cws.Options.KeepAliveTimeout); + + cws.Options.KeepAliveTimeout = Timeout.InfiniteTimeSpan; + Assert.Equal(Timeout.InfiniteTimeSpan, cws.Options.KeepAliveTimeout); + + AssertExtensions.Throws("value", () => cws.Options.KeepAliveTimeout = TimeSpan.MinValue); + } + [ConditionalFact(nameof(WebSocketsSupported))] [SkipOnPlatform(TestPlatforms.Browser, "Certificates not supported on browser")] public void RemoteCertificateValidationCallback_Roundtrips() diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs index fb73485fc7fe1..c0e71d42bb047 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs @@ -495,11 +495,11 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => try { using (var cws = new ClientWebSocket()) - using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + using (var testTimeoutCts = new CancellationTokenSource(TimeOutMilliseconds)) { - await ConnectAsync(cws, uri, cts.Token); + await ConnectAsync(cws, uri, testTimeoutCts.Token); - Task receiveTask = cws.ReceiveAsync(new byte[1], CancellationToken.None); + Task receiveTask = cws.ReceiveAsync(new byte[1], testTimeoutCts.Token); var cancelCloseCts = new CancellationTokenSource(); await Assert.ThrowsAnyAsync(async () => @@ -509,7 +509,12 @@ await Assert.ThrowsAnyAsync(async () => await t; }); + Assert.True(cancelCloseCts.Token.IsCancellationRequested); + Assert.False(testTimeoutCts.Token.IsCancellationRequested); + await Assert.ThrowsAnyAsync(() => receiveTask); + + Assert.False(testTimeoutCts.Token.IsCancellationRequested); } } finally diff --git a/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs index d409007f9995d..df29e843590e9 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs @@ -69,19 +69,35 @@ public static Task GetConnectedWebSocket( ITestOutputHelper output, TimeSpan keepAliveInterval = default, IWebProxy proxy = null, + HttpMessageInvoker? invoker = null) => + GetConnectedWebSocket( + server, + timeOutMilliseconds, + output, + options => + { + if (proxy != null) + { + options.Proxy = proxy; + } + if (keepAliveInterval.TotalSeconds > 0) + { + options.KeepAliveInterval = keepAliveInterval; + } + }, + invoker + ); + + public static Task GetConnectedWebSocket( + Uri server, + int timeOutMilliseconds, + ITestOutputHelper output, + Action configureOptions, HttpMessageInvoker? invoker = null) => Retry(output, async () => { var cws = new ClientWebSocket(); - if (proxy != null) - { - cws.Options.Proxy = proxy; - } - - if (keepAliveInterval.TotalSeconds > 0) - { - cws.Options.KeepAliveInterval = keepAliveInterval; - } + configureOptions(cws.Options); using (var cts = new CancellationTokenSource(timeOutMilliseconds)) { diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index e3d230708b3f2..ae5337ec05385 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -141,6 +141,7 @@ public sealed partial class WebSocketCreationOptions public bool IsServer { get { throw null; } set { } } public string? SubProtocol { get { throw null; } set { } } public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } + public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } } public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } } public sealed partial class WebSocketDeflateOptions diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index 8e01fce49ad88..a57e81b239a92 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -138,6 +138,9 @@ The argument must be a value between {0} and {1}. + + The WebSocket didn't recieve a Pong frame in response to a Ping frame within the configured KeepAliveTimeout. + The WebSocket received a continuation frame with Per-Message Compressed flag set. diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 98ace5cfbf038..177e95dacee0a 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -17,6 +17,8 @@ + + @@ -29,6 +31,8 @@ + + + @@ -57,6 +65,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs index 4191466dd4efa..0de994799cd59 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Net; using System.Threading.Tasks; namespace System.Threading @@ -46,6 +47,12 @@ internal sealed class AsyncMutex /// Gets the object used to synchronize contended operations. private object SyncObj => this; + /// Attempts to syncronously enter the mutex. + /// This will succeed in case the mutex is not currently held nor contended. + /// Whether the mutex has been entered. + public bool TryEnter() + => Interlocked.CompareExchange(ref _gate, 0, 1) == 1; + /// Asynchronously waits to enter the mutex. /// The CancellationToken token to observe. /// A task that will complete when the mutex has been entered or the enter canceled. @@ -65,6 +72,8 @@ public Task EnterAsync(CancellationToken cancellationToken) Task Contended(CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate); + var w = new Waiter(this); // We need to register for cancellation before storing the waiter into the list. @@ -185,6 +194,8 @@ public void Exit() void Contended() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate); + Waiter? w; lock (SyncObj) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs new file mode 100644 index 0000000000000..cc34b61e7a044 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -0,0 +1,315 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal sealed partial class ManagedWebSocket : WebSocket + { + // "Observe" either a ValueTask result, or any exception, ignoring it + // to prevent the unobserved exception event from being raised. + public void Observe(ValueTask t) + { + if (t.IsCompletedSuccessfully) + { + t.GetAwaiter().GetResult(); + } + else + { + ObserveException(t.AsTask()); + } + } + + // "Observe" either a Task result, or any exception, ignoring it + // to prevent the unobserved exception event from being raised. + public void Observe(Task t) + { + if (t.IsCompletedSuccessfully) + { + t.GetAwaiter().GetResult(); + } + else + { + ObserveException(t); + } + } + + private void ObserveException(Task task) + { + task.ContinueWith( + LogFaulted, + CancellationToken.None, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + + void LogFaulted(Task task) + { + Debug.Assert(task.IsFaulted); + + _ = task.Exception; // accessing exception anyway, to observe it regardless of whether the tracing is enabled + + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, task.Exception); + } + } + + private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null; + private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1; + private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1; + + private void HeartBeat() + { + if (IsUnsolicitedPongKeepAlive) + { + UnsolicitedPongHeartBeat(); + } + else + { + KeepAlivePingHeartBeat(); + } + } + + private void UnsolicitedPongHeartBeat() + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + this.Observe( + TrySendKeepAliveFrameAsync(MessageOpcode.Pong)); + } + + private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory? payload = null) + { + Debug.Assert(opcode is MessageOpcode.Pong || !IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping); + + if (!IsValidSendState(_state)) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Cannot send keep-alive frame in {nameof(_state)}={_state}"); + + // we can't send any frames, but no need to throw as we are not observing errors anyway + return ValueTask.CompletedTask; + } + + payload ??= ReadOnlyMemory.Empty; + + return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload.Value, CancellationToken.None); + } + + private void KeepAlivePingHeartBeat() + { + Debug.Assert(_keepAlivePingState != null); + Debug.Assert(_keepAlivePingState.Exception == null); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"{nameof(_keepAlivePingState.AwaitingPong)}={_keepAlivePingState.AwaitingPong}"); + + try + { + if (_keepAlivePingState.AwaitingPong) + { + KeepAlivePingThrowIfTimedOut(); + } + else + { + SendKeepAlivePingIfNeeded(); + } + } + catch (Exception e) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, e); + + bool aborting = false; + lock (StateUpdateLock) + { + if (!_disposed) + { + // We only save the exception in the keep-alive state if we will actually trigger the abort/disposal + // The exception needs to be assigned before _disposed is set to true + _keepAlivePingState.Exception = e; + aborting = true; + } + } + + if (aborting) + { + Abort(); + } + } + } + + private void KeepAlivePingThrowIfTimedOut() + { + Debug.Assert(_keepAlivePingState != null); + Debug.Assert(_keepAlivePingState.AwaitingPong); + Debug.Assert(_keepAlivePingState.WillTimeoutTimestamp != Timeout.Infinite); + + long now = Environment.TickCount64; + + if (now > Interlocked.Read(ref _keepAlivePingState.WillTimeoutTimestamp)) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {_keepAlivePingState.PingPayload}"); + } + + throw new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout); + } + } + + private void SendKeepAlivePingIfNeeded() + { + Debug.Assert(_keepAlivePingState != null); + Debug.Assert(!_keepAlivePingState.AwaitingPong); + + long now = Environment.TickCount64; + + // Check whether keep alive delay has passed since last frame received + if (now > Interlocked.Read(ref _keepAlivePingState.NextPingTimestamp)) + { + // Set the status directly to ping sent and set the timestamp + Interlocked.Exchange(ref _keepAlivePingState.WillTimeoutTimestamp, now + _keepAlivePingState.TimeoutMs); + _keepAlivePingState.AwaitingPong = true; + + long pingPayload = Interlocked.Increment(ref _keepAlivePingState.PingPayload); + + this.Observe( + SendPingAsync(pingPayload)); + } + } + + private async ValueTask SendPingAsync(long pingPayload) + { + Debug.Assert(_keepAlivePingState != null); + + byte[] pingPayloadBuffer = ArrayPool.Shared.Rent(sizeof(long)); + BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload); + try + { + await TrySendKeepAliveFrameAsync( + MessageOpcode.Ping, + pingPayloadBuffer.AsMemory(0, sizeof(long))) + .ConfigureAwait(false); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload); + } + finally + { + ArrayPool.Shared.Return(pingPayloadBuffer); + } + } + + private void OnDataReceived(int bytesRead) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + if (_keepAlivePingState != null && bytesRead > 0) + { + _keepAlivePingState.OnDataReceived(); + } + } + + private void ThrowIfDisposedOrKeepAliveFaulted() + { + Debug.Assert(_keepAlivePingState is not null); + + if (_disposed && _state == WebSocketState.Aborted && _keepAlivePingState.Exception is not null) + { + // If Exception is not null, it triggered the abort which also disposed the websocket + // We only save the Exception if it actually triggered the abort + throw new OperationCanceledException(nameof(WebSocketState.Aborted), _keepAlivePingState.Exception); + } + + ObjectDisposedException.ThrowIf(_disposed, this); + } + + private void ThrowIfInvalidStateOrKeepAliveFaulted(WebSocketState[] validStates) + { + Debug.Assert(_keepAlivePingState is not null); + + try + { + WebSocketValidate.ThrowIfInvalidState(_state, _disposed, validStates); + } + catch (Exception exc) when (_state == WebSocketState.Aborted && _keepAlivePingState.Exception is not null) + { + // If Exception is not null, it triggered the abort which also disposed the websocket + // We only save the Exception if it actually triggered the abort + if (exc is ObjectDisposedException ode && ode.ObjectName == typeof(ManagedWebSocket).FullName) + { + throw new OperationCanceledException(nameof(WebSocketState.Aborted), _keepAlivePingState.Exception); + } + + if (exc is WebSocketException we && we.WebSocketErrorCode == WebSocketError.InvalidState) + { + throw new WebSocketException(WebSocketError.InvalidState, we.Message, _keepAlivePingState.Exception); + } + } + } + + private sealed class KeepAlivePingState + { + internal const int PingPayloadSize = sizeof(long); + internal const long MinIntervalMs = 1; + + internal long DelayMs; + internal long TimeoutMs; + internal long NextPingTimestamp; + internal long WillTimeoutTimestamp; + + internal long HeartBeatIntervalMs; + + internal bool AwaitingPong; + internal long PingPayload; + internal Exception? Exception; + + public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) + { + DelayMs = TimeSpanToMs(keepAliveInterval); + TimeoutMs = TimeSpanToMs(keepAliveTimeout); + NextPingTimestamp = Environment.TickCount64 + DelayMs; + WillTimeoutTimestamp = Timeout.Infinite; + + HeartBeatIntervalMs = Math.Max( + Math.Min(DelayMs, TimeoutMs) / 4, + MinIntervalMs); + + static long TimeSpanToMs(TimeSpan value) => Math.Max( + (long) Math.Min(value.TotalMilliseconds, int.MaxValue), + MinIntervalMs); + } + + internal void OnDataReceived() + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + Interlocked.Exchange(ref NextPingTimestamp, Environment.TickCount64 + DelayMs); + } + + internal void OnPongResponseReceived(Span pongPayload) + { + Debug.Assert(AwaitingPong); + Debug.Assert(pongPayload.Length == sizeof(long)); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + long pongPayloadValue = BinaryPrimitives.ReadInt64BigEndian(pongPayload); + if (pongPayloadValue == Interlocked.Read(ref PingPayload)) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayloadValue); + + Interlocked.Exchange(ref WillTimeoutTimestamp, Timeout.Infinite); + AwaitingPong = false; + } + else if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Received pong with unexpected payload {pongPayloadValue}. Expected {Interlocked.Read(ref PingPayload)}. Skipping."); + } + } + } + + + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 3ee864e71cc8c..04f934a1aec68 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -137,23 +137,34 @@ internal sealed partial class ManagedWebSocket : WebSocket private readonly WebSocketInflater? _inflater; private readonly WebSocketDeflater? _deflater; + private readonly KeepAlivePingState? _keepAlivePingState; + /// Initializes the websocket. /// The connected Stream. /// true if this is the server-side of the connection; false if this is the client-side of the connection. /// The agreed upon subprotocol for the connection. /// The interval to use for keep-alive pings. - internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + /// The timeout to use when waiting for keep-alive pong response. + internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(stream != null, $"Expected non-null {nameof(stream)}"); Debug.Assert(stream.CanRead, $"Expected readable {nameof(stream)}"); Debug.Assert(stream.CanWrite, $"Expected writeable {nameof(stream)}"); Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid {nameof(keepAliveInterval)}: {keepAliveInterval}"); + Debug.Assert(keepAliveTimeout == Timeout.InfiniteTimeSpan || keepAliveTimeout >= TimeSpan.Zero, $"Invalid {nameof(keepAliveTimeout)}: {keepAliveTimeout}"); _stream = stream; _isServer = isServer; _subprotocol = subprotocol; + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Associate(this, stream); + NetEventSource.Associate(this, _sendMutex); + NetEventSource.Associate(this, _receiveMutex); + } + // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer // supplied to ReceiveAsync. @@ -165,14 +176,33 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim // that could keep the web socket rooted in erroneous cases. if (keepAliveInterval > TimeSpan.Zero) { + long heartBeatIntervalMs = (long)keepAliveInterval.TotalMilliseconds; + if (keepAliveTimeout > TimeSpan.Zero) + { + _keepAlivePingState = new KeepAlivePingState(keepAliveInterval, keepAliveTimeout); + heartBeatIntervalMs = _keepAlivePingState.HeartBeatIntervalMs; + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Associate(this, _keepAlivePingState); + + NetEventSource.Trace(this, + $"Enabling Ping/Pong Keep-Alive strategy: ping delay={_keepAlivePingState.DelayMs}ms, timeout={_keepAlivePingState.TimeoutMs}ms, heartbeat={heartBeatIntervalMs}ms"); + } + } + else if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Enabling Unsolicited Pong Keep-Alive strategy: heartbeat={heartBeatIntervalMs}ms"); + } + _keepAliveTimer = new Timer(static s => { var wr = (WeakReference)s!; if (wr.TryGetTarget(out ManagedWebSocket? thisRef)) { - thisRef.SendKeepAliveFrameAsync(); + thisRef.HeartBeat(); } - }, new WeakReference(this), keepAliveInterval, keepAliveInterval); + }, new WeakReference(this), heartBeatIntervalMs, heartBeatIntervalMs); } } @@ -180,7 +210,7 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim /// The connected Stream. /// The options with which the websocket must be created. internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) - : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval) + : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval, options.KeepAliveTimeout) { var deflateOptions = options.DangerousDeflateOptions; @@ -201,6 +231,8 @@ internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) public override void Dispose() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + lock (StateUpdateLock) { DisposeCore(); @@ -210,17 +242,23 @@ public override void Dispose() private void DisposeCore() { Debug.Assert(Monitor.IsEntered(StateUpdateLock), $"Expected {nameof(StateUpdateLock)} to be held"); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"{nameof(_disposed)}={_disposed}"); + if (!_disposed) { _disposed = true; _keepAliveTimer?.Dispose(); _stream.Dispose(); - if (_state < WebSocketState.Aborted) + WebSocketState state = _state; + if (state < WebSocketState.Aborted) { _state = WebSocketState.Closed; } + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); + DisposeSafe(_inflater, _receiveMutex); DisposeSafe(_deflater, _sendMutex); } @@ -234,15 +272,23 @@ private static void DisposeSafe(IDisposable? resource, AsyncMutex mutex) if (lockTask.IsCompleted) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(mutex); + resource.Dispose(); mutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(mutex); } else { lockTask.GetAwaiter().UnsafeOnCompleted(() => { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(mutex); + resource.Dispose(); mutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(mutex); }); } } @@ -258,6 +304,8 @@ private static void DisposeSafe(IDisposable? resource, AsyncMutex mutex) public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { throw new ArgumentException(SR.Format( @@ -276,6 +324,8 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { throw new ArgumentException(SR.Format( @@ -286,10 +336,11 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates); + ThrowIfInvalidState(s_validSendStates); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); return ValueTask.FromException(exc); } @@ -319,44 +370,53 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates); + ThrowIfInvalidState(s_validReceiveStates); return ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); return Task.FromException(exc); } } public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates); + ThrowIfInvalidState(s_validReceiveStates); return ReceiveAsyncPrivate(buffer, cancellationToken); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); return ValueTask.FromException(exc); } } public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription); try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseStates); + ThrowIfInvalidState(s_validCloseStates); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); return Task.FromException(exc); } @@ -365,13 +425,17 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription); return CloseOutputAsyncCore(closeStatus, statusDescription, cancellationToken); } private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseOutputStates); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + ThrowIfInvalidState(s_validCloseOutputStates); await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); @@ -388,12 +452,16 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string public override void Abort() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + OnAborted(); Dispose(); // forcibly tear down connection } private void OnAborted() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + lock (StateUpdateLock) { WebSocketState state = _state; @@ -403,6 +471,8 @@ private void OnAborted() WebSocketState.Aborted : WebSocketState.Closed; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } } @@ -414,6 +484,8 @@ private void OnAborted() /// The CancellationToken to use to cancel the websocket. private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.SendFrameAsyncStarted(this, opcode.ToString(), payloadBuffer.Length); + // If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to // pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path. // Similarly, it should be rare that there are multiple outstanding calls to SendFrameAsync, but if there are, again @@ -433,6 +505,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, { Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_sendMutex); + // If we get here, the cancellation token is not cancelable so we don't have to worry about it, // and we own the semaphore, so we don't need to asynchronously wait for it. ValueTask writeTask = default; @@ -468,6 +542,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + return ValueTask.FromException( exc is OperationCanceledException ? exc : _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc) : @@ -479,12 +555,19 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, { ReleaseSendBuffer(); _sendMutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.MutexExited(_sendMutex); + NetEventSource.SendFrameAsyncCompleted(this); + } } } return WaitForWriteTaskAsync(writeTask, shouldFlush: true); } + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFlush) { try @@ -495,8 +578,15 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl await _stream.FlushAsync().ConfigureAwait(false); } } - catch (Exception exc) when (exc is not OperationCanceledException) + catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + + if (exc is not OperationCanceledException) + { + throw; + } + throw _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc) : new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc); @@ -505,12 +595,21 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl { ReleaseSendBuffer(); _sendMutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.MutexExited(_sendMutex); + NetEventSource.SendFrameAsyncCompleted(this); + } } } + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, Task lockTask, CancellationToken cancellationToken) { await lockTask.ConfigureAwait(false); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_sendMutex); + try { int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span); @@ -520,8 +619,15 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); } } - catch (Exception exc) when (exc is not OperationCanceledException) + catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + + if (exc is not OperationCanceledException) + { + throw; + } + throw _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc, cancellationToken) : new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc); @@ -530,13 +636,19 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM { ReleaseSendBuffer(); _sendMutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.MutexExited(_sendMutex); + NetEventSource.SendFrameAsyncCompleted(this); + } } } /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer) { - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + ThrowIfDisposed(); if (_deflater is not null && !disableCompression) { @@ -585,26 +697,6 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool return headerLength + payloadLength; } - private void SendKeepAliveFrameAsync() - { - // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures. - // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by - // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response. - ValueTask t = SendFrameAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory.Empty, CancellationToken.None); - if (t.IsCompletedSuccessfully) - { - t.GetAwaiter().GetResult(); - } - else - { - // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised. - t.AsTask().ContinueWith(static p => { _ = p.Exception; }, - CancellationToken.None, - TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Default); - } - } - private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask, bool compressed) { // Client header format: @@ -697,13 +789,22 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo // those to be much less frequent (e.g. we should only get one close per websocket), and thus we can afford to pay // a bit more for readability and maintainability. - CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this); + if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateStarted(this, payloadBuffer.Length); + + CancellationTokenRegistration registration = default; try { + if (cancellationToken.CanBeCanceled) + { + registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this); + } + await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex); + try { - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + ThrowIfDisposed(); while (true) // in case we get control frames that should be ignored from the user's perspective { @@ -715,6 +816,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo MessageHeader header = _lastReceiveHeader; if (header.Processed) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Reading the next frame header"); + if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength))) { // Make sure we have the first two bytes, which includes the start of the payload length. @@ -758,6 +861,11 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } _receivedMaskOffsetOffset = 0; + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Next frame opcode={header.Opcode}, fin={header.Fin}, compressed={header.Compressed}, payloadLength={header.PayloadLength}"); + } + if (header.PayloadLength == 0 && header.Compressed) { // In the rare case where we receive a compressed message with no payload @@ -845,6 +953,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo { ThrowEOFUnexpected(); } + OnDataReceived(numBytesRead); totalBytesReceived += numBytesRead; } @@ -882,6 +991,11 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } + if (header.Processed) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Data frame fully processed"); + } + _lastReceiveHeader = header; return GetReceiveResult( totalBytesReceived, @@ -892,12 +1006,25 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo finally { _receiveMutex.Exit(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(_receiveMutex); } } - catch (Exception exc) when (exc is not OperationCanceledException) + catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + + if (exc is OperationCanceledException) + { + throw; + } + if (_state == WebSocketState.Aborted) { + if (_disposed && _keepAlivePingState?.Exception is not null) + { + throw new OperationCanceledException(nameof(WebSocketState.Aborted), new AggregateException(exc, _keepAlivePingState.Exception)); + } + throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc); } OnAborted(); @@ -912,6 +1039,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo finally { registration.Dispose(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateCompleted(this); } } @@ -941,14 +1069,17 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat lock (StateUpdateLock) { _receivedCloseFrame = true; - if (_sentCloseFrame && _state < WebSocketState.Closed) + WebSocketState state = _state; + if (_sentCloseFrame && state < WebSocketState.Closed) { _state = WebSocketState.Closed; } - else if (_state < WebSocketState.CloseReceived) + else if (state < WebSocketState.CloseReceived) { _state = WebSocketState.CloseReceived; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } WebSocketCloseStatus closeStatus = WebSocketCloseStatus.NormalClosure; @@ -1005,6 +1136,8 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat /// Issues a read on the stream to wait for EOF. private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + // Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second. // We simply issue a read and don't care what we get back; we could validate that we don't get // additional data, but at this point we're about to close the connection and we're just stalling @@ -1036,19 +1169,31 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca /// The CancellationToken used to cancel the websocket operation. private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, CancellationToken cancellationToken) { + Debug.Assert(_receiveMutex.IsHeld, $"Caller should hold the {nameof(_receiveMutex)}"); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + // Consume any (optional) payload associated with the ping/pong. if (header.PayloadLength > 0 && _receiveBufferCount < header.PayloadLength) { await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false); } + bool processPing = header.Opcode == MessageOpcode.Ping; + + bool processPong = header.Opcode == MessageOpcode.Pong + && _keepAlivePingState is not null && _keepAlivePingState.AwaitingPong + && header.PayloadLength == KeepAlivePingState.PingPayloadSize; + + if ((processPing || processPong) && _isServer) + { + ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); + } + // If this was a ping, send back a pong response. - if (header.Opcode == MessageOpcode.Ping) + if (processPing) { - if (_isServer) - { - ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); - } + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Ping"); await SendFrameAsync( MessageOpcode.Pong, @@ -1057,6 +1202,16 @@ await SendFrameAsync( _receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), cancellationToken).ConfigureAwait(false); } + else if (processPong) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Pong with a suitable payload length"); + + _keepAlivePingState!.OnPongResponseReceived(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength)); + } + else + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Ignoring incoming Unsolicited Pong"); + } // Regardless of whether it was a ping or pong, we no longer need the payload. if (header.PayloadLength > 0) @@ -1115,6 +1270,8 @@ private static bool IsValidCloseStatus(WebSocketCloseStatus closeStatus) private async ValueTask CloseWithReceiveErrorAndThrowAsync( WebSocketCloseStatus closeStatus, WebSocketError error, string? errorMessage = null, Exception? innerException = null) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, errorMessage); + // Close the connection if it hasn't already been closed if (!_sentCloseFrame) { @@ -1258,79 +1415,98 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( /// The CancellationToken to use to cancel the websocket. private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { - // Send the close message. Skip sending a close frame if we're currently in a CloseSent state, - // for example having just done a CloseOutputAsync. - if (!_sentCloseFrame) + if (NetEventSource.Log.IsEnabled()) NetEventSource.CloseAsyncPrivateStarted(this); + try { - await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); - } - - // We should now either be in a CloseSent case (because we just sent one), or in a Closed state, in case - // there was a concurrent receive that ended up handling an immediate close frame response from the server. - // Of course it could also be Aborted if something happened concurrently to cause things to blow up. - Debug.Assert( - State == WebSocketState.CloseSent || - State == WebSocketState.Closed || - State == WebSocketState.Aborted, - $"Unexpected state {State}."); + // Send the close message. Skip sending a close frame if we're currently in a CloseSent state, + // for example having just done a CloseOutputAsync. + if (!_sentCloseFrame) + { + await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); + } - // We only need to wait for a received close frame if we are in the CloseSent State. If we are in the Closed - // State then it means we already received a close frame. If we are in the Aborted State, then we should not - // wait for a close frame as per RFC 6455 Section 7.1.7 "Fail the WebSocket Connection". - if (State == WebSocketState.CloseSent) - { - // Wait until we've received a close response - byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); - try + // We should now either be in a CloseSent case (because we just sent one), or in a Closed state, in case + // there was a concurrent receive that ended up handling an immediate close frame response from the server. + // Of course it could also be Aborted if something happened concurrently to cause things to blow up. + Debug.Assert( + State == WebSocketState.CloseSent || + State == WebSocketState.Closed || + State == WebSocketState.Aborted, + $"Unexpected state {State}."); + + // We only need to wait for a received close frame if we are in the CloseSent State. If we are in the Closed + // State then it means we already received a close frame. If we are in the Aborted State, then we should not + // wait for a close frame as per RFC 6455 Section 7.1.7 "Fail the WebSocket Connection". + if (State == WebSocketState.CloseSent) { - // Loop until we've received a close frame. - while (!_receivedCloseFrame) + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Waiting for a close frame"); + + // Wait until we've received a close response + byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); + try { - // Enter the receive lock in order to get a consistent view of whether we've received a close - // frame. If we haven't, issue a receive. Since that receive will try to take the same - // non-entrant receive lock, we then exit the lock before waiting for the receive to complete, - // as it will always complete asynchronously and only after we've exited the lock. - ValueTask receiveTask = default; - try + // Loop until we've received a close frame. + while (!_receivedCloseFrame) { - await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); + // Enter the receive lock in order to get a consistent view of whether we've received a close + // frame. If we haven't, issue a receive. Since that receive will try to take the same + // non-entrant receive lock, we then exit the lock before waiting for the receive to complete, + // as it will always complete asynchronously and only after we've exited the lock. + ValueTask receiveTask = default; try { - if (!_receivedCloseFrame) + await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex); + + try + { + if (!_receivedCloseFrame) + { + receiveTask = ReceiveAsyncPrivate(closeBuffer, cancellationToken); + } + } + finally { - receiveTask = ReceiveAsyncPrivate(closeBuffer, cancellationToken); + _receiveMutex.Exit(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(_receiveMutex); } } - finally + catch (OperationCanceledException) { - _receiveMutex.Exit(); + // If waiting on the receive lock was canceled, abort the connection, as we would do + // as part of the receive itself. + Abort(); + throw; } - } - catch (OperationCanceledException) - { - // If waiting on the receive lock was canceled, abort the connection, as we would do - // as part of the receive itself. - Abort(); - throw; - } - // Wait for the receive to complete if we issued one. - await receiveTask.ConfigureAwait(false); + // Wait for the receive to complete if we issued one. + await receiveTask.ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(closeBuffer); } } - finally + + // We're closed. Close the connection and update the status. + lock (StateUpdateLock) { - ArrayPool.Shared.Return(closeBuffer); + DisposeCore(); } } - - // We're closed. Close the connection and update the status. - lock (StateUpdateLock) + catch (Exception exc) { - DisposeCore(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + throw; + } + finally + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.CloseAsyncPrivateCompleted(this); } } + /// Sends a close message to the server. /// The close status to send. /// The close status description to send. @@ -1370,14 +1546,17 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st lock (StateUpdateLock) { _sentCloseFrame = true; - if (_receivedCloseFrame && _state < WebSocketState.Closed) + WebSocketState state = _state; + if (_receivedCloseFrame && state < WebSocketState.Closed) { _state = WebSocketState.Closed; } - else if (_state < WebSocketState.CloseSent) + else if (state < WebSocketState.CloseSent) { _state = WebSocketState.CloseSent; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } if (!_isServer && _receivedCloseFrame) @@ -1421,6 +1600,7 @@ private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, Canc { ThrowEOFUnexpected(); } + OnDataReceived(numRead); } } } @@ -1430,7 +1610,7 @@ private void ThrowEOFUnexpected() // The connection closed before we were able to read everything we needed. // If it was due to us being disposed, fail with the correct exception. // Otherwise, it was due to the connection being closed and it wasn't expected. - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + ThrowIfDisposed(); throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); } @@ -1542,6 +1722,28 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa cancellationToken); } + private void ThrowIfDisposed() + { + if (_keepAlivePingState is null) + { + ThrowIfDisposedOrKeepAliveFaulted(); + return; + } + + ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + } + + private void ThrowIfInvalidState(WebSocketState[] validStates) + { + if (_keepAlivePingState is null) + { + ThrowIfInvalidStateOrKeepAliveFaulted(validStates); + return; + } + + WebSocketValidate.ThrowIfInvalidState(_state, _disposed, validStates); + } + // From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75 // Performs a stateful validation of UTF-8 bytes. // It checks for valid formatting, overlong encodings, surrogates, and value ranges. diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs new file mode 100644 index 0000000000000..d0977fb767060 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs @@ -0,0 +1,286 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; + +namespace System.Net +{ + [EventSource(Name = "Private.InternalDiagnostics.System.Net.WebSockets")] + internal sealed partial class NetEventSource + { + // NOTE + // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They + // enable creating 'activities'. + // For more information, take a look at the following blog post: + // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ + // - A stop event's event id must be next one after its start event. + + private const int KeepAliveSentId = NextAvailableEventId; + private const int KeepAliveAckedId = KeepAliveSentId + 1; + + private const int WsTraceId = KeepAliveAckedId + 1; + + private const int CloseStartId = WsTraceId + 1; + private const int CloseStopId = CloseStartId + 1; + + private const int ReceiveStartId = CloseStopId + 1; + private const int ReceiveStopId = ReceiveStartId + 1; + + private const int SendStartId = ReceiveStopId + 1; + private const int SendStopId = SendStartId + 1; + + private const int MutexEnterId = SendStopId + 1; + private const int MutexExitId = MutexEnterId + 1; + private const int MutexContendedId = MutexExitId + 1; + + // + // Keep-Alive + // + + private const string Ping = "Ping"; + private const string Pong = "Pong"; + + [Event(KeepAliveSentId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void KeepAliveSent(string objName, string opcode, long payload) => + WriteEvent(KeepAliveSentId, objName, opcode, payload); + + [Event(KeepAliveAckedId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void KeepAliveAcked(string objName, long payload) => + WriteEvent(KeepAliveAckedId, objName, payload); + + [NonEvent] + public static void KeepAlivePingSent(object? obj, long payload) + { + Debug.Assert(Log.IsEnabled()); + Log.KeepAliveSent(IdOf(obj), Ping, payload); + } + + [NonEvent] + public static void UnsolicitedPongSent(object? obj) + { + Debug.Assert(Log.IsEnabled()); + Log.KeepAliveSent(IdOf(obj), Pong, 0); + } + + [NonEvent] + public static void PongResponseReceived(object? obj, long payload) + { + Debug.Assert(Log.IsEnabled()); + Log.KeepAliveAcked(IdOf(obj), payload); + } + + // + // Debug Messages + // + + [Event(WsTraceId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void WsTrace(string objName, string memberName, string message) => + WriteEvent(WsTraceId, objName, memberName, message); + + [NonEvent] + public static void TraceErrorMsg(object? obj, Exception exception, [CallerMemberName] string? memberName = null) + => Trace(obj, $"{exception.GetType().Name}: {exception.Message}", memberName); + + [NonEvent] + public static void TraceException(object? obj, Exception exception, [CallerMemberName] string? memberName = null) + => Trace(obj, exception.ToString(), memberName); + + [NonEvent] + public static void Trace(object? obj, string? message = null, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.WsTrace(IdOf(obj), memberName ?? MissingMember, message ?? memberName ?? string.Empty); + } + + // + // Close + // + + [Event(CloseStartId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void CloseStart(string objName, string memberName) => + WriteEvent(CloseStartId, objName, memberName); + + [Event(CloseStopId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void CloseStop(string objName, string memberName) => + WriteEvent(CloseStopId, objName, memberName); + + [NonEvent] + public static void CloseAsyncPrivateStarted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.CloseStart(IdOf(obj), memberName ?? MissingMember); + } + + [NonEvent] + public static void CloseAsyncPrivateCompleted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.CloseStop(IdOf(obj), memberName ?? MissingMember); + } + + // + // ReceiveAsyncPrivate + // + + [Event(ReceiveStartId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void ReceiveStart(string objName, string memberName, int bufferLength) => + WriteEvent(ReceiveStartId, objName, memberName, bufferLength); + + [Event(ReceiveStopId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void ReceiveStop(string objName, string memberName) => + WriteEvent(ReceiveStopId, objName, memberName); + + [NonEvent] + public static void ReceiveAsyncPrivateStarted(object? obj, int bufferLength, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.ReceiveStart(IdOf(obj), memberName ?? MissingMember, bufferLength); + } + + [NonEvent] + public static void ReceiveAsyncPrivateCompleted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.ReceiveStop(IdOf(obj), memberName ?? MissingMember); + } + + // + // SendFrameAsync + // + + [Event(SendStartId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void SendStart(string objName, string memberName, string opcode, int bufferLength) => + WriteEvent(SendStartId, objName, memberName, opcode, bufferLength); + + [Event(SendStopId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void SendStop(string objName, string memberName) => + WriteEvent(SendStopId, objName, memberName); + + [NonEvent] + public static void SendFrameAsyncStarted(object? obj, string opcode, int bufferLength, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.SendStart(IdOf(obj), memberName ?? MissingMember, opcode, bufferLength); + } + + [NonEvent] + public static void SendFrameAsyncCompleted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.SendStop(IdOf(obj), memberName ?? MissingMember); + } + + // + // AsyncMutex + // + + [Event(MutexEnterId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void MutexEnter(string objName, string memberName) => + WriteEvent(MutexEnterId, objName, memberName); + + [Event(MutexExitId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void MutexExit(string objName, string memberName) => + WriteEvent(MutexExitId, objName, memberName); + + [Event(MutexContendedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void MutexContended(string objName, string memberName, int queueLength) => + WriteEvent(MutexContendedId, objName, memberName, queueLength); + + [NonEvent] + public static void MutexEntered(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.MutexEnter(IdOf(obj), memberName ?? MissingMember); + } + + [NonEvent] + public static void MutexExited(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.MutexExit(IdOf(obj), memberName ?? MissingMember); + } + + [NonEvent] + public static void MutexContended(object? obj, int gateValue, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.MutexContended(IdOf(obj), memberName ?? MissingMember, -gateValue); + } + + // + // WriteEvent overloads + // + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern", + Justification = EventSourceSuppressMessage)] + [NonEvent] + private unsafe void WriteEvent(int eventId, string arg1, string arg2, long arg3) + { + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + { + const int NumEventDatas = 3; + EventData* descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(arg2Ptr), + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(&arg3), + Size = sizeof(long) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern", + Justification = EventSourceSuppressMessage)] + [NonEvent] + private unsafe void WriteEvent(int eventId, string arg1, string arg2, string arg3, int arg4) + { + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + fixed (char* arg3Ptr = arg3) + { + const int NumEventDatas = 4; + EventData* descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(arg2Ptr), + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(arg3Ptr), + Size = (arg3.Length + 1) * sizeof(char) + }; + descrs[3] = new EventData + { + DataPointer = (IntPtr)(&arg4), + Size = sizeof(int) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index fc4436926a6f0..ce47b894d32c7 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -84,7 +84,7 @@ private async ValueTask SendWithArrayPoolAsync( public static TimeSpan DefaultKeepAliveInterval { // In the .NET Framework, this pulls the value from a P/Invoke. Here we just hardcode it to a reasonable default. - get { return TimeSpan.FromSeconds(30); } + get { return WebSocketDefaults.DefaultClientKeepAliveInterval; } } protected static void ThrowOnInvalidState(WebSocketState state, params WebSocketState[] validStates) @@ -150,7 +150,7 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s 0)); } - return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval); + return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval, WebSocketDefaults.DefaultKeepAliveTimeout); } /// Creates a that operates on a representing a web socket connection. @@ -209,7 +209,7 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval); + return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval, WebSocketDefaults.DefaultKeepAliveTimeout); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index d042583da5444..fb002592b8d39 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -11,7 +11,8 @@ namespace System.Net.WebSockets public sealed class WebSocketCreationOptions { private string? _subProtocol; - private TimeSpan _keepAliveInterval; + private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultKeepAliveInterval; + private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout; /// /// Defines if this websocket is the server-side of the connection. The default value is false. @@ -52,6 +53,25 @@ public TimeSpan KeepAliveInterval } } + /// + /// The timeout to use when waiting for the peer's PONG in response to us sending a PING; or or + /// to disable waiting for peer's response, and use an unsolicited PONG as a Keep-Alive heartbeat instead. + /// The default is . + /// + public TimeSpan KeepAliveTimeout + { + get => _keepAliveTimeout; + set + { + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(KeepAliveTimeout), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + } + _keepAliveTimeout = value; + } + } + /// /// The agreed upon options for per message deflate. /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks. diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 807da709ea755..a7f09ff31db29 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,6 +1,7 @@ $(NetCoreAppCurrent) + ../src/Resources/Strings.resx @@ -9,6 +10,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs index 86d1dfb2cd530..423c8d40ed5d5 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs @@ -85,5 +85,30 @@ public async Task ReceiveAsync_ValidCloseStatus_Success(WebSocketCloseStatus clo Assert.Equal(closeStatusDescription, closing.CloseStatusDescription); } } + + [Fact] + public async Task CloseAsync_CancelableEvenWhenPendingReceive_Throws() + { + using var stream = new WebSocketTestStream(); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + + Task receiveTask = websocket.ReceiveAsync(new byte[1], CancellationToken); + await Task.Delay(100); // give the receive task a chance to aquire the lock + var cancelCloseCts = new CancellationTokenSource(); + await Assert.ThrowsAnyAsync(async () => + { + Task t = websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancelCloseCts.Token); + await Task.Delay(100); // give the close task time to get in the queue waiting for the lock + cancelCloseCts.Cancel(); + await t; + }); + + Assert.True(cancelCloseCts.Token.IsCancellationRequested); + Assert.False(CancellationToken.IsCancellationRequested); + + await Assert.ThrowsAnyAsync(() => receiveTask); + + Assert.False(CancellationToken.IsCancellationRequested); + } } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs new file mode 100644 index 0000000000000..11dd28117ac5d --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs @@ -0,0 +1,280 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + public class WebSocketKeepAliveTests + { + public static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10); + public static readonly TimeSpan KeepAliveInterval = TimeSpan.FromMilliseconds(100); + public static readonly TimeSpan KeepAliveTimeout = TimeSpan.FromSeconds(1); + public const int FramesToTestCount = 5; + +#region Frame format helper constants + + public const int MinHeaderLength = 2; + public const int MaskLength = 4; + public const int SingleInt64PayloadLength = sizeof(long); + public const int PingPayloadLength = SingleInt64PayloadLength; + + // 0b_1_***_**** -- fin=true + public const byte FirstByteBits_FinFlag = 0b_1_000_0000; + + // 0b_*_***_0010 -- opcode=BINARY (0x02) + public const byte FirstByteBits_OpcodeBinary = 0b_0_000_0010; + + // 0b_*_***_1001 -- opcode=PING (0x09) + public const byte FirstByteBits_OpcodePing = 0b_0_000_1001; + + // 0b_*_***_1010 -- opcode=PONG (0x10) + public const byte FirstByteBits_OpcodePong = 0b_0_000_1010; + + // 0b_1_******* -- mask=true + public const byte SecondByteBits_MaskFlag = 0b_1_0000000; + + // 0b_*_0001000 -- length=8 + public const byte SecondByteBits_PayloadLength8 = SingleInt64PayloadLength; + + public const byte FirstByte_PingFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodePing; + public const byte FirstByte_PongFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodePong; + public const byte FirstByte_DataFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodeBinary; + + public const byte SecondByte_Server_NoPayload = 0; + public const byte SecondByte_Client_NoPayload = SecondByteBits_MaskFlag; + + public const byte SecondByte_Server_8bPayload = SecondByteBits_PayloadLength8; + public const byte SecondByte_Client_8bPayload = SecondByteBits_MaskFlag | SecondByteBits_PayloadLength8; + + public const int Server_FrameHeaderLength = MinHeaderLength; + public const int Client_FrameHeaderLength = MinHeaderLength + MaskLength; + +#endregion + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WebSocket_NoUserReadOrWrite_SendsUnsolicitedPong(bool isServer) + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream localEndpointStream = testStream; + Stream remoteEndpointStream = testStream.Remote; + + using WebSocket webSocket = WebSocket.CreateFromStream(localEndpointStream, new WebSocketCreationOptions + { + IsServer = isServer, + KeepAliveInterval = KeepAliveInterval + }); + + // --- "remote endpoint" side --- + + int pongFrameLength = isServer ? Server_FrameHeaderLength : Client_FrameHeaderLength; + var pongBuffer = new byte[pongFrameLength]; + for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pongs "indefinitely", let's check a few + { + await remoteEndpointStream.ReadExactlyAsync(pongBuffer, cancellationToken); + + Assert.Equal(FirstByte_PongFrame, pongBuffer[0]); + Assert.Equal( + isServer ? SecondByte_Server_NoPayload : SecondByte_Client_NoPayload, + pongBuffer[1]); + } + } + + [Fact] + public async Task WebSocketServer_NoUserReadOrWrite_SendsPingAndReadsPongResponse() + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream serverStream = testStream; + Stream clientStream = testStream.Remote; + + using WebSocket webSocketServer = WebSocket.CreateFromStream(serverStream, new WebSocketCreationOptions + { + IsServer = true, + KeepAliveInterval = KeepAliveInterval, + KeepAliveTimeout = TestTimeout // we don't care about the actual timeout here + }); + + // we need an outstanding read to ensure the client receives pongs + var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var serverReadTask = webSocketServer.ReceiveAsync(Memory.Empty, readCts.Token); + + // --- "client" side --- + + var buffer = new byte[Client_FrameHeaderLength + PingPayloadLength]; // client frame is bigger because of masking + + for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pings "indefinitely", let's check a few + { + Assert.Equal(WebSocketState.Open, webSocketServer.State); + Assert.False(serverReadTask.IsCompleted); + + buffer.AsSpan().Clear(); + await clientStream.ReadExactlyAsync( + buffer.AsMemory(0, Server_FrameHeaderLength + PingPayloadLength), + cancellationToken); + + Assert.Equal(FirstByte_PingFrame, buffer[0]); + + // implementation detail: payload is a long counter starting from 1 + Assert.Equal(SecondByte_Server_8bPayload, buffer[1]); + + var payloadBytes = buffer.AsSpan().Slice(Server_FrameHeaderLength, PingPayloadLength); + long pingCounter = BinaryPrimitives.ReadInt64BigEndian(payloadBytes); + + Assert.Equal(i+1, pingCounter); + + // --- sending pong back --- + + buffer[0] = FirstByte_PongFrame; + buffer[1] = SecondByte_Client_8bPayload; + + // using zeroes as a "mask" -- applying such a mask is a no-op + Array.Clear(buffer, MinHeaderLength, MaskLength); + + // sending the same payload back + BinaryPrimitives.WriteInt64BigEndian(buffer.AsSpan().Slice(Client_FrameHeaderLength), pingCounter); + + await clientStream.WriteAsync(buffer, cancellationToken); + } + + Assert.Equal(WebSocketState.Open, webSocketServer.State); + Assert.False(serverReadTask.IsCompleted); + + readCts.Cancel(); + + await Assert.ThrowsAsync(() => serverReadTask.AsTask()); + Assert.Equal(WebSocketState.Aborted, webSocketServer.State); + } + + [Fact] + public async Task WebSocketClient_NoServerDataSent_SendsPingAndReadsPongResponse() + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream clientStream = testStream; + Stream serverStream = testStream.Remote; + + using WebSocket webSocketClient = WebSocket.CreateFromStream(clientStream, new WebSocketCreationOptions + { + IsServer = false, + KeepAliveInterval = KeepAliveInterval, + KeepAliveTimeout = TestTimeout // we don't care about the actual timeout here + }); + + // we need an outstanding read to ensure the client receives pongs + var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var clientReadTask = webSocketClient.ReceiveAsync(Memory.Empty, readCts.Token); + + + // --- "server" side --- + + var buffer = new byte[Client_FrameHeaderLength + PingPayloadLength]; // client frame is bigger because of masking + + for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pings "indefinitely", let's check a few + { + Assert.Equal(WebSocketState.Open, webSocketClient.State); + Assert.False(clientReadTask.IsCompleted); + + buffer.AsSpan().Clear(); + await serverStream.ReadExactlyAsync(buffer, cancellationToken); + + Assert.Equal(FirstByte_PingFrame, buffer[0]); + + // implementation detail: payload is a long counter starting from 1 + Assert.Equal(SecondByte_Client_8bPayload, buffer[1]); + + var payloadBytes = buffer.AsSpan().Slice(Client_FrameHeaderLength, PingPayloadLength); + ApplyMask(payloadBytes, buffer.AsSpan().Slice(Client_FrameHeaderLength - MaskLength, MaskLength)); + long pingCounter = BinaryPrimitives.ReadInt64BigEndian(payloadBytes); + Assert.Equal(i+1, pingCounter); + + // --- sending pong back --- + + buffer[0] = FirstByte_PongFrame; + buffer[1] = SecondByte_Server_8bPayload; + + // sending the same payload back + BinaryPrimitives.WriteInt64BigEndian(buffer.AsSpan().Slice(Server_FrameHeaderLength), pingCounter); + + await serverStream.WriteAsync( + buffer.AsMemory(0, Server_FrameHeaderLength + PingPayloadLength), + cancellationToken); + } + + Assert.Equal(WebSocketState.Open, webSocketClient.State); + Assert.False(clientReadTask.IsCompleted); + + readCts.Cancel(); + + await Assert.ThrowsAsync(() => clientReadTask.AsTask()); + Assert.Equal(WebSocketState.Aborted, webSocketClient.State); + + // Octet i of the transformed data ("transformed-octet-i") is the XOR of + // octet i of the original data ("original-octet-i") with octet at index + // i modulo 4 of the masking key ("masking-key-octet-j"): + // + // j = i MOD 4 + // transformed-octet-i = original-octet-i XOR masking-key-octet-j + // + static void ApplyMask(Span buffer, Span mask) + { + + for (int i = 0; i < buffer.Length; i++) + { + buffer[i] ^= mask[i % MaskLength]; + } + } + } + + [OuterLoop("Uses Task.Delay")] + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WebSocket_NoPongResponseWithinTimeout_Aborted(bool outstandingUserRead) + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream localEndpointStream = testStream; + Stream remoteEndpointStream = testStream.Remote; + + using WebSocket webSocket = WebSocket.CreateFromStream(localEndpointStream, new WebSocketCreationOptions + { + IsServer = true, + KeepAliveInterval = KeepAliveInterval, + KeepAliveTimeout = KeepAliveTimeout + }); + + Debug.Assert(webSocket.State == WebSocketState.Open); + + ValueTask userReadTask = default; + if (outstandingUserRead) + { + userReadTask = webSocket.ReceiveAsync(Memory.Empty, cancellationToken); + } + + await Task.Delay(2 * (KeepAliveTimeout + KeepAliveInterval), cancellationToken); + + Assert.Equal(WebSocketState.Aborted, webSocket.State); + + Exception readException = outstandingUserRead + ? await Assert.ThrowsAsync(() => userReadTask.AsTask()) + : await Assert.ThrowsAsync(() => webSocket.ReceiveAsync(Memory.Empty, cancellationToken).AsTask()); + + var wse = Assert.IsType(readException.InnerException); + Assert.Equal(WebSocketError.Faulted, wse.WebSocketErrorCode); + Assert.Equal(SR.net_Websockets_KeepAlivePingTimeout, wse.Message); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index 4cf9c279ba5f3..73e84998a9419 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -200,6 +201,26 @@ public async Task ReceiveAsync_WhenDisposedInParallel_DoesNotGetStuck() await Assert.ThrowsAsync(() => r3.WaitAsync(TimeSpan.FromSeconds(1))); } + [Fact] + public async Task ReceiveAsync_AfterCancellationDoReceiveAsync_ThrowsWebSocketException() + { + using var stream = new WebSocketTestStream(); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + var recvBuffer = new byte[100]; + var segment = new ArraySegment(recvBuffer); + var cts = new CancellationTokenSource(); + + Task receive = websocket.ReceiveAsync(segment, cts.Token); + cts.Cancel(); + await Assert.ThrowsAnyAsync(() => receive); + + WebSocketException ex = await Assert.ThrowsAsync(() => + websocket.ReceiveAsync(segment, CancellationToken.None)); + Assert.Equal( + SR.Format(SR.net_WebSockets_InvalidState, "Aborted", "Open, CloseSent"), + ex.Message); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) => From b1a30b749a8ca52de028390c9b776eb87b21530e Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 1 Aug 2024 19:55:20 +0100 Subject: [PATCH 02/13] Minor fixes and add ClientWebSocket tests --- .../tests/KeepAliveTest.Loopback.cs | 123 ++++++++++++++++++ .../tests/KeepAliveTest.cs | 34 ++++- .../LoopbackServer/LoopbackWebSocketServer.cs | 3 + .../System.Net.WebSockets.Client.Tests.csproj | 1 + .../WebSockets/ManagedWebSocket.KeepAlive.cs | 33 ++--- .../System/Net/WebSockets/ManagedWebSocket.cs | 28 ++-- 6 files changed, 192 insertions(+), 30 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs new file mode 100644 index 0000000000000..e1d21d533dda7 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] + public class KeepAliveTest_Loopback : ClientWebSocketTestBase + { + public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } + + protected virtual Version HttpVersion => Net.HttpVersion.Version11; + + public static readonly object[][] UseSsl_MemberData = PlatformDetection.SupportsAlpn + ? new[] { new object[] { false }, new object[] { true } } + : new[] { new object[] { false } }; + + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) + { + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var longDelayByServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + TimeSpan LongDelay = TimeSpan.FromSeconds(10); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + var options = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()) + { + DisposeServerWebSocket = true, + DisposeClientWebSocket = true, + ConfigureClientOptions = clientOptions => + { + //clientOptions.KeepAliveInterval = TimeSpan.FromSeconds(100); + //clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(500); + }, + }; + + return LoopbackWebSocketServer.RunAsync( + async (clientWebSocket, token) => + { + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); + + // We need to always have a read task active to keep processing pongs + var outstandingReadTask = clientWebSocket.ReceiveAsync(Array.Empty(), token); + + await longDelayByServerTcs.Task.WaitAsync(token); + + var result = await outstandingReadTask; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + Assert.False(result.EndOfMessage); + Assert.Equal(0, result.Count); // we issued a zero byte read, just to wait for data to become available + + Assert.Equal(WebSocketState.Open, clientWebSocket.State); + + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); + + Assert.Equal(WebSocketState.Open, clientWebSocket.State); + + await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token); + + Assert.Equal(WebSocketState.Closed, clientWebSocket.State); + }, + async (serverWebSocket, token) => + { + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + + Assert.Equal(WebSocketState.Open, serverWebSocket.State); + + await Task.Delay(LongDelay); + + Assert.Equal(WebSocketState.Open, serverWebSocket.State); + + // recreate already-completed TCS for another round + clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + longDelayByServerTcs.SetResult(); + + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + + var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty(), token); + Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType); + Assert.Equal(WebSocketState.CloseReceived, serverWebSocket.State); + + await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token); + Assert.Equal(WebSocketState.Closed, serverWebSocket.State); + }, + options, + timeoutCts.Token); + } + + private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, + TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken) + { + var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken); + + var recvBuf = new byte[remoteMsg.Length * 2]; + var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false); + + Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType); + Assert.Equal(remoteMsg.Length, recvResult.Count); + Assert.True(recvResult.EndOfMessage); + Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]); + + localAckTcs.SetResult(); + + await sendTask.ConfigureAwait(false); + await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs index 56acccdc05590..e819d9800c675 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs @@ -9,6 +9,8 @@ using Xunit; using Xunit.Abstractions; +using static System.Net.Test.Common.Configuration.WebSockets; + namespace System.Net.WebSockets.Client.Tests { [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] @@ -20,7 +22,7 @@ public KeepAliveTest(ITestOutputHelper output) : base(output) { } [OuterLoop] // involves long delay public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() { - using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(System.Net.Test.Common.Configuration.WebSockets.RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1))) + using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1))) { await cws.SendAsync(new ArraySegment(new byte[1] { 42 }), WebSocketMessageType.Binary, true, CancellationToken.None); @@ -33,5 +35,35 @@ public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() await cws.CloseAsync(WebSocketCloseStatus.NormalClosure, "KeepAlive_LongDelayBetweenSendReceives_Succeeds", CancellationToken.None); } } + + [ConditionalTheory(nameof(WebSocketsSupported))] + [OuterLoop] // involves long delay + [InlineData(1, 0)] // unsolicited pong + [InlineData(1, 2)] // ping/pong + public async Task KeepAlive_LongDelayBetweenReceiveSends_Succeeds(int keepAliveIntervalSec, int keepAliveTimeoutSec) + { + using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket( + RemoteEchoServer, + TimeOutMilliseconds, + _output, + options => + { + options.KeepAliveInterval = TimeSpan.FromSeconds(keepAliveIntervalSec); + options.KeepAliveTimeout = TimeSpan.FromSeconds(keepAliveTimeoutSec); + })) + { + byte[] receiveBuffer = new byte[1]; + var receiveTask = cws.ReceiveAsync(new ArraySegment(receiveBuffer), CancellationToken.None); // this will wait until we trigger the echo server by sending a message + + await Task.Delay(TimeSpan.FromSeconds(10)); + + await cws.SendAsync(new ArraySegment(new byte[1] { 42 }), WebSocketMessageType.Binary, true, CancellationToken.None); + + Assert.Equal(1, (await receiveTask).Count); + Assert.Equal(42, receiveBuffer[0]); + + await cws.CloseAsync(WebSocketCloseStatus.NormalClosure, "KeepAlive_LongDelayBetweenSendReceives_Succeeds", CancellationToken.None); + } + } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs index b24e2e20d40df..7ea2e5962c4d4 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -132,6 +132,8 @@ public static async Task GetConnectedClientAsync(Uri uri, Optio clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; }; } + options.ConfigureClientOptions?.Invoke(clientWebSocket.Options); + await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false); return clientWebSocket; @@ -143,6 +145,7 @@ public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker public bool DisposeClientWebSocket { get; set; } public bool DisposeHttpInvoker { get; set; } public bool ManualServerHandshakeResponse { get; set; } + public Action? ConfigureClientOptions { get; set; } } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 0c07922eb10ec..b45fbad02d0c9 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -58,6 +58,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index cc34b61e7a044..9ca844838aafb 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -4,7 +4,6 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; -using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; @@ -54,7 +53,7 @@ void LogFaulted(Task task) _ = task.Exception; // accessing exception anyway, to observe it regardless of whether the tracing is enabled - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, task.Exception); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, task.Exception); } } @@ -119,22 +118,20 @@ private void KeepAlivePingHeartBeat() } catch (Exception e) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, e); - - bool aborting = false; - lock (StateUpdateLock) + if (NetEventSource.Log.IsEnabled()) { - if (!_disposed) - { - // We only save the exception in the keep-alive state if we will actually trigger the abort/disposal - // The exception needs to be assigned before _disposed is set to true - _keepAlivePingState.Exception = e; - aborting = true; - } + NetEventSource.TraceException(this, e); + NetEventSource.Trace(this, $"_disposed={_disposed}"); } - if (aborting) + if (!_disposed) { + // We only save the exception in the keep-alive state if we will actually trigger the abort/disposal + // The exception needs to be assigned before _disposed is set to true + _keepAlivePingState.Exception = e; + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Exception saved in _keepAlivePingState, aborting..."); + Abort(); } } @@ -215,7 +212,9 @@ private void ThrowIfDisposedOrKeepAliveFaulted() { Debug.Assert(_keepAlivePingState is not null); - if (_disposed && _state == WebSocketState.Aborted && _keepAlivePingState.Exception is not null) + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_disposed={_disposed}, _state={_state}, _keepAlivePingState.Exception={_keepAlivePingState.Exception?.Message}"); + + if (_disposed && _keepAlivePingState.Exception is not null) { // If Exception is not null, it triggered the abort which also disposed the websocket // We only save the Exception if it actually triggered the abort @@ -229,11 +228,13 @@ private void ThrowIfInvalidStateOrKeepAliveFaulted(WebSocketState[] validStates) { Debug.Assert(_keepAlivePingState is not null); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_disposed={_disposed}, _state={_state}, _keepAlivePingState.Exception={_keepAlivePingState.Exception?.Message}"); + try { WebSocketValidate.ThrowIfInvalidState(_state, _disposed, validStates); } - catch (Exception exc) when (_state == WebSocketState.Aborted && _keepAlivePingState.Exception is not null) + catch (Exception exc) when (_disposed && _keepAlivePingState.Exception is not null) { // If Exception is not null, it triggered the abort which also disposed the websocket // We only save the Exception if it actually triggered the abort diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 04f934a1aec68..2e9c78c44bbc7 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -340,7 +340,7 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return ValueTask.FromException(exc); } @@ -382,7 +382,7 @@ public override Task ReceiveAsync(ArraySegment buf } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return Task.FromException(exc); } } @@ -399,7 +399,7 @@ public override ValueTask ReceiveAsync(Memory } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return ValueTask.FromException(exc); } } @@ -416,7 +416,7 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return Task.FromException(exc); } @@ -542,7 +542,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return ValueTask.FromException( exc is OperationCanceledException ? exc : @@ -580,9 +580,9 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); - if (exc is not OperationCanceledException) + if (exc is OperationCanceledException) { throw; } @@ -621,9 +621,9 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); - if (exc is not OperationCanceledException) + if (exc is OperationCanceledException) { throw; } @@ -1011,7 +1011,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); if (exc is OperationCanceledException) { @@ -1022,6 +1022,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo { if (_disposed && _keepAlivePingState?.Exception is not null) { + // it should have already been wrapped in an OperationCanceledException and thrown above, + // but just in case it wasn't due to some race, let's surface both exceptions throw new OperationCanceledException(nameof(WebSocketState.Aborted), new AggregateException(exc, _keepAlivePingState.Exception)); } @@ -1497,7 +1499,7 @@ private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string? s } catch (Exception exc) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); throw; } finally @@ -1724,7 +1726,7 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa private void ThrowIfDisposed() { - if (_keepAlivePingState is null) + if (_keepAlivePingState is not null) { ThrowIfDisposedOrKeepAliveFaulted(); return; @@ -1735,7 +1737,7 @@ private void ThrowIfDisposed() private void ThrowIfInvalidState(WebSocketState[] validStates) { - if (_keepAlivePingState is null) + if (_keepAlivePingState is not null) { ThrowIfInvalidStateOrKeepAliveFaulted(validStates); return; From 5f30316ef9429d9385f97254b4e89d28536109de Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 1 Aug 2024 20:15:06 +0100 Subject: [PATCH 03/13] Expand KeepAliveTest_Loopback --- .../tests/KeepAliveTest.Loopback.cs | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index e1d21d533dda7..c507ac292e657 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -12,7 +12,7 @@ namespace System.Net.WebSockets.Client.Tests { [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] - public class KeepAliveTest_Loopback : ClientWebSocketTestBase + public abstract class KeepAliveTest_Loopback : ClientWebSocketTestBase { public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } @@ -41,8 +41,8 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) DisposeClientWebSocket = true, ConfigureClientOptions = clientOptions => { - //clientOptions.KeepAliveInterval = TimeSpan.FromSeconds(100); - //clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(500); + clientOptions.KeepAliveInterval = TimeSpan.FromSeconds(100); + clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1); }, }; @@ -118,6 +118,38 @@ private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, await sendTask.ConfigureAwait(false); await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false); } + } + + // --- HTTP/1.1 WebSocket loopback tests --- + + public class KeepAliveTest_Invoker_Loopback : KeepAliveTest_Loopback + { + public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseCustomInvoker => true; + } + + public class KeepAliveTest_HttpClient_Loopback : KeepAliveTest_Loopback + { + public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseHttpClient => true; + } + public class KeepAliveTest_SharedHandler_Loopback : KeepAliveTest_Loopback + { + public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { } + } + + // --- HTTP/2 WebSocket loopback tests --- + + public class KeepAliveTest_Invoker_Http2 : KeepAliveTest_Invoker_Loopback + { + public KeepAliveTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + } + + public class KeepAliveTest_HttpClient_Http2 : KeepAliveTest_HttpClient_Loopback + { + public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; } } From 9c9312d7c4ce4c8fd5af6ee222b94d0ba0062b7b Mon Sep 17 00:00:00 2001 From: Radek Zikmund <32671551+rzikm@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:29:49 +0200 Subject: [PATCH 04/13] Update WebSocketCreationOptions.cs --- .../src/System/Net/WebSockets/WebSocketCreationOptions.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index fb002592b8d39..dfc74241379f8 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -37,6 +37,8 @@ public string? SubProtocol /// /// The keep-alive interval to use, or or to disable keep-alives. + /// If is set, then PING messages are sent and peer's PONG responses are expected, otherwise, + /// unsolicited PONG messages are used as a keep-alive heartbeat. /// The default is . /// public TimeSpan KeepAliveInterval From 9854c6373c8252c28af578128177ac903a65a3a9 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 2 Aug 2024 16:02:45 +0200 Subject: [PATCH 05/13] Address code review feedback --- .../Net/WebSockets/ClientWebSocketOptions.cs | 5 ++++ .../src/System/Net/WebSockets/AsyncMutex.cs | 6 ---- .../WebSockets/ManagedWebSocket.KeepAlive.cs | 28 +++++++++---------- .../System/Net/WebSockets/ManagedWebSocket.cs | 7 +---- 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index 3639bf8caaf4d..2b28a561d8fdb 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -189,6 +189,11 @@ public TimeSpan KeepAliveInterval } } + /// + /// The timeout to use when waiting for the peer's PONG in response to us sending a PING; or or + /// to disable waiting for peer's response, and use an unsolicited PONG as a Keep-Alive heartbeat instead. + /// The default is . + /// [UnsupportedOSPlatform("browser")] public TimeSpan KeepAliveTimeout { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs index 0de994799cd59..abf7e5e56a276 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs @@ -47,12 +47,6 @@ internal sealed class AsyncMutex /// Gets the object used to synchronize contended operations. private object SyncObj => this; - /// Attempts to syncronously enter the mutex. - /// This will succeed in case the mutex is not currently held nor contended. - /// Whether the mutex has been entered. - public bool TryEnter() - => Interlocked.CompareExchange(ref _gate, 0, 1) == 1; - /// Asynchronously waits to enter the mutex. /// The CancellationToken token to observe. /// A task that will complete when the mutex has been entered or the enter canceled. diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index 9ca844838aafb..036c09acfa067 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -43,17 +43,18 @@ private void ObserveException(Task task) { task.ContinueWith( LogFaulted, + this, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); - void LogFaulted(Task task) + static void LogFaulted(Task task, object? state) { Debug.Assert(task.IsFaulted); - _ = task.Exception; // accessing exception anyway, to observe it regardless of whether the tracing is enabled + Exception? e = task.Exception!.InnerException; // accessing exception anyway, to observe it regardless of whether the tracing is enabled - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, task.Exception); + if (NetEventSource.Log.IsEnabled() && e != null) NetEventSource.TraceException((ManagedWebSocket)state!, e); } } @@ -77,7 +78,7 @@ private void UnsolicitedPongHeartBeat() { if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); - this.Observe( + Observe( TrySendKeepAliveFrameAsync(MessageOpcode.Pong)); } @@ -172,7 +173,7 @@ private void SendKeepAlivePingIfNeeded() long pingPayload = Interlocked.Increment(ref _keepAlivePingState.PingPayload); - this.Observe( + Observe( SendPingAsync(pingPayload)); } } @@ -253,15 +254,15 @@ private void ThrowIfInvalidStateOrKeepAliveFaulted(WebSocketState[] validStates) private sealed class KeepAlivePingState { internal const int PingPayloadSize = sizeof(long); - internal const long MinIntervalMs = 1; + internal const int MinIntervalMs = 1; + + internal int DelayMs; + internal int TimeoutMs; + internal int HeartBeatIntervalMs; - internal long DelayMs; - internal long TimeoutMs; internal long NextPingTimestamp; internal long WillTimeoutTimestamp; - internal long HeartBeatIntervalMs; - internal bool AwaitingPong; internal long PingPayload; internal Exception? Exception; @@ -277,9 +278,8 @@ public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) Math.Min(DelayMs, TimeoutMs) / 4, MinIntervalMs); - static long TimeSpanToMs(TimeSpan value) => Math.Max( - (long) Math.Min(value.TotalMilliseconds, int.MaxValue), - MinIntervalMs); + static int TimeSpanToMs(TimeSpan value) => + Math.Clamp((int)value.TotalMilliseconds, MinIntervalMs, int.MaxValue); } internal void OnDataReceived() @@ -310,7 +310,5 @@ internal void OnPongResponseReceived(Span pongPayload) } } } - - } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 2e9c78c44bbc7..d2bbfc28f5f1e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -567,7 +567,6 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, return WaitForWriteTaskAsync(writeTask, shouldFlush: true); } - [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFlush) { try @@ -604,7 +603,6 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl } } - [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, Task lockTask, CancellationToken cancellationToken) { await lockTask.ConfigureAwait(false); @@ -794,10 +792,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo CancellationTokenRegistration registration = default; try { - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this); - } + registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this); await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex); From 42ba661b1c1fae28708afb15485a663e6050a828 Mon Sep 17 00:00:00 2001 From: Radek Zikmund Date: Fri, 2 Aug 2024 16:22:47 +0200 Subject: [PATCH 06/13] More feedback --- .../src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index 036c09acfa067..60bc112838ebc 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -256,9 +256,9 @@ private sealed class KeepAlivePingState internal const int PingPayloadSize = sizeof(long); internal const int MinIntervalMs = 1; - internal int DelayMs; - internal int TimeoutMs; - internal int HeartBeatIntervalMs; + internal readonly int DelayMs; + internal readonly int TimeoutMs; + internal readonly int HeartBeatIntervalMs; internal long NextPingTimestamp; internal long WillTimeoutTimestamp; From eeae3204232122655e3442119da0ea6f6120d040 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 6 Aug 2024 15:31:30 +0100 Subject: [PATCH 07/13] add lock and some debug logs --- .../Net/WebSockets/WebSocketValidate.cs | 25 +- .../Net/Http/Http2LoopbackConnection.cs | 7 +- .../System/Net/Http/Http2LoopbackServer.cs | 4 + .../tests/KeepAliveTest.Loopback.cs | 22 ++ .../LoopbackServer/Http2LoopbackStream.cs | 5 +- .../LoopbackServer/LoopbackWebSocketServer.cs | 6 +- .../WebSocketHandshakeHelper.cs | 2 +- .../WebSockets/ManagedWebSocket.KeepAlive.cs | 221 ++++++++++-------- .../System/Net/WebSockets/ManagedWebSocket.cs | 19 +- 9 files changed, 193 insertions(+), 118 deletions(-) diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs index e087677be4608..1da18d62093ac 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -38,6 +38,23 @@ internal static partial class WebSocketValidate SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"); internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates) + { + // Exception order: + // 1. WebSocketException(InvalidState) -- if invalid state + // 2. ObjectDisposedException + + string? invalidStateMessage = GetInvalidStateMessage(currentState, validStates); + if (invalidStateMessage is null) // state is valid + { + // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. + ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket)); + return; + } + + throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage); + } + + internal static string? GetInvalidStateMessage(WebSocketState currentState, WebSocketState[] validStates) { string validStatesText = string.Empty; @@ -47,18 +64,14 @@ internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDis { if (currentState == validState) { - // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. - ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket)); - return; + return null; } } validStatesText = string.Join(", ", validStates); } - throw new WebSocketException( - WebSocketError.InvalidState, - SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText)); + return SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText); } internal static void ValidateSubprotocol(string subProtocol) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index e607c42aa48ba..e3bd2380c541a 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -28,6 +28,7 @@ public class Http2LoopbackConnection : GenericLoopbackConnection private readonly TimeSpan _timeout; private int _lastStreamId; private bool _expectClientDisconnect; + private readonly Action? _debugLog; private readonly byte[] _prefix = new byte[24]; public string PrefixString => Encoding.UTF8.GetString(_prefix, 0, _prefix.Length); @@ -35,12 +36,13 @@ public class Http2LoopbackConnection : GenericLoopbackConnection public Stream Stream => _connectionStream; public Task SettingAckWaiter => _ignoredSettingsAckPromise?.Task; - private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse) + private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse, Action? debugLog = null) { _connectionSocket = socket; _connectionStream = stream; _timeout = timeout; _transparentPingResponse = transparentPingResponse; + _debugLog = debugLog; } public override string ToString() @@ -83,7 +85,7 @@ public static async Task CreateAsync(SocketWrapper sock stream = sslStream; } - var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse); + var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse, httpOptions.DebugLog); await con.ReadPrefixAsync().ConfigureAwait(false); return con; @@ -368,6 +370,7 @@ public async Task WaitForConnectionShutdownAsync(bool ignoreUnexpectedFrames = f // and will ignore any errors if client has already shutdown public async Task ShutdownIgnoringErrorsAsync(int lastStreamId, ProtocolErrors errorCode = ProtocolErrors.NO_ERROR) { + _debugLog?.Invoke($"Http2LoopbackConnection.DisposeAsync() with errorCode={errorCode}; stack={Environment.StackTrace}"); try { await SendGoAway(lastStreamId, errorCode).ConfigureAwait(false); diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs index 90929b70eec37..5f9f52d3cd2b1 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs @@ -32,6 +32,8 @@ private Http2LoopbackConnection Connection } } + public Action? DebugLog => _options.DebugLog; + public static readonly TimeSpan Timeout = TimeSpan.FromSeconds(30); public override Uri Address @@ -186,6 +188,8 @@ public class Http2Options : GenericLoopbackOptions public bool EnableTransparentPingResponse { get; set; } = true; + public Action? DebugLog { get; set; } + public Http2Options() { SslProtocols = SslProtocols.Tls12; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index c507ac292e657..218a6269110c0 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -24,6 +24,7 @@ public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } [Theory] [MemberData(nameof(UseSsl_MemberData))] + [InlineData(false)] public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) { var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; @@ -44,6 +45,7 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) clientOptions.KeepAliveInterval = TimeSpan.FromSeconds(100); clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1); }, + DebugLog = DebugLog }; return LoopbackWebSocketServer.RunAsync( @@ -67,9 +69,13 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) Assert.Equal(WebSocketState.Open, clientWebSocket.State); + DebugLog("Sending close frame from client"); + await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token); Assert.Equal(WebSocketState.Closed, clientWebSocket.State); + + DebugLog("Client closed"); }, async (serverWebSocket, token) => { @@ -89,15 +95,31 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + DebugLog("Receiving close frame on server"); + var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty(), token); Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType); Assert.Equal(WebSocketState.CloseReceived, serverWebSocket.State); + DebugLog("Sending close frame response from server"); + await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token); Assert.Equal(WebSocketState.Closed, serverWebSocket.State); + + DebugLog("Server closed"); }, options, timeoutCts.Token); + + + void DebugLog(string str) + { + const int MaxLogLength = 3000; + lock (Console.Out) + { + Console.WriteLine($"{this.GetType().Name} | useSsl={useSsl} | {str.Substring(0, Math.Min(str.Length, MaxLogLength))}{(str.Length > MaxLogLength ? "" : "")}"); + } + } } private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs index 1b3b51840ec99..3f62ae33e5f95 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -15,6 +15,7 @@ public class Http2LoopbackStream : Stream private readonly int _streamId; private bool _readEnded; private ReadOnlyMemory _leftoverReadData; + private readonly Action? _debugLog; public override bool CanRead => true; public override bool CanSeek => false; @@ -23,10 +24,11 @@ public class Http2LoopbackStream : Stream public Http2LoopbackConnection Connection => _connection; public int StreamId => _streamId; - public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId) + public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId, Action? debugLog = null) { _connection = connection; _streamId = streamId; + _debugLog = debugLog; } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) @@ -67,6 +69,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override async ValueTask DisposeAsync() { + _debugLog?.Invoke($"Http2LoopbackStream.DisposeAsync() for stream {_streamId}; readEnded={_readEnded}; stack={Environment.StackTrace}"); try { await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs index 7ea2e5962c4d4..34f7a5327c8d2 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -73,9 +73,11 @@ await server.AcceptConnectionAsync(async connection => await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); + options.DebugLog?.Invoke("loopbackServerFunc completed; disposing the connection"); + await http2Connection.DisposeAsync().ConfigureAwait(false); }, - new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); + new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl, DebugLog = options.DebugLog }); } else { @@ -96,6 +98,7 @@ private static async Task RunServerAsync( if (options.DisposeServerWebSocket) { + options.DebugLog?.Invoke("Disposing server websocket"); serverWebSocket.Dispose(); } } @@ -146,6 +149,7 @@ public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker public bool DisposeHttpInvoker { get; set; } public bool ManualServerHandshakeResponse { get; set; } public Action? ConfigureClientOptions { get; set; } + public Action? DebugLog { get; set; } } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs index 2a8c84e7de8ea..271ed4beca79e 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs @@ -83,7 +83,7 @@ public static async Task ProcessHttp2RequestAsync(Http2Loo await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false); } - data.WebSocketStream = new Http2LoopbackStream(connection, streamId); + data.WebSocketStream = new Http2LoopbackStream(connection, streamId, server.DebugLog); return data; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index 60bc112838ebc..4cdd537cc658d 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -48,13 +48,13 @@ private void ObserveException(Task task) TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); - static void LogFaulted(Task task, object? state) + static void LogFaulted(Task task, object? thisObj) { Debug.Assert(task.IsFaulted); - Exception? e = task.Exception!.InnerException; // accessing exception anyway, to observe it regardless of whether the tracing is enabled + Exception? innerException = task.Exception!.InnerException; // accessing exception anyway, to observe it regardless of whether the tracing is enabled - if (NetEventSource.Log.IsEnabled() && e != null) NetEventSource.TraceException((ManagedWebSocket)state!, e); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, innerException ?? task.Exception!); } } @@ -102,82 +102,87 @@ private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemor private void KeepAlivePingHeartBeat() { Debug.Assert(_keepAlivePingState != null); - Debug.Assert(_keepAlivePingState.Exception == null); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"{nameof(_keepAlivePingState.AwaitingPong)}={_keepAlivePingState.AwaitingPong}"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); try { - if (_keepAlivePingState.AwaitingPong) + bool timedOut = false; + bool sendPing = false; + long pingPayload = -1; + + lock (StateUpdateLock) { - KeepAlivePingThrowIfTimedOut(); + if (_keepAlivePingState.Exception is not null) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"KeepAlive already faulted, skipping... (exception: {_keepAlivePingState.Exception.Message})"); + return; + } + + long now = Environment.TickCount64; + + if (_keepAlivePingState.AwaitingPong) + { + Debug.Assert(_keepAlivePingState.WillTimeoutTimestamp != Timeout.Infinite); + + if (now > _keepAlivePingState.WillTimeoutTimestamp) + { + timedOut = true; + pingPayload = _keepAlivePingState.PingPayload; + } + } + else + { + if (now > _keepAlivePingState.NextPingTimestamp) + { + sendPing = true; + pingPayload = ++_keepAlivePingState.PingPayload; + + _keepAlivePingState.AwaitingPong = true; + _keepAlivePingState.WillTimeoutTimestamp = now + _keepAlivePingState.TimeoutMs; + } + } } - else + + if (timedOut) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {pingPayload}"); + } + + throw new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout); + } + else if (sendPing) { - SendKeepAlivePingIfNeeded(); + Observe( + SendPingAsync(pingPayload)); } } catch (Exception e) { - if (NetEventSource.Log.IsEnabled()) + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, e); + + bool shouldAbort = false; + lock (StateUpdateLock) { - NetEventSource.TraceException(this, e); - NetEventSource.Trace(this, $"_disposed={_disposed}"); + if (!_disposed) + { + // We only save the exception in the keep-alive state if we will actually trigger the abort/disposal + // The exception needs to be assigned before _disposed is set to true + _keepAlivePingState.Exception = e; + shouldAbort = true; + } } - if (!_disposed) + if (shouldAbort) { - // We only save the exception in the keep-alive state if we will actually trigger the abort/disposal - // The exception needs to be assigned before _disposed is set to true - _keepAlivePingState.Exception = e; - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Exception saved in _keepAlivePingState, aborting..."); - Abort(); } } } - private void KeepAlivePingThrowIfTimedOut() - { - Debug.Assert(_keepAlivePingState != null); - Debug.Assert(_keepAlivePingState.AwaitingPong); - Debug.Assert(_keepAlivePingState.WillTimeoutTimestamp != Timeout.Infinite); - - long now = Environment.TickCount64; - - if (now > Interlocked.Read(ref _keepAlivePingState.WillTimeoutTimestamp)) - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {_keepAlivePingState.PingPayload}"); - } - - throw new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout); - } - } - - private void SendKeepAlivePingIfNeeded() - { - Debug.Assert(_keepAlivePingState != null); - Debug.Assert(!_keepAlivePingState.AwaitingPong); - - long now = Environment.TickCount64; - - // Check whether keep alive delay has passed since last frame received - if (now > Interlocked.Read(ref _keepAlivePingState.NextPingTimestamp)) - { - // Set the status directly to ping sent and set the timestamp - Interlocked.Exchange(ref _keepAlivePingState.WillTimeoutTimestamp, now + _keepAlivePingState.TimeoutMs); - _keepAlivePingState.AwaitingPong = true; - - long pingPayload = Interlocked.Increment(ref _keepAlivePingState.PingPayload); - - Observe( - SendPingAsync(pingPayload)); - } - } - private async ValueTask SendPingAsync(long pingPayload) { Debug.Assert(_keepAlivePingState != null); @@ -201,54 +206,62 @@ await TrySendKeepAliveFrameAsync( private void OnDataReceived(int bytesRead) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={bytesRead}"); if (_keepAlivePingState != null && bytesRead > 0) { - _keepAlivePingState.OnDataReceived(); + lock (StateUpdateLock) + { + _keepAlivePingState.OnDataReceived(); + } } } private void ThrowIfDisposedOrKeepAliveFaulted() + => ThrowIfInvalidStateOrKeepAliveFaulted(validStates: null); + + private void ThrowIfInvalidStateOrKeepAliveFaulted(WebSocketState[]? validStates) { Debug.Assert(_keepAlivePingState is not null); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_disposed={_disposed}, _state={_state}, _keepAlivePingState.Exception={_keepAlivePingState.Exception?.Message}"); - - if (_disposed && _keepAlivePingState.Exception is not null) + // Exception order: WebSocketException -> OperationCanceledException -> ObjectDisposedException + // + // If keepAlive exception present: + // 1. WebSocketException(InvalidState), keepAlive exception as inner -- if invalid state + // 2. OperationCanceledException, keepAlive exception as inner + // + // If keepAlive exception not present: + // 1. WebSocketException(InvalidState) -- if invalid state + // 2. ObjectDisposedException + + bool disposed; + WebSocketState state; + Exception? keepAliveException; + lock (StateUpdateLock) { - // If Exception is not null, it triggered the abort which also disposed the websocket - // We only save the Exception if it actually triggered the abort - throw new OperationCanceledException(nameof(WebSocketState.Aborted), _keepAlivePingState.Exception); + disposed = _disposed; + state = _state; + keepAliveException = _keepAlivePingState.Exception; } - ObjectDisposedException.ThrowIf(_disposed, this); - } + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_disposed={disposed}, _state={state}, _keepAlivePingState.Exception={keepAliveException?.Message}"); - private void ThrowIfInvalidStateOrKeepAliveFaulted(WebSocketState[] validStates) - { - Debug.Assert(_keepAlivePingState is not null); - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_disposed={_disposed}, _state={_state}, _keepAlivePingState.Exception={_keepAlivePingState.Exception?.Message}"); - - try + string? invalidStateMessage = validStates is not null ? WebSocketValidate.GetInvalidStateMessage(state, validStates) : null; + if (invalidStateMessage is not null) { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, validStates); + // Surface keepAliveException as inner exception, if present + throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, keepAliveException); } - catch (Exception exc) when (_disposed && _keepAlivePingState.Exception is not null) - { - // If Exception is not null, it triggered the abort which also disposed the websocket - // We only save the Exception if it actually triggered the abort - if (exc is ObjectDisposedException ode && ode.ObjectName == typeof(ManagedWebSocket).FullName) - { - throw new OperationCanceledException(nameof(WebSocketState.Aborted), _keepAlivePingState.Exception); - } - if (exc is WebSocketException we && we.WebSocketErrorCode == WebSocketError.InvalidState) - { - throw new WebSocketException(WebSocketError.InvalidState, we.Message, _keepAlivePingState.Exception); - } + // If keepAliveException is not null, it triggered the abort which also disposed the websocket + // We only save the exception if it actually triggered the abort + if (keepAliveException is not null) + { + throw new OperationCanceledException(nameof(WebSocketState.Aborted), keepAliveException); } + + // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. + ObjectDisposedException.ThrowIf(disposed, this); } private sealed class KeepAlivePingState @@ -267,6 +280,8 @@ private sealed class KeepAlivePingState internal long PingPayload; internal Exception? Exception; + internal object Debug_WebSocket_StateUpdateLock = null!; // for Debug.Asserts + public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) { DelayMs = TimeSpanToMs(keepAliveInterval); @@ -279,34 +294,38 @@ public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) MinIntervalMs); static int TimeSpanToMs(TimeSpan value) => - Math.Clamp((int)value.TotalMilliseconds, MinIntervalMs, int.MaxValue); + (int)Math.Clamp((long)value.TotalMilliseconds, MinIntervalMs, int.MaxValue); } internal void OnDataReceived() { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + Debug.Assert(Monitor.IsEntered(Debug_WebSocket_StateUpdateLock)); - Interlocked.Exchange(ref NextPingTimestamp, Environment.TickCount64 + DelayMs); + NextPingTimestamp = Environment.TickCount64 + DelayMs; } - internal void OnPongResponseReceived(Span pongPayload) + internal void OnPongResponseReceived(long pongPayload) { - Debug.Assert(AwaitingPong); - Debug.Assert(pongPayload.Length == sizeof(long)); + Debug.Assert(Monitor.IsEntered(Debug_WebSocket_StateUpdateLock)); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"pongPayload={pongPayload}"); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + if (!AwaitingPong) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Not waiting for Pong. Skipping."); + return; + } - long pongPayloadValue = BinaryPrimitives.ReadInt64BigEndian(pongPayload); - if (pongPayloadValue == Interlocked.Read(ref PingPayload)) + if (pongPayload == PingPayload) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayloadValue); + if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayload); - Interlocked.Exchange(ref WillTimeoutTimestamp, Timeout.Infinite); + WillTimeoutTimestamp = Timeout.Infinite; AwaitingPong = false; } - else if (NetEventSource.Log.IsEnabled()) + else { - NetEventSource.Trace(this, $"Received pong with unexpected payload {pongPayloadValue}. Expected {Interlocked.Read(ref PingPayload)}. Skipping."); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Expected payload {PingPayload}. Skipping."); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index d2bbfc28f5f1e..cbdc09ae7dbfe 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -180,6 +180,10 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim if (keepAliveTimeout > TimeSpan.Zero) { _keepAlivePingState = new KeepAlivePingState(keepAliveInterval, keepAliveTimeout); +#if DEBUG + _keepAlivePingState.Debug_WebSocket_StateUpdateLock = StateUpdateLock; +#endif + heartBeatIntervalMs = _keepAlivePingState.HeartBeatIntervalMs; if (NetEventSource.Log.IsEnabled()) @@ -1015,7 +1019,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo if (_state == WebSocketState.Aborted) { - if (_disposed && _keepAlivePingState?.Exception is not null) + if (_keepAlivePingState?.Exception is not null) { // it should have already been wrapped in an OperationCanceledException and thrown above, // but just in case it wasn't due to some race, let's surface both exceptions @@ -1178,8 +1182,7 @@ private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, Cancel bool processPing = header.Opcode == MessageOpcode.Ping; - bool processPong = header.Opcode == MessageOpcode.Pong - && _keepAlivePingState is not null && _keepAlivePingState.AwaitingPong + bool processPong = header.Opcode == MessageOpcode.Pong && _keepAlivePingState is not null && header.PayloadLength == KeepAlivePingState.PingPayloadSize; if ((processPing || processPong) && _isServer) @@ -1201,13 +1204,17 @@ await SendFrameAsync( } else if (processPong) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Pong with a suitable payload length"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Pong"); - _keepAlivePingState!.OnPongResponseReceived(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength)); + long pongPayload = BinaryPrimitives.ReadInt64BigEndian(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength)); + lock (StateUpdateLock) + { + _keepAlivePingState!.OnPongResponseReceived(pongPayload); + } } else { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Ignoring incoming Unsolicited Pong"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Received Unsolicited Pong. Skipping."); } // Regardless of whether it was a ping or pong, we no longer need the payload. From 3641d8629679502997e2098527e7084100c5d9a9 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Tue, 6 Aug 2024 23:21:58 +0100 Subject: [PATCH 08/13] more debug logs --- .../Net/Http/Http2LoopbackConnection.cs | 2 +- .../tests/TestUtilities/TestEventListener.cs | 146 ++++++++++++++-- .../tests/AbortTest.Loopback.cs | 163 ++++++++++++++---- .../tests/ClientWebSocketTestBase.cs | 108 +++++++++++- .../tests/CloseTest.cs | 11 +- .../tests/KeepAliveTest.Loopback.cs | 63 ++++--- .../LoopbackServer/Http2LoopbackStream.cs | 2 +- 7 files changed, 407 insertions(+), 88 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index e3bd2380c541a..e75474cd5bf8f 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -370,7 +370,7 @@ public async Task WaitForConnectionShutdownAsync(bool ignoreUnexpectedFrames = f // and will ignore any errors if client has already shutdown public async Task ShutdownIgnoringErrorsAsync(int lastStreamId, ProtocolErrors errorCode = ProtocolErrors.NO_ERROR) { - _debugLog?.Invoke($"Http2LoopbackConnection.DisposeAsync() with errorCode={errorCode}; stack={Environment.StackTrace}"); + _debugLog?.Invoke($"Http2LoopbackConnection.ShutdownIgnoringErrorsAsync() with lastStreamId={lastStreamId}, errorCode={errorCode}"); try { await SendGoAway(lastStreamId, errorCode).ConfigureAwait(false); diff --git a/src/libraries/Common/tests/TestUtilities/TestEventListener.cs b/src/libraries/Common/tests/TestUtilities/TestEventListener.cs index 8cb70ee3cbd8c..b83a165906164 100644 --- a/src/libraries/Common/tests/TestUtilities/TestEventListener.cs +++ b/src/libraries/Common/tests/TestUtilities/TestEventListener.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.Tracing; using System.IO; using System.Text; @@ -31,6 +32,7 @@ public sealed class TestEventListener : EventListener "Private.InternalDiagnostics.System.Net.Sockets", "Private.InternalDiagnostics.System.Net.Security", "Private.InternalDiagnostics.System.Net.Quic", + "Private.InternalDiagnostics.System.Net.WebSockets", "Private.InternalDiagnostics.System.Net.Http.WinHttpHandler", "Private.InternalDiagnostics.System.Net.HttpListener", "Private.InternalDiagnostics.System.Net.Mail", @@ -41,19 +43,24 @@ public sealed class TestEventListener : EventListener private readonly Action _writeFunc; private readonly HashSet _sourceNames; + private readonly bool _enableActivityId; // Until https://github.com/dotnet/runtime/issues/63979 is solved. private List _eventSources = new List(); public TestEventListener(TextWriter output, params string[] sourceNames) - : this(str => output.WriteLine(str), sourceNames) + : this(output.WriteLine, sourceNames) { } public TestEventListener(ITestOutputHelper output, params string[] sourceNames) - : this(str => output.WriteLine(str), sourceNames) + : this(output.WriteLine, sourceNames) { } public TestEventListener(Action writeFunc, params string[] sourceNames) + : this(writeFunc, enableActivityId: false, sourceNames) + { } + + public TestEventListener(Action writeFunc, bool enableActivityId, params string[] sourceNames) { List eventSources = _eventSources; @@ -61,16 +68,14 @@ public TestEventListener(Action writeFunc, params string[] sourceNames) { _writeFunc = writeFunc; _sourceNames = new HashSet(sourceNames); + _enableActivityId = enableActivityId; _eventSources = null; } // eventSources were populated in the base ctor and are now owned by this thread, enable them now. foreach (EventSource eventSource in eventSources) { - if (_sourceNames.Contains(eventSource.Name)) - { - EnableEvents(eventSource, EventLevel.LogAlways); - } + EnableEventSource(eventSource); } } @@ -90,20 +95,42 @@ protected override void OnEventSourceCreated(EventSource eventSource) } // Second pass called after our ctor, allow logging for specified source names. + EnableEventSource(eventSource); + } + + private void EnableEventSource(EventSource eventSource) + { if (_sourceNames.Contains(eventSource.Name)) { EnableEvents(eventSource, EventLevel.LogAlways); } + else if (_enableActivityId && eventSource.Name == "System.Threading.Tasks.TplEventSource") + { + EnableEvents(eventSource, EventLevel.LogAlways, (EventKeywords)0x80 /* TasksFlowActivityIds */); + } } protected override void OnEventWritten(EventWrittenEventArgs eventData) { - StringBuilder sb = new StringBuilder(). + StringBuilder sb = new StringBuilder(); + #if NET || NETSTANDARD2_1_OR_GREATER - Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}[{eventData.EventName}] "); -#else - Append($"[{eventData.EventName}] "); + sb.Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}"); + if (_enableActivityId) + { + if (eventData.ActivityId != Guid.Empty) + { + string activityId = ActivityHelpers.ActivityPathString(eventData.ActivityId); + sb.Append($" {activityId} {new string('-', activityId.Length / 2 - 1 )} "); + } + else + { + sb.Append(" / "); + } + } #endif + sb.Append($"[{eventData.EventName}] "); + for (int i = 0; i < eventData.Payload?.Count; i++) { if (i > 0) @@ -116,4 +143,103 @@ protected override void OnEventWritten(EventWrittenEventArgs eventData) } catch { } } + + // From https://gist.github.com/MihaZupan/cc63ee68b4146892f2e5b640ed57bc09 + private static class ActivityHelpers + { + private enum NumberListCodes : byte + { + End = 0x0, + LastImmediateValue = 0xA, + PrefixCode = 0xB, + MultiByte1 = 0xC, + } + + public static unsafe bool IsActivityPath(Guid guid) + { + uint* uintPtr = (uint*)&guid; + uint sum = uintPtr[0] + uintPtr[1] + uintPtr[2] + 0x599D99AD; + return ((sum & 0xFFF00000) == (uintPtr[3] & 0xFFF00000)); + } + + public static unsafe string ActivityPathString(Guid guid) + => IsActivityPath(guid) ? CreateActivityPathString(guid) : guid.ToString(); + + internal static unsafe string CreateActivityPathString(Guid guid) + { + Debug.Assert(IsActivityPath(guid)); + + StringBuilder sb = new StringBuilder(); + + byte* bytePtr = (byte*)&guid; + byte* endPtr = bytePtr + 12; + char separator = '/'; + while (bytePtr < endPtr) + { + uint nibble = (uint)(*bytePtr >> 4); + bool secondNibble = false; + NextNibble: + if (nibble == (uint)NumberListCodes.End) + { + break; + } + if (nibble <= (uint)NumberListCodes.LastImmediateValue) + { + sb.Append('/').Append(nibble); + if (!secondNibble) + { + nibble = (uint)(*bytePtr & 0xF); + secondNibble = true; + goto NextNibble; + } + bytePtr++; + continue; + } + else if (nibble == (uint)NumberListCodes.PrefixCode) + { + if (!secondNibble) + { + nibble = (uint)(*bytePtr & 0xF); + } + else + { + bytePtr++; + if (endPtr <= bytePtr) + { + break; + } + nibble = (uint)(*bytePtr >> 4); + } + if (nibble < (uint)NumberListCodes.MultiByte1) + { + return guid.ToString(); + } + separator = '$'; + } + Debug.Assert((uint)NumberListCodes.MultiByte1 <= nibble); + uint numBytes = nibble - (uint)NumberListCodes.MultiByte1; + uint value = 0; + if (!secondNibble) + { + value = (uint)(*bytePtr & 0xF); + } + bytePtr++; + numBytes++; + if (endPtr < bytePtr + numBytes) + { + break; + } + for (int i = (int)numBytes - 1; 0 <= i; --i) + { + value = (value << 8) + bytePtr[i]; + } + sb.Append(separator).Append(value); + + bytePtr += numBytes; + } + + sb.Append('/'); + return sb.ToString(); + } + } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs index 0aa83697a9de7..471e82c5e15fd 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.IO; +using System.Net.Sockets; +using System.Net.Test.Common; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -17,6 +20,8 @@ public AbortTest_Loopback(ITestOutputHelper output) : base(output) { } protected virtual Version HttpVersion => Net.HttpVersion.Version11; + public static object[][] AbortClient_MemberData = ToMemberData(Enum.GetValues(), UseSsl_Values, /* verifySendReceive */ Bool_Values); + [Theory] [MemberData(nameof(AbortClient_MemberData))] public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive) @@ -64,6 +69,8 @@ public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool use timeoutCts.Token); } + public static object[][] ServerPrematureEos_MemberData = ToMemberData(Enum.GetValues(), UseSsl_Values); + [Theory] [MemberData(nameof(ServerPrematureEos_MemberData))] public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl) @@ -146,34 +153,6 @@ await SendServerResponseAndEosAsync( protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func serverFunc, CancellationToken cancellationToken) => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2 - private static readonly bool[] Bool_Values = new[] { false, true }; - private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false }; - - public static IEnumerable AbortClient_MemberData() - { - foreach (var abortType in Enum.GetValues()) - { - foreach (var useSsl in UseSsl_Values) - { - foreach (var verifySendReceive in Bool_Values) - { - yield return new object[] { abortType, useSsl, verifySendReceive }; - } - } - } - } - - public static IEnumerable ServerPrematureEos_MemberData() - { - foreach (var serverEosType in Enum.GetValues()) - { - foreach (var useSsl in UseSsl_Values) - { - yield return new object[] { serverEosType, useSsl }; - } - } - } - public enum AbortType { Abort, @@ -187,7 +166,7 @@ public enum ServerEosType AfterSomeData } - private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, + protected static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken) { var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken); @@ -228,19 +207,133 @@ public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) // --- HTTP/2 WebSocket loopback tests --- - public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback + public abstract class AbortTest_Http2 : AbortTest_Loopback { - public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + public AbortTest_Http2(ITestOutputHelper output) : base(output) { } protected override Version HttpVersion => Net.HttpVersion.Version20; protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); + + public static object[][] ServerResetsAfterCloseHandshake_MemberData = ToMemberData(Bool_Values, UseSsl_Values); + + [Theory] + [MemberData(nameof(ServerResetsAfterCloseHandshake_MemberData))] + public Task ServerResetsAfterCloseHandshake_NoExceptionOnClient(bool sendGoAway, bool useSsl) + { + Assert.True(HttpVersion == Net.HttpVersion.Version20); + + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var serverFinalMsg = new byte[] { 123 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null) + { + DisposeServerWebSocket = false + }; + + var serverSentResetTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + return LoopbackWebSocketServer.RunAsync( + async uri => + { + var token = timeoutCts.Token; + var clientOptions = globalOptions with { HttpInvoker = GetInvoker() }; + var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false); + + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); + + await clientWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); + + await serverSentResetTcs.Task.WaitAsync(token).ConfigureAwait(false); + + // even though the server sent a reset, the client should receive all the data sent before the reset + + // receive final message + var readBuffer = new byte[1]; + var result = await clientWebSocket.ReceiveAsync(readBuffer, token).ConfigureAwait(false); + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + Assert.Equal(1, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal(serverFinalMsg, readBuffer); + + // receive close frame + result = await clientWebSocket.ReceiveAsync(readBuffer, token).ConfigureAwait(false); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + + // we've already closed our side, so we should be fully closed now + Assert.Equal(WebSocketState.Closed, clientWebSocket.State); + clientWebSocket.Dispose(); + clientReceivedEosTcs.SetResult(); + }, + async (requestData, token) => + { + var connection = requestData.Http2Connection!; + var streamId = requestData.Http2StreamId!.Value; + + var wsOptions = new WebSocketCreationOptions { IsServer = true }; + var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions); + + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token).ConfigureAwait(false); + + // wait for client to send close frame + var result = await serverWebSocket.ReceiveAsync(new byte[1], token).ConfigureAwait(false); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + + // send final message + await serverWebSocket.SendAsync(serverFinalMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false); + + await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); + // we've already closed our side, so we should be fully closed now + Assert.Equal(WebSocketState.Closed, serverWebSocket.State); + + try + { + await connection.SendResponseDataAsync(streamId, Memory.Empty, endStream: true).ConfigureAwait(false); + + if (sendGoAway) + { + await connection.SendGoAway(streamId).ConfigureAwait(false); + } + else + { + var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, streamId); + await connection.WriteFrameAsync(rstFrame).ConfigureAwait(false); + } + } + catch (IOException) + { + // Ignore connection errors + } + catch (SocketException) + { + // Ignore connection errors + } + + await Task.Delay(1000); // give the client some time to process the reset + + serverSentResetTcs.SetResult(); + await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false); + serverWebSocket.Dispose(); + }, + globalOptions, + timeoutCts.Token); + } + } + + public class AbortTest_Invoker_Http2 : AbortTest_Http2 + { + public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + protected override bool UseCustomInvoker => true; } - public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback + public class AbortTest_HttpClient_Http2 : AbortTest_Http2 { public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } - protected override Version HttpVersion => Net.HttpVersion.Version20; - protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) - => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); + protected override bool UseHttpClient => true; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs index 0dc1775b57345..58fa33d71f5ec 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs @@ -2,18 +2,20 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Net.Http; +using System.Reflection; using System.Threading; using System.Threading.Tasks; +using TestUtilities; using Xunit; using Xunit.Abstractions; -using System.Net.Http; -using System.Diagnostics; namespace System.Net.WebSockets.Client.Tests { - public class ClientWebSocketTestBase + public class ClientWebSocketTestBase : IDisposable { public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.GetEchoServers(); public static readonly object[][] EchoHeadersServers = System.Net.Test.Common.Configuration.WebSockets.GetEchoHeadersServers(); @@ -23,13 +25,40 @@ public class ClientWebSocketTestBase new object[] { o[0], true } }).ToArray(); + public static readonly bool[] Bool_Values = new[] { false, true }; + public static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false }; + public static readonly object[][] UseSsl_MemberData = ToMemberData(UseSsl_Values); + + public static object[][] ToMemberData(IEnumerable data) + => data.Select(a => new object[] { a }).ToArray(); + + public static object[][] ToMemberData(IEnumerable dataA, IEnumerable dataB) + => dataA.SelectMany(a => dataB.Select(b => new object[] { a, b })).ToArray(); + + public static object[][] ToMemberData(IEnumerable dataA, IEnumerable dataB, IEnumerable dataC) + => dataA.SelectMany(a => dataB.SelectMany(b => dataC.Select(c => new object[] { a, b, c }))).ToArray(); + public const int TimeOutMilliseconds = 30000; public const int CloseDescriptionMaxLength = 123; public readonly ITestOutputHelper _output; + public readonly TracingTestCollection? _collection; - public ClientWebSocketTestBase(ITestOutputHelper output) + public ClientWebSocketTestBase(ITestOutputHelper output, TracingTestCollection? collection = null) { _output = output; + _collection = collection; + + if (_collection != null) + { + _collection._tracePrefix = $"{GetType().Name}#{GetHashCode()}"; + } + + Trace($"{Environment.NewLine}===== Starting {GetType().Name}#{GetHashCode()} ====={Environment.NewLine}"); + } + + public void Dispose() + { + Trace($"{Environment.NewLine}===== Disposing {GetType().Name}#{GetHashCode()} ====={Environment.NewLine}"); } public static IEnumerable UnavailableWebSocketServers @@ -146,5 +175,76 @@ protected Task TestEcho(Uri uri, WebSocketMessageType type, int timeOutMilliseco WebSocketHelper.TestEcho(uri, WebSocketMessageType.Text, TimeOutMilliseconds, _output, GetInvoker()); public static bool WebSocketsSupported { get { return WebSocketHelper.WebSocketsSupported; } } + + public void Trace(FormattableString message) => _collection?.Trace(message); + + public void Trace(string message) => _collection?.Trace(message); + } + + [CollectionDefinition(nameof(TracingTestCollection), DisableParallelization = true)] + public class TracingTestCollection : ICollectionFixture, IDisposable + { + private static readonly Dictionary s_unobservedExceptions = new Dictionary(); + + internal string _tracePrefix = "(null)"; + + private readonly TestEventListener _listener; + + private static readonly EventHandler s_eventHandler = static (_, e) => + { + lock (s_unobservedExceptions) + { + string text = e.Exception.ToString(); + s_unobservedExceptions[text] = s_unobservedExceptions.GetValueOrDefault(text) + 1; + } + }; + + private static readonly FieldInfo s_ClientWebSocket_innerWebSocketField = + typeof(ClientWebSocket).GetField("_innerWebSocket", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new Exception("Could not find ClientWebSocket._innerWebSocket field"); + private static readonly PropertyInfo s_WebSocketHandle_WebSocketProperty = + typeof(ClientWebSocket).Assembly.GetType("System.Net.WebSockets.WebSocketHandle", throwOnError: true)! + .GetProperty("WebSocket", BindingFlags.Instance | BindingFlags.Public) + ?? throw new Exception("Could not find WebSocketHandle.WebSocket property"); + + private static WebSocket GetUnderlyingWebSocket(ClientWebSocket clientWebSocket) + { + object? innerWebSocket = s_ClientWebSocket_innerWebSocketField.GetValue(clientWebSocket); + if (innerWebSocket == null) + { + throw new Exception("ClientWebSocket._innerWebSocket is null"); + } + + return (WebSocket)s_WebSocketHandle_WebSocketProperty.GetValue(innerWebSocket); + } + + public TracingTestCollection() + { + Console.WriteLine(Environment.NewLine + "===== Running TracingTestCollection =====" + Environment.NewLine); + + TaskScheduler.UnobservedTaskException += s_eventHandler; + + _listener = new TestEventListener(Trace, enableActivityId: true, "System.Net.Http", "Private.InternalDiagnostics.System.Net.Http", "Private.InternalDiagnostics.System.Net.WebSockets"); + } + + public void Dispose() + { + Console.WriteLine(Environment.NewLine + "===== Disposing TracingTestCollection =====" + Environment.NewLine); + _listener.Dispose(); + + TaskScheduler.UnobservedTaskException -= s_eventHandler; + Console.WriteLine($"Unobserved exceptions of {s_unobservedExceptions.Count} different types: {Environment.NewLine}{string.Join(Environment.NewLine + new string('=', 120) + Environment.NewLine, s_unobservedExceptions.Select(pair => $"Count {pair.Value}: {pair.Key}"))}"); + } + + public void Trace(string message) => Trace((FormattableString)$"{message}"); + + public void Trace(FormattableString message) + { + var str = $"{DateTime.UtcNow:HH:mm:ss.fff} {_tracePrefix} | {message}"; + lock (Console.Out) + { + Console.WriteLine(str); + } + } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs index c0e71d42bb047..14affae6bd39e 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs @@ -264,7 +264,7 @@ public async Task CloseOutputAsync_ClientInitiated_CanReceive_CanClose(Uri serve [ActiveIssue("https://github.com/dotnet/runtime/issues/28957", typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] [OuterLoop("Uses external servers", typeof(PlatformDetection), nameof(PlatformDetection.LocalEchoServerIsNotAvailable))] - [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersWithSwitch))] + [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersAndBoolean))] public async Task CloseOutputAsync_ServerInitiated_CanReceive(Uri server, bool delayReceiving) { var expectedCloseStatus = WebSocketCloseStatus.NormalClosure; @@ -367,15 +367,8 @@ await cws.SendAsync( } } - public static IEnumerable EchoServersWithSwitch => - EchoServers.SelectMany(server => new List - { - new object[] { server[0], true }, - new object[] { server[0], false } - }); - [ActiveIssue("https://github.com/dotnet/runtime/issues/28957", typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] - [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersWithSwitch))] + [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersAndBoolean))] public async Task CloseOutputAsync_ServerInitiated_CanReceiveAfterClose(Uri server, bool syncState) { using (ClientWebSocket cws = await GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index 218a6269110c0..7d11771c26906 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -14,17 +14,13 @@ namespace System.Net.WebSockets.Client.Tests [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] public abstract class KeepAliveTest_Loopback : ClientWebSocketTestBase { - public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } + public KeepAliveTest_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } protected virtual Version HttpVersion => Net.HttpVersion.Version11; - public static readonly object[][] UseSsl_MemberData = PlatformDetection.SupportsAlpn - ? new[] { new object[] { false }, new object[] { true } } - : new[] { new object[] { false } }; - + [ActiveIssue("")] // TODO [Theory] [MemberData(nameof(UseSsl_MemberData))] - [InlineData(false)] public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) { var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; @@ -42,21 +38,26 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) DisposeClientWebSocket = true, ConfigureClientOptions = clientOptions => { - clientOptions.KeepAliveInterval = TimeSpan.FromSeconds(100); + clientOptions.KeepAliveInterval = TimeSpan.FromMilliseconds(100); clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1); }, - DebugLog = DebugLog + //DebugLog = Trace }; return LoopbackWebSocketServer.RunAsync( async (clientWebSocket, token) => { + Trace("VerifySendReceiveAsync #1 starting on client"); await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); + Trace("VerifySendReceiveAsync #1 completed on client"); // We need to always have a read task active to keep processing pongs var outstandingReadTask = clientWebSocket.ReceiveAsync(Array.Empty(), token); + Trace("Client waiting for long delay by server"); + await longDelayByServerTcs.Task.WaitAsync(token); + Trace("Long delay completed on client"); var result = await outstandingReadTask; Assert.Equal(WebSocketMessageType.Binary, result.MessageType); @@ -65,26 +66,36 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) Assert.Equal(WebSocketState.Open, clientWebSocket.State); + Trace("VerifySendReceiveAsync #2 starting on client"); await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); + Trace("VerifySendReceiveAsync #2 completed on client"); Assert.Equal(WebSocketState.Open, clientWebSocket.State); - DebugLog("Sending close frame from client"); + Trace("Sending close frame from client"); await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token); Assert.Equal(WebSocketState.Closed, clientWebSocket.State); - DebugLog("Client closed"); + Trace("Client closed"); }, async (serverWebSocket, token) => { + Trace("VerifySendReceiveAsync #1 starting on server"); + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + Trace("VerifySendReceiveAsync #1 completed on server"); + Assert.Equal(WebSocketState.Open, serverWebSocket.State); + Trace("Server initiating long delay"); + await Task.Delay(LongDelay); + Trace("Server long delay completed"); + Assert.Equal(WebSocketState.Open, serverWebSocket.State); // recreate already-completed TCS for another round @@ -93,33 +104,27 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) longDelayByServerTcs.SetResult(); + Trace("VerifySendReceiveAsync #2 starting on server"); + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); - DebugLog("Receiving close frame on server"); + Trace("VerifySendReceiveAsync #2 completed on server"); + + Trace("Receiving close frame on server"); var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty(), token); Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType); Assert.Equal(WebSocketState.CloseReceived, serverWebSocket.State); - DebugLog("Sending close frame response from server"); + Trace("Sending close frame response from server"); await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token); Assert.Equal(WebSocketState.Closed, serverWebSocket.State); - DebugLog("Server closed"); + Trace("Server closed"); }, options, timeoutCts.Token); - - - void DebugLog(string str) - { - const int MaxLogLength = 3000; - lock (Console.Out) - { - Console.WriteLine($"{this.GetType().Name} | useSsl={useSsl} | {str.Substring(0, Math.Min(str.Length, MaxLogLength))}{(str.Length > MaxLogLength ? "" : "")}"); - } - } } private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, @@ -146,32 +151,34 @@ private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, public class KeepAliveTest_Invoker_Loopback : KeepAliveTest_Loopback { - public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { } + public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } protected override bool UseCustomInvoker => true; } public class KeepAliveTest_HttpClient_Loopback : KeepAliveTest_Loopback { - public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { } + public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } protected override bool UseHttpClient => true; } public class KeepAliveTest_SharedHandler_Loopback : KeepAliveTest_Loopback { - public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { } + public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } } // --- HTTP/2 WebSocket loopback tests --- + [Collection(nameof(TracingTestCollection))] public class KeepAliveTest_Invoker_Http2 : KeepAliveTest_Invoker_Loopback { - public KeepAliveTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + public KeepAliveTest_Invoker_Http2(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } protected override Version HttpVersion => Net.HttpVersion.Version20; } + [Collection(nameof(TracingTestCollection))] public class KeepAliveTest_HttpClient_Http2 : KeepAliveTest_HttpClient_Loopback { - public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } + public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } protected override Version HttpVersion => Net.HttpVersion.Version20; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs index 3f62ae33e5f95..94ecf46f5afb7 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -69,7 +69,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override async ValueTask DisposeAsync() { - _debugLog?.Invoke($"Http2LoopbackStream.DisposeAsync() for stream {_streamId}; readEnded={_readEnded}; stack={Environment.StackTrace}"); + _debugLog?.Invoke($"Http2LoopbackStream.DisposeAsync() for stream {_streamId}; readEnded={_readEnded}"); try { await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); From 1803be89f65bc96a6b09f63ae752d11a862f2128 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Wed, 7 Aug 2024 19:05:57 +0100 Subject: [PATCH 09/13] fix test --- .../src/System.Net.WebSockets.Client.csproj | 2 +- .../tests/AbortTest.Loopback.cs | 10 +- .../tests/KeepAliveTest.Loopback.cs | 81 +++--------- .../tests/KeepAliveTest.cs | 4 +- .../LoopbackServer/Http2LoopbackStream.cs | 8 +- .../LoopbackServer/LoopbackWebSocketServer.cs | 2 +- .../LoopbackServer/ReadAheadWebSocket.cs | 123 ++++++++++++++++++ .../WebSocketHandshakeHelper.cs | 2 +- .../System.Net.WebSockets.Client.Tests.csproj | 1 + 9 files changed, 161 insertions(+), 72 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index cbad5c01b6da0..8265edd7e9369 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -48,6 +48,7 @@ + @@ -58,7 +59,6 @@ - diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs index 471e82c5e15fd..4bd2e574300ec 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs @@ -207,15 +207,17 @@ public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) // --- HTTP/2 WebSocket loopback tests --- - public abstract class AbortTest_Http2 : AbortTest_Loopback + public abstract class AbortTest_Loopback_Http2 : AbortTest_Loopback { - public AbortTest_Http2(ITestOutputHelper output) : base(output) { } + public AbortTest_Loopback_Http2(ITestOutputHelper output) : base(output) { } protected override Version HttpVersion => Net.HttpVersion.Version20; protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); public static object[][] ServerResetsAfterCloseHandshake_MemberData = ToMemberData(Bool_Values, UseSsl_Values); + [ActiveIssue("TODO")] // flaky test; unrelated existing issue + [OuterLoop("Uses Task.Delay")] [Theory] [MemberData(nameof(ServerResetsAfterCloseHandshake_MemberData))] public Task ServerResetsAfterCloseHandshake_NoExceptionOnClient(bool sendGoAway, bool useSsl) @@ -325,13 +327,13 @@ public Task ServerResetsAfterCloseHandshake_NoExceptionOnClient(bool sendGoAway, } } - public class AbortTest_Invoker_Http2 : AbortTest_Http2 + public class AbortTest_Invoker_Http2 : AbortTest_Loopback_Http2 { public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } protected override bool UseCustomInvoker => true; } - public class AbortTest_HttpClient_Http2 : AbortTest_Http2 + public class AbortTest_HttpClient_Http2 : AbortTest_Loopback_Http2 { public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } protected override bool UseHttpClient => true; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index 7d11771c26906..d34c076bedd1e 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -5,6 +5,7 @@ using System.Net.Test.Common; using System.Threading; using System.Threading.Tasks; +using System.Threading.Channels; using Xunit; using Xunit.Abstractions; @@ -14,11 +15,11 @@ namespace System.Net.WebSockets.Client.Tests [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] public abstract class KeepAliveTest_Loopback : ClientWebSocketTestBase { - public KeepAliveTest_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } + public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } protected virtual Version HttpVersion => Net.HttpVersion.Version11; - [ActiveIssue("")] // TODO + [OuterLoop("Uses Task.Delay")] [Theory] [MemberData(nameof(UseSsl_MemberData))] public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) @@ -41,61 +42,33 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) clientOptions.KeepAliveInterval = TimeSpan.FromMilliseconds(100); clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1); }, - //DebugLog = Trace + DebugLog = Trace }; return LoopbackWebSocketServer.RunAsync( - async (clientWebSocket, token) => + async (cws, token) => { - Trace("VerifySendReceiveAsync #1 starting on client"); - await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); - Trace("VerifySendReceiveAsync #1 completed on client"); + ReadAheadWebSocket clientWebSocket = new(cws); - // We need to always have a read task active to keep processing pongs - var outstandingReadTask = clientWebSocket.ReceiveAsync(Array.Empty(), token); - - Trace("Client waiting for long delay by server"); - - await longDelayByServerTcs.Task.WaitAsync(token); - Trace("Long delay completed on client"); - - var result = await outstandingReadTask; - Assert.Equal(WebSocketMessageType.Binary, result.MessageType); - Assert.False(result.EndOfMessage); - Assert.Equal(0, result.Count); // we issued a zero byte read, just to wait for data to become available + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); + await longDelayByServerTcs.Task.WaitAsync(token).ConfigureAwait(false); Assert.Equal(WebSocketState.Open, clientWebSocket.State); - Trace("VerifySendReceiveAsync #2 starting on client"); - await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); - Trace("VerifySendReceiveAsync #2 completed on client"); - + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); Assert.Equal(WebSocketState.Open, clientWebSocket.State); - Trace("Sending close frame from client"); - - await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token); - + await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); Assert.Equal(WebSocketState.Closed, clientWebSocket.State); - - Trace("Client closed"); }, - async (serverWebSocket, token) => + async (sws, token) => { - Trace("VerifySendReceiveAsync #1 starting on server"); - - await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); - - Trace("VerifySendReceiveAsync #1 completed on server"); + ReadAheadWebSocket serverWebSocket = new(sws); + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token).ConfigureAwait(false); Assert.Equal(WebSocketState.Open, serverWebSocket.State); - Trace("Server initiating long delay"); - await Task.Delay(LongDelay); - - Trace("Server long delay completed"); - Assert.Equal(WebSocketState.Open, serverWebSocket.State); // recreate already-completed TCS for another round @@ -104,24 +77,14 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) longDelayByServerTcs.SetResult(); - Trace("VerifySendReceiveAsync #2 starting on server"); - - await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token).ConfigureAwait(false); - Trace("VerifySendReceiveAsync #2 completed on server"); - - Trace("Receiving close frame on server"); - - var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty(), token); + var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty(), token).ConfigureAwait(false); Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType); Assert.Equal(WebSocketState.CloseReceived, serverWebSocket.State); - Trace("Sending close frame response from server"); - - await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token); + await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); Assert.Equal(WebSocketState.Closed, serverWebSocket.State); - - Trace("Server closed"); }, options, timeoutCts.Token); @@ -151,34 +114,32 @@ private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, public class KeepAliveTest_Invoker_Loopback : KeepAliveTest_Loopback { - public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } + public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { } protected override bool UseCustomInvoker => true; } public class KeepAliveTest_HttpClient_Loopback : KeepAliveTest_Loopback { - public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } + public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { } protected override bool UseHttpClient => true; } public class KeepAliveTest_SharedHandler_Loopback : KeepAliveTest_Loopback { - public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } + public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { } } // --- HTTP/2 WebSocket loopback tests --- - [Collection(nameof(TracingTestCollection))] public class KeepAliveTest_Invoker_Http2 : KeepAliveTest_Invoker_Loopback { - public KeepAliveTest_Invoker_Http2(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } + public KeepAliveTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } protected override Version HttpVersion => Net.HttpVersion.Version20; } - [Collection(nameof(TracingTestCollection))] public class KeepAliveTest_HttpClient_Http2 : KeepAliveTest_HttpClient_Loopback { - public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output, TracingTestCollection c) : base(output, c) { } + public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } protected override Version HttpVersion => Net.HttpVersion.Version20; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs index e819d9800c675..5ff9c51e56a6a 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs @@ -19,7 +19,7 @@ public class KeepAliveTest : ClientWebSocketTestBase public KeepAliveTest(ITestOutputHelper output) : base(output) { } [ConditionalFact(nameof(WebSocketsSupported))] - [OuterLoop] // involves long delay + [OuterLoop("Uses Task.Delay")] public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() { using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1))) @@ -37,7 +37,7 @@ public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() } [ConditionalTheory(nameof(WebSocketsSupported))] - [OuterLoop] // involves long delay + [OuterLoop("Uses Task.Delay")] [InlineData(1, 0)] // unsolicited pong [InlineData(1, 2)] // ping/pong public async Task KeepAlive_LongDelayBetweenReceiveSends_Succeeds(int keepAliveIntervalSec, int keepAliveTimeoutSec) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs index 94ecf46f5afb7..8c4dbe0604c74 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -16,6 +16,7 @@ public class Http2LoopbackStream : Stream private bool _readEnded; private ReadOnlyMemory _leftoverReadData; private readonly Action? _debugLog; + private bool _sendResetOnDispose; public override bool CanRead => true; public override bool CanSeek => false; @@ -24,11 +25,12 @@ public class Http2LoopbackStream : Stream public Http2LoopbackConnection Connection => _connection; public int StreamId => _streamId; - public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId, Action? debugLog = null) + public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId, bool sendResetOnDispose = true, Action? debugLog = null) { _connection = connection; _streamId = streamId; _debugLog = debugLog; + _sendResetOnDispose = sendResetOnDispose; } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) @@ -69,12 +71,12 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override async ValueTask DisposeAsync() { - _debugLog?.Invoke($"Http2LoopbackStream.DisposeAsync() for stream {_streamId}; readEnded={_readEnded}"); + _debugLog?.Invoke($"Http2LoopbackStream.DisposeAsync() for stream {_streamId}; readEnded={_readEnded}; sendResetOnDispose={_sendResetOnDispose}"); try { await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); - if (!_readEnded) + if (!_readEnded && _sendResetOnDispose) { var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId); await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs index 34f7a5327c8d2..f2ccb6e95a82f 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -75,7 +75,7 @@ await server.AcceptConnectionAsync(async connection => options.DebugLog?.Invoke("loopbackServerFunc completed; disposing the connection"); - await http2Connection.DisposeAsync().ConfigureAwait(false); + await http2Connection.ShutdownIgnoringErrorsAsync(http2StreamId).ConfigureAwait(false); }, new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl, DebugLog = options.DebugLog }); } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs new file mode 100644 index 0000000000000..f3c27e048429e --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Generic; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Channels; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests; + +internal class ReadAheadWebSocket : WebSocket +{ + private const int ReadAheadBufferSize = 64 * 1024 * 1024; + + private record struct DataFrame(ValueWebSocketReceiveResult Metadata, Memory Memory, byte[] _rented); + + private Channel _incomingFrames = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleReader = true, SingleWriter = true }); + private DataFrame? _currentFrame; + + private SemaphoreSlim receiveMutex = new SemaphoreSlim(1, 1); + private readonly WebSocket _innerWebSocket; + + public ReadAheadWebSocket(WebSocket innerWebSocket, Action? debugLog = null) + { + _innerWebSocket = innerWebSocket; + _ = ProcessIncomingFrames(debugLog); + } + + private async Task ProcessIncomingFrames(Action? debugLog) + { + var buffer = new byte[ReadAheadBufferSize]; + while (true) + { + try + { + ValueWebSocketReceiveResult result = await _innerWebSocket.ReceiveAsync((Memory)buffer, default).ConfigureAwait(false); + + byte[] rented = result.Count > 0 ? ArrayPool.Shared.Rent(result.Count) : Array.Empty(); + Memory message = rented.AsMemory(0, result.Count); + buffer.AsMemory(0, result.Count).CopyTo(message); + + await _incomingFrames.Writer.WriteAsync(new DataFrame(result, message, rented), default).ConfigureAwait(false); + + if (result.MessageType == WebSocketMessageType.Close) + { + _incomingFrames.Writer.Complete(); + break; + } + } + catch (Exception e) + { + _incomingFrames.Writer.Complete(e); + debugLog?.Invoke($"Exception during {nameof(ProcessIncomingFrames)}: {e}"); + break; + } + } + } + + public override async ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + { + await receiveMutex.WaitAsync(cancellationToken).ConfigureAwait(false); + + try + { + _currentFrame ??= await _incomingFrames.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); + + var (result, message, rented) = _currentFrame.Value; + + if (buffer.Length < result.Count) + { + message.Slice(0, buffer.Length).CopyTo(buffer); + var remaining = message.Slice(buffer.Length); + _currentFrame = _currentFrame.Value with { Metadata = new (remaining.Length, result.MessageType, result.EndOfMessage), Memory = remaining }; + + return new (buffer.Length, result.MessageType, endOfMessage: false); + } + else + { + message.CopyTo(buffer); + if (rented.Length > 0) + { + ArrayPool.Shared.Return(rented); + } + _currentFrame = null; + return result; + } + } + finally + { + receiveMutex.Release(); + } + } + + public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + ValueWebSocketReceiveResult valueResult = await ReceiveAsync((Memory)buffer, cancellationToken).ConfigureAwait(false); + var result = new WebSocketReceiveResult( + valueResult.Count, + valueResult.MessageType, + valueResult.EndOfMessage, + valueResult.MessageType == WebSocketMessageType.Close ? CloseStatus : null, + valueResult.MessageType == WebSocketMessageType.Close ? CloseStatusDescription : null); + return result; + } + + public override WebSocketCloseStatus? CloseStatus => _innerWebSocket.CloseStatus; + public override string? CloseStatusDescription => _innerWebSocket.CloseStatusDescription; + public override string? SubProtocol => _innerWebSocket.SubProtocol; + public override WebSocketState State => _innerWebSocket.State; + public override void Abort() => _innerWebSocket.Abort(); + public override void Dispose() => _innerWebSocket.Dispose(); + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) => _innerWebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) => _innerWebSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => _innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => _innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) => _innerWebSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs index 271ed4beca79e..e1eb0f3c63520 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs @@ -83,7 +83,7 @@ public static async Task ProcessHttp2RequestAsync(Http2Loo await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false); } - data.WebSocketStream = new Http2LoopbackStream(connection, streamId, server.DebugLog); + data.WebSocketStream = new Http2LoopbackStream(connection, streamId, sendResetOnDispose: false, server.DebugLog); return data; } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index b45fbad02d0c9..af3355cd701c6 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -62,6 +62,7 @@ + From a4c5bcf4257af73779dae45a5493036ea3ecc27f Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 8 Aug 2024 15:00:50 +0100 Subject: [PATCH 10/13] Address Feedback --- .../Net/WebSockets/WebSocketValidate.cs | 40 +-- .../Net/Http/Http2LoopbackConnection.cs | 7 +- .../System/Net/Http/Http2LoopbackServer.cs | 4 - .../tests/AbortTest.Loopback.cs | 128 +------ .../tests/ClientWebSocketTestBase.cs | 89 +---- .../tests/KeepAliveTest.Loopback.cs | 3 +- .../LoopbackServer/Http2LoopbackStream.cs | 5 +- .../LoopbackServer/LoopbackWebSocketServer.cs | 6 +- .../LoopbackServer/ReadAheadWebSocket.cs | 7 +- .../WebSocketHandshakeHelper.cs | 2 +- .../WebSockets/ManagedWebSocket.KeepAlive.cs | 319 ++++++++---------- .../System/Net/WebSockets/ManagedWebSocket.cs | 79 +++-- 12 files changed, 222 insertions(+), 467 deletions(-) diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs index 1da18d62093ac..c11524bdeef9b 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -38,40 +38,26 @@ internal static partial class WebSocketValidate SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"); internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates) - { - // Exception order: - // 1. WebSocketException(InvalidState) -- if invalid state - // 2. ObjectDisposedException + => ThrowIfInvalidState(currentState, isDisposed, innerException: null, validStates ?? []); - string? invalidStateMessage = GetInvalidStateMessage(currentState, validStates); - if (invalidStateMessage is null) // state is valid + internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, WebSocketState[]? validStates = null) + { + if (validStates is not null && Array.IndexOf(validStates, currentState) == -1) { - // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. - ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket)); - return; - } + string invalidStateMessage = SR.Format( + SR.net_WebSockets_InvalidState, currentState, string.Join(", ", validStates)); - throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage); - } - - internal static string? GetInvalidStateMessage(WebSocketState currentState, WebSocketState[] validStates) - { - string validStatesText = string.Empty; + throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException); + } - if (validStates != null && validStates.Length > 0) + if (innerException is not null) { - foreach (WebSocketState validState in validStates) - { - if (currentState == validState) - { - return null; - } - } - - validStatesText = string.Join(", ", validStates); + Debug.Assert(currentState == WebSocketState.Aborted); + throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException); } - return SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText); + // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. + ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket)); } internal static void ValidateSubprotocol(string subProtocol) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index e75474cd5bf8f..e607c42aa48ba 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -28,7 +28,6 @@ public class Http2LoopbackConnection : GenericLoopbackConnection private readonly TimeSpan _timeout; private int _lastStreamId; private bool _expectClientDisconnect; - private readonly Action? _debugLog; private readonly byte[] _prefix = new byte[24]; public string PrefixString => Encoding.UTF8.GetString(_prefix, 0, _prefix.Length); @@ -36,13 +35,12 @@ public class Http2LoopbackConnection : GenericLoopbackConnection public Stream Stream => _connectionStream; public Task SettingAckWaiter => _ignoredSettingsAckPromise?.Task; - private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse, Action? debugLog = null) + private Http2LoopbackConnection(SocketWrapper socket, Stream stream, TimeSpan timeout, bool transparentPingResponse) { _connectionSocket = socket; _connectionStream = stream; _timeout = timeout; _transparentPingResponse = transparentPingResponse; - _debugLog = debugLog; } public override string ToString() @@ -85,7 +83,7 @@ public static async Task CreateAsync(SocketWrapper sock stream = sslStream; } - var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse, httpOptions.DebugLog); + var con = new Http2LoopbackConnection(socket, stream, timeout, httpOptions.EnableTransparentPingResponse); await con.ReadPrefixAsync().ConfigureAwait(false); return con; @@ -370,7 +368,6 @@ public async Task WaitForConnectionShutdownAsync(bool ignoreUnexpectedFrames = f // and will ignore any errors if client has already shutdown public async Task ShutdownIgnoringErrorsAsync(int lastStreamId, ProtocolErrors errorCode = ProtocolErrors.NO_ERROR) { - _debugLog?.Invoke($"Http2LoopbackConnection.ShutdownIgnoringErrorsAsync() with lastStreamId={lastStreamId}, errorCode={errorCode}"); try { await SendGoAway(lastStreamId, errorCode).ConfigureAwait(false); diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs index 5f9f52d3cd2b1..90929b70eec37 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs @@ -32,8 +32,6 @@ private Http2LoopbackConnection Connection } } - public Action? DebugLog => _options.DebugLog; - public static readonly TimeSpan Timeout = TimeSpan.FromSeconds(30); public override Uri Address @@ -188,8 +186,6 @@ public class Http2Options : GenericLoopbackOptions public bool EnableTransparentPingResponse { get; set; } = true; - public Action? DebugLog { get; set; } - public Http2Options() { SslProtocols = SslProtocols.Tls12; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs index 4bd2e574300ec..8d0a89b320d61 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs @@ -207,135 +207,19 @@ public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) // --- HTTP/2 WebSocket loopback tests --- - public abstract class AbortTest_Loopback_Http2 : AbortTest_Loopback + public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback { - public AbortTest_Loopback_Http2(ITestOutputHelper output) : base(output) { } + public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } protected override Version HttpVersion => Net.HttpVersion.Version20; protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); - - public static object[][] ServerResetsAfterCloseHandshake_MemberData = ToMemberData(Bool_Values, UseSsl_Values); - - [ActiveIssue("TODO")] // flaky test; unrelated existing issue - [OuterLoop("Uses Task.Delay")] - [Theory] - [MemberData(nameof(ServerResetsAfterCloseHandshake_MemberData))] - public Task ServerResetsAfterCloseHandshake_NoExceptionOnClient(bool sendGoAway, bool useSsl) - { - Assert.True(HttpVersion == Net.HttpVersion.Version20); - - var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; - var serverMsg = new byte[] { 42 }; - var serverFinalMsg = new byte[] { 123 }; - var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); - - var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null) - { - DisposeServerWebSocket = false - }; - - var serverSentResetTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - return LoopbackWebSocketServer.RunAsync( - async uri => - { - var token = timeoutCts.Token; - var clientOptions = globalOptions with { HttpInvoker = GetInvoker() }; - var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false); - - await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); - - await clientWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); - - await serverSentResetTcs.Task.WaitAsync(token).ConfigureAwait(false); - - // even though the server sent a reset, the client should receive all the data sent before the reset - - // receive final message - var readBuffer = new byte[1]; - var result = await clientWebSocket.ReceiveAsync(readBuffer, token).ConfigureAwait(false); - Assert.Equal(WebSocketMessageType.Binary, result.MessageType); - Assert.Equal(1, result.Count); - Assert.True(result.EndOfMessage); - Assert.Equal(serverFinalMsg, readBuffer); - - // receive close frame - result = await clientWebSocket.ReceiveAsync(readBuffer, token).ConfigureAwait(false); - Assert.Equal(WebSocketMessageType.Close, result.MessageType); - - // we've already closed our side, so we should be fully closed now - Assert.Equal(WebSocketState.Closed, clientWebSocket.State); - clientWebSocket.Dispose(); - clientReceivedEosTcs.SetResult(); - }, - async (requestData, token) => - { - var connection = requestData.Http2Connection!; - var streamId = requestData.Http2StreamId!.Value; - - var wsOptions = new WebSocketCreationOptions { IsServer = true }; - var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions); - - await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token).ConfigureAwait(false); - - // wait for client to send close frame - var result = await serverWebSocket.ReceiveAsync(new byte[1], token).ConfigureAwait(false); - Assert.Equal(WebSocketMessageType.Close, result.MessageType); - - // send final message - await serverWebSocket.SendAsync(serverFinalMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false); - - await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); - // we've already closed our side, so we should be fully closed now - Assert.Equal(WebSocketState.Closed, serverWebSocket.State); - - try - { - await connection.SendResponseDataAsync(streamId, Memory.Empty, endStream: true).ConfigureAwait(false); - - if (sendGoAway) - { - await connection.SendGoAway(streamId).ConfigureAwait(false); - } - else - { - var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, streamId); - await connection.WriteFrameAsync(rstFrame).ConfigureAwait(false); - } - } - catch (IOException) - { - // Ignore connection errors - } - catch (SocketException) - { - // Ignore connection errors - } - - await Task.Delay(1000); // give the client some time to process the reset - - serverSentResetTcs.SetResult(); - await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false); - serverWebSocket.Dispose(); - }, - globalOptions, - timeoutCts.Token); - } } - public class AbortTest_Invoker_Http2 : AbortTest_Loopback_Http2 - { - public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } - protected override bool UseCustomInvoker => true; - } - - public class AbortTest_HttpClient_Http2 : AbortTest_Loopback_Http2 + public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback { public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } - protected override bool UseHttpClient => true; + protected override Version HttpVersion => Net.HttpVersion.Version20; + protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) + => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs index 58fa33d71f5ec..4e4fb4b3d87c7 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs @@ -15,7 +15,7 @@ namespace System.Net.WebSockets.Client.Tests { - public class ClientWebSocketTestBase : IDisposable + public class ClientWebSocketTestBase { public static readonly object[][] EchoServers = System.Net.Test.Common.Configuration.WebSockets.GetEchoServers(); public static readonly object[][] EchoHeadersServers = System.Net.Test.Common.Configuration.WebSockets.GetEchoHeadersServers(); @@ -41,24 +41,10 @@ public static object[][] ToMemberData(IEnumerable dataA, IEnumer public const int TimeOutMilliseconds = 30000; public const int CloseDescriptionMaxLength = 123; public readonly ITestOutputHelper _output; - public readonly TracingTestCollection? _collection; - public ClientWebSocketTestBase(ITestOutputHelper output, TracingTestCollection? collection = null) + public ClientWebSocketTestBase(ITestOutputHelper output) { _output = output; - _collection = collection; - - if (_collection != null) - { - _collection._tracePrefix = $"{GetType().Name}#{GetHashCode()}"; - } - - Trace($"{Environment.NewLine}===== Starting {GetType().Name}#{GetHashCode()} ====={Environment.NewLine}"); - } - - public void Dispose() - { - Trace($"{Environment.NewLine}===== Disposing {GetType().Name}#{GetHashCode()} ====={Environment.NewLine}"); } public static IEnumerable UnavailableWebSocketServers @@ -175,76 +161,5 @@ protected Task TestEcho(Uri uri, WebSocketMessageType type, int timeOutMilliseco WebSocketHelper.TestEcho(uri, WebSocketMessageType.Text, TimeOutMilliseconds, _output, GetInvoker()); public static bool WebSocketsSupported { get { return WebSocketHelper.WebSocketsSupported; } } - - public void Trace(FormattableString message) => _collection?.Trace(message); - - public void Trace(string message) => _collection?.Trace(message); - } - - [CollectionDefinition(nameof(TracingTestCollection), DisableParallelization = true)] - public class TracingTestCollection : ICollectionFixture, IDisposable - { - private static readonly Dictionary s_unobservedExceptions = new Dictionary(); - - internal string _tracePrefix = "(null)"; - - private readonly TestEventListener _listener; - - private static readonly EventHandler s_eventHandler = static (_, e) => - { - lock (s_unobservedExceptions) - { - string text = e.Exception.ToString(); - s_unobservedExceptions[text] = s_unobservedExceptions.GetValueOrDefault(text) + 1; - } - }; - - private static readonly FieldInfo s_ClientWebSocket_innerWebSocketField = - typeof(ClientWebSocket).GetField("_innerWebSocket", BindingFlags.NonPublic | BindingFlags.Instance) - ?? throw new Exception("Could not find ClientWebSocket._innerWebSocket field"); - private static readonly PropertyInfo s_WebSocketHandle_WebSocketProperty = - typeof(ClientWebSocket).Assembly.GetType("System.Net.WebSockets.WebSocketHandle", throwOnError: true)! - .GetProperty("WebSocket", BindingFlags.Instance | BindingFlags.Public) - ?? throw new Exception("Could not find WebSocketHandle.WebSocket property"); - - private static WebSocket GetUnderlyingWebSocket(ClientWebSocket clientWebSocket) - { - object? innerWebSocket = s_ClientWebSocket_innerWebSocketField.GetValue(clientWebSocket); - if (innerWebSocket == null) - { - throw new Exception("ClientWebSocket._innerWebSocket is null"); - } - - return (WebSocket)s_WebSocketHandle_WebSocketProperty.GetValue(innerWebSocket); - } - - public TracingTestCollection() - { - Console.WriteLine(Environment.NewLine + "===== Running TracingTestCollection =====" + Environment.NewLine); - - TaskScheduler.UnobservedTaskException += s_eventHandler; - - _listener = new TestEventListener(Trace, enableActivityId: true, "System.Net.Http", "Private.InternalDiagnostics.System.Net.Http", "Private.InternalDiagnostics.System.Net.WebSockets"); - } - - public void Dispose() - { - Console.WriteLine(Environment.NewLine + "===== Disposing TracingTestCollection =====" + Environment.NewLine); - _listener.Dispose(); - - TaskScheduler.UnobservedTaskException -= s_eventHandler; - Console.WriteLine($"Unobserved exceptions of {s_unobservedExceptions.Count} different types: {Environment.NewLine}{string.Join(Environment.NewLine + new string('=', 120) + Environment.NewLine, s_unobservedExceptions.Select(pair => $"Count {pair.Value}: {pair.Key}"))}"); - } - - public void Trace(string message) => Trace((FormattableString)$"{message}"); - - public void Trace(FormattableString message) - { - var str = $"{DateTime.UtcNow:HH:mm:ss.fff} {_tracePrefix} | {message}"; - lock (Console.Out) - { - Console.WriteLine(str); - } - } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index d34c076bedd1e..08306c0804ee4 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -41,8 +41,7 @@ public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) { clientOptions.KeepAliveInterval = TimeSpan.FromMilliseconds(100); clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1); - }, - DebugLog = Trace + } }; return LoopbackWebSocketServer.RunAsync( diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs index 8c4dbe0604c74..b841eead6ea24 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -15,7 +15,6 @@ public class Http2LoopbackStream : Stream private readonly int _streamId; private bool _readEnded; private ReadOnlyMemory _leftoverReadData; - private readonly Action? _debugLog; private bool _sendResetOnDispose; public override bool CanRead => true; @@ -25,11 +24,10 @@ public class Http2LoopbackStream : Stream public Http2LoopbackConnection Connection => _connection; public int StreamId => _streamId; - public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId, bool sendResetOnDispose = true, Action? debugLog = null) + public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId, bool sendResetOnDispose = true) { _connection = connection; _streamId = streamId; - _debugLog = debugLog; _sendResetOnDispose = sendResetOnDispose; } @@ -71,7 +69,6 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override async ValueTask DisposeAsync() { - _debugLog?.Invoke($"Http2LoopbackStream.DisposeAsync() for stream {_streamId}; readEnded={_readEnded}; sendResetOnDispose={_sendResetOnDispose}"); try { await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs index f2ccb6e95a82f..ec53020184802 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -73,11 +73,9 @@ await server.AcceptConnectionAsync(async connection => await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); - options.DebugLog?.Invoke("loopbackServerFunc completed; disposing the connection"); - await http2Connection.ShutdownIgnoringErrorsAsync(http2StreamId).ConfigureAwait(false); }, - new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl, DebugLog = options.DebugLog }); + new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); } else { @@ -98,7 +96,6 @@ private static async Task RunServerAsync( if (options.DisposeServerWebSocket) { - options.DebugLog?.Invoke("Disposing server websocket"); serverWebSocket.Dispose(); } } @@ -149,7 +146,6 @@ public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker public bool DisposeHttpInvoker { get; set; } public bool ManualServerHandshakeResponse { get; set; } public Action? ConfigureClientOptions { get; set; } - public Action? DebugLog { get; set; } } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs index f3c27e048429e..af98d76580cf2 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs @@ -26,13 +26,13 @@ private record struct DataFrame(ValueWebSocketReceiveResult Metadata, Memory? debugLog = null) + public ReadAheadWebSocket(WebSocket innerWebSocket) { _innerWebSocket = innerWebSocket; - _ = ProcessIncomingFrames(debugLog); + _ = ProcessIncomingFrames(); } - private async Task ProcessIncomingFrames(Action? debugLog) + private async Task ProcessIncomingFrames() { var buffer = new byte[ReadAheadBufferSize]; while (true) @@ -56,7 +56,6 @@ private async Task ProcessIncomingFrames(Action? debugLog) catch (Exception e) { _incomingFrames.Writer.Complete(e); - debugLog?.Invoke($"Exception during {nameof(ProcessIncomingFrames)}: {e}"); break; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs index e1eb0f3c63520..06e62d4a17e48 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs @@ -83,7 +83,7 @@ public static async Task ProcessHttp2RequestAsync(Http2Loo await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false); } - data.WebSocketStream = new Http2LoopbackStream(connection, streamId, sendResetOnDispose: false, server.DebugLog); + data.WebSocketStream = new Http2LoopbackStream(connection, streamId, sendResetOnDispose: false); return data; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index 4cdd537cc658d..ad1d18b403dd3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; @@ -11,53 +12,6 @@ namespace System.Net.WebSockets { internal sealed partial class ManagedWebSocket : WebSocket { - // "Observe" either a ValueTask result, or any exception, ignoring it - // to prevent the unobserved exception event from being raised. - public void Observe(ValueTask t) - { - if (t.IsCompletedSuccessfully) - { - t.GetAwaiter().GetResult(); - } - else - { - ObserveException(t.AsTask()); - } - } - - // "Observe" either a Task result, or any exception, ignoring it - // to prevent the unobserved exception event from being raised. - public void Observe(Task t) - { - if (t.IsCompletedSuccessfully) - { - t.GetAwaiter().GetResult(); - } - else - { - ObserveException(t); - } - } - - private void ObserveException(Task task) - { - task.ContinueWith( - LogFaulted, - this, - CancellationToken.None, - TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Default); - - static void LogFaulted(Task task, object? thisObj) - { - Debug.Assert(task.IsFaulted); - - Exception? innerException = task.Exception!.InnerException; // accessing exception anyway, to observe it regardless of whether the tracing is enabled - - if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, innerException ?? task.Exception!); - } - } - private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null; private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1; private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1; @@ -105,12 +59,11 @@ private void KeepAlivePingHeartBeat() if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + bool shouldSendPing = false; + long pingPayload = -1; + try { - bool timedOut = false; - bool sendPing = false; - long pingPayload = -1; - lock (StateUpdateLock) { if (_keepAlivePingState.Exception is not null) @@ -121,39 +74,34 @@ private void KeepAlivePingHeartBeat() long now = Environment.TickCount64; - if (_keepAlivePingState.AwaitingPong) + if (_keepAlivePingState.PingSent) { - Debug.Assert(_keepAlivePingState.WillTimeoutTimestamp != Timeout.Infinite); - - if (now > _keepAlivePingState.WillTimeoutTimestamp) + if (Environment.TickCount64 > _keepAlivePingState.PingTimeoutTimestamp) { - timedOut = true; - pingPayload = _keepAlivePingState.PingPayload; + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {_keepAlivePingState.PingPayload}"); + } + + Exception exc = ExceptionDispatchInfo.SetCurrentStackTrace( + new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout)); + + _keepAlivePingState.OnKeepAliveFaultedCore(exc); // we are holding the lock + return; } } else { - if (now > _keepAlivePingState.NextPingTimestamp) + if (Environment.TickCount64 > _keepAlivePingState.NextPingRequestTimestamp) { - sendPing = true; - pingPayload = ++_keepAlivePingState.PingPayload; - - _keepAlivePingState.AwaitingPong = true; - _keepAlivePingState.WillTimeoutTimestamp = now + _keepAlivePingState.TimeoutMs; + _keepAlivePingState.OnNextPingRequestCore(); // we are holding the lock + shouldSendPing = true; + pingPayload = _keepAlivePingState.PingPayload; } } } - if (timedOut) - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {pingPayload}"); - } - - throw new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout); - } - else if (sendPing) + if (shouldSendPing) { Observe( SendPingAsync(pingPayload)); @@ -163,23 +111,7 @@ private void KeepAlivePingHeartBeat() { if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, e); - bool shouldAbort = false; - lock (StateUpdateLock) - { - if (!_disposed) - { - // We only save the exception in the keep-alive state if we will actually trigger the abort/disposal - // The exception needs to be assigned before _disposed is set to true - _keepAlivePingState.Exception = e; - shouldAbort = true; - } - } - - if (shouldAbort) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Exception saved in _keepAlivePingState, aborting..."); - Abort(); - } + _keepAlivePingState.OnKeepAliveFaulted(e); } } @@ -191,10 +123,7 @@ private async ValueTask SendPingAsync(long pingPayload) BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload); try { - await TrySendKeepAliveFrameAsync( - MessageOpcode.Ping, - pingPayloadBuffer.AsMemory(0, sizeof(long))) - .ConfigureAwait(false); + await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer.AsMemory(0, sizeof(long))).ConfigureAwait(false); if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload); } @@ -204,129 +133,173 @@ await TrySendKeepAliveFrameAsync( } } - private void OnDataReceived(int bytesRead) + // "Observe" either a ValueTask result, or any exception, ignoring it + // to prevent the unobserved exception event from being raised. + private void Observe(ValueTask t) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={bytesRead}"); - - if (_keepAlivePingState != null && bytesRead > 0) + if (t.IsCompletedSuccessfully) { - lock (StateUpdateLock) - { - _keepAlivePingState.OnDataReceived(); - } + t.GetAwaiter().GetResult(); + } + else + { + Observe(t.AsTask()); } } - private void ThrowIfDisposedOrKeepAliveFaulted() - => ThrowIfInvalidStateOrKeepAliveFaulted(validStates: null); - - private void ThrowIfInvalidStateOrKeepAliveFaulted(WebSocketState[]? validStates) + // "Observe" any exception, ignoring it to prevent the unobserved task + // exception event from being raised. + private void Observe(Task t) { - Debug.Assert(_keepAlivePingState is not null); - - // Exception order: WebSocketException -> OperationCanceledException -> ObjectDisposedException - // - // If keepAlive exception present: - // 1. WebSocketException(InvalidState), keepAlive exception as inner -- if invalid state - // 2. OperationCanceledException, keepAlive exception as inner - // - // If keepAlive exception not present: - // 1. WebSocketException(InvalidState) -- if invalid state - // 2. ObjectDisposedException - - bool disposed; - WebSocketState state; - Exception? keepAliveException; - lock (StateUpdateLock) + if (t.IsCompleted) { - disposed = _disposed; - state = _state; - keepAliveException = _keepAlivePingState.Exception; + if (t.IsFaulted) + { + LogFaulted(t, this); + } } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_disposed={disposed}, _state={state}, _keepAlivePingState.Exception={keepAliveException?.Message}"); - - string? invalidStateMessage = validStates is not null ? WebSocketValidate.GetInvalidStateMessage(state, validStates) : null; - if (invalidStateMessage is not null) + else { - // Surface keepAliveException as inner exception, if present - throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, keepAliveException); + t.ContinueWith( + LogFaulted, + this, + CancellationToken.None, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); } - // If keepAliveException is not null, it triggered the abort which also disposed the websocket - // We only save the exception if it actually triggered the abort - if (keepAliveException is not null) + static void LogFaulted(Task task, object? thisObj) { - throw new OperationCanceledException(nameof(WebSocketState.Aborted), keepAliveException); - } + Debug.Assert(task.IsFaulted); + + // accessing exception to observe it regardless of whether the tracing is enabled + Exception e = task.Exception!.InnerException!; - // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. - ObjectDisposedException.ThrowIf(disposed, this); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e); + } } private sealed class KeepAlivePingState { internal const int PingPayloadSize = sizeof(long); - internal const int MinIntervalMs = 1; - - internal readonly int DelayMs; - internal readonly int TimeoutMs; - internal readonly int HeartBeatIntervalMs; + private const int MinIntervalMs = 1; - internal long NextPingTimestamp; - internal long WillTimeoutTimestamp; + private readonly ManagedWebSocket _parent; + private object StateUpdateLock => _parent.StateUpdateLock; - internal bool AwaitingPong; - internal long PingPayload; - internal Exception? Exception; + internal int DelayMs { get; } + internal int TimeoutMs { get; } + internal int HeartBeatIntervalMs => Math.Max(Math.Min(DelayMs, TimeoutMs) / 4, MinIntervalMs); - internal object Debug_WebSocket_StateUpdateLock = null!; // for Debug.Asserts + internal long PingPayload { get; private set; } + internal bool PingSent { get; private set; } + internal long PingTimeoutTimestamp { get; private set; } + internal long NextPingRequestTimestamp { get; private set; } + internal Exception? Exception { get; private set; } - public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) + public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout, ManagedWebSocket parent) { DelayMs = TimeSpanToMs(keepAliveInterval); TimeoutMs = TimeSpanToMs(keepAliveTimeout); - NextPingTimestamp = Environment.TickCount64 + DelayMs; - WillTimeoutTimestamp = Timeout.Infinite; + NextPingRequestTimestamp = Environment.TickCount64 + DelayMs; + PingTimeoutTimestamp = Timeout.Infinite; + _parent = parent; - HeartBeatIntervalMs = Math.Max( - Math.Min(DelayMs, TimeoutMs) / 4, - MinIntervalMs); - - static int TimeSpanToMs(TimeSpan value) => - (int)Math.Clamp((long)value.TotalMilliseconds, MinIntervalMs, int.MaxValue); + static int TimeSpanToMs(TimeSpan value) => (int)Math.Clamp((long)value.TotalMilliseconds, MinIntervalMs, int.MaxValue); } internal void OnDataReceived() { - Debug.Assert(Monitor.IsEntered(Debug_WebSocket_StateUpdateLock)); - - NextPingTimestamp = Environment.TickCount64 + DelayMs; + lock (StateUpdateLock) + { + NextPingRequestTimestamp = Environment.TickCount64 + DelayMs; + } } internal void OnPongResponseReceived(long pongPayload) { - Debug.Assert(Monitor.IsEntered(Debug_WebSocket_StateUpdateLock)); + lock (StateUpdateLock) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"pongPayload={pongPayload}"); - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"pongPayload={pongPayload}"); + if (!PingSent) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Not waiting for Pong. Skipping."); + return; + } - if (!AwaitingPong) + if (pongPayload == PingPayload) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayload); + + PingTimeoutTimestamp = long.MaxValue; + PingSent = false; + } + else + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Expected payload {PingPayload}. Skipping."); + } + } + } + + internal void OnNextPingRequestCore() + { + Debug.Assert(Monitor.IsEntered(StateUpdateLock)); + + PingSent = true; + PingTimeoutTimestamp = Environment.TickCount64 + TimeoutMs; + ++PingPayload; + } + + internal void OnKeepAliveFaulted(Exception exc) + { + lock (StateUpdateLock) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Not waiting for Pong. Skipping."); - return; + OnKeepAliveFaultedCore(exc); } + } - if (pongPayload == PingPayload) + internal void OnKeepAliveFaultedCore(Exception exc) + { + Debug.Assert(Monitor.IsEntered(StateUpdateLock)); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + + if (_parent._disposed) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayload); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket already disposed, skipping..."); + return; + } - WillTimeoutTimestamp = Timeout.Infinite; - AwaitingPong = false; + if (_parent.State is WebSocketState.Closed) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already closed, skipping..."); + // We've transferred into the Closed state, but didn't dispose yet + // This can happen in e.g. HandleReceivedCloseAsync where we first change the state + // but then still do some operations with the stream. + // No need to do anything as we've already completed the Closing Handshake + return; } - else + + if (_parent.State is WebSocketState.Aborted) { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Expected payload {PingPayload}. Skipping."); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already aborted, skipping..."); + // Something else already aborted the websocket, but didn't dispose it (yet?)? + // This han happen either + // (1) in the Abort() method, e.g. on cancellation, if we interjected between the state + // change and the Dispose() call; or + // (2) in the catch block of ReceiveAsyncPrivate (which doesn't do the dispose after??). + // This most possibly happens if we've hit a premature EOF from the server. + // Websocket is not usable in the Aborted state anyway, so let's free the resources while we're at it? + _parent.Dispose(); + return; } + + // we were the ones who triggered the abort, let's save the exception + Exception = exc; + + _parent.OnAbortedCore(); + _parent.DisposeCore(); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index cbdc09ae7dbfe..8a26a4c29e2eb 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -8,6 +8,7 @@ using System.Net.WebSockets.Compression; using System.Numerics; using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; @@ -179,11 +180,7 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim long heartBeatIntervalMs = (long)keepAliveInterval.TotalMilliseconds; if (keepAliveTimeout > TimeSpan.Zero) { - _keepAlivePingState = new KeepAlivePingState(keepAliveInterval, keepAliveTimeout); -#if DEBUG - _keepAlivePingState.Debug_WebSocket_StateUpdateLock = StateUpdateLock; -#endif - + _keepAlivePingState = new KeepAlivePingState(keepAliveInterval, keepAliveTimeout, this); heartBeatIntervalMs = _keepAlivePingState.HeartBeatIntervalMs; if (NetEventSource.Log.IsEnabled()) @@ -468,16 +465,23 @@ private void OnAborted() lock (StateUpdateLock) { - WebSocketState state = _state; - if (state != WebSocketState.Closed && state != WebSocketState.Aborted) - { - _state = state != WebSocketState.None && state != WebSocketState.Connecting ? - WebSocketState.Aborted : - WebSocketState.Closed; - } + OnAbortedCore(); + } + } - if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); + private void OnAbortedCore() + { + Debug.Assert(Monitor.IsEntered(StateUpdateLock), $"Expected {nameof(StateUpdateLock)} to be held"); + + WebSocketState state = _state; + if (state is not WebSocketState.Closed and not WebSocketState.Aborted) + { + _state = state is not WebSocketState.None and not WebSocketState.Connecting ? + WebSocketState.Aborted : + WebSocketState.Closed; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } /// Sends a websocket frame to the network. @@ -948,11 +952,14 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo int numBytesRead = await _stream.ReadAtLeastAsync( readBuffer, bytesToRead, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={numBytesRead}"); + if (numBytesRead < bytesToRead) { ThrowEOFUnexpected(); } - OnDataReceived(numBytesRead); + _keepAlivePingState?.OnDataReceived(); totalBytesReceived += numBytesRead; } @@ -1019,14 +1026,16 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo if (_state == WebSocketState.Aborted) { + Exception inner = exc; if (_keepAlivePingState?.Exception is not null) { - // it should have already been wrapped in an OperationCanceledException and thrown above, - // but just in case it wasn't due to some race, let's surface both exceptions - throw new OperationCanceledException(nameof(WebSocketState.Aborted), new AggregateException(exc, _keepAlivePingState.Exception)); + // exception was most likely caused by us aborting the connection due to + // keep-alive timeout; but let's surface both just in case + inner = ExceptionDispatchInfo.SetCurrentStackTrace( + new AggregateException(_keepAlivePingState.Exception, exc)); } - throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc); + throw new OperationCanceledException(nameof(WebSocketState.Aborted), inner); } OnAborted(); @@ -1600,11 +1609,13 @@ private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, Canc _receiveBuffer.Slice(_receiveBufferCount), bytesToRead, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false); _receiveBufferCount += numRead; + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={numRead}"); + if (numRead < bytesToRead) { ThrowEOFUnexpected(); } - OnDataReceived(numRead); + _keepAlivePingState?.OnDataReceived(); } } } @@ -1726,26 +1737,28 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa cancellationToken); } - private void ThrowIfDisposed() - { - if (_keepAlivePingState is not null) - { - ThrowIfDisposedOrKeepAliveFaulted(); - return; - } + private void ThrowIfDisposed() => ThrowIfInvalidState(); - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); - } - - private void ThrowIfInvalidState(WebSocketState[] validStates) + private void ThrowIfInvalidState(WebSocketState[]? validStates = null) { + bool disposed = _disposed; + WebSocketState state = _state; + Exception? keepAliveException = null; + if (_keepAlivePingState is not null) { - ThrowIfInvalidStateOrKeepAliveFaulted(validStates); - return; + // we need to take a lock to maintain consistency + lock (StateUpdateLock) + { + disposed = _disposed; + state = _state; + keepAliveException = _keepAlivePingState.Exception; + } } - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, validStates); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_state={state}, _disposed={disposed}, _keepAlivePingState.Exception={keepAliveException}"); + + WebSocketValidate.ThrowIfInvalidState(state, disposed, keepAliveException, validStates); } // From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75 From 83d88e41f2e81252f97da8241d54df1bcd9f8901 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 8 Aug 2024 15:11:38 +0100 Subject: [PATCH 11/13] TO REVERT: run new tests in innerloop --- .../tests/KeepAliveTest.Loopback.cs | 2 +- .../System.Net.WebSockets.Client/tests/KeepAliveTest.cs | 2 +- .../System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index 08306c0804ee4..5157992ee8796 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -19,7 +19,7 @@ public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } protected virtual Version HttpVersion => Net.HttpVersion.Version11; - [OuterLoop("Uses Task.Delay")] + //[OuterLoop("Uses Task.Delay")] [Theory] [MemberData(nameof(UseSsl_MemberData))] public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs index 5ff9c51e56a6a..98022d8ef859e 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs @@ -37,7 +37,7 @@ public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() } [ConditionalTheory(nameof(WebSocketsSupported))] - [OuterLoop("Uses Task.Delay")] + //[OuterLoop("Uses Task.Delay")] [InlineData(1, 0)] // unsolicited pong [InlineData(1, 2)] // ping/pong public async Task KeepAlive_LongDelayBetweenReceiveSends_Succeeds(int keepAliveIntervalSec, int keepAliveTimeoutSec) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs index 11dd28117ac5d..486af3fd197ef 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs @@ -237,7 +237,7 @@ static void ApplyMask(Span buffer, Span mask) } } - [OuterLoop("Uses Task.Delay")] + //[OuterLoop("Uses Task.Delay")] [Theory] [InlineData(true)] [InlineData(false)] From 5698eacb77d6e859b5d7c700b1f789b12295b1bf Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 8 Aug 2024 17:02:35 +0100 Subject: [PATCH 12/13] Revert "TO REVERT: run new tests in innerloop" This reverts commit 83d88e41f2e81252f97da8241d54df1bcd9f8901. --- .../tests/KeepAliveTest.Loopback.cs | 2 +- .../System.Net.WebSockets.Client/tests/KeepAliveTest.cs | 2 +- .../System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs index 5157992ee8796..08306c0804ee4 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -19,7 +19,7 @@ public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } protected virtual Version HttpVersion => Net.HttpVersion.Version11; - //[OuterLoop("Uses Task.Delay")] + [OuterLoop("Uses Task.Delay")] [Theory] [MemberData(nameof(UseSsl_MemberData))] public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs index 98022d8ef859e..5ff9c51e56a6a 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs @@ -37,7 +37,7 @@ public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() } [ConditionalTheory(nameof(WebSocketsSupported))] - //[OuterLoop("Uses Task.Delay")] + [OuterLoop("Uses Task.Delay")] [InlineData(1, 0)] // unsolicited pong [InlineData(1, 2)] // ping/pong public async Task KeepAlive_LongDelayBetweenReceiveSends_Succeeds(int keepAliveIntervalSec, int keepAliveTimeoutSec) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs index 486af3fd197ef..11dd28117ac5d 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs @@ -237,7 +237,7 @@ static void ApplyMask(Span buffer, Span mask) } } - //[OuterLoop("Uses Task.Delay")] + [OuterLoop("Uses Task.Delay")] [Theory] [InlineData(true)] [InlineData(false)] From 9ee42215a3ee84168f11c11d414f2a59667825c2 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Thu, 8 Aug 2024 17:32:52 +0100 Subject: [PATCH 13/13] add tripleslash docs --- .../src/System/Net/WebSockets/ClientWebSocketOptions.cs | 6 ++++++ .../src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index 2b28a561d8fdb..dc7155e11cf58 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -172,6 +172,12 @@ public void AddSubProtocol(string subProtocol) subprotocols.Add(subProtocol); } + /// + /// The keep-alive interval to use, or or to disable keep-alives. + /// If is set, then PING messages are sent and peer's PONG responses are expected, otherwise, + /// unsolicited PONG messages are used as a keep-alive heartbeat. + /// The default is (typically 30 seconds). + /// [UnsupportedOSPlatform("browser")] public TimeSpan KeepAliveInterval { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs index ad1d18b403dd3..c9ff393cb7180 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -285,7 +285,7 @@ internal void OnKeepAliveFaultedCore(Exception exc) { if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already aborted, skipping..."); // Something else already aborted the websocket, but didn't dispose it (yet?)? - // This han happen either + // This can happen either // (1) in the Abort() method, e.g. on cancellation, if we interjected between the state // change and the Dispose() call; or // (2) in the catch block of ReceiveAsyncPrivate (which doesn't do the dispose after??).