Skip to content

Commit

Permalink
(#109) XmppTcpConnection: better connection cancellation, better term…
Browse files Browse the repository at this point in the history
…ination
  • Loading branch information
ForNeVeR committed Sep 23, 2021
1 parent 2d05994 commit a20980e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 28 deletions.
24 changes: 24 additions & 0 deletions SharpXMPP.Shared/Compat/SslStreamEx.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Net.Security;
using System.Threading;
using System.Threading.Tasks;

namespace SharpXMPP.Compat
{
public static class SslStreamEx
{
public static Task AuthenticateAsClientWithCancellationAsync(
this SslStream sslStream,
string targetHost,
CancellationToken cancellationToken)
{
#if NET5_0_OR_GREATER
return sslStream.AuthenticateAsClientAsync(
new SslClientAuthenticationOptions { TargetHost = targetHost },
cancellationToken);
#else
// No cancellation on older runtimes :(
return sslStream.AuthenticateAsClientAsync(targetHost);
#endif
}
}
}
85 changes: 57 additions & 28 deletions SharpXMPP.Shared/XmppTcpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace SharpXMPP
{
public class XmppTcpConnection : XmppConnection
{

private object _terminationLock = new object();
private TcpClient _client;

protected virtual int TcpPort
Expand Down Expand Up @@ -90,6 +90,24 @@ public override void Send(XElement data)
Writer.Flush();
}

private void TerminateTcpConnection()
{
// There are two callers for this method: connection timeout and external dispose. This lock is placed in
// case of a race condition between the two.
lock (_terminationLock)
{
// NOTE: Client is explicitly Disposable on older runtimes, so cast is required.
((IDisposable)_client)?.Dispose();
_client = null;
Writer?.Dispose();
Writer = null;
Reader?.Dispose();
Reader = null;
ConnectionStream?.Dispose();
ConnectionStream = null;
}
}

// In what context this method should be used?
public void Close()
{
Expand All @@ -101,15 +119,7 @@ public override void Dispose()
if (!_disposed)
{
_disposed = true;
// NOTE: used this statement because faced issue with compilation under net451
((IDisposable)_client)?.Dispose();
_client = null;
Writer?.Dispose();
Writer = null;
Reader?.Dispose();
Reader = null;
ConnectionStream?.Dispose();
ConnectionStream = null;
TerminateTcpConnection();
}
base.Dispose();
}
Expand Down Expand Up @@ -150,13 +160,13 @@ public override async Task ConnectAsync(CancellationToken token)
RestartXmlStreams();

Features features = GetServerFeatures();
var tlsSupported = await InitTlsIfSupported(features);
var tlsSupported = await InitTlsIfSupported(features, token);
if (tlsSupported)
{
features = GetServerFeatures();
}

await StartAuthentication(features);
await StartAuthentication(features, token);
}

public Task SessionLoopAsync() => SessionLoopAsync(CancellationToken.None);
Expand Down Expand Up @@ -204,7 +214,7 @@ public Task SessionLoopAsync(CancellationToken token)

private async Task ConnectOverTcp(List<IPAddress> HostAddresses, CancellationToken cancellationToken)
{
((IDisposable)_client)?.Dispose();
TerminateTcpConnection();

_client = new TcpClient();
try
Expand All @@ -221,38 +231,55 @@ private async Task ConnectOverTcp(List<IPAddress> HostAddresses, CancellationTok
}
}

private Task StartAuthentication(Features features)
private Task StartAuthentication(Features features, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
Task.Run(() =>

void RunCatching(Action act)
{
try
{
act();
}
catch (Exception ex)
{
tcs.SetException(ex);
}
}

Task.Run(() => RunCatching(() =>
{
var authenticator = SASLHandler.Create(features.SaslMechanisms, Jid, Password);
if (authenticator == null)
{
OnConnectionFailed(new ConnFailedArgs { Message = "supported sasl mechanism not available" });
tcs.SetResult(false);
tcs.TrySetResult(false);
return;
}
authenticator.Authenticated += sender =>
authenticator.Authenticated += _ => RunCatching(() =>
{
RestartXmlStreams();
var session = new SessionHandler();
session.SessionStarted += connection =>
session.SessionStarted += connection => RunCatching(() =>
{
OnSignedIn(new SignedInArgs { Jid = connection.Jid });
tcs.SetResult(true);
};
tcs.TrySetResult(true);
});
// TODO make async
// Locks stream with SessionLoop
session.Start(this);
};
authenticator.AuthenticationFailed += sender =>
});
authenticator.AuthenticationFailed += _ => RunCatching(() =>
{
OnConnectionFailed(new ConnFailedArgs { Message = "Authentication failed" });
tcs.SetResult(true);
};
authenticator.Start(this);
});
tcs.TrySetResult(true);
});

using (cancellationToken.Register(TerminateTcpConnection))
authenticator.Start(this);

tcs.TrySetResult(false);
}), cancellationToken);
return tcs.Task;
}

Expand Down Expand Up @@ -280,7 +307,7 @@ private async Task<List<IPAddress>> ResolveHostAddresses(CancellationToken cance
return HostAddresses;
}

private async Task<bool> InitTlsIfSupported(Features features)
private async Task<bool> InitTlsIfSupported(Features features, CancellationToken cancellationToken)
{
if (!features.Tls)
{
Expand All @@ -295,7 +322,9 @@ private async Task<bool> InitTlsIfSupported(Features features)
}

ConnectionStream = new SslStream(ConnectionStream, true);
await ((SslStream)ConnectionStream).AuthenticateAsClientAsync(Jid.Domain);
await ((SslStream)ConnectionStream).AuthenticateAsClientWithCancellationAsync(
Jid.Domain,
cancellationToken);
RestartXmlStreams();
return true;
}
Expand Down

0 comments on commit a20980e

Please sign in to comment.