Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocket Keep-Alive Ping and Timeout (minimal) implementation #105841

Merged
merged 14 commits into from
Aug 9, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading;

namespace System.Net.WebSockets
{
/// <summary>
/// Central repository for default values used in WebSocket settings. Not all settings are relevant
/// to or configurable by all WebSocket implementations.
/// </summary>
internal static partial class WebSocketDefaults
{
public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.Zero;
public static readonly TimeSpan DefaultClientKeepAliveInterval = TimeSpan.FromSeconds(30);

public static readonly TimeSpan DefaultKeepAliveTimeout = Timeout.InfiniteTimeSpan;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ internal ClientWebSocketOptions() { }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.TimeSpan KeepAliveInterval { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.IWebProxy? Proxy { get { throw null; } set { } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<Compile Include="System\Net\WebSockets\ClientWebSocketOptions.cs" />
<Compile Include="System\Net\WebSockets\WebSocketHandle.Managed.cs" />
<Compile Include="$(CommonPath)System\Net\HttpKnownHeaderNames.cs" Link="Common\System\Net\HttpKnownHeaderNames.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketDefaults.cs" Link="Common\System\Net\WebSockets\WebSocketDefaults.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'browser'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ public TimeSpan KeepAliveInterval
set => throw new PlatformNotSupportedException();
}

[UnsupportedOSPlatform("browser")]
public TimeSpan KeepAliveTimeout
{
get => throw new PlatformNotSupportedException();
set => throw new PlatformNotSupportedException();
}

[UnsupportedOSPlatform("browser")]
public WebSocketDeflateOptions? DangerousDeflateOptions
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace System.Net.WebSockets
public sealed class ClientWebSocketOptions
{
private bool _isReadOnly; // After ConnectAsync is called the options cannot be modified.
private TimeSpan _keepAliveInterval = WebSocket.DefaultKeepAliveInterval;
private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultClientKeepAliveInterval;
private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout;
private bool _useDefaultCredentials;
private ICredentials? _credentials;
private IWebProxy? _proxy;
Expand Down Expand Up @@ -188,6 +189,23 @@ public TimeSpan KeepAliveInterval
}
}

[UnsupportedOSPlatform("browser")]
public TimeSpan KeepAliveTimeout
rzikm marked this conversation as resolved.
Show resolved Hide resolved
{
get => _keepAliveTimeout;
set
{
ThrowIfReadOnly();
if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero)
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
{
throw new ArgumentOutOfRangeException(nameof(value), value,
SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall,
Timeout.InfiniteTimeSpan.ToString()));
}
_keepAliveTimeout = value;
}
}

/// <summary>
/// Gets or sets the options for the per-message-deflate extension.
/// When present, the options are sent to the server during the handshake phase. If the server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio
IsServer = false,
SubProtocol = subprotocol,
KeepAliveInterval = options.KeepAliveInterval,
KeepAliveTimeout = options.KeepAliveTimeout,
DangerousDeflateOptions = negotiatedDeflateOptions
});
_negotiatedDeflateOptions = negotiatedDeflateOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ public static void KeepAliveInterval_Roundtrips()
AssertExtensions.Throws<ArgumentOutOfRangeException>("value", () => cws.Options.KeepAliveInterval = TimeSpan.MinValue);
}

