Skip to content

Commit

Permalink
add Task-based DisconnectAsync and refactor APM methods on top of it (#…
Browse files Browse the repository at this point in the history
…51213)

* add Task-based DisconnectAsync and refactor APM methods on top of it

* fix BeginDisconnect to throw synchronously and add relevant tests

* remove #region stuff in Socket.cs and add link to github issue

Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
  • Loading branch information
geoffkizer and Geoffrey Kizer authored Apr 18, 2021
1 parent d2daf0b commit 8093c52
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 267 deletions.
1 change: 1 addition & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ public void Connect(string host, int port) { }
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public void Disconnect(bool reuseSocket) { }
public bool DisconnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.Threading.Tasks.ValueTask DisconnectAsync(bool reuseSocket, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
<Compile Include="System\Net\Sockets\UdpReceiveResult.cs" />
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.cs" />
<Compile Include="System\Net\Sockets\UnixDomainSocketEndPoint.cs" />
<!-- Logging -->
<Compile Include="$(CommonPath)System\Net\Logging\NetEventSource.Common.cs"
Expand Down Expand Up @@ -187,7 +186,6 @@
<ItemGroup Condition="'$(TargetsUnix)' == 'true'">
<Compile Include="System\Net\Sockets\AcceptOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\BaseOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\DisconnectOverlappedAsyncResult.Unix.cs" />
<Compile Include="System\Net\Sockets\SafeSocketHandle.Unix.cs" />
<Compile Include="System\Net\Sockets\Socket.Unix.cs" />
<Compile Include="System\Net\Sockets\SocketAsyncContext.Unix.cs" />
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,29 @@ public ValueTask ConnectAsync(string host, int port, CancellationToken cancellat
return ConnectAsync(ep, cancellationToken);
}

/// <summary>
/// Disconnects a connected socket from the remote host.
/// </summary>
/// <param name="reuseSocket">Indicates whether the socket should be available for reuse after disconnect.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes when the socket is disconnected.</returns>
public ValueTask DisconnectAsync(bool reuseSocket, CancellationToken cancellationToken = default)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

saea.DisconnectReuseSocket = reuseSocket;
saea.WrapExceptionsForNetworkStream = false;

return saea.DisconnectAsync(this, cancellationToken);
}

/// <summary>
/// Receives data from a connected socket.
/// </summary>
Expand Down Expand Up @@ -1028,6 +1051,25 @@ public ValueTask ConnectAsync(Socket socket)
ValueTask.FromException(CreateException(error));
}

public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.DisconnectAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask(this, _token);
}

SocketError error = SocketError;

Release();

return error == SocketError.Success ?
ValueTask.CompletedTask :
ValueTask.FromException(CreateException(error));
}

/// <summary>Gets the status of the operation.</summary>
public ValueTaskSourceStatus GetStatus(short token)
{
Expand Down
119 changes: 36 additions & 83 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ private sealed class CacheSet
private int _closeTimeout = Socket.DefaultCloseTimeout;
private int _disposed; // 0 == false, anything else == true

#region Constructors
public Socket(SocketType socketType, ProtocolType protocolType)
: this(OSSupportsIPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, socketType, protocolType)
{
Expand Down Expand Up @@ -242,9 +241,10 @@ private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) =>
handle is null ? throw new ArgumentNullException(nameof(handle)) :
handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) :
handle;
#endregion

#region Properties
//
// Properties
//

// The CLR allows configuration of these properties, separately from whether the OS supports IPv4/6. We
// do not provide these config options, so SupportsIPvX === OSSupportsIPvX.
Expand Down Expand Up @@ -761,9 +761,10 @@ internal bool CanTryAddressFamily(AddressFamily family)
{
return (family == _addressFamily) || (family == AddressFamily.InterNetwork && IsDualMode);
}
#endregion

#region Public Methods
//
// Public Methods
//

// Associates a socket with an end point.
public void Bind(EndPoint localEP)
Expand Down Expand Up @@ -2116,43 +2117,14 @@ public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? req
public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) =>
TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state);

public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state)
public void EndConnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();

// Start context-flowing op. No need to lock - we don't use the context till the callback.
DisconnectOverlappedAsyncResult asyncResult = new DisconnectOverlappedAsyncResult(this, state, callback);
asyncResult.StartPostingAsyncOp(false);

// Post the disconnect.
DoBeginDisconnect(reuseSocket, asyncResult);

// Finish flowing (or call the callback), and return.
asyncResult.FinishPostingAsyncOp();
return asyncResult;
TaskToApm.End(asyncResult);
}

private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
SocketError errorCode = SocketError.Success;

errorCode = SocketPal.DisconnectAsync(this, _handle, reuseSocket, asyncResult);

if (errorCode == SocketError.Success)
{
SetToDisconnected();
_remoteEndPoint = null;
_localEndPoint = null;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}");

