Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket Feedback Follow-up #107662

Merged
merged 9 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,6 @@ internal static partial class WebSocketValidate
private static readonly SearchValues<char> s_validSubprotocolChars =
SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~");

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
=> ThrowIfInvalidState(currentState, isDisposed, innerException: null, validStates ?? []);

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, WebSocketState[]? validStates = null)
{
if (validStates is not null && Array.IndexOf(validStates, currentState) == -1)
{
string invalidStateMessage = SR.Format(
SR.net_WebSockets_InvalidState, currentState, string.Join(", ", validStates));

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
}

if (innerException is not null)
{
Debug.Assert(currentState == WebSocketState.Aborted);
throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}

// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}

internal static void ValidateSubprotocol(string subProtocol)
{
ArgumentException.ThrowIfNullOrWhiteSpace(subProtocol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<Compile Include="System\Net\WebSockets\WebSocketMessageFlags.cs" />
<Compile Include="System\Net\WebSockets\WebSocketReceiveResult.cs" />
<Compile Include="System\Net\WebSockets\WebSocketState.cs" />
<Compile Include="System\Net\WebSockets\WebSocketStateHelper.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketDefaults.cs"
Link="Common\System\Net\WebSockets\WebSocketDefaults.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketValidate.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,28 @@ public Task EnterAsync(CancellationToken cancellationToken)
// If cancellation was requested, bail immediately.
// If the mutex is not currently held nor contended, enter immediately.
// Otherwise, fall back to a more expensive likely-asynchronous wait.
return
cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) :
Interlocked.Decrement(ref _gate) >= 0 ? Task.CompletedTask :
Contended(cancellationToken);

if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled(cancellationToken);
}

int gate = Interlocked.Decrement(ref _gate);
if (gate >= 0)
{
return Task.CompletedTask;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Waiting to enter, queue length {-gate}");
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved

return Contended(cancellationToken);

// Everything that follows is the equivalent of:
// return _sem.WaitAsync(cancellationToken);
// if _sem were to be constructed as `new SemaphoreSlim(0)`.

Task Contended(CancellationToken cancellationToken)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);

var w = new Waiter(this);

// We need to register for cancellation before storing the waiter into the list.
Expand Down Expand Up @@ -178,18 +187,18 @@ static void OnCancellation(object? state, CancellationToken cancellationToken)
/// <remarks>The caller must logically own the mutex. This is not validated.</remarks>
public void Exit()
{
if (Interlocked.Increment(ref _gate) < 1)
// This is the equivalent of:
// _sem.Release();
// if _sem were to be constructed as `new SemaphoreSlim(0)`.
int gate = Interlocked.Increment(ref _gate);
if (gate < 1)
{
// This is the equivalent of:
// _sem.Release();
// if _sem were to be constructed as `new SemaphoreSlim(0)`.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Unblocking next waiter on exit, remaining queue length {-_gate}", nameof(Exit));
Contended();
}

void Contended()
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);

Waiter? w;
lock (SyncObj)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Runtime.ExceptionServices;
Expand All @@ -13,8 +12,8 @@ namespace System.Net.WebSockets
internal sealed partial class ManagedWebSocket : WebSocket
{
private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null;
private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1;
private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1;
private static bool IsValidSendState(WebSocketState state) => WebSocketStateHelper.HasFlag(s_validSendStates, state);
private static bool IsValidReceiveState(WebSocketState state) => WebSocketStateHelper.HasFlag(s_validReceiveStates, state);

private void HeartBeat()
{
Expand All @@ -36,9 +35,9 @@ private void UnsolicitedPongHeartBeat()
TrySendKeepAliveFrameAsync(MessageOpcode.Pong));
}

private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory<byte>? payload = null)
private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory<byte> payload = default)
{
Debug.Assert(opcode is MessageOpcode.Pong || !IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping);
Debug.Assert((opcode is MessageOpcode.Pong) || (!IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping));

if (!IsValidSendState(_state))
{
Expand All @@ -48,9 +47,7 @@ private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemor
return ValueTask.CompletedTask;
}

payload ??= ReadOnlyMemory<byte>.Empty;

return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload.Value, CancellationToken.None);
return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload, CancellationToken.None);
}

private void KeepAlivePingHeartBeat()
Expand All @@ -76,7 +73,7 @@ private void KeepAlivePingHeartBeat()

