Skip to content

Commit

Permalink
Do not throw ObjectDisposeException from Socket.EndXyz methods (#73335)
Browse files Browse the repository at this point in the history
Fix #61411 and harmonize argument exception handling.
  • Loading branch information
antonfirsov authored Aug 5, 2022
1 parent f999d0a commit 5d4526d
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 154 deletions.
55 changes: 9 additions & 46 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2248,11 +2248,7 @@ public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? req
public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) =>
TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state);

public void EndConnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
TaskToApm.End(asyncResult);
}
public void EndConnect(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);

public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) =>
TaskToApm.Begin(DisconnectAsync(reuseSocket).AsTask(), callback, state);
Expand All @@ -2278,12 +2274,7 @@ public void Disconnect(bool reuseSocket)
_localEndPoint = null;
}

public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
TaskToApm.End(asyncResult);
}

public void EndDisconnect(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);

public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
{
Expand Down Expand Up @@ -2331,12 +2322,7 @@ public IAsyncResult BeginSend(IList<ArraySegment<byte>> buffers, SocketFlags soc
return TaskToApm.Begin(t, callback, state);
}

public int EndSend(IAsyncResult asyncResult)
{
ThrowIfDisposed();

return TaskToApm.End<int>(asyncResult);
}
public int EndSend(IAsyncResult asyncResult) => TaskToApm.End<int>(asyncResult);

public int EndSend(IAsyncResult asyncResult, out SocketError errorCode) =>
EndSendReceive(asyncResult, out errorCode);
Expand All @@ -2360,13 +2346,7 @@ public IAsyncResult BeginSendFile(string? fileName, byte[]? preBuffer, byte[]? p
return TaskToApm.Begin(SendFileAsync(fileName, preBuffer, postBuffer, flags).AsTask(), callback, state);
}

public void EndSendFile(IAsyncResult asyncResult)
{
ThrowIfDisposed();
ArgumentNullException.ThrowIfNull(asyncResult);

TaskToApm.End(asyncResult);
}
public void EndSendFile(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);

public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint remoteEP, AsyncCallback? callback, object? state)
{
Expand All @@ -2378,11 +2358,7 @@ public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags
return TaskToApm.Begin(t, callback, state);
}

public int EndSendTo(IAsyncResult asyncResult)
{
ThrowIfDisposed();
return TaskToApm.End<int>(asyncResult);
}
public int EndSendTo(IAsyncResult asyncResult) => TaskToApm.End<int>(asyncResult);

