Skip to content
This repository has been archived by the owner on May 8, 2024. It is now read-only.

Always kill spawned shell's process group to avoid pipe FD hangs #217

Merged
merged 6 commits into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ func (executor *Executor) ExecuteScriptsAndStreamLogs(
scripts []string,
env map[string]string,
) (*exec.Cmd, error) {
sc, err := NewShellCommands(scripts, &env, func(bytes []byte) (int, error) {
sc, err := NewShellCommands(ctx, scripts, &env, func(bytes []byte) (int, error) {
return logUploader.Write(bytes)
})
var cmd *exec.Cmd
Expand Down
19 changes: 10 additions & 9 deletions internal/executor/piper/piper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package piper

import (
"context"
"errors"
"io"
"os"
Expand Down Expand Up @@ -32,25 +33,25 @@ func New(output io.Writer) (*Piper, error) {
return piper, nil
}

func (piper *Piper) Input() *os.File {
func (piper *Piper) FileProxy() *os.File {
return piper.w
}

func (piper *Piper) Close(force bool) (result error) {
// Terminate the Goroutine started in New()
if force {
_ = piper.r.Close()
}

func (piper *Piper) Close(ctx context.Context) (result error) {
// Close our writing end (if not closed yet)
if err := piper.w.Close(); err != nil && !errors.Is(err, os.ErrClosed) && result == nil {
result = err
}

// Wait for the Goroutine started in New(): it will reach EOF once
// all the copies of the writing end file descriptor are closed
if err := <-piper.errChan; err != nil && !errors.Is(err, os.ErrClosed) && result == nil {
result = err
select {
case err := <-piper.errChan:
if err != nil && !errors.Is(err, os.ErrClosed) && result == nil {
result = err
}
case <-ctx.Done():
result = ctx.Err()
}

return result
Expand Down
31 changes: 17 additions & 14 deletions internal/executor/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func ShellCommandsAndGetOutput(ctx context.Context, scripts []string, custom_env

// return true if executed successful
func ShellCommandsAndWait(ctx context.Context, scripts []string, custom_env *map[string]string, handler ShellOutputHandler) (*exec.Cmd, error) {
sc, err := NewShellCommands(scripts, custom_env, handler)
sc, err := NewShellCommands(ctx, scripts, custom_env, handler)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -73,14 +73,10 @@ func ShellCommandsAndWait(ctx context.Context, scripts []string, custom_env *map

return cmd, TimeOutError
case <-done:
if err := sc.piper.Close(false); err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
if err := sc.kill(); err != nil {
handler([]byte(fmt.Sprintf("\nFailed to kill a partially completed shell session: %s", err)))
}
} else {
handler([]byte(fmt.Sprintf("\nShell session I/O error: %s", err)))
}
_ = sc.kill()

if err := sc.piper.Close(ctx); err != nil {
handler([]byte(fmt.Sprintf("\nShell session I/O error: %s", err)))
}

if ws, ok := cmd.ProcessState.Sys().(syscall.WaitStatus); ok {
Expand All @@ -99,7 +95,12 @@ func ShellCommandsAndWait(ctx context.Context, scripts []string, custom_env *map
}
}

func NewShellCommands(scripts []string, custom_env *map[string]string, handler ShellOutputHandler) (*ShellCommands, error) {
func NewShellCommands(
ctx context.Context,
scripts []string,
custom_env *map[string]string,
handler ShellOutputHandler,
) (*ShellCommands, error) {
var cmd *exec.Cmd
var scriptFile *os.File
var err error
Expand Down Expand Up @@ -159,12 +160,12 @@ func NewShellCommands(scripts []string, custom_env *map[string]string, handler S
return nil, err
}

cmd.Stderr = sc.piper.Input()
cmd.Stdout = sc.piper.Input()
cmd.Stderr = sc.piper.FileProxy()
cmd.Stdout = sc.piper.FileProxy()

err = cmd.Start()
if err != nil {
if err := sc.piper.Close(true); err != nil {
if err := sc.piper.Close(ctx); err != nil {
_, _ = fmt.Fprintf(writer, "Shell session I/O error: %s", err)
}

Expand All @@ -175,7 +176,9 @@ func NewShellCommands(scripts []string, custom_env *map[string]string, handler S

sc.afterStart()

if err := sc.piper.Input().Close(); err != nil {
// At this point the shell has successfully started and inherited
// the proxy file descriptor. We can release our own descriptor now.
if err := sc.piper.FileProxy().Close(); err != nil {
_, _ = fmt.Fprintf(writer, "Shell session I/O error: %s", err)
}

Expand Down
32 changes: 30 additions & 2 deletions internal/executor/shell_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package executor
import (
"context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"os/exec"
"runtime"
"testing"
Expand Down Expand Up @@ -107,12 +108,39 @@ func Test_ShellCommands_Timeout_Unix(t *testing.T) {
}
}

func TestChildrenProcessesAreNotWaitedFor(t *testing.T) {
func TestChildrenProcessesAreCancelled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

success, output := ShellCommandsAndGetOutput(ctx, []string{"sleep 60 & sleep 1"}, nil)
success, output := ShellCommandsAndGetOutput(ctx, []string{"sleep 60 & sleep 10"}, nil)

assert.False(t, success)
assert.Contains(t, output, "Timed out!")
}

func TestChildrenProcessesAreNotWaitedFor(t *testing.T) {
startTime := time.Now()

success, output := ShellCommandsAndGetOutput(context.Background(), []string{"sleep 60 & sleep 1"}, nil)

if time.Since(startTime) > 5*time.Second {
t.Fatalf("took more than 5 seconds")
}

assert.True(t, success)
assert.NotContains(t, output, "Timed out!")
}

func TestShellStartFailureDoesNotHang(t *testing.T) {
startTime := time.Now()

success, _ := ShellCommandsAndGetOutput(context.Background(), []string{"true"}, &map[string]string{
"CIRRUS_SHELL": "/bin/non-existent-shell",
})

if time.Since(startTime) > 1*time.Second {
t.Fatalf("took more than 1 second")
}

require.False(t, success)
}