From f9e0d38bc864e66b806d9a2b83ef5e5f95eef480 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Mon, 18 Nov 2024 14:07:35 +0100 Subject: [PATCH 1/6] Implement server keep alive. --- README.md | 6 +- src/Tmds.Ssh/SocketSshConnection.cs | 68 ++++++++++++- src/Tmds.Ssh/SshClientSettings.Defaults.cs | 2 + src/Tmds.Ssh/SshClientSettings.SshConfig.cs | 4 +- src/Tmds.Ssh/SshClientSettings.cs | 19 ++++ src/Tmds.Ssh/SshConfig.cs | 12 ++- src/Tmds.Ssh/SshConfigOption.cs | 4 +- src/Tmds.Ssh/SshConnection.cs | 2 + src/Tmds.Ssh/SshConnectionClosedException.cs | 1 + src/Tmds.Ssh/SshSequencePoolExtensions.cs | 12 +++ src/Tmds.Ssh/SshSession.cs | 102 ++++++++++++++++++- 11 files changed, 225 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 712ea139..d38bbc7c 100644 --- a/README.md +++ b/README.md @@ -231,6 +231,8 @@ class SshClientSettings bool AutoReconnect { get; set; } = false; bool TcpKeepAlive { get; set; } = true; + TimeSpan KeepAliveInterval { get; set; } = TimeSpan.Zero; + public int KeepAliveCountMax = 3; List GlobalKnownHostsFilePaths { get; set; } = DefaultGlobalKnownHostsFilePaths; List UserKnownHostsFilePaths { get; set; } = DefaultUserKnownHostsFilePaths; @@ -281,7 +283,9 @@ public enum SshConfigOption KexAlgorithms, MACs, PubkeyAcceptedAlgorithms, - TCPKeepAlive + TCPKeepAlive, + ServerAliveCountMax, + ServerAliveInterval } struct SshConfigOptionValue { diff --git a/src/Tmds.Ssh/SocketSshConnection.cs b/src/Tmds.Ssh/SocketSshConnection.cs index 05c73c01..c94949aa 100644 --- a/src/Tmds.Ssh/SocketSshConnection.cs +++ b/src/Tmds.Ssh/SocketSshConnection.cs @@ -2,6 +2,7 @@ // See file LICENSE for full license details. using System.Buffers; +using System.Diagnostics; using System.Net.Sockets; using System.Text; using Microsoft.Extensions.Logging; @@ -10,7 +11,6 @@ namespace Tmds.Ssh; sealed class SocketSshConnection : SshConnection { - private static ReadOnlySpan NewLine => new byte[] { (byte)'\r', (byte)'\n' }; private static readonly UTF8Encoding s_utf8Encoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true); @@ -22,6 +22,61 @@ sealed class SocketSshConnection : SshConnection private IPacketEncryptor _encryptor; private uint _sendSequenceNumber; private uint _receiveSequenceNumber; + private int _keepAlivePeriod; + private Action? _keepAliveCallback; + private Timer? _keepAliveTimer; + private int _lastReceivedTime; + + public override void EnableKeepAlive(int period, Action callback) + { + if (period > 0) + { + if (_keepAliveTimer is not null) + { + throw new InvalidOperationException(); + } + + _keepAlivePeriod = period; + _keepAliveCallback = callback; + _lastReceivedTime = GetTime(); + _keepAliveTimer = new Timer(o => ((SocketSshConnection)o!).OnKeepAliveTimerCallback(), this, _keepAlivePeriod, _keepAlivePeriod); + } + } + + private static int GetTime() + => Environment.TickCount; + + private static int GetElapsed(int previous) + => Math.Max(GetTime() - previous, 0); + + private void OnKeepAliveTimerCallback() + { + Debug.Assert(_keepAliveTimer is not null); + Debug.Assert(_keepAliveCallback is not null); + + int elapsedTime = GetElapsed(_lastReceivedTime); + lock (_keepAliveTimer) + { + // Synchronize with dispose. + if (_keepAlivePeriod < 0) + { + return; + } + + if (elapsedTime < _keepAlivePeriod) + { + // Wait for the period to expire. + _keepAliveTimer.Change(_keepAlivePeriod - elapsedTime, _keepAlivePeriod); + return; + } + else + { + _keepAliveTimer.Change(_keepAlivePeriod, _keepAlivePeriod); + } + } + + _keepAliveCallback(); + } public SocketSshConnection(ILogger logger, SequencePool sequencePool, Socket socket) : base(sequencePool) @@ -102,6 +157,8 @@ public async override ValueTask ReceivePacketAsync(CancellationToken ct, { if (_decryptor.TryDecrypt(_receiveBuffer, _receiveSequenceNumber, maxLength, out Packet packet)) { + _lastReceivedTime = GetTime(); + _receiveSequenceNumber++; using Packet p = packet; @@ -163,6 +220,15 @@ public override void SetEncryptorDecryptor(IPacketEncryptor packetEncoder, IPack public override void Dispose() { + if (_keepAliveTimer is not null) + { + lock (_keepAliveTimer) + { + _keepAlivePeriod = -1; + _keepAliveTimer.Dispose(); + } + } + _keepAliveTimer?.Dispose(); _receiveBuffer.Dispose(); _sendBuffer.Dispose(); _encryptor.Dispose(); diff --git a/src/Tmds.Ssh/SshClientSettings.Defaults.cs b/src/Tmds.Ssh/SshClientSettings.Defaults.cs index a561bdab..4fc58992 100644 --- a/src/Tmds.Ssh/SshClientSettings.Defaults.cs +++ b/src/Tmds.Ssh/SshClientSettings.Defaults.cs @@ -35,6 +35,8 @@ partial class SshClientSettings private static bool DefaultTcpKeepAlive => true; + private static int DefaultKeepAliveCountMax => 3; + // Algorithms are in **order of preference**. private readonly static List EmptyList = []; internal readonly static List SupportedKeyExchangeAlgorithms = [ AlgorithmNames.EcdhSha2Nistp256, AlgorithmNames.EcdhSha2Nistp384, AlgorithmNames.EcdhSha2Nistp521 ]; diff --git a/src/Tmds.Ssh/SshClientSettings.SshConfig.cs b/src/Tmds.Ssh/SshClientSettings.SshConfig.cs index cffa3be3..c6120dab 100644 --- a/src/Tmds.Ssh/SshClientSettings.SshConfig.cs +++ b/src/Tmds.Ssh/SshClientSettings.SshConfig.cs @@ -45,7 +45,9 @@ internal static async ValueTask LoadFromConfigAsync(string? u MinimumRSAKeySize = sshConfig.RequiredRSASize ?? DefaultMinimumRSAKeySize, Credentials = DetermineCredentials(sshConfig), HashKnownHosts = sshConfig.HashKnownHosts ?? DefaultHashKnownHosts, - TcpKeepAlive = sshConfig.TcpKeepAlive ?? DefaultTcpKeepAlive + TcpKeepAlive = sshConfig.TcpKeepAlive ?? DefaultTcpKeepAlive, + KeepAliveCountMax = sshConfig.ServerAliveCountMax ?? DefaultKeepAliveCountMax, + KeepAliveInterval = sshConfig.ServerAliveInterval > 0 ? TimeSpan.FromSeconds(sshConfig.ServerAliveInterval.Value) : TimeSpan.Zero, }; if (sshConfig.UserKnownHostsFiles is not null) { diff --git a/src/Tmds.Ssh/SshClientSettings.cs b/src/Tmds.Ssh/SshClientSettings.cs index 2f5a3eed..942b9307 100644 --- a/src/Tmds.Ssh/SshClientSettings.cs +++ b/src/Tmds.Ssh/SshClientSettings.cs @@ -11,6 +11,8 @@ public sealed partial class SshClientSettings private string _userName = ""; private List? _credentials; private TimeSpan _connectTimeout = DefaultConnectTimeout; + private TimeSpan _keepAliveInterval = TimeSpan.Zero; + private int _keepAliveCountMax = 3; private List? _userKnownHostsFilePaths; private List? _globalKnownHostsFilePaths; private Dictionary? _environmentVariables; @@ -107,6 +109,23 @@ public TimeSpan ConnectTimeout } } + public int KeepAliveCountMax + { + get => _keepAliveCountMax; + set + { + ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(value, 1); + _keepAliveCountMax = value; + } + } + + // Zero or < 0 is disabled. + public TimeSpan KeepAliveInterval + { + get => _keepAliveInterval; + set => _keepAliveInterval = value; + } + internal void Validate() { if (_credentials is not null) diff --git a/src/Tmds.Ssh/SshConfig.cs b/src/Tmds.Ssh/SshConfig.cs index 1cddb99c..4ad4f138 100644 --- a/src/Tmds.Ssh/SshConfig.cs +++ b/src/Tmds.Ssh/SshConfig.cs @@ -60,6 +60,8 @@ public struct AlgorithmList public bool? HashKnownHosts { get; set; } public List? SendEnv { get; set; } public bool? TcpKeepAlive { get; set; } + public int? ServerAliveCountMax { get; set; } + public int? ServerAliveInterval { get; set; } internal static ValueTask DetermineConfigForHost(string? userName, string host, int? port, IReadOnlyDictionary? options, IReadOnlyList configFiles, CancellationToken cancellationToken) { @@ -447,6 +449,14 @@ private static void HandleMatchedKeyword(SshConfig config, ReadOnlySpan ke config.TcpKeepAlive ??= ParseYesNoKeywordValue(keyword, ref remainder); break; + case "serveralivecountmax": + config.ServerAliveCountMax ??= NextTokenAsInt(keyword, ref remainder); + break; + + case "serveraliveinterval": + config.ServerAliveInterval ??= NextTokenAsInt(keyword, ref remainder); + break; + /* The following options are unsupported, we have some basic handling that checks the option value indicates the feature is disabled */ case "permitlocalcommand": @@ -539,8 +549,6 @@ we have some basic handling that checks the option value indicates the feature i // case "ipqos": // case "streamlocalbindmask": // case "streamlocalbindunlink": - // case "serveralivecountmax": - // case "serveraliveinterval": // case "setenv": // case "tag": // case "proxycommand": diff --git a/src/Tmds.Ssh/SshConfigOption.cs b/src/Tmds.Ssh/SshConfigOption.cs index 7729a641..cb51a5f6 100644 --- a/src/Tmds.Ssh/SshConfigOption.cs +++ b/src/Tmds.Ssh/SshConfigOption.cs @@ -32,5 +32,7 @@ public enum SshConfigOption KexAlgorithms, MACs, PubkeyAcceptedAlgorithms, - TCPKeepAlive + TCPKeepAlive, + ServerAliveCountMax, + ServerAliveInterval } \ No newline at end of file diff --git a/src/Tmds.Ssh/SshConnection.cs b/src/Tmds.Ssh/SshConnection.cs index 7d1e4d09..5c7bfebc 100644 --- a/src/Tmds.Ssh/SshConnection.cs +++ b/src/Tmds.Ssh/SshConnection.cs @@ -15,6 +15,8 @@ protected SshConnection(SequencePool sequencePool) public SequencePool SequencePool { get; } + public abstract void EnableKeepAlive(int period, Action callback); + public abstract ValueTask ReceiveLineAsync(int maxLength, CancellationToken ct); public abstract ValueTask WriteLineAsync(string line, CancellationToken ct); diff --git a/src/Tmds.Ssh/SshConnectionClosedException.cs b/src/Tmds.Ssh/SshConnectionClosedException.cs index 4cb3fa15..4d1e6b37 100644 --- a/src/Tmds.Ssh/SshConnectionClosedException.cs +++ b/src/Tmds.Ssh/SshConnectionClosedException.cs @@ -6,6 +6,7 @@ namespace Tmds.Ssh; public class SshConnectionClosedException : SshConnectionException { internal const string ConnectionClosedByPeer = "Connection closed by peer."; + internal const string ConnectionClosedByKeepAliveTimeout = "Connection closed due to keep alive timeout."; internal const string ConnectionClosedByAbort = "Connection closed due to an unexpected error."; internal const string ConnectionClosedByDispose = "Connection closed by dispose."; diff --git a/src/Tmds.Ssh/SshSequencePoolExtensions.cs b/src/Tmds.Ssh/SshSequencePoolExtensions.cs index 17a9af4d..1bb8a3c2 100644 --- a/src/Tmds.Ssh/SshSequencePoolExtensions.cs +++ b/src/Tmds.Ssh/SshSequencePoolExtensions.cs @@ -212,4 +212,16 @@ uint32 bytes to add writer.WriteUInt32(bytesToAdd); return packet.Move(); } + + public static Packet CreateKeepAliveMessage(this SequencePool sequencePool) + { + using var packet = sequencePool.RentPacket(); + var writer = packet.GetWriter(); + writer.WriteMessageId(MessageId.SSH_MSG_GLOBAL_REQUEST); + // The request name can be any unknown name (to trigger an SSH_MSG_REQUEST_FAILURE response). + // We use the same name as the OpenSSH client. + writer.WriteString("keepalive@openssh.com"); + writer.WriteBoolean(true); // want reply + return packet.Move(); + } } diff --git a/src/Tmds.Ssh/SshSession.cs b/src/Tmds.Ssh/SshSession.cs index 21b6b53b..b8201923 100644 --- a/src/Tmds.Ssh/SshSession.cs +++ b/src/Tmds.Ssh/SshSession.cs @@ -11,7 +11,9 @@ namespace Tmds.Ssh; sealed partial class SshSession { - private static readonly Exception ClosedByPeer = new Exception(); // Sentinel _abortReason + // Sentinal _abortReason. Note: these get logged. + private static readonly Exception ClosedByPeer = new SshConnectionClosedException(SshConnectionClosedException.ConnectionClosedByPeer); + private static readonly Exception ClosedByKeepAliveTimeout = new SshConnectionClosedException(SshConnectionClosedException.ConnectionClosedByKeepAliveTimeout); private static readonly ObjectDisposedException DisposedException = SshClient.NewObjectDisposedException(); private readonly SshClient _client; @@ -29,7 +31,15 @@ sealed partial class SshSession private SemaphoreSlim? _keyReExchangeSemaphore; private const int BitsPerAllocatedItem = sizeof(int) * 8; private readonly List _allocatedChannels = new List(); + private readonly Queue?> _pendingGlobalRequestReplies = new(); private readonly SshLoggers _loggers; + private int _keepAliveMax; + private int _keepAliveCount; + + internal struct GlobalRequestReply + { + public required MessageId Id { get; init; } + } private ILogger Logger => _loggers.SshClientLogger; @@ -301,7 +311,37 @@ private SshChannel CreateChannel(Type channelType, Action? onAbort = } } + private Task SendGlobalRequestAsync(Packet packet, bool wantReply, bool ignoreReply = false) + { + if (!wantReply) + { + TrySendPacket(packet); + return Task.FromResult(default(GlobalRequestReply)); + } + + TaskCompletionSource? tcs = ignoreReply ? null : new(TaskCreationOptions.RunContinuationsAsynchronously); + lock (_pendingGlobalRequestReplies) + { + _pendingGlobalRequestReplies.Enqueue(tcs); + bool packetQueued = TrySendPacketCore(packet); + if (tcs is null) + { + return Task.FromResult(default(GlobalRequestReply)); + } + if (!packetQueued) + { + // We don't remove this because a queue is a FIFO. + tcs.SetException(CreateCloseException()); + } + } + + return tcs.Task; + } + internal void TrySendPacket(Packet packet) + => TrySendPacketCore(packet); + + internal bool TrySendPacketCore(Packet packet) { Channel? sendQueue = _sendQueue; @@ -314,7 +354,10 @@ internal void TrySendPacket(Packet packet) if (!sendQueue!.Writer.TryWrite(packet)) { packet.Dispose(); + return false; } + + return true; } private async Task SendLoopAsync(SshConnection connection) @@ -363,6 +406,17 @@ private async Task SendLoopAsync(SshConnection connection) packet.Dispose(); } } + + // Do this after we've completed the sendQueue writer + // so SendGlobalRequestAsync can act upon not being able to send when enqueueing. + lock (_pendingGlobalRequestReplies) + { + while (_pendingGlobalRequestReplies.TryDequeue(out TaskCompletionSource? tcs)) + { + // Use Try as we're competing with the read loop and SendGlobalRequestAsync (for failed sends). + tcs?.TrySetException(CreateCloseException()); + } + } } } @@ -372,6 +426,8 @@ private async Task ReceiveLoopAsync(SshConnection connection, SshConnectionInfo try { + EnableKeepAlive(connection, _settings.KeepAliveCountMax, _settings.KeepAliveInterval); + CancellationToken abortToken = _abortCts.Token; while (true) { @@ -423,6 +479,30 @@ private async Task ReceiveLoopAsync(SshConnection connection, SshConnectionInfo } } + private void EnableKeepAlive(SshConnection connection, int serverAliveCountMax, TimeSpan serverAliveInterval) + { + int period = (int)serverAliveInterval.TotalMilliseconds; + if (serverAliveCountMax > 0 && period > 0) + { + _keepAliveMax = serverAliveCountMax; + connection.EnableKeepAlive(period, OnKeepAlive); + } + } + + private void OnKeepAlive() + { + if (_keepAliveCount++ > _keepAliveMax) + { + Abort(ClosedByKeepAliveTimeout); + } + else + { + Task sendTask = SendGlobalRequestAsync(_sequencePool.CreateKeepAliveMessage(), wantReply: true, ignoreReply: true); + Debug.Assert(sendTask.IsCompletedSuccessfully); + sendTask.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing).GetAwaiter().GetResult(); + } + } + internal void HandleNonKexPacket(MessageId msgId, Packet _p) { using Packet packet = _p; // Ensure dispose @@ -465,6 +545,10 @@ internal void HandleNonKexPacket(MessageId msgId, Packet _p) case MessageId.SSH_MSG_DISCONNECT: HandleDisconnectMessage(packet); break; + case MessageId.SSH_MSG_REQUEST_SUCCESS: + case MessageId.SSH_MSG_REQUEST_FAILURE: + HandleGlobalRequestReply(packet); + break; default: ThrowHelper.ThrowProtocolUnexpectedMessageId(msgId); break; @@ -478,6 +562,18 @@ static uint GetChannelNumber(ReadOnlyPacket packet) } } + private void HandleGlobalRequestReply(ReadOnlyPacket packet) + { + _keepAliveCount = 0; + + TaskCompletionSource? tcs; + lock (_pendingGlobalRequestReplies) + { + _pendingGlobalRequestReplies.TryDequeue(out tcs); + } + tcs?.SetResult(new GlobalRequestReply { Id = packet.MessageId!.Value }); + } + private void HandleDisconnectMessage(ReadOnlyPacket packet) { /* @@ -576,6 +672,10 @@ internal Exception CreateCloseException() { return new SshConnectionClosedException(SshConnectionClosedException.ConnectionClosedByPeer); } + else if (_abortReason == ClosedByKeepAliveTimeout) + { + return new SshConnectionClosedException(SshConnectionClosedException.ConnectionClosedByKeepAliveTimeout); + } else if (_abortReason == DisposedException) { return new SshConnectionClosedException(SshConnectionClosedException.ConnectionClosedByDispose, _abortReason); From b4d2378b27e5a9ffc147e4e1001839953da3c26c Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Tue, 19 Nov 2024 05:48:24 +0100 Subject: [PATCH 2/6] Add test. --- test/Tmds.Ssh.Tests/KeepAliveTests.cs | 114 ++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 test/Tmds.Ssh.Tests/KeepAliveTests.cs diff --git a/test/Tmds.Ssh.Tests/KeepAliveTests.cs b/test/Tmds.Ssh.Tests/KeepAliveTests.cs new file mode 100644 index 00000000..3c9717fa --- /dev/null +++ b/test/Tmds.Ssh.Tests/KeepAliveTests.cs @@ -0,0 +1,114 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using Xunit; + +namespace Tmds.Ssh.Tests; + +[Collection(nameof(SshServerCollection))] +public class KeepAliveTests +{ + private readonly SshServer _sshServer; + + public KeepAliveTests(SshServer sshServer) + { + _sshServer = sshServer; + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task KeepAliveTimeOutClosesConnection(bool enableKeepAlive) + { + TimeSpan keepAliveInterval = TimeSpan.FromMilliseconds(500); + int keepAliveCountMax = 3; + + // Establish a proxied connection. + TcpListener proxyServer = new TcpListener(IPAddress.Loopback, 0); + proxyServer.Start(); + IPEndPoint localEndPoint = (IPEndPoint)proxyServer.LocalEndpoint; + string destination = $"{_sshServer.TestUser}@{localEndPoint.Address}:{localEndPoint.Port}"; + var settings = new SshClientSettings(destination) + { + HostAuthentication = delegate { return ValueTask.FromResult(true); }, + Credentials = [ new PasswordCredential(_sshServer.TestUserPassword) ], + KeepAliveCountMax = keepAliveCountMax, + KeepAliveInterval = enableKeepAlive ? keepAliveInterval : TimeSpan.Zero + }; + using var client = new SshClient(settings); + Task connectTask = client.ConnectAsync(); + using ProxyConnection proxyConnection = new ProxyConnection(await proxyServer.AcceptSocketAsync()); + proxyConnection.ProxyTo(_sshServer.ServerHost, _sshServer.ServerPort); + await connectTask; + + // Keep the TCP connection but stop relaying data. + proxyConnection.StopProxying(); + + // Start a command. + long startTime = Stopwatch.GetTimestamp(); + Task executeHello = client.ExecuteAsync("echo 'hello world'"); + + Task timeoutTask = Task.Delay(keepAliveInterval * (keepAliveCountMax + 2) + TimeSpan.FromSeconds(2)); + Task completedTask = await Task.WhenAny(executeHello, timeoutTask); + if (enableKeepAlive) + { + TimeSpan elapsedTime = Stopwatch.GetElapsedTime(startTime); + Assert.Equal(executeHello, completedTask); + Assert.True(elapsedTime > keepAliveInterval * keepAliveCountMax); + await Assert.ThrowsAsync(() => completedTask); + } + else + { + Assert.Equal(timeoutTask, timeoutTask); + } + } + + private class ProxyConnection : IDisposable + { + private readonly Socket _socket; + private readonly CancellationTokenSource _cts = new(); + private Socket? _proxySocket; + + public ProxyConnection(Socket socket) + => _socket = socket; + + public void StopProxying() + { + _cts.Cancel(); + } + + public void ProxyTo(string host, int port) + { + _proxySocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _proxySocket.Connect(host, port); + _ = RelayDataAsync(_socket, _proxySocket, _cts.Token); + _ = RelayDataAsync(_proxySocket, _socket, _cts.Token); + } + + private static async Task RelayDataAsync(Socket source, Socket destination, CancellationToken cancellationToken) + { + try + { + var buffer = new byte[4096]; + while (!cancellationToken.IsCancellationRequested) + { + int bytesRead = await source.ReceiveAsync(buffer, SocketFlags.None, cancellationToken); + if (bytesRead == 0) + { + break; + } + + await destination.SendAsync(new ArraySegment(buffer, 0, bytesRead), SocketFlags.None, cancellationToken); + } + } + catch (OperationCanceledException) + { } + } + + public void Dispose() + { + _socket.Dispose(); + _proxySocket?.Dispose(); + } + } +} From cffaa159541e78668877ccd5e3d2959eb117b65a Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Tue, 19 Nov 2024 10:03:09 +0100 Subject: [PATCH 3/6] Improvements. --- src/Tmds.Ssh/SocketSshConnection.cs | 4 +++- test/Tmds.Ssh.Tests/KeepAliveTests.cs | 10 ++++++++-- test/Tmds.Ssh.Tests/SshConfigTests.cs | 6 ++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/Tmds.Ssh/SocketSshConnection.cs b/src/Tmds.Ssh/SocketSshConnection.cs index c94949aa..03451c0a 100644 --- a/src/Tmds.Ssh/SocketSshConnection.cs +++ b/src/Tmds.Ssh/SocketSshConnection.cs @@ -39,7 +39,9 @@ public override void EnableKeepAlive(int period, Action callback) _keepAlivePeriod = period; _keepAliveCallback = callback; _lastReceivedTime = GetTime(); - _keepAliveTimer = new Timer(o => ((SocketSshConnection)o!).OnKeepAliveTimerCallback(), this, _keepAlivePeriod, _keepAlivePeriod); + _keepAliveTimer = new Timer(o => ((SocketSshConnection)o!).OnKeepAliveTimerCallback(), this, -1, -1); + // Start timer AFTER assigning the variable. + _keepAliveTimer.Change(_keepAlivePeriod, _keepAlivePeriod); } } diff --git a/test/Tmds.Ssh.Tests/KeepAliveTests.cs b/test/Tmds.Ssh.Tests/KeepAliveTests.cs index 3c9717fa..4bab1c42 100644 --- a/test/Tmds.Ssh.Tests/KeepAliveTests.cs +++ b/test/Tmds.Ssh.Tests/KeepAliveTests.cs @@ -24,7 +24,7 @@ public async Task KeepAliveTimeOutClosesConnection(bool enableKeepAlive) int keepAliveCountMax = 3; // Establish a proxied connection. - TcpListener proxyServer = new TcpListener(IPAddress.Loopback, 0); + using TcpListener proxyServer = new TcpListener(IPAddress.Loopback, 0); proxyServer.Start(); IPEndPoint localEndPoint = (IPEndPoint)proxyServer.LocalEndpoint; string destination = $"{_sshServer.TestUser}@{localEndPoint.Address}:{localEndPoint.Port}"; @@ -48,7 +48,9 @@ public async Task KeepAliveTimeOutClosesConnection(bool enableKeepAlive) long startTime = Stopwatch.GetTimestamp(); Task executeHello = client.ExecuteAsync("echo 'hello world'"); - Task timeoutTask = Task.Delay(keepAliveInterval * (keepAliveCountMax + 2) + TimeSpan.FromSeconds(2)); + // Task that times out after the keep alive. + Task timeoutTask = Task.Delay(keepAliveInterval * (keepAliveCountMax + 1) + TimeSpan.FromSeconds(1)); + Task completedTask = await Task.WhenAny(executeHello, timeoutTask); if (enableKeepAlive) { @@ -60,6 +62,10 @@ public async Task KeepAliveTimeOutClosesConnection(bool enableKeepAlive) else { Assert.Equal(timeoutTask, timeoutTask); + + Assert.False(executeHello.IsCompleted); + client.Dispose(); + await Assert.ThrowsAsync(() => executeHello); } } diff --git a/test/Tmds.Ssh.Tests/SshConfigTests.cs b/test/Tmds.Ssh.Tests/SshConfigTests.cs index 0ca7a14a..df22930d 100644 --- a/test/Tmds.Ssh.Tests/SshConfigTests.cs +++ b/test/Tmds.Ssh.Tests/SshConfigTests.cs @@ -30,6 +30,8 @@ Compression yes Macs hmac-sha2-256-etm@openssh.com,hmac-sha2-512 PubKeyAcceptedAlgorithms rsa-sha2-256,ecdsa-sha2-nistp256 CanonicalizeHostName no + ServerAliveCountMax 7 + ServerAliveInterval 20 # !!! update SupportedSettingsAlternateConfig when adding values here !!! """; @@ -57,6 +59,8 @@ KexAlgorithms ecdh-sha2-nistp384 Macs hmac-sha2-512 PubKeyAcceptedAlgorithms ecdsa-sha2-nistp256 CanonicalizeHostName yes + ServerAliveCountMax 8 + ServerAliveInterval 30 """; [Fact] @@ -103,6 +107,8 @@ private void VerifyConfig(SshConfig config) Assert.Equal(true, config.GssApiAuthentication); Assert.Equal(true, config.GssApiDelegateCredentials); Assert.Equal("serverid", config.GssApiServerIdentity); + Assert.Equal(7, config.ServerAliveCountMax); + Assert.Equal(20, config.ServerAliveInterval); } [Fact] From 577d79a76d384cb0833d74bfac1b0ccba82c5414 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Tue, 19 Nov 2024 10:10:26 +0100 Subject: [PATCH 4/6] Tweaks. --- src/Tmds.Ssh/SocketSshConnection.cs | 2 +- src/Tmds.Ssh/SshClientSettings.cs | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/Tmds.Ssh/SocketSshConnection.cs b/src/Tmds.Ssh/SocketSshConnection.cs index 03451c0a..0e6bf37a 100644 --- a/src/Tmds.Ssh/SocketSshConnection.cs +++ b/src/Tmds.Ssh/SocketSshConnection.cs @@ -40,7 +40,7 @@ public override void EnableKeepAlive(int period, Action callback) _keepAliveCallback = callback; _lastReceivedTime = GetTime(); _keepAliveTimer = new Timer(o => ((SocketSshConnection)o!).OnKeepAliveTimerCallback(), this, -1, -1); - // Start timer AFTER assigning the variable. + // Start timer after assigning the variable to ensure it is set when the callback is invoked. _keepAliveTimer.Change(_keepAlivePeriod, _keepAlivePeriod); } } diff --git a/src/Tmds.Ssh/SshClientSettings.cs b/src/Tmds.Ssh/SshClientSettings.cs index 942b9307..aedbb9e8 100644 --- a/src/Tmds.Ssh/SshClientSettings.cs +++ b/src/Tmds.Ssh/SshClientSettings.cs @@ -114,16 +114,19 @@ public int KeepAliveCountMax get => _keepAliveCountMax; set { - ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(value, 1); + ArgumentOutOfRangeException.ThrowIfLessThan(value, 0); _keepAliveCountMax = value; } } - // Zero or < 0 is disabled. public TimeSpan KeepAliveInterval { get => _keepAliveInterval; - set => _keepAliveInterval = value; + set + { + ArgumentOutOfRangeException.ThrowIfLessThan(value, TimeSpan.Zero); + _keepAliveInterval = value; + } } internal void Validate() From c66f8bb2867bfa2fc15677596ab20d668650b0a3 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Tue, 19 Nov 2024 10:12:40 +0100 Subject: [PATCH 5/6] Update ClientSettingsTests. --- test/Tmds.Ssh.Tests/SshClientSettingsTests.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Tmds.Ssh.Tests/SshClientSettingsTests.cs b/test/Tmds.Ssh.Tests/SshClientSettingsTests.cs index c59307c4..5d250da8 100644 --- a/test/Tmds.Ssh.Tests/SshClientSettingsTests.cs +++ b/test/Tmds.Ssh.Tests/SshClientSettingsTests.cs @@ -32,6 +32,8 @@ public void Defaults() Assert.Equal(new[] { new Name("none") }, settings.CompressionAlgorithmsServerToClient); Assert.Equal(Array.Empty(), settings.LanguagesClientToServer); Assert.Equal(Array.Empty(), settings.LanguagesServerToClient); + Assert.Equal(3, settings.KeepAliveCountMax); + Assert.Equal(TimeSpan.Zero, settings.KeepAliveInterval); } [Theory] From bcf8cf932b65c80d0c833e1059418e423504d161 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Tue, 19 Nov 2024 10:36:33 +0100 Subject: [PATCH 6/6] README tweaks. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d38bbc7c..1ac6debc 100644 --- a/README.md +++ b/README.md @@ -232,7 +232,7 @@ class SshClientSettings bool TcpKeepAlive { get; set; } = true; TimeSpan KeepAliveInterval { get; set; } = TimeSpan.Zero; - public int KeepAliveCountMax = 3; + int KeepAliveCountMax = 3; List GlobalKnownHostsFilePaths { get; set; } = DefaultGlobalKnownHostsFilePaths; List UserKnownHostsFilePaths { get; set; } = DefaultUserKnownHostsFilePaths; @@ -260,7 +260,7 @@ class SshConfigSettings HostAuthentication? HostAuthentication { get; set; } // Called for Unknown when StrictHostKeyChecking is 'ask' (default) } -public enum SshConfigOption +enum SshConfigOption { Hostname, User,