diff --git a/src/Renci.SshNet/SshCommand.cs b/src/Renci.SshNet/SshCommand.cs
index 58b1942cc..076cb901a 100644
--- a/src/Renci.SshNet/SshCommand.cs
+++ b/src/Renci.SshNet/SshCommand.cs
@@ -26,11 +26,13 @@ public class SshCommand : IDisposable
private CommandAsyncResult _asyncResult;
private AsyncCallback _callback;
private EventWaitHandle _sessionErrorOccuredWaitHandle;
+ private EventWaitHandle _commandCancelledWaitHandle;
private Exception _exception;
private StringBuilder _result;
private StringBuilder _error;
private bool _hasError;
private bool _isDisposed;
+ private bool _isCancelled;
private ChannelInputStream _inputStream;
private TimeSpan _commandTimeout;
@@ -84,7 +86,7 @@ public TimeSpan CommandTimeout
///
/// The stream that can be used to transfer data to the command's input stream.
///
- #pragma warning disable CA1859 // Use concrete types when possible for improved performance
+#pragma warning disable CA1859 // Use concrete types when possible for improved performance
public Stream CreateInputStream()
#pragma warning restore CA1859 // Use concrete types when possible for improved performance
{
@@ -186,7 +188,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding)
_encoding = encoding;
CommandTimeout = Timeout.InfiniteTimeSpan;
_sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false);
-
+ _commandCancelledWaitHandle = new AutoResetEvent(initialState: false);
_session.Disconnected += Session_Disconnected;
_session.ErrorOccured += Session_ErrorOccured;
}
@@ -249,11 +251,11 @@ public IAsyncResult BeginExecute(AsyncCallback callback, object state)
// Create new AsyncResult object
_asyncResult = new CommandAsyncResult
- {
- AsyncWaitHandle = new ManualResetEvent(initialState: false),
- IsCompleted = false,
- AsyncState = state,
- };
+ {
+ AsyncWaitHandle = new ManualResetEvent(initialState: false),
+ IsCompleted = false,
+ AsyncState = state,
+ };
if (_channel is not null)
{
@@ -349,20 +351,25 @@ public string EndExecute(IAsyncResult asyncResult)
commandAsyncResult.EndCalled = true;
- return Result;
+ if (!_isCancelled)
+ {
+ return Result;
+ }
+
+ SetAsyncComplete();
+ throw new OperationCanceledException();
}
}
///
/// Cancels command execution in asynchronous scenarios.
///
- public void CancelAsync()
+ /// if true send SIGKILL instead of SIGTERM.
+ public void CancelAsync(bool forceKill = false)
{
- if (_channel is not null && _channel.IsOpen && _asyncResult is not null)
- {
- // TODO: check with Oleg if we shouldn't dispose the channel and uninitialize it ?
- _channel.Dispose();
- }
+ var signal = forceKill ? "KILL" : "TERM";
+ _ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en");
+ _ = _commandCancelledWaitHandle?.Set();
}
///
@@ -430,14 +437,14 @@ private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
_ = _sessionErrorOccuredWaitHandle.Set();
}
- private void Channel_Closed(object sender, ChannelEventArgs e)
+ private void SetAsyncComplete()
{
OutputStream?.Flush();
ExtendedOutputStream?.Flush();
_asyncResult.IsCompleted = true;
- if (_callback is not null)
+ if (_callback is not null && !_isCancelled)
{
// Execute callback on different thread
ThreadAbstraction.ExecuteThread(() => _callback(_asyncResult));
@@ -446,6 +453,11 @@ private void Channel_Closed(object sender, ChannelEventArgs e)
_ = ((EventWaitHandle) _asyncResult.AsyncWaitHandle).Set();
}
+ private void Channel_Closed(object sender, ChannelEventArgs e)
+ {
+ SetAsyncComplete();
+ }
+
private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e)
{
if (e.Info is ExitStatusRequestInfo exitStatusInfo)
@@ -506,7 +518,8 @@ private void WaitOnHandle(WaitHandle waitHandle)
var waitHandles = new[]
{
_sessionErrorOccuredWaitHandle,
- waitHandle
+ waitHandle,
+ _commandCancelledWaitHandle
};
var signaledElement = WaitHandle.WaitAny(waitHandles, CommandTimeout);
@@ -518,6 +531,9 @@ private void WaitOnHandle(WaitHandle waitHandle)
case 1:
// Specified waithandle was signaled
break;
+ case 2:
+ _isCancelled = true;
+ break;
case WaitHandle.WaitTimeout:
throw new SshOperationTimeoutException(string.Format(CultureInfo.CurrentCulture, "Command '{0}' has timed out.", CommandText));
default:
@@ -620,6 +636,9 @@ protected virtual void Dispose(bool disposing)
_sessionErrorOccuredWaitHandle = null;
}
+ _commandCancelledWaitHandle?.Dispose();
+ _commandCancelledWaitHandle = null;
+
_isDisposed = true;
}
}
diff --git a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs
index aefe1d6d0..4273339f4 100644
--- a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs
+++ b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs
@@ -51,6 +51,49 @@ public void Test_Execute_SingleCommand()
}
}
+ [TestMethod]
+ [Timeout(5000)]
+ public void Test_CancelAsync_Unfinished_Command()
+ {
+ using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
+ #region Example SshCommand CancelAsync Unfinished Command Without Sending exit-signal
+ client.Connect();
+ var testValue = Guid.NewGuid().ToString();
+ var command = $"sleep 15s; echo {testValue}";
+ using var cmd = client.CreateCommand(command);
+ var asyncResult = cmd.BeginExecute();
+ cmd.CancelAsync();
+ Assert.ThrowsException(() => cmd.EndExecute(asyncResult));
+ Assert.IsTrue(asyncResult.IsCompleted);
+ client.Disconnect();
+ Assert.AreEqual(string.Empty, cmd.Result.Trim());
+ #endregion
+ }
+
+ [TestMethod]
+ public async Task Test_CancelAsync_Finished_Command()
+ {
+ using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
+ #region Example SshCommand CancelAsync Finished Command
+ client.Connect();
+ var testValue = Guid.NewGuid().ToString();
+ var command = $"echo {testValue}";
+ using var cmd = client.CreateCommand(command);
+ var asyncResult = cmd.BeginExecute();
+ while (!asyncResult.IsCompleted)
+ {
+ await Task.Delay(200);
+ }
+
+ cmd.CancelAsync();
+ cmd.EndExecute(asyncResult);
+ client.Disconnect();
+
+ Assert.IsTrue(asyncResult.IsCompleted);
+ Assert.AreEqual(testValue, cmd.Result.Trim());
+ #endregion
+ }
+
[TestMethod]
public void Test_Execute_OutputStream()
{
@@ -222,7 +265,7 @@ public void Test_Execute_Command_ExitStatus()
client.Connect();
var cmd = client.RunCommand("exit 128");
-
+
Console.WriteLine(cmd.ExitStatus);
client.Disconnect();
@@ -443,7 +486,7 @@ public void Test_Execute_Invalid_Command()
}
[TestMethod]
-
+
public void Test_MultipleThread_100_MultipleConnections()
{
try