public IAsyncResult BeginReceive(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
{
Expand Down Expand Up @@ -2428,21 +2404,16 @@ public IAsyncResult BeginReceive(IList<ArraySegment<byte>> buffers, SocketFlags
return TaskToApm.Begin(t, callback, state);
}

public int EndReceive(IAsyncResult asyncResult)
{
ThrowIfDisposed();
return TaskToApm.End<int>(asyncResult);
}
public int EndReceive(IAsyncResult asyncResult) => TaskToApm.End<int>(asyncResult);

public int EndReceive(IAsyncResult asyncResult, out SocketError errorCode) =>
EndSendReceive(asyncResult, out errorCode);

private int EndSendReceive(IAsyncResult asyncResult, out SocketError errorCode)
private static int EndSendReceive(IAsyncResult asyncResult, out SocketError errorCode)
{
ThrowIfDisposed();

if (TaskToApm.GetTask(asyncResult) is not Task<int> ti)
{
ArgumentNullException.ThrowIfNull(asyncResult);
throw new ArgumentException(null, nameof(asyncResult));
}

Expand Down Expand Up @@ -2485,7 +2456,6 @@ public IAsyncResult BeginReceiveMessageFrom(byte[] buffer, int offset, int size,

public int EndReceiveMessageFrom(IAsyncResult asyncResult, ref SocketFlags socketFlags, ref EndPoint endPoint, out IPPacketInformation ipPacketInformation)
{
ThrowIfDisposed();
ArgumentNullException.ThrowIfNull(endPoint);
if (!CanTryAddressFamily(endPoint.AddressFamily))
{
Expand Down Expand Up @@ -2522,8 +2492,6 @@ public IAsyncResult BeginReceiveFrom(byte[] buffer, int offset, int size, Socket

public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint)
{
ThrowIfDisposed();

ArgumentNullException.ThrowIfNull(endPoint);
if (!CanTryAddressFamily(endPoint.AddressFamily))
{
Expand All @@ -2541,11 +2509,7 @@ public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint)
public IAsyncResult BeginAccept(AsyncCallback? callback, object? state) =>
TaskToApm.Begin(AcceptAsync(), callback, state);

public Socket EndAccept(IAsyncResult asyncResult)
{
ThrowIfDisposed();
return TaskToApm.End<Socket>(asyncResult);
}
public Socket EndAccept(IAsyncResult asyncResult) => TaskToApm.End<Socket>(asyncResult);

// This method provides support for legacy BeginAccept methods that take a "receiveSize" argument and
// allow data to be received as part of the accept operation.
Expand Down Expand Up @@ -2599,7 +2563,6 @@ public Socket EndAccept(out byte[] buffer, IAsyncResult asyncResult)

public Socket EndAccept(out byte[] buffer, out int bytesTransferred, IAsyncResult asyncResult)
{
ThrowIfDisposed();
Socket s;
(s, buffer, bytesTransferred) = TaskToApm.End<(Socket, byte[], int)>(asyncResult);
return s;
Expand Down
57 changes: 46 additions & 11 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/Accept.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Threading;
using System.Threading.Tasks;
using Microsoft.DotNet.RemoteExecutor;
using Xunit;
using Xunit.Abstractions;
using Xunit.Sdk;
Expand Down Expand Up @@ -333,7 +334,7 @@ await RetryHelper.ExecuteAsync(async () =>
await disposeTask;
SocketError? localSocketError = null;
bool disposedException = false;
try
{
await acceptTask;
Expand All @@ -342,17 +343,8 @@ await RetryHelper.ExecuteAsync(async () =>
{
localSocketError = se.SocketErrorCode;
}
catch (ObjectDisposedException)
{
disposedException = true;
}
if (UsesApm)
{
Assert.Null(localSocketError);
Assert.True(disposedException);
}
else if (UsesSync)
if (UsesSync)
{
Assert.Equal(SocketError.Interrupted, localSocketError);
}
Expand Down Expand Up @@ -401,6 +393,49 @@ public AcceptSyncForceNonBlocking(ITestOutputHelper output) : base(output) {}
public sealed class AcceptApm : Accept<SocketHelperApm>
{
public AcceptApm(ITestOutputHelper output) : base(output) {}

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void AbortedByDispose_LeaksNoUnobservedExceptions()
{
RemoteExecutor.Invoke(static async () =>
{
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
socket.BindToAnonymousPort(IPAddress.Loopback);
socket.Listen(10);
bool unobservedThrown = false;
TaskScheduler.UnobservedTaskException += (_, __) => unobservedThrown = true;
await Task.Run(() =>
{
socket.BeginAccept(asyncResult =>
{
try
{
socket.EndAccept(asyncResult);
}
catch
{
}
}, socket);
});
// Give some time for the Accept operation to start
await Task.Delay(30);
// Close the socket aborting Accept
socket.Dispose();
// Wait for the internal AcceptAsync Task to complete with the exception.
await Task.Delay(30);
// Ensure that the internal TaskExceptionHolder is finalized and the exception published to UnobservedTaskException.
GC.Collect();
GC.WaitForPendingFinalizers();
Assert.False(unobservedThrown);
}).Dispose();
}
}

public sealed class AcceptTask : Accept<SocketHelperTask>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,7 @@ await RetryHelper.ExecuteAsync(async () =>
disposedException = true;
}
if (UsesApm)
{
Assert.Null(localSocketError);
Assert.True(disposedException);
}
else if (UsesSync)
if (UsesSync)
{
Assert.True(disposedException || localSocketError == SocketError.NotSocket, $"{disposedException} {localSocketError}");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,16 @@ public void BeginConnect_Host_Throws_ObjectDisposed()
}

[Fact]
public void EndConnect_Throws_ObjectDisposed()
public void EndConnect_Throws_SocketException()
{
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndConnect(null));
using Socket notListening = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
notListening.BindToAnonymousPort(IPAddress.Loopback);

Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
var iar = socket.BeginConnect(notListening.LocalEndPoint, null, null);
socket.Dispose();

Assert.Throws<SocketException>(() => socket.EndConnect(iar));
}

[Fact]
Expand All @@ -656,25 +663,12 @@ public void BeginSend_Buffers_SocketError_Throws_ObjectDisposedException()
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().BeginSend(s_buffers, SocketFlags.None, TheAsyncCallback, null));
}

[Fact]
public void EndSend_Throws_ObjectDisposedException()
{
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndSend(null));
}

[Fact]
public void BeginSendTo_Throws_ObjectDisposedException()
{
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().BeginSendTo(s_buffer, 0, s_buffer.Length, SocketFlags.None, new IPEndPoint(IPAddress.Loopback, 1), TheAsyncCallback, null));
}

[Fact]
public void EndSendTo_Throws_ObjectDisposedException()
{
// Behavior difference: EndSendTo_Throws_ObjectDisposed
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndSendTo(null));
}

[Fact]
public void BeginReceive_Buffer_Throws_ObjectDisposedException()
{
Expand All @@ -700,9 +694,17 @@ public void BeginReceive_Buffers_SocketError_Throws_ObjectDisposedException()
}

[Fact]
public void EndReceive_Throws_ObjectDisposedException()
public void EndReceive_Throws_SocketException()
{
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndReceive(null));
(Socket a, Socket b) = SocketTestExtensions.CreateConnectedSocketPair();

using (b)
{
var iar = a.BeginReceive(new byte[1], 0, 1, SocketFlags.None, null, null);
a.Dispose();

Assert.Throws<SocketException>(() => a.EndReceive(iar));
}
}

[Fact]
Expand All @@ -713,10 +715,16 @@ public void BeginReceiveFrom_Throws_ObjectDisposedException()
}

[Fact]
public void EndReceiveFrom_Throws_ObjectDisposedException()
public void EndReceiveFrom_Throws_SocketException()
{
EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1);
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndReceiveFrom(null, ref remote));
(Socket sender, Socket receiver, EndPoint senderEp) = CreateUdpSocketPair();

using (sender)
{
var iar = receiver.BeginReceiveFrom(new byte[1], 0, 1, SocketFlags.None, ref senderEp, null, null);
receiver.Dispose();
Assert.Throws<SocketException>(() => receiver.EndReceiveFrom(iar, ref senderEp));
}
}

[Fact]
Expand All @@ -726,25 +734,49 @@ public void BeginReceiveMessageFrom_Throws_ObjectDisposedException()
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().BeginReceiveMessageFrom(s_buffer, 0, s_buffer.Length, SocketFlags.None, ref remote, TheAsyncCallback, null));
}