[ConditionalFact(nameof(WebSocketsSupported))]
[SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")]
public static void KeepAliveTimeout_Roundtrips()
{
var cws = new ClientWebSocket();
Assert.True(cws.Options.KeepAliveTimeout == Timeout.InfiniteTimeSpan);

cws.Options.KeepAliveTimeout = TimeSpan.Zero;
Assert.Equal(TimeSpan.Zero, cws.Options.KeepAliveTimeout);

cws.Options.KeepAliveTimeout = TimeSpan.MaxValue;
Assert.Equal(TimeSpan.MaxValue, cws.Options.KeepAliveTimeout);

cws.Options.KeepAliveTimeout = Timeout.InfiniteTimeSpan;
Assert.Equal(Timeout.InfiniteTimeSpan, cws.Options.KeepAliveTimeout);

AssertExtensions.Throws<ArgumentOutOfRangeException>("value", () => cws.Options.KeepAliveTimeout = TimeSpan.MinValue);
}

[ConditionalFact(nameof(WebSocketsSupported))]
[SkipOnPlatform(TestPlatforms.Browser, "Certificates not supported on browser")]
public void RemoteCertificateValidationCallback_Roundtrips()
Expand Down
11 changes: 8 additions & 3 deletions src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,11 @@ await LoopbackServer.CreateClientAndServerAsync(async uri =>
try
{
using (var cws = new ClientWebSocket())
using (var cts = new CancellationTokenSource(TimeOutMilliseconds))
using (var testTimeoutCts = new CancellationTokenSource(TimeOutMilliseconds))
{
await ConnectAsync(cws, uri, cts.Token);
await ConnectAsync(cws, uri, testTimeoutCts.Token);

Task receiveTask = cws.ReceiveAsync(new byte[1], CancellationToken.None);
Task receiveTask = cws.ReceiveAsync(new byte[1], testTimeoutCts.Token);

var cancelCloseCts = new CancellationTokenSource();
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
Expand All @@ -509,7 +509,12 @@ await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
await t;
});

Assert.True(cancelCloseCts.Token.IsCancellationRequested);
Assert.False(testTimeoutCts.Token.IsCancellationRequested);

await Assert.ThrowsAnyAsync<OperationCanceledException>(() => receiveTask);

Assert.False(testTimeoutCts.Token.IsCancellationRequested);
}
}
finally
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Net.Test.Common;
using System.Threading;
using System.Threading.Tasks;

using Xunit;
using Xunit.Abstractions;

namespace System.Net.WebSockets.Client.Tests
{
[SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")]
public abstract class KeepAliveTest_Loopback : ClientWebSocketTestBase
{
public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { }

protected virtual Version HttpVersion => Net.HttpVersion.Version11;

public static readonly object[][] UseSsl_MemberData = PlatformDetection.SupportsAlpn
? new[] { new object[] { false }, new object[] { true } }
: new[] { new object[] { false } };

[Theory]
[MemberData(nameof(UseSsl_MemberData))]
public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl)
{
var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
var serverMsg = new byte[] { 42 };
var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var longDelayByServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
TimeSpan LongDelay = TimeSpan.FromSeconds(10);

var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);

var options = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker())
{
DisposeServerWebSocket = true,
DisposeClientWebSocket = true,
ConfigureClientOptions = clientOptions =>
{
clientOptions.KeepAliveInterval = TimeSpan.FromSeconds(100);
clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1);
},
};

return LoopbackWebSocketServer.RunAsync(
async (clientWebSocket, token) =>
{
await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token);

// We need to always have a read task active to keep processing pongs
var outstandingReadTask = clientWebSocket.ReceiveAsync(Array.Empty<byte>(), token);

await longDelayByServerTcs.Task.WaitAsync(token);

var result = await outstandingReadTask;
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
Assert.False(result.EndOfMessage);
Assert.Equal(0, result.Count); // we issued a zero byte read, just to wait for data to become available

Assert.Equal(WebSocketState.Open, clientWebSocket.State);

await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token);

Assert.Equal(WebSocketState.Open, clientWebSocket.State);

await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token);

Assert.Equal(WebSocketState.Closed, clientWebSocket.State);
},
async (serverWebSocket, token) =>
{
await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token);

Assert.Equal(WebSocketState.Open, serverWebSocket.State);

await Task.Delay(LongDelay);

Assert.Equal(WebSocketState.Open, serverWebSocket.State);

// recreate already-completed TCS for another round
clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

longDelayByServerTcs.SetResult();

await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token);

var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty<byte>(), token);
Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType);
Assert.Equal(WebSocketState.CloseReceived, serverWebSocket.State);

await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token);
Assert.Equal(WebSocketState.Closed, serverWebSocket.State);
},
options,
timeoutCts.Token);
}

private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg,
TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken)
{
var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken);

var recvBuf = new byte[remoteMsg.Length * 2];
var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false);

Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType);
Assert.Equal(remoteMsg.Length, recvResult.Count);
Assert.True(recvResult.EndOfMessage);
Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]);

localAckTcs.SetResult();

await sendTask.ConfigureAwait(false);
await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false);
}
}

