Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the per connection overhead in SocketConnection #31308

Merged
merged 3 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,49 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;

namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
{
internal class SocketAwaitableEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion
// A slimmed down version of https://github.com/dotnet/runtime/blob/82ca681cbac89d813a3ce397e0c665e6c051ed67/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs#L798 that
// 1. Doesn't support any custom scheduling other than the PipeScheduler (no sync context, no task scheduler)
// 2. Doesn't do ValueTask validation using the token
// 3. Doesn't support usage outside of async/await (doesn't try to capture and restore the execution context)
// 4. Doesn't use cancellation tokens
internal class SocketAwaitableEventArgs : SocketAsyncEventArgs, IValueTaskSource<int>
{
private static readonly Action _callbackCompleted = () => { };
private static readonly Action<object?> _continuationCompleted = _ => { };

private readonly PipeScheduler _ioScheduler;

private Action? _callback;
private Action<object?>? _continuation;

public SocketAwaitableEventArgs(PipeScheduler ioScheduler)
: base(unsafeSuppressExecutionContextFlow: true)
{
_ioScheduler = ioScheduler;
}

public SocketAwaitableEventArgs GetAwaiter() => this;
public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted);

public int GetResult()
protected override void OnCompleted(SocketAsyncEventArgs _)
{
Debug.Assert(ReferenceEquals(_callback, _callbackCompleted));
var c = _continuation;

if (c != null || (c = Interlocked.CompareExchange(ref _continuation, _continuationCompleted, null)) != null)
{
var continuationState = UserToken;
UserToken = null;
_continuation = _continuationCompleted; // in case someone's polling IsCompleted

_callback = null;
_ioScheduler.Schedule(c, continuationState);
}
}

public int GetResult(short token)
{
_continuation = null;

if (SocketError != SocketError.Success)
{
Expand All @@ -43,36 +55,30 @@ public int GetResult()

static void ThrowSocketException(SocketError e)
{
throw new SocketException((int)e);
throw CreateException(e);
}
}

public void OnCompleted(Action continuation)
protected static SocketException CreateException(SocketError e)
{
if (ReferenceEquals(_callback, _callbackCompleted) ||
ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted))
{
Task.Run(continuation);
}
return new SocketException((int)e);
}

public void UnsafeOnCompleted(Action continuation)
public ValueTaskSourceStatus GetStatus(short token)
{
OnCompleted(continuation);
return !ReferenceEquals(_continuation, _continuationCompleted) ? ValueTaskSourceStatus.Pending :
SocketError == SocketError.Success ? ValueTaskSourceStatus.Succeeded :
ValueTaskSourceStatus.Faulted;
}

public void Complete()
public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
{
OnCompleted(this);
}

protected override void OnCompleted(SocketAsyncEventArgs _)
{
var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted);

if (continuation != null)
UserToken = state;
var prevContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null);
if (ReferenceEquals(prevContinuation, _continuationCompleted))
{
_ioScheduler.Schedule(state => ((Action)state!)(), continuation);
UserToken = null;
ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ internal sealed class SocketConnection : TransportConnection
private readonly object _shutdownLock = new object();
private volatile bool _socketDisposed;
private volatile Exception? _shutdownReason;
private Task? _processingTask;
private Task? _sendingTask;
private Task? _receivingTask;
private readonly TaskCompletionSource _waitForConnectionClosedTcs = new TaskCompletionSource();
private bool _connectionClosed;
private readonly bool _waitForData;
Expand Down Expand Up @@ -78,28 +79,16 @@ internal SocketConnection(Socket socket,
public override MemoryPool<byte> MemoryPool { get; }

public void Start()
{
_processingTask = StartAsync();
}

private async Task StartAsync()
{
try
{
// Spawn send and receive logic
var receiveTask = DoReceive();
var sendTask = DoSend();

// Now wait for both to complete
await receiveTask;
await sendTask;

_receiver.Dispose();
_sender?.Dispose();
_receivingTask = DoReceive();
_sendingTask = DoSend();
}
catch (Exception ex)
{
_trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(StartAsync)}.");
_trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(Start)}.");
}
}

Expand All @@ -118,9 +107,28 @@ public override async ValueTask DisposeAsync()
_originalTransport.Input.Complete();
_originalTransport.Output.Complete();