private static (Socket a, Socket b, IPEndPoint aEndPoint) CreateUdpSocketPair()
{
Socket a = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
int port = a.BindToAnonymousPort(IPAddress.Loopback);
IPEndPoint aEndPoint = new IPEndPoint(IPAddress.Loopback, port);
Socket b = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
b.BindToAnonymousPort(IPAddress.Loopback);
return (a, b, aEndPoint);
}

[Fact]
public void EndReceiveMessageFrom_Throws_ObjectDisposedException()
public void EndReceiveMessageFrom_Throws_SocketException()
{
SocketFlags flags = SocketFlags.None;
EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1);
IPPacketInformation packetInfo;
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndReceiveMessageFrom(null, ref flags, ref remote, out packetInfo));
(Socket sender, Socket receiver, EndPoint senderEp) = CreateUdpSocketPair();

using (sender)
{
var iar = receiver.BeginReceiveMessageFrom(new byte[1], 0, 1, SocketFlags.None, ref senderEp, null, null);
receiver.Dispose();

SocketFlags flags = SocketFlags.None;
EndPoint remote = new IPEndPoint(IPAddress.Loopback, 1);
IPPacketInformation packetInfo;
Assert.Throws<SocketException>(() => receiver.EndReceiveMessageFrom(iar, ref flags, ref remote, out packetInfo));
}
}

[Fact]
public void BeginAccept_Throws_ObjectDisposed()
public void BeginAccept_Throws_ObjectDisposedException()
{
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().BeginAccept(TheAsyncCallback, null));
}

[Fact]
public void EndAccept_Throws_ObjectDisposed()
public void EndAccept_Throws_SocketException()
{
Assert.Throws<ObjectDisposedException>(() => GetDisposedSocket().EndAccept(null));
Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
listener.BindToAnonymousPort(IPAddress.Loopback);
listener.Listen();
var iar = listener.BeginAccept(null, null);
listener.Dispose();

Assert.Throws<SocketException>(() => listener.EndAccept(iar));
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,10 @@ async Task RunTestAsync()
if (closeOrDispose) socket.Close();
else socket.Dispose();

if (DisposeDuringOperationResultsInDisposedException)
{
await Assert.ThrowsAsync<ObjectDisposedException>(() => receiveTask)
SocketException ex = await Assert.ThrowsAsync<SocketException>(() => receiveTask)
.WaitAsync(CancellationTestTimeout);
}
else
{
SocketException ex = await Assert.ThrowsAsync<SocketException>(() => receiveTask)
.WaitAsync(CancellationTestTimeout);
SocketError expectedError = UsesSync ? SocketError.Interrupted : SocketError.OperationAborted;
Assert.Equal(expectedError, ex.SocketErrorCode);
}
SocketError expectedError = UsesSync ? SocketError.Interrupted : SocketError.OperationAborted;
Assert.Equal(expectedError, ex.SocketErrorCode);
}
}

Expand Down
Loading

0 comments on commit 5d4526d

Please sign in to comment.