diff --git a/src/DtronixMessageQueue.Tests/Mq/MqClientTests.cs b/src/DtronixMessageQueue.Tests/Mq/MqClientTests.cs index b75304d..1cb668f 100644 --- a/src/DtronixMessageQueue.Tests/Mq/MqClientTests.cs +++ b/src/DtronixMessageQueue.Tests/Mq/MqClientTests.cs @@ -292,7 +292,7 @@ public void Client_times_out_while_connecting_for_too_long() { } }; - StartAndWait(false, 1000, false); + StartAndWait(false, 10000, false); if (TestStatus.IsSet == false) { throw new Exception("Socket did not timeout."); diff --git a/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs b/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs index cfc5111..d6e81cc 100644 --- a/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs +++ b/src/DtronixMessageQueue.Tests/Rpc/RpcClientTests.cs @@ -266,9 +266,9 @@ public void Client_notified_of_authentication_success() { } [Fact] - public void Client_times_out_on_auth_failure() { + public void Client_times_out_on_long_auth() { Server.Config.RequireAuthentication = true; - Server.Config.ConnectionTimeout = 100; + Client.Config.ConnectionTimeout = 100; Client.Closed += (sender, e) => { if (e.CloseReason != SocketCloseReason.TimeOut) { @@ -277,11 +277,12 @@ public void Client_times_out_on_auth_failure() { TestStatus.Set(); }; - Client.Authenticate += (sender, e) => { - Thread.Sleep(200); + Server.Authenticate += (sender, e) => + { + Thread.Sleep(500); }; - StartAndWait(); + StartAndWait(true, 5000, true); } [Fact] diff --git a/src/DtronixMessageQueue.Tests/Rpc/RpcTestsBase.cs b/src/DtronixMessageQueue.Tests/Rpc/RpcTestsBase.cs index 825a48a..e97a308 100644 --- a/src/DtronixMessageQueue.Tests/Rpc/RpcTestsBase.cs +++ b/src/DtronixMessageQueue.Tests/Rpc/RpcTestsBase.cs @@ -45,8 +45,8 @@ public static int FreeTcpPort() { return port; } - public void StartAndWait(bool timeout_error = true, int timeout_length = -1) { - if (Server.IsRunning == false) { + public void StartAndWait(bool timeout_error = true, int timeout_length = -1, bool start_server = true) { + if (Server.IsRunning == false && start_server) { Server.Start(); } if (Client.IsRunning == false) { diff --git a/src/DtronixMessageQueue/MqClient.cs b/src/DtronixMessageQueue/MqClient.cs index cedfa05..f187285 100644 --- a/src/DtronixMessageQueue/MqClient.cs +++ b/src/DtronixMessageQueue/MqClient.cs @@ -104,6 +104,10 @@ public void Send(MqMessage message) { } public void Close() { + if (Session == null) + { + return; + } Session.IncomingMessage -= OnIncomingMessage; Session.Close(SocketCloseReason.ClientClosing); Session.Dispose(); diff --git a/src/DtronixMessageQueue/Rpc/RpcSession.cs b/src/DtronixMessageQueue/Rpc/RpcSession.cs index 5be6ba8..d775bff 100644 --- a/src/DtronixMessageQueue/Rpc/RpcSession.cs +++ b/src/DtronixMessageQueue/Rpc/RpcSession.cs @@ -64,6 +64,10 @@ public abstract class RpcSession : MqSession public bool Authenticated { get; private set; } + private Task auth_timeout; + + private CancellationTokenSource auth_timeout_cancel = new CancellationTokenSource(); + protected RpcSession() { MessageHandlers = new Dictionary>(); } @@ -152,9 +156,27 @@ protected override void ProcessCommand(MqFrame frame) { var auth_message = serializer.MessageWriter.ToMessage(true); auth_message[0].FrameType = MqFrameType.Command; - // RpcCommand:byte; RpcCommandType:byte; AuthData:byte[]; - Send(auth_message); - } else { + auth_timeout = new Task(async () => + { + try + { + await Task.Delay(Config.ConnectionTimeout, auth_timeout_cancel.Token); + } + catch + { + return; + } + + if(!auth_timeout_cancel.IsCancellationRequested) + Close(SocketCloseReason.TimeOut); + }); + + // RpcCommand:byte; RpcCommandType:byte; AuthData:byte[]; + Send(auth_message); + + auth_timeout.Start(); + + } else { // If no authentication is required, set this client to authenticated. Authenticated = true; @@ -216,6 +238,9 @@ protected override void ProcessCommand(MqFrame frame) { } else if (rpc_command_type == RpcCommandType.AuthenticationResult) { // RpcCommand:byte; RpcCommandType:byte; AuthResult:bool; + // Cancel the timeout request. + auth_timeout_cancel.Cancel(); + // Ensure that this command is running on the client. if (BaseSocket.Mode != SocketMode.Client) { Close(SocketCloseReason.ProtocolError); diff --git a/src/DtronixMessageQueue/Socket/SocketClient.cs b/src/DtronixMessageQueue/Socket/SocketClient.cs index 2e8d584..09f052f 100644 --- a/src/DtronixMessageQueue/Socket/SocketClient.cs +++ b/src/DtronixMessageQueue/Socket/SocketClient.cs @@ -2,8 +2,10 @@ using System.Net; using System.Net.Sockets; using System.Threading; +using System.Threading.Tasks; -namespace DtronixMessageQueue.Socket { +namespace DtronixMessageQueue.Socket +{ /// /// Base functionality for all client connections to a remote server. @@ -12,7 +14,8 @@ namespace DtronixMessageQueue.Socket { /// Configuration for this connection. public class SocketClient : SessionHandler where TSession : SocketSession, new() - where TConfig : SocketConfig { + where TConfig : SocketConfig + { /// /// True if the client is connected to a server. @@ -29,57 +32,83 @@ public class SocketClient : SessionHandler /// Creates a socket client with the specified configurations. /// /// Configurations to use. - public SocketClient(TConfig config) : base(config, SocketMode.Client) { + public SocketClient(TConfig config) : base(config, SocketMode.Client) + { } /// /// Connects to the configured endpoint. /// - public void Connect() { + public void Connect() + { Connect(new IPEndPoint(IPAddress.Parse(Config.Ip), Config.Port)); } - + /// + /// Task which will run when a connection times out. + /// + private Task connection_timeout_task; + + /// + /// Cancellation token to cancel the timeout event for connections. + /// + private CancellationTokenSource connection_timeout_cancellation; /// /// Connects to the specified endpoint. /// /// Endpoint to connect to. - public void Connect(IPEndPoint end_point) { - if (MainSocket != null && Session?.CurrentState != SocketSession.State.Closed) { + public void Connect(IPEndPoint end_point) + { + if (MainSocket != null && Session?.CurrentState != SocketSession.State.Closed) + { throw new InvalidOperationException("Client is in the process of connecting."); } - MainSocket = new System.Net.Sockets.Socket(end_point.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { + MainSocket = new System.Net.Sockets.Socket(end_point.AddressFamily, SocketType.Stream, ProtocolType.Tcp) + { NoDelay = true }; + // Set to true if the client connection either timed out or was canceled. bool timed_out = false; + connection_timeout_cancellation = new CancellationTokenSource(); - Timer connection_timer = null; + connection_timeout_task = new Task(async () => + { + try + { + await Task.Delay(Config.ConnectionTimeout, connection_timeout_cancellation.Token); + } + catch + { + return; + } - connection_timer = new Timer(o => { timed_out = true; - MainSocket.Close(); - connection_timer.Change(Timeout.Infinite, Timeout.Infinite); - connection_timer.Dispose(); - OnClose(null, SocketCloseReason.TimeOut); + MainSocket.Close(); }); - var event_arg = new SocketAsyncEventArgs { + + + + + var event_arg = new SocketAsyncEventArgs + { RemoteEndPoint = end_point }; event_arg.Completed += (sender, args) => { - if (timed_out) { + if (timed_out) + { return; } - if (args.LastOperation == SocketAsyncOperation.Connect) { + if (args.LastOperation == SocketAsyncOperation.Connect) + { // Stop the timeout timer. - connection_timer.Change(Timeout.Infinite, Timeout.Infinite); - connection_timer.Dispose(); + connection_timeout_cancellation.Cancel(); Session = CreateSession(MainSocket); Session.Connected += (sndr, e) => OnConnect(Session); @@ -90,21 +119,24 @@ public void Connect(IPEndPoint end_point) { } }; - + MainSocket.ConnectAsync(event_arg); - connection_timer.Change(Config.ConnectionTimeout, Timeout.Infinite); + connection_timeout_task.Start(); + } - protected override void OnClose(TSession session, SocketCloseReason reason) { + protected override void OnClose(TSession session, SocketCloseReason reason) + { MainSocket.Close(); TSession sess_out; // If the session is null, the connection timed out while trying to connect. - if (session != null) { + if (session != null) + { ConnectedSessions.TryRemove(Session.Id, out sess_out); }