// --- HTTP/1.1 WebSocket loopback tests ---

public class KeepAliveTest_Invoker_Loopback : KeepAliveTest_Loopback
{
public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { }
protected override bool UseCustomInvoker => true;
}

public class KeepAliveTest_HttpClient_Loopback : KeepAliveTest_Loopback
{
public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { }
protected override bool UseHttpClient => true;
}

public class KeepAliveTest_SharedHandler_Loopback : KeepAliveTest_Loopback
{
public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { }
}

// --- HTTP/2 WebSocket loopback tests ---

public class KeepAliveTest_Invoker_Http2 : KeepAliveTest_Invoker_Loopback
{
public KeepAliveTest_Invoker_Http2(ITestOutputHelper output) : base(output) { }
protected override Version HttpVersion => Net.HttpVersion.Version20;
}

public class KeepAliveTest_HttpClient_Http2 : KeepAliveTest_HttpClient_Loopback
{
public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { }
protected override Version HttpVersion => Net.HttpVersion.Version20;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using Xunit;
using Xunit.Abstractions;

using static System.Net.Test.Common.Configuration.WebSockets;

namespace System.Net.WebSockets.Client.Tests
{
[SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")]
Expand All @@ -20,7 +22,7 @@ public KeepAliveTest(ITestOutputHelper output) : base(output) { }
[OuterLoop] // involves long delay
public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds()
{
using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(System.Net.Test.Common.Configuration.WebSockets.RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1)))
using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1)))
{
await cws.SendAsync(new ArraySegment<byte>(new byte[1] { 42 }), WebSocketMessageType.Binary, true, CancellationToken.None);

Expand All @@ -33,5 +35,35 @@ public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds()
await cws.CloseAsync(WebSocketCloseStatus.NormalClosure, "KeepAlive_LongDelayBetweenSendReceives_Succeeds", CancellationToken.None);
}
}

[ConditionalTheory(nameof(WebSocketsSupported))]
[OuterLoop] // involves long delay
[InlineData(1, 0)] // unsolicited pong
[InlineData(1, 2)] // ping/pong
public async Task KeepAlive_LongDelayBetweenReceiveSends_Succeeds(int keepAliveIntervalSec, int keepAliveTimeoutSec)
{
using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(
RemoteEchoServer,
TimeOutMilliseconds,
_output,
options =>
{
options.KeepAliveInterval = TimeSpan.FromSeconds(keepAliveIntervalSec);
options.KeepAliveTimeout = TimeSpan.FromSeconds(keepAliveTimeoutSec);
}))
{
byte[] receiveBuffer = new byte[1];
var receiveTask = cws.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), CancellationToken.None); // this will wait until we trigger the echo server by sending a message

await Task.Delay(TimeSpan.FromSeconds(10));

await cws.SendAsync(new ArraySegment<byte>(new byte[1] { 42 }), WebSocketMessageType.Binary, true, CancellationToken.None);

Assert.Equal(1, (await receiveTask).Count);
Assert.Equal(42, receiveBuffer[0]);

await cws.CloseAsync(WebSocketCloseStatus.NormalClosure, "KeepAlive_LongDelayBetweenSendReceives_Succeeds", CancellationToken.None);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ public static async Task<ClientWebSocket> GetConnectedClientAsync(Uri uri, Optio
clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; };
}

options.ConfigureClientOptions?.Invoke(clientWebSocket.Options);

await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false);

return clientWebSocket;
Expand All @@ -143,6 +145,7 @@ public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker
public bool DisposeClientWebSocket { get; set; }
public bool DisposeHttpInvoker { get; set; }
public bool ManualServerHandshakeResponse { get; set; }
public Action<ClientWebSocketOptions>? ConfigureClientOptions { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
<Compile Include="ConnectTest.Http2.cs" />
<Compile Include="ConnectTest.cs" />
<Compile Include="KeepAliveTest.cs" />
<Compile Include="KeepAliveTest.Loopback.cs" />
<Compile Include="LoopbackHelper.cs" />
<Compile Include="LoopbackServer\Http2LoopbackStream.cs" />
<Compile Include="LoopbackServer\LoopbackWebSocketServer.cs" />
Expand Down
Loading
Loading