Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Add OnInterrupt to run function code on CTRL+C #301

Merged
merged 7 commits into from
Dec 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions survey.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/AlecAivazis/survey/v2/terminal"
)

// OnInterrupt is the function to run when
// SIGINT (CTRL+C) is sent to the process.
var OnInterrupt func()

// DefaultAskOptions is the default options on ask, using the OS stdio.
func defaultAskOptions() *AskOptions {
return &AskOptions{
Expand Down Expand Up @@ -55,6 +59,7 @@ func defaultAskOptions() *AskOptions {
},
KeepFilter: false,
},
OnInterrupt: OnInterrupt,
}
}
func defaultPromptConfig() *PromptConfig {
Expand Down Expand Up @@ -135,6 +140,7 @@ type AskOptions struct {
Stdio terminal.Stdio
Validators []Validator
PromptConfig PromptConfig
OnInterrupt func()
}

// WithStdio specifies the standard input, output and error files survey
Expand Down Expand Up @@ -180,6 +186,16 @@ func WithValidator(v Validator) AskOpt {
}
}

// WithInterruptFunc specifies a function to run on recieving
// SIGINT (aka CTRL+C) during prompt.
func WithInterruptFunc(fn func()) AskOpt {
return func(options *AskOptions) error {
options.OnInterrupt = fn
// nothing went wrong
return nil
}
}

type wantsStdio interface {
WithStdio(terminal.Stdio)
}
Expand Down Expand Up @@ -289,6 +305,10 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error {

// grab the user input and save it
ans, err := q.Prompt.Prompt(&options.PromptConfig)
// if SIGINT is recieved.
if err == terminal.InterruptErr {
options.OnInterrupt()
}
// if there was a problem
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions survey_posix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/require"
)

func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) {
func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) error {
t.Parallel()

// Multiplex output to a buffer as well for the raw bytes.
Expand All @@ -28,7 +28,6 @@ func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.S
}()

err = test(Stdio(c))
require.Nil(t, err)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this intentionally removed? If so, can you explain why?

Copy link
Contributor Author

@infalmo infalmo Nov 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because, L#30 would return an error when SIGINT was encountered, and the tests would fail (due to the assert). So I shifted the error assertion to the helper functions (RunPromptTest and the other one).

I could revert this, but that would mean rewriting another RunTest function exclusively for the OnInterrupt test, which isn't a very good idea imo.


// Close the slave end of the pty, and read the remaining bytes from the master end.
c.Tty().Close()
Expand All @@ -38,4 +37,6 @@ func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.S

// Dump the terminal's screen.
t.Logf("\n%s", expect.StripTrailingEmptyLines(state.String()))

return err
}
72 changes: 64 additions & 8 deletions survey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type PromptTest struct {

func RunPromptTest(t *testing.T, test PromptTest) {
var answer interface{}
RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
var err error
if p, ok := test.prompt.(wantsStdio); ok {
p.WithStdio(stdio)
Expand All @@ -40,12 +40,13 @@ func RunPromptTest(t *testing.T, test PromptTest) {
answer, err = test.prompt.Prompt(defaultPromptConfig())
return err
})
require.Nil(t, err)
require.Equal(t, test.expected, answer)
}

func RunPromptTestKeepFilter(t *testing.T, test PromptTest) {
var answer interface{}
RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
var err error
if p, ok := test.prompt.(wantsStdio); ok {
p.WithStdio(stdio)
Expand All @@ -55,6 +56,7 @@ func RunPromptTestKeepFilter(t *testing.T, test PromptTest) {
answer, err = test.prompt.Prompt(config)
return err
})
require.Nil(t, err)
require.Equal(t, test.expected, answer)
}

Expand Down Expand Up @@ -133,7 +135,6 @@ func TestPagination_lastHalf(t *testing.T) {

func TestAsk(t *testing.T) {
t.Skip()
return
tests := []struct {
name string
questions []*Question
Expand Down Expand Up @@ -250,10 +251,10 @@ func TestAsk(t *testing.T) {
"pizza": true,
"commit-message": "Add editor prompt tests\n",
"commit-message-validated": "Add editor prompt tests\n",
"name": "Johnny Appleseed",
"day": []string{"Monday", "Wednesday"},
"password": "secret",
"color": "yellow",
"name": "Johnny Appleseed",
"day": []string{"Monday", "Wednesday"},
"password": "secret",
"color": "yellow",
},
},
{
Expand Down Expand Up @@ -305,9 +306,10 @@ func TestAsk(t *testing.T) {
test := test
t.Run(test.name, func(t *testing.T) {
answers := make(map[string]interface{})
RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
return Ask(test.questions, &answers, WithStdio(stdio.In, stdio.Out, stdio.Err))
})
require.Nil(t, err)
require.Equal(t, test.expected, answers)
})
}
Expand All @@ -323,3 +325,57 @@ func TestAsk_returnsErrorIfTargetIsNil(t *testing.T) {
t.Error("Did not encounter error when asking with no where to record.")
}
}

func TestOnInterruptFunc(t *testing.T) {
// No Interrupt function set.
t.Run("No OnInterrupt", func(t *testing.T) {
answer := ""
err := RunTest(t, func(e *expect.Console) {
e.ExpectString("Are you a bot?")
e.SendLine(string(terminal.KeyInterrupt))
e.ExpectEOF()
}, func(t terminal.Stdio) error {
return AskOne(&Input{Message: "Are you a bot?"}, &answer,
WithStdio(t.In, t.Out, t.Err))
})

require.Equal(t, terminal.InterruptErr, err)
require.Equal(t, "", answer)
})

// Set global Interrupt function.
OnInterrupt = func() { fmt.Println("Ended abruptly!") }
t.Run("Global OnInterrupt", func(t *testing.T) {
answer := ""
err := RunTest(t, func(e *expect.Console) {
e.ExpectString("Are you a bot?")
e.SendLine(string(terminal.KeyInterrupt))
e.ExpectString("Ended abruptly!")
e.ExpectEOF()
}, func(t terminal.Stdio) error {
return AskOne(&Input{Message: "Are you a bot?"}, &answer,
WithStdio(t.In, t.Out, t.Err))
})

require.Equal(t, terminal.InterruptErr, err)
require.Equal(t, "", answer)
})

// Set local Interrupt function (overide global).
t.Run("Local Interrupt (override global", func(t *testing.T) {
answer := ""
err := RunTest(t, func(e *expect.Console) {
e.ExpectString("Are you a bot?")
e.SendLine(string(terminal.KeyInterrupt))
e.ExpectString("The end.")
e.ExpectEOF()
}, func(t terminal.Stdio) error {
return AskOne(&Input{Message: "Are you a bot?"}, &answer,
WithStdio(t.In, t.Out, t.Err),
WithInterruptFunc(func() { fmt.Println("The end.") }))
})

require.Equal(t, terminal.InterruptErr, err)
require.Equal(t, "", answer)
})
}
3 changes: 2 additions & 1 deletion survey_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
expect "github.com/Netflix/go-expect"
)

func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) {
func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) error {
t.Skip("Windows does not support psuedoterminals")
return nil
}