if (_keepAlivePingState.PingSent)
{
if (Environment.TickCount64 > _keepAlivePingState.PingTimeoutTimestamp)
if (now > _keepAlivePingState.PingTimeoutTimestamp)
{
if (NetEventSource.Log.IsEnabled())
{
Expand All @@ -92,7 +89,7 @@ private void KeepAlivePingHeartBeat()
}
else
{
if (Environment.TickCount64 > _keepAlivePingState.NextPingRequestTimestamp)
if (now > _keepAlivePingState.NextPingRequestTimestamp)
{
_keepAlivePingState.OnNextPingRequestCore(); // we are holding the lock
shouldSendPing = true;
Expand All @@ -119,18 +116,12 @@ private async ValueTask SendPingAsync(long pingPayload)
{
Debug.Assert(_keepAlivePingState != null);

byte[] pingPayloadBuffer = ArrayPool<byte>.Shared.Rent(sizeof(long));
byte[] pingPayloadBuffer = new byte[sizeof(long)];
BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload);
try
{
await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer.AsMemory(0, sizeof(long))).ConfigureAwait(false);

if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}
finally
{
ArrayPool<byte>.Shared.Return(pingPayloadBuffer);
}
await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer).ConfigureAwait(false);

if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}

// "Observe" either a ValueTask result, or any exception, ignoring it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ internal sealed partial class ManagedWebSocket : WebSocket
private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

/// <summary>Valid states to be in when calling SendAsync.</summary>
private static readonly WebSocketState[] s_validSendStates = { WebSocketState.Open, WebSocketState.CloseReceived };
private static readonly int s_validSendStates = WebSocketStateHelper.ToFlags(WebSocketState.Open, WebSocketState.CloseReceived);
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>Valid states to be in when calling ReceiveAsync.</summary>
private static readonly WebSocketState[] s_validReceiveStates = { WebSocketState.Open, WebSocketState.CloseSent };
private static readonly int s_validReceiveStates = WebSocketStateHelper.ToFlags(WebSocketState.Open, WebSocketState.CloseSent);
/// <summary>Valid states to be in when calling CloseOutputAsync.</summary>
private static readonly WebSocketState[] s_validCloseOutputStates = { WebSocketState.Open, WebSocketState.CloseReceived };
private static readonly int s_validCloseOutputStates = WebSocketStateHelper.ToFlags(WebSocketState.Open, WebSocketState.CloseReceived);
/// <summary>Valid states to be in when calling CloseAsync.</summary>
private static readonly WebSocketState[] s_validCloseStates = { WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent };
private static readonly int s_validCloseStates = WebSocketStateHelper.ToFlags(WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent);

/// <summary>The maximum size in bytes of a message frame header that includes mask bytes.</summary>
internal const int MaxMessageHeaderLength = 14;
Expand Down Expand Up @@ -797,11 +797,9 @@ private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> paylo

if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateStarted(this, payloadBuffer.Length);

CancellationTokenRegistration registration = default;
CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
try
{
registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);

await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex);

Expand Down Expand Up @@ -1737,9 +1735,9 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa
cancellationToken);
}

private void ThrowIfDisposed() => ThrowIfInvalidState();
private void ThrowIfDisposed() => ThrowIfInvalidState(validStates: WebSocketStateHelper.All);

private void ThrowIfInvalidState(WebSocketState[]? validStates = null)
private void ThrowIfInvalidState(int validStates)
{
bool disposed = _disposed;
WebSocketState state = _state;
Expand All @@ -1758,7 +1756,7 @@ private void ThrowIfInvalidState(WebSocketState[]? validStates = null)

if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_state={state}, _disposed={disposed}, _keepAlivePingState.Exception={keepAliveException}");

WebSocketValidate.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
WebSocketStateHelper.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
}

// From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ internal sealed partial class NetEventSource

private const int MutexEnterId = SendStopId + 1;
private const int MutexExitId = MutexEnterId + 1;
private const int MutexContendedId = MutexExitId + 1;

//
// Keep-Alive
Expand Down Expand Up @@ -185,10 +184,6 @@ private void MutexEnter(string objName, string memberName) =>
private void MutexExit(string objName, string memberName) =>
WriteEvent(MutexExitId, objName, memberName);

[Event(MutexContendedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
private void MutexContended(string objName, string memberName, int queueLength) =>
WriteEvent(MutexContendedId, objName, memberName, queueLength);

[NonEvent]
public static void MutexEntered(object? obj, [CallerMemberName] string? memberName = null)
{
Expand All @@ -203,13 +198,6 @@ public static void MutexExited(object? obj, [CallerMemberName] string? memberNam
Log.MutexExit(IdOf(obj), memberName ?? MissingMember);
}

[NonEvent]
public static void MutexContended(object? obj, int gateValue, [CallerMemberName] string? memberName = null)
{
Debug.Assert(Log.IsEnabled());
Log.MutexContended(IdOf(obj), memberName ?? MissingMember, -gateValue);
}

//
// WriteEvent overloads
//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;

namespace System.Net.WebSockets
{
internal static class WebSocketStateHelper
{
internal const int All = (1 << ((int)WebSocketState.Aborted + 1)) - 1;

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, int validStates)
{
if (!HasFlag(validStates, currentState))
{
string invalidStateMessage = SR.Format(
SR.net_WebSockets_InvalidState, currentState, string.Join(", ", FromFlags(validStates)));

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
}

if (innerException is not null)
{
Debug.Assert(currentState == WebSocketState.Aborted);
throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}

// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}

internal static bool HasFlag(int states, WebSocketState value) => (states & 1 << (int)value) != 0;

internal static int ToFlags(params WebSocketState[] values)
{
int states = 0;
foreach (WebSocketState value in values)
{
states |= 1 << (int)value;
}
return states;
}

private static IEnumerable<WebSocketState> FromFlags(int states)
{
foreach (WebSocketState value in Enum.GetValues<WebSocketState>())
{
if (HasFlag(states, value))
{
yield return value;
}
}
}
}
}