Skip to content

Commit

Permalink
Surface exceptions thrown by LoopbackSocksServer (#56396)
Browse files Browse the repository at this point in the history
  • Loading branch information
MihaZupan authored Jul 27, 2021
1 parent 59e7258 commit d61aeca
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -1,34 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net.Sockets;
using System.Net.Test.Common;
using System.Runtime.ExceptionServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Http.Functional.Tests
{
/// <summary>
/// Provides a test-only SOCKS4/5 proxy.
/// </summary>
internal class LoopbackSocksServer : IDisposable
internal class LoopbackSocksServer : IAsyncDisposable
{
private readonly Socket _listener;
private readonly ManualResetEvent _serverStopped;
private bool _disposed;

private int _connections;
public int Connections => _connections;
private readonly List<Task> _connectionTasks = new();
private readonly TaskCompletionSource _serverStopped = new(TaskCreationOptions.RunContinuationsAsynchronously);

public int Port { get; }

private string? _username, _password;

private LoopbackSocksServer(string? username = null, string? password = null)
public LoopbackSocksServer(string? username = null, string? password = null)
{
if (password != null && username == null)
{
Expand All @@ -40,74 +35,37 @@ private LoopbackSocksServer(string? username = null, string? password = null)

_listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
_listener.Listen(int.MaxValue);
_listener.Listen();

var ep = (IPEndPoint)_listener.LocalEndPoint;
Port = ep.Port;

_serverStopped = new ManualResetEvent(false);
}

private void Start()
{
Task.Run(async () =>
{
var activeTasks = new ConcurrentDictionary<Task, int>();
try
while (true)
{
while (true)
try
{
Socket s = await _listener.AcceptAsync().ConfigureAwait(false);
var connectionTask = Task.Run(async () =>
_connectionTasks.Add(Task.Run(async () =>
{
try
{
await ProcessConnection(s).ConfigureAwait(false);
}
catch (Exception ex)
using (var ns = new NetworkStream(s, ownsSocket: true))
{
EventSourceTestLogging.Log.TestAncillaryError(ex);
await ProcessRequest(s, ns).ConfigureAwait(false);
}
});
activeTasks.TryAdd(connectionTask, 0);
_ = connectionTask.ContinueWith(t => activeTasks.TryRemove(connectionTask, out _), TaskContinuationOptions.ExecuteSynchronously);
}));
}
catch
{
break;
}
}
catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted)
{
// caused during Dispose() to cancel the loop. ignore.
}
catch (Exception ex)
{
EventSourceTestLogging.Log.TestAncillaryError(ex);
}
try
{
await Task.WhenAll(activeTasks.Keys).ConfigureAwait(false);
}
catch (Exception ex)
{
EventSourceTestLogging.Log.TestAncillaryError(ex);
}
_serverStopped.Set();
_serverStopped.SetResult();
});
}

private async Task ProcessConnection(Socket s)
{
Interlocked.Increment(ref _connections);

using (var ns = new NetworkStream(s, ownsSocket: true))
{
await ProcessRequest(s, ns).ConfigureAwait(false);
}
}

private async Task ProcessRequest(Socket clientSocket, NetworkStream ns)
{
int version = await ns.ReadByteAsync().ConfigureAwait(false);
Expand Down Expand Up @@ -344,21 +302,27 @@ private async ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer)
}
}

public static LoopbackSocksServer Create(string? username = null, string? password = null)
public async ValueTask DisposeAsync()
{
var server = new LoopbackSocksServer(username, password);
server.Start();
_listener.Dispose();
await _serverStopped.Task;

return server;
}
List<Exception> exceptions = new();
foreach (Task task in _connectionTasks)
{
try
{
await task;
}
catch (Exception ex)
{
exceptions.Add(ex);
}
}

public void Dispose()
{
if (!_disposed)
if (exceptions.Count > 0)
{
_listener.Dispose();
_serverStopped.WaitOne();
_disposed = true;
throw new AggregateException(exceptions);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Test.Common;
using System.Threading.Tasks;
Expand Down Expand Up @@ -38,7 +37,7 @@ public async Task TestLoopbackAsync(string scheme, bool useSsl, bool useAuth, st
await LoopbackServerFactory.CreateClientAndServerAsync(
async uri =>
{
using LoopbackSocksServer proxy = useAuth ? LoopbackSocksServer.Create("DOTNET", "424242") : LoopbackSocksServer.Create();
await using var proxy = useAuth ? new LoopbackSocksServer("DOTNET", "424242") : new LoopbackSocksServer();
using HttpClientHandler handler = CreateHttpClientHandler();
using HttpClient client = CreateHttpClient(handler);
Expand Down Expand Up @@ -93,7 +92,7 @@ public static IEnumerable<object[]> TestExceptionalAsync_MemberData()
[MemberData(nameof(TestExceptionalAsync_MemberData))]
public async Task TestExceptionalAsync(string scheme, string host, bool useAuth, ICredentials? credentials, string exceptionMessage)
{
using LoopbackSocksServer proxy = useAuth ? LoopbackSocksServer.Create("DOTNET", "424242") : LoopbackSocksServer.Create();
var proxy = useAuth ? new LoopbackSocksServer("DOTNET", "424242") : new LoopbackSocksServer();
using HttpClientHandler handler = CreateHttpClientHandler();
using HttpClient client = CreateHttpClient(handler);

Expand All @@ -109,6 +108,12 @@ public async Task TestExceptionalAsync(string scheme, string host, bool useAuth,
var innerException = ex.InnerException;
Assert.Equal(exceptionMessage, innerException.Message);
Assert.Equal("SocksException", innerException.GetType().Name);

try
{
await proxy.DisposeAsync();
}
catch { }
}
}

Expand Down

0 comments on commit d61aeca

Please sign in to comment.