if (_processingTask != null)
try
{
// Now wait for both to complete
if (_receivingTask != null)
{
await _receivingTask;
}

if (_sendingTask != null)
{
await _sendingTask;
}

}
catch (Exception ex)
{
_trace.LogError(0, ex, $"Unexpected exception in {nameof(SocketConnection)}.{nameof(Start)}.");
}
finally
{
await _processingTask;
_receiver.Dispose();
_sender?.Dispose();
}

_connectionClosedTokenSource.Dispose();
Expand All @@ -132,7 +140,50 @@ private async Task DoReceive()

try
{
await ProcessReceives();
while (true)
{
if (_waitForData)
{
// Wait for data before allocating a buffer.
await _receiver.WaitForDataAsync(_socket);
}

// Ensure we have some reasonable amount of buffer space
var buffer = Input.GetMemory(MinAllocBufferSize);

var bytesReceived = await _receiver.ReceiveAsync(_socket, buffer);

if (bytesReceived == 0)
{
// FIN
_trace.ConnectionReadFin(ConnectionId);
break;
}

Input.Advance(bytesReceived);

var flushTask = Input.FlushAsync();

var paused = !flushTask.IsCompleted;

if (paused)
{
_trace.ConnectionPause(ConnectionId);
}

var result = await flushTask;

if (paused)
{
_trace.ConnectionResume(ConnectionId);
}

if (result.IsCompleted || result.IsCanceled)
{
// Pipe consumer is shut down, do we stop writing
break;
}
}
}
catch (SocketException ex) when (IsConnectionResetError(ex.SocketErrorCode))
{
Expand Down Expand Up @@ -176,64 +227,40 @@ private async Task DoReceive()
}
}

private async Task ProcessReceives()
{
// Resolve `input` PipeWriter via the IDuplexPipe interface prior to loop start for performance.
var input = Input;
while (true)
{
if (_waitForData)
{
// Wait for data before allocating a buffer.
await _receiver.WaitForDataAsync(_socket);
}

// Ensure we have some reasonable amount of buffer space
var buffer = input.GetMemory(MinAllocBufferSize);

var bytesReceived = await _receiver.ReceiveAsync(_socket, buffer);

if (bytesReceived == 0)
{
// FIN
_trace.ConnectionReadFin(ConnectionId);
break;
}

input.Advance(bytesReceived);

var flushTask = input.FlushAsync();

var paused = !flushTask.IsCompleted;

if (paused)
{
_trace.ConnectionPause(ConnectionId);
}

var result = await flushTask;

if (paused)
{
_trace.ConnectionResume(ConnectionId);
}

if (result.IsCompleted || result.IsCanceled)
{
// Pipe consumer is shut down, do we stop writing
break;
}
}
}

private async Task DoSend()
{
Exception? shutdownReason = null;
Exception? unexpectedError = null;

try
{
await ProcessSends();
while (true)
{
var result = await Output.ReadAsync();

if (result.IsCanceled)
{
break;
}
var buffer = result.Buffer;

if (!buffer.IsEmpty)
{
_sender = _socketSenderPool.Rent();
await _sender.SendAsync(_socket, buffer);
// We don't return to the pool if there was an exception, and
// we keep the _sender assigned so that we can dispose it in StartAsync.
_socketSenderPool.Return(_sender);
_sender = null;
}

Output.AdvanceTo(buffer.End);

if (result.IsCompleted)
{
break;
}
}
}
catch (SocketException ex) when (IsConnectionResetError(ex.SocketErrorCode))
{
Expand Down Expand Up @@ -265,42 +292,6 @@ private async Task DoSend()
}
}

private async Task ProcessSends()
{
// Resolve `output` PipeReader via the IDuplexPipe interface prior to loop start for performance.
var output = Output;
while (true)
{
var result = await output.ReadAsync();

if (result.IsCanceled)
{
break;
}

var buffer = result.Buffer;

var end = buffer.End;
var isCompleted = result.IsCompleted;
if (!buffer.IsEmpty)
{
_sender = _socketSenderPool.Rent();
await _sender.SendAsync(_socket, buffer);
// We don't return to the pool if there was an exception, and
// we keep the _sender assigned so that we can dispose it in StartAsync.
_socketSenderPool.Return(_sender);
_sender = null;
}

output.AdvanceTo(end);

if (isCompleted)
{
break;
}
}
}

private void FireConnectionClosed()
{
// Guard against scheduling this multiple times
Expand Down
Loading