Skip to content

Commit

Permalink
make Socket useable after cancellation (#99181)
Browse files Browse the repository at this point in the history
  • Loading branch information
wfurt authored May 10, 2024
1 parent 95a0265 commit 776ff7e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3754,7 +3754,7 @@ internal void UpdateStatusAfterSocketError(SocketError errorCode, bool disconnec

if (disconnectOnFailure && _isConnected && (_handle.IsInvalid || (errorCode != SocketError.WouldBlock &&
errorCode != SocketError.IOPending && errorCode != SocketError.NoBufferSpaceAvailable &&
errorCode != SocketError.TimedOut)))
errorCode != SocketError.TimedOut && errorCode != SocketError.OperationAborted)))
{
// The socket is no longer a valid socket.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "Invalidating socket.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ private static async Task RunWithConnectedNetworkStreamsAsync(Func<NetworkStream
await Task.WhenAll(remoteTask, clientConnectTask);

using (TcpClient remote = remoteTask.Result)
using (NetworkStream serverStream = new NetworkStream(remote.Client, serverAccess, ownsSocket:true))
using (NetworkStream serverStream = new NetworkStream(remote.Client, serverAccess, ownsSocket: true))
using (NetworkStream clientStream = new NetworkStream(client.Client, clientAccess, ownsSocket: true))
{
await func(serverStream, clientStream);
Expand All @@ -560,6 +560,77 @@ private static async Task RunWithConnectedNetworkStreamsAsync(Func<NetworkStream
}
}

[Fact]
public async Task NetworkStream_ReadTimeout_RemainUseable()
{
using StreamPair streams = await CreateConnectedStreamsAsync();
NetworkStream readable = (NetworkStream)streams.Stream1;

Assert.True(readable.Socket.Connected);
readable.Socket.ReceiveTimeout = TestSettings.FailingTestTimeout;
var buffer = new byte[100];
int readBytes;
try
{
readBytes = readable.Read(buffer);
}
catch (IOException ex) when (ex.InnerException is SocketException && ((SocketException)ex.InnerException).SocketErrorCode == SocketError.TimedOut)
{
}
Assert.True(readable.Socket.Connected);

try
{
readBytes = readable.Read(buffer);
}
catch (IOException ex) when (ex.InnerException is SocketException && ((SocketException)ex.InnerException).SocketErrorCode == SocketError.TimedOut)
{
}
Assert.True(readable.Socket.Connected);

streams.Stream2.Write(new byte[] { 65 });
readBytes = readable.Read(buffer);
Assert.Equal(1, readBytes);
Assert.True(readable.Socket.Connected);
}


[Fact]
public async Task NetworkStream_ReadAsyncTimeout_RemainUseable()
{
using StreamPair streams = await CreateConnectedStreamsAsync();
NetworkStream readable = (NetworkStream)streams.Stream1;

Assert.True(readable.Socket.Connected);

CancellationTokenSource cts = new CancellationTokenSource(TestSettings.FailingTestTimeout);
var buffer = new byte[100];
int readBytes;
try
{
readBytes = await readable.ReadAsync(buffer, cts.Token);
}
catch (OperationCanceledException)
{
}
Assert.True(readable.Socket.Connected);

try
{
cts = new CancellationTokenSource(TestSettings.FailingTestTimeout);
readBytes = await readable.ReadAsync(buffer, cts.Token);
}
catch (OperationCanceledException)
{
}
Assert.True(readable.Socket.Connected);

await streams.Stream2.WriteAsync(new byte[] { 65 });
readBytes = await readable.ReadAsync(buffer);
Assert.Equal(1, readBytes);
Assert.True(readable.Socket.Connected);
}

private sealed class DerivedNetworkStream : NetworkStream
{
public DerivedNetworkStream(Socket socket) : base(socket) { }
Expand Down

0 comments on commit 776ff7e

Please sign in to comment.