Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CancelAsync Cause Deadlock #1345

Merged
merged 9 commits into from
Mar 24, 2024
53 changes: 36 additions & 17 deletions src/Renci.SshNet/SshCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ public class SshCommand : IDisposable
private CommandAsyncResult _asyncResult;
private AsyncCallback _callback;
private EventWaitHandle _sessionErrorOccuredWaitHandle;
private EventWaitHandle _commandCancelledWaitHandle;
private Exception _exception;
private StringBuilder _result;
private StringBuilder _error;
private bool _hasError;
private bool _isDisposed;
private bool _isCancelled;
private ChannelInputStream _inputStream;
private TimeSpan _commandTimeout;

Expand Down Expand Up @@ -84,7 +86,7 @@ public TimeSpan CommandTimeout
/// <returns>
/// The stream that can be used to transfer data to the command's input stream.
/// </returns>
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
public Stream CreateInputStream()
#pragma warning restore CA1859 // Use concrete types when possible for improved performance
{
Expand Down Expand Up @@ -186,7 +188,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding)
_encoding = encoding;
CommandTimeout = Timeout.InfiniteTimeSpan;
_sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false);

_commandCancelledWaitHandle = new AutoResetEvent(initialState: false);
_session.Disconnected += Session_Disconnected;
_session.ErrorOccured += Session_ErrorOccured;
}
Expand Down Expand Up @@ -249,11 +251,11 @@ public IAsyncResult BeginExecute(AsyncCallback callback, object state)

// Create new AsyncResult object
_asyncResult = new CommandAsyncResult
{
AsyncWaitHandle = new ManualResetEvent(initialState: false),
IsCompleted = false,
AsyncState = state,
};
{
AsyncWaitHandle = new ManualResetEvent(initialState: false),
IsCompleted = false,
AsyncState = state,
};

if (_channel is not null)
{
Expand Down Expand Up @@ -349,20 +351,25 @@ public string EndExecute(IAsyncResult asyncResult)

commandAsyncResult.EndCalled = true;

return Result;
if (!_isCancelled)
{
return Result;
}

SetAsyncComplete();
throw new OperationCanceledException();
}
}

/// <summary>
/// Cancels command execution in asynchronous scenarios.
/// </summary>
public void CancelAsync()
/// <param name="forceKill">if true send SIGKILL instead of SIGTERM.</param>
public void CancelAsync(bool forceKill = false)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we drop the Async suffix here (and obsolete CancelAsync) as it is not an async method. @WojciechNagorski?

{
if (_channel is not null && _channel.IsOpen && _asyncResult is not null)
{
// TODO: check with Oleg if we shouldn't dispose the channel and uninitialize it ?
_channel.Dispose();
}
var signal = forceKill ? "KILL" : "TERM";
_ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en");
_ = _commandCancelledWaitHandle?.Set();
}

/// <summary>
Expand Down Expand Up @@ -430,14 +437,14 @@ private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
_ = _sessionErrorOccuredWaitHandle.Set();
}

private void Channel_Closed(object sender, ChannelEventArgs e)
private void SetAsyncComplete()
{
OutputStream?.Flush();
ExtendedOutputStream?.Flush();

_asyncResult.IsCompleted = true;

if (_callback is not null)
if (_callback is not null && !_isCancelled)
{
// Execute callback on different thread
ThreadAbstraction.ExecuteThread(() => _callback(_asyncResult));
Expand All @@ -446,6 +453,11 @@ private void Channel_Closed(object sender, ChannelEventArgs e)
_ = ((EventWaitHandle) _asyncResult.AsyncWaitHandle).Set();
}

private void Channel_Closed(object sender, ChannelEventArgs e)
{
SetAsyncComplete();
}

private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e)
{
if (e.Info is ExitStatusRequestInfo exitStatusInfo)
Expand Down Expand Up @@ -506,7 +518,8 @@ private void WaitOnHandle(WaitHandle waitHandle)
var waitHandles = new[]
{
_sessionErrorOccuredWaitHandle,
waitHandle
waitHandle,
_commandCancelledWaitHandle
};

var signaledElement = WaitHandle.WaitAny(waitHandles, CommandTimeout);
Expand All @@ -518,6 +531,9 @@ private void WaitOnHandle(WaitHandle waitHandle)
case 1:
// Specified waithandle was signaled
break;
case 2:
_isCancelled = true;
break;
case WaitHandle.WaitTimeout:
throw new SshOperationTimeoutException(string.Format(CultureInfo.CurrentCulture, "Command '{0}' has timed out.", CommandText));
default:
Expand Down Expand Up @@ -620,6 +636,9 @@ protected virtual void Dispose(bool disposing)
_sessionErrorOccuredWaitHandle = null;
}

_commandCancelledWaitHandle?.Dispose();
_commandCancelledWaitHandle = null;

_isDisposed = true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,49 @@ public void Test_Execute_SingleCommand()
}
}

[TestMethod]
[Timeout(5000)]
public void Test_CancelAsync_Unfinished_Command()
{
using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
#region Example SshCommand CancelAsync Unfinished Command Without Sending exit-signal
client.Connect();
var testValue = Guid.NewGuid().ToString();
var command = $"sleep 15s; echo {testValue}";
using var cmd = client.CreateCommand(command);
var asyncResult = cmd.BeginExecute();
cmd.CancelAsync();
Assert.ThrowsException<OperationCanceledException>(() => cmd.EndExecute(asyncResult));
Assert.IsTrue(asyncResult.IsCompleted);
client.Disconnect();
Assert.AreEqual(string.Empty, cmd.Result.Trim());
#endregion
}

[TestMethod]
public async Task Test_CancelAsync_Finished_Command()
{
using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
#region Example SshCommand CancelAsync Finished Command
client.Connect();
var testValue = Guid.NewGuid().ToString();
var command = $"echo {testValue}";
using var cmd = client.CreateCommand(command);
var asyncResult = cmd.BeginExecute();
while (!asyncResult.IsCompleted)
{
await Task.Delay(200);
}

cmd.CancelAsync();
cmd.EndExecute(asyncResult);
client.Disconnect();

Assert.IsTrue(asyncResult.IsCompleted);
Assert.AreEqual(testValue, cmd.Result.Trim());
#endregion
}

[TestMethod]
public void Test_Execute_OutputStream()
{
Expand Down Expand Up @@ -222,7 +265,7 @@ public void Test_Execute_Command_ExitStatus()
client.Connect();

var cmd = client.RunCommand("exit 128");

Console.WriteLine(cmd.ExitStatus);

client.Disconnect();
Expand Down Expand Up @@ -443,7 +486,7 @@ public void Test_Execute_Invalid_Command()
}

[TestMethod]

public void Test_MultipleThread_100_MultipleConnections()
{
try
Expand Down