diff --git a/cmd/main.go b/cmd/main.go index 270ab1db..b513ddb6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,6 +8,7 @@ import ( "os/exec" "os/signal" "strings" + "sync/atomic" "syscall" "github.com/dnephin/pflag" @@ -210,11 +211,8 @@ func run(opts *options) error { return finishRun(opts, exec, err) } exitErr := goTestProc.cmd.Wait() - siggedOut := <-goTestProc.signal // check if we received a SIGINT - - if siggedOut != nil { - n, _ := (siggedOut).(syscall.Signal) - return finishRun(opts, exec, fmt.Errorf("syscall.Signal==%d", int(n))) + if signum := atomic.LoadInt32(&goTestProc.signal); signum != 0 { + return finishRun(opts, exec, exitError{num: signalExitCode + int(signum)}) } if exitErr == nil || opts.rerunFailsMaxAttempts == 0 { return finishRun(opts, exec, exitErr) @@ -342,39 +340,41 @@ type proc struct { cmd waiter stdout io.Reader stderr io.Reader - signal chan os.Signal + // signal is atomically set to the signal value when a signal is received + // by newSignalHandler. + signal int32 } type waiter interface { Wait() error } -func startGoTest(ctx context.Context, args []string) (proc, error) { +func startGoTest(ctx context.Context, args []string) (*proc, error) { if len(args) == 0 { - return proc{}, errors.New("missing command to run") + return nil, errors.New("missing command to run") } cmd := exec.CommandContext(ctx, args[0], args[1:]...) - p := proc{cmd: cmd, signal: make(chan os.Signal, 1)} + p := proc{cmd: cmd} log.Debugf("exec: %s", cmd.Args) var err error p.stdout, err = cmd.StdoutPipe() if err != nil { - return p, err + return nil, err } p.stderr, err = cmd.StderrPipe() if err != nil { - return p, err + return nil, err } if err := cmd.Start(); err != nil { - return p, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " ")) + return nil, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " ")) } log.Debugf("go test pid: %d", cmd.Process.Pid) ctx, cancel := context.WithCancel(ctx) newSignalHandler(ctx, cmd.Process.Pid, &p) p.cmd = &cancelWaiter{cancel: cancel, wrapped: p.cmd} - return p, nil + return &p, nil } // ExitCodeWithDefault returns the ExitStatus of a process from the error returned by @@ -395,11 +395,27 @@ type exitCoder interface { ExitCode() int } -func isExitCoder(err error) bool { +func IsExitCoder(err error) bool { _, ok := err.(exitCoder) return ok } +type exitError struct { + num int +} + +func (e exitError) Error() string { + return fmt.Sprintf("exit code %d", e.num) +} + +func (e exitError) ExitCode() int { + return e.num +} + +// signalExitCode is the base value added to a signal number to produce the +// exit code value. This matches the behaviour of bash. +const signalExitCode = 128 + func newSignalHandler(ctx context.Context, pid int, p *proc) { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) @@ -409,11 +425,11 @@ func newSignalHandler(ctx context.Context, pid int, p *proc) { select { case <-ctx.Done(): - close(p.signal) return case s := <-c: + atomic.StoreInt32(&p.signal, int32(s.(syscall.Signal))) + proc, err := os.FindProcess(pid) - p.signal <- s if err != nil { log.Errorf("failed to find pid of 'go test': %v", err) return diff --git a/cmd/main_e2e_test.go b/cmd/main_e2e_test.go index de8e2d38..f09b701a 100644 --- a/cmd/main_e2e_test.go +++ b/cmd/main_e2e_test.go @@ -214,7 +214,7 @@ func TestE2E_SignalHandler(t *testing.T) { assert.NilError(t, result.Cmd.Process.Signal(os.Interrupt)) icmd.WaitOnCmd(2*time.Second, result) - result.Assert(t, icmd.Expected{ExitCode: 102}) + result.Assert(t, icmd.Expected{ExitCode: 130}) } func TestE2E_MaxFails_EndTestRun(t *testing.T) { diff --git a/cmd/main_test.go b/cmd/main_test.go index ad33aec3..cd5deb39 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -305,8 +305,8 @@ func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) { {"Package": "pkg", "Action": "fail"} ` - fn := func(args []string) proc { - return proc{ + fn := func(args []string) *proc { + return &proc{ cmd: fakeWaiter{result: newExitCode("failed", 1)}, stdout: strings.NewReader(jsonFailed), stderr: bytes.NewReader(nil), @@ -339,8 +339,8 @@ func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) { {"Package": "pkg", "Action": "fail"} ` - fn := func(args []string) proc { - return proc{ + fn := func(args []string) *proc { + return &proc{ cmd: fakeWaiter{result: newExitCode("failed", 1)}, stdout: strings.NewReader(jsonFailed), stderr: strings.NewReader("anything here is an error\n"), @@ -375,8 +375,8 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) { {"Package": "pkg", "Action": "fail"} ` - fn := func(args []string) proc { - return proc{ + fn := func(args []string) *proc { + return &proc{ cmd: fakeWaiter{result: newExitCode("failed", 1)}, stdout: strings.NewReader(jsonFailed), stderr: bytes.NewReader(nil), diff --git a/cmd/rerunfails_test.go b/cmd/rerunfails_test.go index faf1153a..0d63c082 100644 --- a/cmd/rerunfails_test.go +++ b/cmd/rerunfails_test.go @@ -109,10 +109,10 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) { }, } - fn := func(args []string) proc { + fn := func(args []string) *proc { next := events[0] events = events[1:] - return proc{ + return &proc{ cmd: fakeWaiter{result: next.err}, stdout: strings.NewReader(next.out), stderr: bytes.NewReader(nil), @@ -136,9 +136,9 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) { assert.Error(t, err, "run-failed-3") } -func patchStartGoTestFn(f func(args []string) proc) func() { +func patchStartGoTestFn(f func(args []string) *proc) func() { orig := startGoTestFn - startGoTestFn = func(ctx context.Context, args []string) (proc, error) { + startGoTestFn = func(ctx context.Context, args []string) (*proc, error) { return f(args), nil } return func() { diff --git a/cmd/watch.go b/cmd/watch.go index 0864ec5e..b0665a67 100644 --- a/cmd/watch.go +++ b/cmd/watch.go @@ -34,7 +34,7 @@ func (w *watchRuns) run(event filewatcher.Event) error { args: w.opts.args, initFilePath: path, } - if err := runDelve(o); !isExitCoder(err) { + if err := runDelve(o); !IsExitCoder(err) { return fmt.Errorf("delve failed: %w", err) } return nil @@ -43,7 +43,7 @@ func (w *watchRuns) run(event filewatcher.Event) error { opts := w.opts opts.packages = []string{event.PkgPath} var err error - if w.prevExec, err = runSingle(&opts); !isExitCoder(err) { + if w.prevExec, err = runSingle(&opts); !IsExitCoder(err) { return err } return nil diff --git a/main.go b/main.go index c532be1e..854472ec 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "os" - "os/exec" "gotest.tools/gotestsum/cmd" "gotest.tools/gotestsum/cmd/tool" @@ -11,10 +10,10 @@ import ( func main() { err := route(os.Args) - switch err.(type) { - case nil: + switch { + case err == nil: return - case *exec.ExitError: + case cmd.IsExitCoder(err): // go test should already report the error to stderr, exit with // the same status code os.Exit(cmd.ExitCodeWithDefault(err))