From 3c17dc52fa393799d30eb41abdf03bd41c81aeb0 Mon Sep 17 00:00:00 2001 From: Friedrich von Never Date: Wed, 22 Sep 2021 00:53:15 +0700 Subject: [PATCH] (#109) XmppTcpConnection: cancellable TCP client connections --- SharpXMPP.Shared/Compat/TcpClientEx.cs | 59 ++++++++++++++++++++++++++ SharpXMPP.Shared/Resolver.cs | 17 +++++--- SharpXMPP.Shared/XmppTcpConnection.cs | 29 ++++++++++--- 3 files changed, 91 insertions(+), 14 deletions(-) create mode 100644 SharpXMPP.Shared/Compat/TcpClientEx.cs diff --git a/SharpXMPP.Shared/Compat/TcpClientEx.cs b/SharpXMPP.Shared/Compat/TcpClientEx.cs new file mode 100644 index 0000000..94b1889 --- /dev/null +++ b/SharpXMPP.Shared/Compat/TcpClientEx.cs @@ -0,0 +1,59 @@ +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace SharpXMPP.Compat +{ + internal static class TcpClientEx + { +#if !NET5_0_OR_GREATER + private static async Task AbandonOnCancel(this Task task, CancellationToken cancellationToken) + { + // See https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/ for details. + var tcs = new TaskCompletionSource(); + using (cancellationToken.Register(() => tcs.TrySetResult(true))) + { + if (task != await Task.WhenAny(task, tcs.Task)) + { + throw new OperationCanceledException(cancellationToken); + } + } + } +#endif + + // Unfortunately, only .NET 5+ supports TcpClient connection cancellation. We'll do the best effort here, + // though. + // + // TcpClient uses Socket::BeginConnect under the covers for all the connection methods on .NET Framework, which + // is documented to be cancelled on Close(). So, ideally, if the caller eventually disposes the client, then all + // the resources will be freed upon its destruction. Which means we are free to just abandon the task in + // question. + + public static Task ConnectWithCancellationAsync( + this TcpClient tcpClient, + IPAddress address, + int port, + CancellationToken cancellationToken) + { +#if NET5_0_OR_GREATER + return tcpClient.ConnectAsync(address, port, cancellationToken).AsTask(); +#else + return tcpClient.ConnectAsync(address, port).AbandonOnCancel(cancellationToken); +#endif + } + + public static Task ConnectWithCancellationAsync( + this TcpClient tcpClient, + IPAddress[] addresses, + int port, + CancellationToken cancellationToken) + { +#if NET5_0_OR_GREATER + return tcpClient.ConnectAsync(addresses, port, cancellationToken).AsTask(); +#else + return tcpClient.ConnectAsync(addresses, port).AbandonOnCancel(cancellationToken); +#endif + } + } +} diff --git a/SharpXMPP.Shared/Resolver.cs b/SharpXMPP.Shared/Resolver.cs index f193039..2be8b0f 100644 --- a/SharpXMPP.Shared/Resolver.cs +++ b/SharpXMPP.Shared/Resolver.cs @@ -4,7 +4,9 @@ using System.Net; using System.Net.Sockets; using System.Text; +using System.Threading; using System.Threading.Tasks; +using SharpXMPP.Compat; namespace SharpXMPP { @@ -18,12 +20,14 @@ public struct SRVRecord public static class Resolver { - public async static Task> ResolveXMPPClient(string domain) + public static Task> ResolveXMPPClient(string domain) => + ResolveXMPPClient(domain, default); + + public static async Task> ResolveXMPPClient(string domain, CancellationToken cancellationToken) { - var result = new List(); - var client = new TcpClient(); - await client.ConnectAsync(IPAddress.Parse("1.1.1.1"), 53); - var stream = client.GetStream(); + using var client = new TcpClient(); + await client.ConnectWithCancellationAsync(IPAddress.Parse("1.1.1.1"), 53, cancellationToken); + using var stream = client.GetStream(); var message = EncodeQuery(domain); var lengthPrefix = IPAddress.HostToNetworkOrder((short)message.Length); var lengthPrefixBytes = BitConverter.GetBytes(lengthPrefix); @@ -34,8 +38,7 @@ public async static Task> ResolveXMPPClient(string domain) stream.Read(responseLengthBytes, 0, 2); var responseMessage = new byte[IPAddress.NetworkToHostOrder(BitConverter.ToInt16(responseLengthBytes, 0))]; stream.Read(responseMessage, 0, responseMessage.Length); - result = Decode(responseMessage); - return result; + return Decode(responseMessage); } private static byte[] EncodeQuery(string domain) diff --git a/SharpXMPP.Shared/XmppTcpConnection.cs b/SharpXMPP.Shared/XmppTcpConnection.cs index 8ebc5ee..2721f97 100644 --- a/SharpXMPP.Shared/XmppTcpConnection.cs +++ b/SharpXMPP.Shared/XmppTcpConnection.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using System.Xml; using System.Xml.Linq; +using SharpXMPP.Compat; using SharpXMPP.Errors; using SharpXMPP.XMPP; using SharpXMPP.XMPP.Bind; @@ -143,8 +144,8 @@ public void SessionLoop() public override async Task ConnectAsync(CancellationToken token) { - List HostAddresses = await ResolveHostAddresses(); - await ConnectOverTcp(HostAddresses); + List HostAddresses = await ResolveHostAddresses(token); + await ConnectOverTcp(HostAddresses, token); RestartXmlStreams(); @@ -201,11 +202,23 @@ public Task SessionLoopAsync(CancellationToken token) private Features GetServerFeatures() => Stanza.Parse(NextElement()); - private async Task ConnectOverTcp(List HostAddresses) + private async Task ConnectOverTcp(List HostAddresses, CancellationToken cancellationToken) { + ((IDisposable)_client)?.Dispose(); + _client = new TcpClient(); - await _client.ConnectAsync(HostAddresses.ToArray(), TcpPort); // TODO: check ports - ConnectionStream = _client.GetStream(); + try + { + // TODO: check ports + await _client.ConnectWithCancellationAsync(HostAddresses.ToArray(), TcpPort, cancellationToken); + ConnectionStream = _client.GetStream(); + } + catch + { + ((IDisposable)_client).Dispose(); + _client = null; + throw; + } } private Task StartAuthentication(Features features) @@ -244,21 +257,23 @@ private Task StartAuthentication(Features features) } - private async Task> ResolveHostAddresses() + private async Task> ResolveHostAddresses(CancellationToken cancellationToken) { List HostAddresses = new List(); - var srvs = await Resolver.ResolveXMPPClient(Jid.Domain); + var srvs = await Resolver.ResolveXMPPClient(Jid.Domain, cancellationToken); if (srvs.Any()) { foreach (var srv in srvs) { + cancellationToken.ThrowIfCancellationRequested(); var addresses = await Dns.GetHostAddressesAsync(srv.Host); HostAddresses.AddRange(addresses); } } else { + cancellationToken.ThrowIfCancellationRequested(); HostAddresses.AddRange(await Dns.GetHostAddressesAsync(Jid.Domain)); }