diff --git a/survey.go b/survey.go index ede27272..890c8925 100644 --- a/survey.go +++ b/survey.go @@ -10,6 +10,10 @@ import ( "github.com/AlecAivazis/survey/v2/terminal" ) +// OnInterrupt is the function to run when +// SIGINT (CTRL+C) is sent to the process. +var OnInterrupt func() + // DefaultAskOptions is the default options on ask, using the OS stdio. func defaultAskOptions() *AskOptions { return &AskOptions{ @@ -55,6 +59,7 @@ func defaultAskOptions() *AskOptions { }, KeepFilter: false, }, + OnInterrupt: OnInterrupt, } } func defaultPromptConfig() *PromptConfig { @@ -135,6 +140,7 @@ type AskOptions struct { Stdio terminal.Stdio Validators []Validator PromptConfig PromptConfig + OnInterrupt func() } // WithStdio specifies the standard input, output and error files survey @@ -180,6 +186,16 @@ func WithValidator(v Validator) AskOpt { } } +// WithInterruptFunc specifies a function to run on recieving +// SIGINT (aka CTRL+C) during prompt. +func WithInterruptFunc(fn func()) AskOpt { + return func(options *AskOptions) error { + options.OnInterrupt = fn + // nothing went wrong + return nil + } +} + type wantsStdio interface { WithStdio(terminal.Stdio) } @@ -289,6 +305,10 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error { // grab the user input and save it ans, err := q.Prompt.Prompt(&options.PromptConfig) + // if SIGINT is recieved. + if err == terminal.InterruptErr { + options.OnInterrupt() + } // if there was a problem if err != nil { return err diff --git a/survey_posix_test.go b/survey_posix_test.go index 95dea35b..c6d5de16 100644 --- a/survey_posix_test.go +++ b/survey_posix_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) { +func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) error { t.Parallel() // Multiplex output to a buffer as well for the raw bytes. @@ -28,7 +28,6 @@ func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.S }() err = test(Stdio(c)) - require.Nil(t, err) // Close the slave end of the pty, and read the remaining bytes from the master end. c.Tty().Close() @@ -38,4 +37,6 @@ func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.S // Dump the terminal's screen. t.Logf("\n%s", expect.StripTrailingEmptyLines(state.String())) + + return err } diff --git a/survey_test.go b/survey_test.go index 58f8420e..6a222dfb 100644 --- a/survey_test.go +++ b/survey_test.go @@ -31,7 +31,7 @@ type PromptTest struct { func RunPromptTest(t *testing.T, test PromptTest) { var answer interface{} - RunTest(t, test.procedure, func(stdio terminal.Stdio) error { + err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error { var err error if p, ok := test.prompt.(wantsStdio); ok { p.WithStdio(stdio) @@ -40,12 +40,13 @@ func RunPromptTest(t *testing.T, test PromptTest) { answer, err = test.prompt.Prompt(defaultPromptConfig()) return err }) + require.Nil(t, err) require.Equal(t, test.expected, answer) } func RunPromptTestKeepFilter(t *testing.T, test PromptTest) { var answer interface{} - RunTest(t, test.procedure, func(stdio terminal.Stdio) error { + err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error { var err error if p, ok := test.prompt.(wantsStdio); ok { p.WithStdio(stdio) @@ -55,6 +56,7 @@ func RunPromptTestKeepFilter(t *testing.T, test PromptTest) { answer, err = test.prompt.Prompt(config) return err }) + require.Nil(t, err) require.Equal(t, test.expected, answer) } @@ -133,7 +135,6 @@ func TestPagination_lastHalf(t *testing.T) { func TestAsk(t *testing.T) { t.Skip() - return tests := []struct { name string questions []*Question @@ -250,10 +251,10 @@ func TestAsk(t *testing.T) { "pizza": true, "commit-message": "Add editor prompt tests\n", "commit-message-validated": "Add editor prompt tests\n", - "name": "Johnny Appleseed", - "day": []string{"Monday", "Wednesday"}, - "password": "secret", - "color": "yellow", + "name": "Johnny Appleseed", + "day": []string{"Monday", "Wednesday"}, + "password": "secret", + "color": "yellow", }, }, { @@ -305,9 +306,10 @@ func TestAsk(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { answers := make(map[string]interface{}) - RunTest(t, test.procedure, func(stdio terminal.Stdio) error { + err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error { return Ask(test.questions, &answers, WithStdio(stdio.In, stdio.Out, stdio.Err)) }) + require.Nil(t, err) require.Equal(t, test.expected, answers) }) } @@ -323,3 +325,57 @@ func TestAsk_returnsErrorIfTargetIsNil(t *testing.T) { t.Error("Did not encounter error when asking with no where to record.") } } + +func TestOnInterruptFunc(t *testing.T) { + // No Interrupt function set. + t.Run("No OnInterrupt", func(t *testing.T) { + answer := "" + err := RunTest(t, func(e *expect.Console) { + e.ExpectString("Are you a bot?") + e.SendLine(string(terminal.KeyInterrupt)) + e.ExpectEOF() + }, func(t terminal.Stdio) error { + return AskOne(&Input{Message: "Are you a bot?"}, &answer, + WithStdio(t.In, t.Out, t.Err)) + }) + + require.Equal(t, terminal.InterruptErr, err) + require.Equal(t, "", answer) + }) + + // Set global Interrupt function. + OnInterrupt = func() { fmt.Println("Ended abruptly!") } + t.Run("Global OnInterrupt", func(t *testing.T) { + answer := "" + err := RunTest(t, func(e *expect.Console) { + e.ExpectString("Are you a bot?") + e.SendLine(string(terminal.KeyInterrupt)) + e.ExpectString("Ended abruptly!") + e.ExpectEOF() + }, func(t terminal.Stdio) error { + return AskOne(&Input{Message: "Are you a bot?"}, &answer, + WithStdio(t.In, t.Out, t.Err)) + }) + + require.Equal(t, terminal.InterruptErr, err) + require.Equal(t, "", answer) + }) + + // Set local Interrupt function (overide global). + t.Run("Local Interrupt (override global", func(t *testing.T) { + answer := "" + err := RunTest(t, func(e *expect.Console) { + e.ExpectString("Are you a bot?") + e.SendLine(string(terminal.KeyInterrupt)) + e.ExpectString("The end.") + e.ExpectEOF() + }, func(t terminal.Stdio) error { + return AskOne(&Input{Message: "Are you a bot?"}, &answer, + WithStdio(t.In, t.Out, t.Err), + WithInterruptFunc(func() { fmt.Println("The end.") })) + }) + + require.Equal(t, terminal.InterruptErr, err) + require.Equal(t, "", answer) + }) +} diff --git a/survey_windows_test.go b/survey_windows_test.go index e22022ce..3c26ad08 100644 --- a/survey_windows_test.go +++ b/survey_windows_test.go @@ -7,6 +7,7 @@ import ( expect "github.com/Netflix/go-expect" ) -func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) { +func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) error { t.Skip("Windows does not support psuedoterminals") + return nil }