Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent propagation of Ctrl+C to child processes #44565

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/BuiltInTools/dotnet-watch/EnvironmentOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ internal enum TestFlags
/// Elevates the severity of <see cref="MessageDescriptor.WaitingForChanges"/> from <see cref="MessageSeverity.Output"/>.
/// </summary>
ElevateWaitingForChangesMessageSeverity = 1 << 2,

/// <summary>
/// Instead of using <see cref="Console.ReadKey()"/> to watch for Ctrl+C, Ctlr+R, and other keys, read from standard input.
/// This allows tests to trigger key based events.
/// </summary>
ReadKeyFromStdin = 1 << 3,
}

internal sealed record EnvironmentOptions(
Expand Down
13 changes: 6 additions & 7 deletions src/BuiltInTools/dotnet-watch/HotReloadDotNetWatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ public override async Task WatchAsync(CancellationToken shutdownCancellationToke

_console.KeyPressed += (key) =>
{
var modifiers = ConsoleModifiers.Control;
if ((key.Modifiers & modifiers) == modifiers && key.Key == ConsoleKey.R && forceRestartCancellationSource is { } source)
if (key.Modifiers.HasFlag(ConsoleModifiers.Control) && key.Key == ConsoleKey.R && forceRestartCancellationSource is { } source)
{
// provide immediate feedback to the user:
Context.Reporter.Report(source.IsCancellationRequested ? MessageDescriptor.RestartInProgress : MessageDescriptor.RestartRequested);
Expand Down Expand Up @@ -327,11 +326,6 @@ await Task.WhenAll(
}
finally
{
if (!rootProcessTerminationSource.IsCancellationRequested)
{
rootProcessTerminationSource.Cancel();
}

if (runtimeProcessLauncher != null)
{
// Request cleanup of all processes created by the launcher before we terminate the root process.
Expand All @@ -345,6 +339,11 @@ await Task.WhenAll(
await compilationHandler.TerminateNonRootProcessesAndDispose(CancellationToken.None);
}

if (!rootProcessTerminationSource.IsCancellationRequested)
{
rootProcessTerminationSource.Cancel();
}

try
{
// Wait for the root process to exit.
Expand Down
5 changes: 0 additions & 5 deletions src/BuiltInTools/dotnet-watch/Internal/IConsole.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@ namespace Microsoft.Extensions.Tools.Internal
/// </summary>
internal interface IConsole
{
event ConsoleCancelEventHandler CancelKeyPress;
event Action<ConsoleKeyInfo> KeyPressed;
TextWriter Out { get; }
TextWriter Error { get; }
TextReader In { get; }
bool IsInputRedirected { get; }
bool IsOutputRedirected { get; }
bool IsErrorRedirected { get; }
ConsoleColor ForegroundColor { get; set; }
void ResetColor();
void Clear();
Expand Down
83 changes: 57 additions & 26 deletions src/BuiltInTools/dotnet-watch/Internal/PhysicalConsole.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.DotNet.Watcher;

namespace Microsoft.Extensions.Tools.Internal
{
/// <summary>
Expand All @@ -9,52 +11,81 @@ namespace Microsoft.Extensions.Tools.Internal
/// </summary>
internal sealed class PhysicalConsole : IConsole
{
private readonly List<Action<ConsoleKeyInfo>> _keyPressedListeners = new();
public const char CtrlC = '\x03';
public const char CtrlR = '\x12';

public event Action<ConsoleKeyInfo>? KeyPressed;

private PhysicalConsole()
public PhysicalConsole(TestFlags testFlags)
{
Console.OutputEncoding = Encoding.UTF8;
Console.CancelKeyPress += (o, e) =>

bool readFromStdin;
if (testFlags.HasFlag(TestFlags.ReadKeyFromStdin))
{
CancelKeyPress?.Invoke(o, e);
};
readFromStdin = true;
}
else
{
try
{
Console.TreatControlCAsInput = true;
readFromStdin = false;
}
catch
{
// fails when stdin is redirected
readFromStdin = true;
}
}

_ = readFromStdin ? ListenToStandardInputAsync() : ListenToConsoleKeyPressAsync();
}

public event Action<ConsoleKeyInfo> KeyPressed
private async Task ListenToStandardInputAsync()
{
add
using var stream = Console.OpenStandardInput();
var buffer = new byte[1];

while (true)
{
_keyPressedListeners.Add(value);
ListenToConsoleKeyPress();
}
var bytesRead = await stream.ReadAsync(buffer, CancellationToken.None);
if (bytesRead != 1)
{
break;
}

var c = (char)buffer[0];

remove => _keyPressedListeners.Remove(value);
// handle all input keys that watcher might consume:
var key = c switch
{
CtrlC => new ConsoleKeyInfo('C', ConsoleKey.C, shift: false, alt: false, control: true),
CtrlR => new ConsoleKeyInfo('R', ConsoleKey.R, shift: false, alt: false, control: true),
>= 'A' and <= 'Z' => new ConsoleKeyInfo(c, ConsoleKey.A + (c - 'A'), shift: false, alt: false, control: false),
_ => default
};

if (key.Key != ConsoleKey.None)
{
KeyPressed?.Invoke(key);
}
}
}

private void ListenToConsoleKeyPress()
{
Task.Factory.StartNew(() =>
private Task ListenToConsoleKeyPressAsync()
=> Task.Factory.StartNew(() =>
{
while (true)
{
var key = Console.ReadKey(intercept: true);
for (var i = 0; i < _keyPressedListeners.Count; i++)
{
_keyPressedListeners[i](key);
}
KeyPressed?.Invoke(key);
}
}, TaskCreationOptions.LongRunning);
}

public static IConsole Singleton { get; } = new PhysicalConsole();

public event ConsoleCancelEventHandler? CancelKeyPress;
public TextWriter Error => Console.Error;
public TextReader In => Console.In;
public TextWriter Out => Console.Out;
public bool IsInputRedirected => Console.IsInputRedirected;
public bool IsOutputRedirected => Console.IsOutputRedirected;
public bool IsErrorRedirected => Console.IsErrorRedirected;

public ConsoleColor ForegroundColor
{
get => Console.ForegroundColor;
Expand Down
58 changes: 48 additions & 10 deletions src/BuiltInTools/dotnet-watch/Internal/ProcessRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,19 +246,11 @@ private static void TerminateProcess(Process process, ProcessState state, IRepor

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
process.Kill();
TerminateWindowsProcess(process, state, reporter);
}
else
{
[DllImport("libc", SetLastError = true, EntryPoint = "kill")]
static extern int sys_kill(int pid, int sig);

var result = sys_kill(state.ProcessId, state.ForceExit ? SIGKILL : SIGTERM);
if (result != 0)
{
var error = Marshal.GetLastPInvokeError();
reporter.Verbose($"Error while sending SIGTERM to process {state.ProcessId}: {Marshal.GetPInvokeErrorMessage(error)} (code {error}).");
}
TerminateUnixProcess(state, reporter);
}

reporter.Verbose($"Process {state.ProcessId} killed.");
Expand All @@ -272,5 +264,51 @@ private static void TerminateProcess(Process process, ProcessState state, IRepor
#endif
}
}

private static void TerminateWindowsProcess(Process process, ProcessState state, IReporter reporter)
{
// Needs API: https://github.com/dotnet/runtime/issues/109432
// Code below does not work because the process creation needs CREATE_NEW_PROCESS_GROUP flag.
#if TODO
if (!state.ForceExit)
{
const uint CTRL_C_EVENT = 0;

[DllImport("kernel32.dll", SetLastError = true)]
static extern bool GenerateConsoleCtrlEvent(uint dwCtrlEvent, uint dwProcessGroupId);

[DllImport("kernel32.dll", SetLastError = true)]
static extern bool AttachConsole(uint dwProcessId);

[DllImport("kernel32.dll", SetLastError = true)]
static extern bool FreeConsole();

if (AttachConsole((uint)state.ProcessId) &&
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0) &&
FreeConsole())
{
return;
}

var error = Marshal.GetLastPInvokeError();
reporter.Verbose($"Failed to send Ctrl+C to process {state.ProcessId}: {Marshal.GetPInvokeErrorMessage(error)} (code {error})");
}
#endif

process.Kill();
}

private static void TerminateUnixProcess(ProcessState state, IReporter reporter)
{
[DllImport("libc", SetLastError = true, EntryPoint = "kill")]
static extern int sys_kill(int pid, int sig);

var result = sys_kill(state.ProcessId, state.ForceExit ? SIGKILL : SIGTERM);
if (result != 0)
{
var error = Marshal.GetLastPInvokeError();
reporter.Verbose($"Error while sending SIGTERM to process {state.ProcessId}: {Marshal.GetPInvokeErrorMessage(error)} (code {error}).");
}
}
}
}
52 changes: 31 additions & 21 deletions src/BuiltInTools/dotnet-watch/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.


using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.Loader;
using Microsoft.Build.Graph;
Expand Down Expand Up @@ -36,10 +37,12 @@ public static async Task<int> Main(string[] args)
// Register listeners that load Roslyn-related assemblies from the `Roslyn/bincore` directory.
RegisterAssemblyResolutionEvents(sdkRootDirectory);

var environmentOptions = EnvironmentOptions.FromEnvironment();

var program = TryCreate(
args,
PhysicalConsole.Singleton,
EnvironmentOptions.FromEnvironment(),
new PhysicalConsole(environmentOptions.TestFlags),
environmentOptions,
EnvironmentVariables.VerboseCliOutput,
out var exitCode);

Expand Down Expand Up @@ -77,6 +80,11 @@ public static async Task<int> Main(string[] args)
var workingDirectory = environmentOptions.WorkingDirectory;
reporter.Verbose($"Working directory: '{workingDirectory}'");

if (environmentOptions.TestFlags != TestFlags.None)
{
reporter.Verbose($"Test flags: {environmentOptions.TestFlags}");
}

string projectPath;
try
{
Expand All @@ -97,9 +105,28 @@ public static async Task<int> Main(string[] args)
// internal for testing
internal async Task<int> RunAsync()
{
var shutdownCancellationSourceDisposed = false;
var shutdownCancellationSource = new CancellationTokenSource();
var shutdownCancellationToken = shutdownCancellationSource.Token;
console.CancelKeyPress += OnCancelKeyPress;

console.KeyPressed += key =>
{
if (!shutdownCancellationSourceDisposed && key.Modifiers.HasFlag(ConsoleModifiers.Control) && key.Key == ConsoleKey.C)
{
// if we already canceled, we force immediate shutdown:
var forceShutdown = shutdownCancellationSource.IsCancellationRequested;

if (!forceShutdown)
{
reporter.Report(MessageDescriptor.ShutdownRequested);
shutdownCancellationSource.Cancel();
}
else
{
Environment.Exit(0);
}
}
};

try
{
Expand Down Expand Up @@ -130,26 +157,9 @@ internal async Task<int> RunAsync()
}
finally
{
console.CancelKeyPress -= OnCancelKeyPress;
shutdownCancellationSourceDisposed = true;
shutdownCancellationSource.Dispose();
}

void OnCancelKeyPress(object? sender, ConsoleCancelEventArgs args)
{
// if we already canceled, we force immediate shutdown:
var forceShutdown = shutdownCancellationSource.IsCancellationRequested;

if (!forceShutdown)
{
reporter.Report(MessageDescriptor.ShutdownRequested);
shutdownCancellationSource.Cancel();
args.Cancel = true;
}
else
{
Environment.Exit(0);
}
}
}

// internal for testing
Expand Down
Loading