Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cmd/main.go): add a channel to proc{} for detecting SIGINT #210

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"os/exec"
"os/signal"
"strings"
"sync/atomic"
"syscall"

"github.com/dnephin/pflag"
"github.com/fatih/color"
Expand Down Expand Up @@ -209,6 +211,9 @@ func run(opts *options) error {
return finishRun(opts, exec, err)
}
exitErr := goTestProc.cmd.Wait()
if signum := atomic.LoadInt32(&goTestProc.signal); signum != 0 {
return finishRun(opts, exec, exitError{num: signalExitCode + int(signum)})
}
if exitErr == nil || opts.rerunFailsMaxAttempts == 0 {
return finishRun(opts, exec, exitErr)
}
Expand Down Expand Up @@ -335,15 +340,18 @@ type proc struct {
cmd waiter
stdout io.Reader
stderr io.Reader
// signal is atomically set to the signal value when a signal is received
// by newSignalHandler.
signal int32
}

type waiter interface {
Wait() error
}

func startGoTest(ctx context.Context, args []string) (proc, error) {
func startGoTest(ctx context.Context, args []string) (*proc, error) {
if len(args) == 0 {
return proc{}, errors.New("missing command to run")
return nil, errors.New("missing command to run")
}

cmd := exec.CommandContext(ctx, args[0], args[1:]...)
Expand All @@ -352,21 +360,21 @@ func startGoTest(ctx context.Context, args []string) (proc, error) {
var err error
p.stdout, err = cmd.StdoutPipe()
if err != nil {
return p, err
return nil, err
}
p.stderr, err = cmd.StderrPipe()
if err != nil {
return p, err
return nil, err
}
if err := cmd.Start(); err != nil {
return p, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " "))
return nil, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " "))
}
log.Debugf("go test pid: %d", cmd.Process.Pid)

ctx, cancel := context.WithCancel(ctx)
newSignalHandler(ctx, cmd.Process.Pid)
newSignalHandler(ctx, cmd.Process.Pid, &p)
p.cmd = &cancelWaiter{cancel: cancel, wrapped: p.cmd}
return p, nil
return &p, nil
}

// ExitCodeWithDefault returns the ExitStatus of a process from the error returned by
Expand All @@ -387,12 +395,28 @@ type exitCoder interface {
ExitCode() int
}

func isExitCoder(err error) bool {
func IsExitCoder(err error) bool {
_, ok := err.(exitCoder)
return ok
}

func newSignalHandler(ctx context.Context, pid int) {
type exitError struct {
num int
}

func (e exitError) Error() string {
return fmt.Sprintf("exit code %d", e.num)
}

func (e exitError) ExitCode() int {
return e.num
}

// signalExitCode is the base value added to a signal number to produce the
// exit code value. This matches the behaviour of bash.
const signalExitCode = 128

func newSignalHandler(ctx context.Context, pid int, p *proc) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)

Expand All @@ -403,6 +427,8 @@ func newSignalHandler(ctx context.Context, pid int) {
case <-ctx.Done():
return
case s := <-c:
atomic.StoreInt32(&p.signal, int32(s.(syscall.Signal)))

proc, err := os.FindProcess(pid)
if err != nil {
log.Errorf("failed to find pid of 'go test': %v", err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/main_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func TestE2E_SignalHandler(t *testing.T) {
assert.NilError(t, result.Cmd.Process.Signal(os.Interrupt))
icmd.WaitOnCmd(2*time.Second, result)

result.Assert(t, icmd.Expected{ExitCode: 102})
result.Assert(t, icmd.Expected{ExitCode: 130})
}

func TestE2E_MaxFails_EndTestRun(t *testing.T) {
Expand Down
12 changes: 6 additions & 6 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: bytes.NewReader(nil),
Expand Down Expand Up @@ -339,8 +339,8 @@ func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: strings.NewReader("anything here is an error\n"),
Expand Down Expand Up @@ -375,8 +375,8 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: bytes.NewReader(nil),
Expand Down
8 changes: 4 additions & 4 deletions cmd/rerunfails_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) {
},
}

fn := func(args []string) proc {
fn := func(args []string) *proc {
next := events[0]
events = events[1:]
return proc{
return &proc{
cmd: fakeWaiter{result: next.err},
stdout: strings.NewReader(next.out),
stderr: bytes.NewReader(nil),
Expand All @@ -136,9 +136,9 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) {
assert.Error(t, err, "run-failed-3")
}

func patchStartGoTestFn(f func(args []string) proc) func() {
func patchStartGoTestFn(f func(args []string) *proc) func() {
orig := startGoTestFn
startGoTestFn = func(ctx context.Context, args []string) (proc, error) {
startGoTestFn = func(ctx context.Context, args []string) (*proc, error) {
return f(args), nil
}
return func() {
Expand Down
4 changes: 2 additions & 2 deletions cmd/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (w *watchRuns) run(event filewatcher.Event) error {
args: w.opts.args,
initFilePath: path,
}
if err := runDelve(o); !isExitCoder(err) {
if err := runDelve(o); !IsExitCoder(err) {
return fmt.Errorf("delve failed: %w", err)
}
return nil
Expand All @@ -43,7 +43,7 @@ func (w *watchRuns) run(event filewatcher.Event) error {
opts := w.opts
opts.packages = []string{event.PkgPath}
var err error
if w.prevExec, err = runSingle(&opts); !isExitCoder(err) {
if w.prevExec, err = runSingle(&opts); !IsExitCoder(err) {
return err
}
return nil
Expand Down
7 changes: 3 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"os"
"os/exec"

"gotest.tools/gotestsum/cmd"
"gotest.tools/gotestsum/cmd/tool"
Expand All @@ -11,10 +10,10 @@ import (

func main() {
err := route(os.Args)
switch err.(type) {
case nil:
switch {
case err == nil:
return
case *exec.ExitError:
case cmd.IsExitCoder(err):
// go test should already report the error to stderr, exit with
// the same status code
os.Exit(cmd.ExitCodeWithDefault(err))
Expand Down