// If the call failed, update our status and throw
if (!CheckErrorAndUpdateStatus(errorCode))
{
throw new SocketException((int)errorCode);
}
}
public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) =>
TaskToApmBeginWithSyncExceptions(DisconnectAsync(reuseSocket).AsTask(), callback, state);

public void Disconnect(bool reuseSocket)
{
Expand All @@ -2175,47 +2147,12 @@ public void Disconnect(bool reuseSocket)
_localEndPoint = null;
}

public void EndConnect(IAsyncResult asyncResult)
public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
TaskToApm.End(asyncResult);
}

public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();

if (asyncResult == null)
{
throw new ArgumentNullException(nameof(asyncResult));
}

//get async result and check for errors
LazyAsyncResult? castedAsyncResult = asyncResult as LazyAsyncResult;
if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this)
{
throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult));
}
if (castedAsyncResult.EndCalled)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, nameof(EndDisconnect)));
}

//wait for completion if it hasn't occurred
castedAsyncResult.InternalWaitForCompletion();
castedAsyncResult.EndCalled = true;

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this);

//
// if the asynchronous native call failed asynchronously
// we'll throw a SocketException
//
if ((SocketError)castedAsyncResult.ErrorCode != SocketError.Success)
{
UpdateStatusAfterSocketErrorAndThrowException((SocketError)castedAsyncResult.ErrorCode);
}
}

public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
{
Expand Down Expand Up @@ -2668,7 +2605,10 @@ public void Shutdown(SocketShutdown how)
InternalSetBlocking(_willBlockInternal);
}

#region Async methods
//
// Async methods
//

public bool AcceptAsync(SocketAsyncEventArgs e)
{
ThrowIfDisposed();
Expand Down Expand Up @@ -2889,7 +2829,9 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e)
e.CancelConnectAsync();
}

public bool DisconnectAsync(SocketAsyncEventArgs e)
public bool DisconnectAsync(SocketAsyncEventArgs e) => DisconnectAsync(e, default);

private bool DisconnectAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
// Throw if socket disposed
ThrowIfDisposed();
Expand All @@ -2904,7 +2846,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
SocketError socketError = SocketError.Success;
try
{
socketError = e.DoOperationDisconnect(this, _handle);
socketError = e.DoOperationDisconnect(this, _handle, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -3155,10 +3097,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT

return socketError == SocketError.IOPending;
}
#endregion
#endregion

#region Internal and private properties
//
// Internal and private properties
//

private CacheSet Caches
{
Expand All @@ -3174,9 +3116,10 @@ private CacheSet Caches
}

internal bool Disposed => _disposed != 0;
#endregion

#region Internal and private methods
//
// Internal and private methods
//

internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6)
{
Expand Down Expand Up @@ -3889,6 +3832,16 @@ private static SocketError GetSocketErrorFromFaultedTask(Task t)
};
}

#endregion
// Helper to maintain existing behavior of Socket APM methods to throw synchronously from Begin*.
private static IAsyncResult TaskToApmBeginWithSyncExceptions(Task task, AsyncCallback? callback, object? state)
{
if (task.IsFaulted)
{
task.GetAwaiter().GetResult();
Debug.Fail("Task faulted but GetResult did not throw???");
}

return TaskToApm.Begin(task, callback, state);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle h
return socketError;
}

internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket);
FinishOperationSync(socketError, 0, SocketFlags.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,11 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
}
}

internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
// Note: CancellationToken is ignored for now.
// See https://github.com/dotnet/runtime/issues/51452

NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
Expand Down Expand Up @@ -1188,6 +1191,7 @@ private unsafe SocketError FinishOperationConnect()
private void CompleteCore()
{
_strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially

if (_singleBufferHandleState != SingleBufferHandleState.None)
{
// If the state isn't None, then either it's Set, in which case there's state to cleanup,
Expand All @@ -1213,6 +1217,8 @@ void CompleteCoreSpin()
sw.SpinOnce();
}

Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);

// Remove any cancellation registration. First dispose the registration
// to ensure that cancellation will either never fine or will have completed
// firing before we continue. Only then can we safely null out the overlapped.
Expand All @@ -1223,6 +1229,8 @@ void CompleteCoreSpin()
}

// Release any GC handles.
Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);

if (_singleBufferHandleState == SingleBufferHandleState.Set)
{
_singleBufferHandleState = SingleBufferHandleState.None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1976,13 +1976,6 @@ public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, Sa
return socketError;
}

internal static SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
{
SocketError socketError = Disconnect(socket, handle, reuseSocket);
asyncResult.PostCompletion(socketError);
return socketError;
}

internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
{
handle.SetToDisconnected();
Expand Down
Loading

0 comments on commit 8093c52

Please sign in to comment.