diff --git a/src/Renci.SshNet/SshCommand.cs b/src/Renci.SshNet/SshCommand.cs index 58b1942cc..076cb901a 100644 --- a/src/Renci.SshNet/SshCommand.cs +++ b/src/Renci.SshNet/SshCommand.cs @@ -26,11 +26,13 @@ public class SshCommand : IDisposable private CommandAsyncResult _asyncResult; private AsyncCallback _callback; private EventWaitHandle _sessionErrorOccuredWaitHandle; + private EventWaitHandle _commandCancelledWaitHandle; private Exception _exception; private StringBuilder _result; private StringBuilder _error; private bool _hasError; private bool _isDisposed; + private bool _isCancelled; private ChannelInputStream _inputStream; private TimeSpan _commandTimeout; @@ -84,7 +86,7 @@ public TimeSpan CommandTimeout /// /// The stream that can be used to transfer data to the command's input stream. /// - #pragma warning disable CA1859 // Use concrete types when possible for improved performance +#pragma warning disable CA1859 // Use concrete types when possible for improved performance public Stream CreateInputStream() #pragma warning restore CA1859 // Use concrete types when possible for improved performance { @@ -186,7 +188,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding) _encoding = encoding; CommandTimeout = Timeout.InfiniteTimeSpan; _sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false); - + _commandCancelledWaitHandle = new AutoResetEvent(initialState: false); _session.Disconnected += Session_Disconnected; _session.ErrorOccured += Session_ErrorOccured; } @@ -249,11 +251,11 @@ public IAsyncResult BeginExecute(AsyncCallback callback, object state) // Create new AsyncResult object _asyncResult = new CommandAsyncResult - { - AsyncWaitHandle = new ManualResetEvent(initialState: false), - IsCompleted = false, - AsyncState = state, - }; + { + AsyncWaitHandle = new ManualResetEvent(initialState: false), + IsCompleted = false, + AsyncState = state, + }; if (_channel is not null) { @@ -349,20 +351,25 @@ public string EndExecute(IAsyncResult asyncResult) commandAsyncResult.EndCalled = true; - return Result; + if (!_isCancelled) + { + return Result; + } + + SetAsyncComplete(); + throw new OperationCanceledException(); } } /// /// Cancels command execution in asynchronous scenarios. /// - public void CancelAsync() + /// if true send SIGKILL instead of SIGTERM. + public void CancelAsync(bool forceKill = false) { - if (_channel is not null && _channel.IsOpen && _asyncResult is not null) - { - // TODO: check with Oleg if we shouldn't dispose the channel and uninitialize it ? - _channel.Dispose(); - } + var signal = forceKill ? "KILL" : "TERM"; + _ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en"); + _ = _commandCancelledWaitHandle?.Set(); } /// @@ -430,14 +437,14 @@ private void Session_ErrorOccured(object sender, ExceptionEventArgs e) _ = _sessionErrorOccuredWaitHandle.Set(); } - private void Channel_Closed(object sender, ChannelEventArgs e) + private void SetAsyncComplete() { OutputStream?.Flush(); ExtendedOutputStream?.Flush(); _asyncResult.IsCompleted = true; - if (_callback is not null) + if (_callback is not null && !_isCancelled) { // Execute callback on different thread ThreadAbstraction.ExecuteThread(() => _callback(_asyncResult)); @@ -446,6 +453,11 @@ private void Channel_Closed(object sender, ChannelEventArgs e) _ = ((EventWaitHandle) _asyncResult.AsyncWaitHandle).Set(); } + private void Channel_Closed(object sender, ChannelEventArgs e) + { + SetAsyncComplete(); + } + private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e) { if (e.Info is ExitStatusRequestInfo exitStatusInfo) @@ -506,7 +518,8 @@ private void WaitOnHandle(WaitHandle waitHandle) var waitHandles = new[] { _sessionErrorOccuredWaitHandle, - waitHandle + waitHandle, + _commandCancelledWaitHandle }; var signaledElement = WaitHandle.WaitAny(waitHandles, CommandTimeout); @@ -518,6 +531,9 @@ private void WaitOnHandle(WaitHandle waitHandle) case 1: // Specified waithandle was signaled break; + case 2: + _isCancelled = true; + break; case WaitHandle.WaitTimeout: throw new SshOperationTimeoutException(string.Format(CultureInfo.CurrentCulture, "Command '{0}' has timed out.", CommandText)); default: @@ -620,6 +636,9 @@ protected virtual void Dispose(bool disposing) _sessionErrorOccuredWaitHandle = null; } + _commandCancelledWaitHandle?.Dispose(); + _commandCancelledWaitHandle = null; + _isDisposed = true; } } diff --git a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs index aefe1d6d0..4273339f4 100644 --- a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs +++ b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs @@ -51,6 +51,49 @@ public void Test_Execute_SingleCommand() } } + [TestMethod] + [Timeout(5000)] + public void Test_CancelAsync_Unfinished_Command() + { + using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password); + #region Example SshCommand CancelAsync Unfinished Command Without Sending exit-signal + client.Connect(); + var testValue = Guid.NewGuid().ToString(); + var command = $"sleep 15s; echo {testValue}"; + using var cmd = client.CreateCommand(command); + var asyncResult = cmd.BeginExecute(); + cmd.CancelAsync(); + Assert.ThrowsException(() => cmd.EndExecute(asyncResult)); + Assert.IsTrue(asyncResult.IsCompleted); + client.Disconnect(); + Assert.AreEqual(string.Empty, cmd.Result.Trim()); + #endregion + } + + [TestMethod] + public async Task Test_CancelAsync_Finished_Command() + { + using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password); + #region Example SshCommand CancelAsync Finished Command + client.Connect(); + var testValue = Guid.NewGuid().ToString(); + var command = $"echo {testValue}"; + using var cmd = client.CreateCommand(command); + var asyncResult = cmd.BeginExecute(); + while (!asyncResult.IsCompleted) + { + await Task.Delay(200); + } + + cmd.CancelAsync(); + cmd.EndExecute(asyncResult); + client.Disconnect(); + + Assert.IsTrue(asyncResult.IsCompleted); + Assert.AreEqual(testValue, cmd.Result.Trim()); + #endregion + } + [TestMethod] public void Test_Execute_OutputStream() { @@ -222,7 +265,7 @@ public void Test_Execute_Command_ExitStatus() client.Connect(); var cmd = client.RunCommand("exit 128"); - + Console.WriteLine(cmd.ExitStatus); client.Disconnect(); @@ -443,7 +486,7 @@ public void Test_Execute_Invalid_Command() } [TestMethod] - + public void Test_MultipleThread_100_MultipleConnections() { try