Skip to content

Commit

Permalink
(#109) XmppTcpConnection: cancellable TCP client connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ForNeVeR committed Sep 21, 2021
1 parent 3cb2ddb commit 3c17dc5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
59 changes: 59 additions & 0 deletions SharpXMPP.Shared/Compat/TcpClientEx.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace SharpXMPP.Compat
{
internal static class TcpClientEx
{
#if !NET5_0_OR_GREATER
private static async Task AbandonOnCancel(this Task task, CancellationToken cancellationToken)
{
// See https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/ for details.
var tcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(() => tcs.TrySetResult(true)))
{
if (task != await Task.WhenAny(task, tcs.Task))
{
throw new OperationCanceledException(cancellationToken);
}
}
}
#endif

// Unfortunately, only .NET 5+ supports TcpClient connection cancellation. We'll do the best effort here,
// though.
//
// TcpClient uses Socket::BeginConnect under the covers for all the connection methods on .NET Framework, which
// is documented to be cancelled on Close(). So, ideally, if the caller eventually disposes the client, then all
// the resources will be freed upon its destruction. Which means we are free to just abandon the task in
// question.

public static Task ConnectWithCancellationAsync(
this TcpClient tcpClient,
IPAddress address,
int port,
CancellationToken cancellationToken)
{
#if NET5_0_OR_GREATER
return tcpClient.ConnectAsync(address, port, cancellationToken).AsTask();
#else
return tcpClient.ConnectAsync(address, port).AbandonOnCancel(cancellationToken);
#endif
}

public static Task ConnectWithCancellationAsync(
this TcpClient tcpClient,
IPAddress[] addresses,
int port,
CancellationToken cancellationToken)
{
#if NET5_0_OR_GREATER
return tcpClient.ConnectAsync(addresses, port, cancellationToken).AsTask();
#else
return tcpClient.ConnectAsync(addresses, port).AbandonOnCancel(cancellationToken);
#endif
}
}
}
17 changes: 10 additions & 7 deletions SharpXMPP.Shared/Resolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using SharpXMPP.Compat;

namespace SharpXMPP
{
Expand All @@ -18,12 +20,14 @@ public struct SRVRecord

public static class Resolver
{
public async static Task<List<SRVRecord>> ResolveXMPPClient(string domain)
public static Task<List<SRVRecord>> ResolveXMPPClient(string domain) =>
ResolveXMPPClient(domain, default);

public static async Task<List<SRVRecord>> ResolveXMPPClient(string domain, CancellationToken cancellationToken)
{
var result = new List<SRVRecord>();
var client = new TcpClient();
await client.ConnectAsync(IPAddress.Parse("1.1.1.1"), 53);
var stream = client.GetStream();
using var client = new TcpClient();
await client.ConnectWithCancellationAsync(IPAddress.Parse("1.1.1.1"), 53, cancellationToken);
using var stream = client.GetStream();
var message = EncodeQuery(domain);
var lengthPrefix = IPAddress.HostToNetworkOrder((short)message.Length);
var lengthPrefixBytes = BitConverter.GetBytes(lengthPrefix);
Expand All @@ -34,8 +38,7 @@ public async static Task<List<SRVRecord>> ResolveXMPPClient(string domain)
stream.Read(responseLengthBytes, 0, 2);
var responseMessage = new byte[IPAddress.NetworkToHostOrder(BitConverter.ToInt16(responseLengthBytes, 0))];
stream.Read(responseMessage, 0, responseMessage.Length);
result = Decode(responseMessage);
return result;
return Decode(responseMessage);
}

private static byte[] EncodeQuery(string domain)
Expand Down
29 changes: 22 additions & 7 deletions SharpXMPP.Shared/XmppTcpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Threading.Tasks;
using System.Xml;
using System.Xml.Linq;
using SharpXMPP.Compat;
using SharpXMPP.Errors;
using SharpXMPP.XMPP;
using SharpXMPP.XMPP.Bind;
Expand Down Expand Up @@ -143,8 +144,8 @@ public void SessionLoop()

public override async Task ConnectAsync(CancellationToken token)
{
List<IPAddress> HostAddresses = await ResolveHostAddresses();
await ConnectOverTcp(HostAddresses);
List<IPAddress> HostAddresses = await ResolveHostAddresses(token);
await ConnectOverTcp(HostAddresses, token);

RestartXmlStreams();

Expand Down Expand Up @@ -201,11 +202,23 @@ public Task SessionLoopAsync(CancellationToken token)

private Features GetServerFeatures() => Stanza.Parse<Features>(NextElement());

private async Task ConnectOverTcp(List<IPAddress> HostAddresses)
private async Task ConnectOverTcp(List<IPAddress> HostAddresses, CancellationToken cancellationToken)
{
((IDisposable)_client)?.Dispose();

_client = new TcpClient();
await _client.ConnectAsync(HostAddresses.ToArray(), TcpPort); // TODO: check ports
ConnectionStream = _client.GetStream();
try
{
// TODO: check ports
await _client.ConnectWithCancellationAsync(HostAddresses.ToArray(), TcpPort, cancellationToken);
ConnectionStream = _client.GetStream();
}
catch
{
((IDisposable)_client).Dispose();
_client = null;
throw;
}
}

private Task StartAuthentication(Features features)
Expand Down Expand Up @@ -244,21 +257,23 @@ private Task StartAuthentication(Features features)
}


private async Task<List<IPAddress>> ResolveHostAddresses()
private async Task<List<IPAddress>> ResolveHostAddresses(CancellationToken cancellationToken)
{
List<IPAddress> HostAddresses = new List<IPAddress>();

var srvs = await Resolver.ResolveXMPPClient(Jid.Domain);
var srvs = await Resolver.ResolveXMPPClient(Jid.Domain, cancellationToken);
if (srvs.Any())
{
foreach (var srv in srvs)
{
cancellationToken.ThrowIfCancellationRequested();
var addresses = await Dns.GetHostAddressesAsync(srv.Host);
HostAddresses.AddRange(addresses);
}
}
else
{
cancellationToken.ThrowIfCancellationRequested();
HostAddresses.AddRange(await Dns.GetHostAddressesAsync(Jid.Domain));
}

Expand Down

0 comments on commit 3c17dc5

Please sign in to comment.