Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: send SIGTERM signal to --cmd instead of SIGKILL #687

Merged
merged 5 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions cmd/templ/generatecmd/run/run_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package run_test

import (
"context"
"embed"
"io"
"net/http"
"os"
"path/filepath"
"syscall"
"testing"
"time"

"github.com/a-h/templ/cmd/templ/generatecmd/run"
)

//go:embed testprogram/*
var testprogram embed.FS

func TestGoRun(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
}

// Copy testprogram to a temporary directory.
dir, err := os.MkdirTemp("", "testprogram")
if err != nil {
t.Fatalf("failed to make test dir: %v", err)
}
files, err := testprogram.ReadDir("testprogram")
if err != nil {
t.Fatalf("failed to read embedded dir: %v", err)
}
for _, file := range files {
srcFileName := "testprogram/" + file.Name()
srcData, err := testprogram.ReadFile(srcFileName)
if err != nil {
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
}
tgtFileName := filepath.Join(dir, file.Name())
tgtFile, err := os.Create(tgtFileName)
if err != nil {
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
}
defer tgtFile.Close()
if _, err := tgtFile.Write(srcData); err != nil {
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
}
}
// Rename the go.mod.embed file to go.mod.
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
t.Fatalf("failed to rename go.mod.embed: %v", err)
}

tests := []struct {
name string
cmd string
}{
{
name: "Well behaved programs get shut down",
cmd: "go run .",
},
{
name: "Badly behaved programs get shut down",
cmd: "go run . -badly-behaved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cmd, err := run.Run(ctx, dir, tt.cmd)
if err != nil {
t.Fatalf("failed to run program: %v", err)
}

time.Sleep(1 * time.Second)

pid := cmd.Process.Pid

if err := run.KillAll(); err != nil {
t.Fatalf("failed to kill all: %v", err)
}

// Check the parent process is no longer running.
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
t.Fatalf("process %d is still running", pid)
}
// Check that the child was stopped.
body, err := readResponse("http://localhost:7777")
if err == nil {
t.Fatalf("child process is still running: %s", body)
}
})
}
}

func readResponse(url string) (body string, err error) {
resp, err := http.Get(url)
if err != nil {
return body, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return body, err
}
return string(b), nil
}
46 changes: 35 additions & 11 deletions cmd/templ/generatecmd/run/run_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,63 @@ package run

import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)

var m = &sync.Mutex{}
var running = map[string]*exec.Cmd{}
var (
m = &sync.Mutex{}
running = map[string]*exec.Cmd{}
)

func KillAll() (err error) {
m.Lock()
defer m.Unlock()
var errs []error
for _, cmd := range running {
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
if err != nil {
return err
if err := kill(cmd); err != nil {
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
}
}
running = map[string]*exec.Cmd{}
return
return errors.Join(errs...)
}

func kill(cmd *exec.Cmd) (err error) {
errs := make([]error, 4)
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
errs[2] = ignoreExited(cmd.Wait())
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
return errors.Join(errs...)
}

func Stop(cmd *exec.Cmd) (err error) {
return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
func ignoreExited(err error) error {
if errors.Is(err, syscall.ESRCH) {
return nil
}
// Ignore *exec.ExitError
if _, ok := err.(*exec.ExitError); ok {
return nil
}
return err
}

func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err error) {
m.Lock()
defer m.Unlock()
cmd, ok := running[input]
if ok {
if err = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil {
return cmd, err
if err := kill(cmd); err != nil {
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
}

delete(running, input)
}
parts := strings.Fields(input)
Expand All @@ -48,7 +70,9 @@ func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err erro
args = append(args, parts[1:]...)
}

cmd = exec.Command(executable, args...)
cmd = exec.CommandContext(ctx, executable, args...)
// Wait for the process to finish gracefully before termination.
cmd.WaitDelay = time.Second * 3
cmd.Env = os.Environ()
cmd.Dir = workingDir
cmd.Stdout = os.Stdout
Expand Down
3 changes: 3 additions & 0 deletions cmd/templ/generatecmd/run/testprogram/go.mod.embed
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module testprogram

go 1.22.6
63 changes: 63 additions & 0 deletions cmd/templ/generatecmd/run/testprogram/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)

// This is a test program. It is used only to test the behaviour of the run package.
// The run package is supposed to be able to run and stop programs. Those programs may start
// child processes, which should also be stopped when the parent program is stopped.

// For example, running `go run .` will compile an executable and run it.

// So, this program does nothing. It just waits for a signal to stop.

// In "Well behaved" mode, the program will stop when it receives a signal.
// In "Badly behaved" mode, the program will ignore the signal and continue running.

// The run package should be able to stop the program in both cases.

var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")

func main() {
flag.Parse()

mode := "Well behaved"
if *badlyBehavedFlag {
mode = "Badly behaved"
}
fmt.Printf("%s process %d started.\n", mode, os.Getpid())

// Start a web server on a known port so that we can check that this process is
// not running, when it's been started as a child process, and we don't know
// its pid.
go func() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%d", os.Getpid())
})
err := http.ListenAndServe("127.0.0.1:7777", nil)
if err != nil {
fmt.Printf("Error running web server: %v\n", err)
}
}()

sigs := make(chan os.Signal, 1)
if !*badlyBehavedFlag {
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
}
for {
select {
case <-sigs:
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
return
case <-time.After(1 * time.Second):
fmt.Printf("Process %d still running...\n", os.Getpid())
}
}
}