diff --git a/internal/testenv/exec.go b/internal/testenv/exec.go new file mode 100644 index 0000000000..4bacdc3ce8 --- /dev/null +++ b/internal/testenv/exec.go @@ -0,0 +1,120 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testenv + +import ( + "context" + "os" + "os/exec" + "reflect" + "strconv" + "testing" + "time" +) + +// CommandContext is like exec.CommandContext, but: +// - skips t if the platform does not support os/exec, +// - sends SIGQUIT (if supported by the platform) instead of SIGKILL +// in its Cancel function +// - if the test has a deadline, adds a Context timeout and WaitDelay +// for an arbitrary grace period before the test's deadline expires, +// - fails the test if the command does not complete before the test's deadline, and +// - sets a Cleanup function that verifies that the test did not leak a subprocess. +func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd { + t.Helper() + + var ( + cancelCtx context.CancelFunc + gracePeriod time.Duration // unlimited unless the test has a deadline (to allow for interactive debugging) + ) + + if t, ok := t.(interface { + testing.TB + Deadline() (time.Time, bool) + }); ok { + if td, ok := t.Deadline(); ok { + // Start with a minimum grace period, just long enough to consume the + // output of a reasonable program after it terminates. + gracePeriod = 100 * time.Millisecond + if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" { + scale, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("invalid GO_TEST_TIMEOUT_SCALE: %v", err) + } + gracePeriod *= time.Duration(scale) + } + + // If time allows, increase the termination grace period to 5% of the + // test's remaining time. + testTimeout := time.Until(td) + if gp := testTimeout / 20; gp > gracePeriod { + gracePeriod = gp + } + + // When we run commands that execute subprocesses, we want to reserve two + // grace periods to clean up: one for the delay between the first + // termination signal being sent (via the Cancel callback when the Context + // expires) and the process being forcibly terminated (via the WaitDelay + // field), and a second one for the delay becween the process being + // terminated and and the test logging its output for debugging. + // + // (We want to ensure that the test process itself has enough time to + // log the output before it is also terminated.) + cmdTimeout := testTimeout - 2*gracePeriod + + if cd, ok := ctx.Deadline(); !ok || time.Until(cd) > cmdTimeout { + // Either ctx doesn't have a deadline, or its deadline would expire + // after (or too close before) the test has already timed out. + // Add a shorter timeout so that the test will produce useful output. + ctx, cancelCtx = context.WithTimeout(ctx, cmdTimeout) + } + } + } + + cmd := exec.CommandContext(ctx, name, args...) + // Set the Cancel and WaitDelay fields only if present (go 1.20 and later). + // TODO: When Go 1.19 is no longer supported, remove this use of reflection + // and instead set the fields directly. + if cmdCancel := reflect.ValueOf(cmd).Elem().FieldByName("Cancel"); cmdCancel.IsValid() { + cmdCancel.Set(reflect.ValueOf(func() error { + if cancelCtx != nil && ctx.Err() == context.DeadlineExceeded { + // The command timed out due to running too close to the test's deadline. + // There is no way the test did that intentionally — it's too close to the + // wire! — so mark it as a test failure. That way, if the test expects the + // command to fail for some other reason, it doesn't have to distinguish + // between that reason and a timeout. + t.Errorf("test timed out while running command: %v", cmd) + } else { + // The command is being terminated due to ctx being canceled, but + // apparently not due to an explicit test deadline that we added. + // Log that information in case it is useful for diagnosing a failure, + // but don't actually fail the test because of it. + t.Logf("%v: terminating command: %v", ctx.Err(), cmd) + } + return cmd.Process.Signal(Sigquit) + })) + } + if cmdWaitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); cmdWaitDelay.IsValid() { + cmdWaitDelay.Set(reflect.ValueOf(gracePeriod)) + } + + t.Cleanup(func() { + if cancelCtx != nil { + cancelCtx() + } + if cmd.Process != nil && cmd.ProcessState == nil { + t.Errorf("command was started, but test did not wait for it to complete: %v", cmd) + } + }) + + return cmd +} + +// Command is like exec.Command, but applies the same changes as +// testenv.CommandContext (with a default Context). +func Command(t testing.TB, name string, args ...string) *exec.Cmd { + t.Helper() + return CommandContext(t, context.Background(), name, args...) +} diff --git a/internal/testenv/testenv_notunix.go b/internal/testenv/testenv_notunix.go new file mode 100644 index 0000000000..c8918ce592 --- /dev/null +++ b/internal/testenv/testenv_notunix.go @@ -0,0 +1,15 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows || plan9 || (js && wasm) || wasip1 + +package testenv + +import ( + "os" +) + +// Sigquit is the signal to send to kill a hanging subprocess. +// On Unix we send SIGQUIT, but on non-Unix we only have os.Kill. +var Sigquit = os.Kill diff --git a/internal/testenv/testenv_unix.go b/internal/testenv/testenv_unix.go new file mode 100644 index 0000000000..4f51823ec6 --- /dev/null +++ b/internal/testenv/testenv_unix.go @@ -0,0 +1,15 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix + +package testenv + +import ( + "syscall" +) + +// Sigquit is the signal to send to kill a hanging subprocess. +// Send SIGQUIT to get a stack trace. +var Sigquit = syscall.SIGQUIT diff --git a/ssh/test/agent_unix_test.go b/ssh/test/agent_unix_test.go index d90526c5cf..43fbdb22eb 100644 --- a/ssh/test/agent_unix_test.go +++ b/ssh/test/agent_unix_test.go @@ -17,7 +17,6 @@ import ( func TestAgentForward(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() diff --git a/ssh/test/banner_test.go b/ssh/test/banner_test.go index 22bdd67d19..3bfdd4b059 100644 --- a/ssh/test/banner_test.go +++ b/ssh/test/banner_test.go @@ -13,7 +13,6 @@ import ( func TestBannerCallbackAgainstOpenSSH(t *testing.T) { server := newServer(t) - defer server.Shutdown() clientConf := clientConfig() diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go index 77891e3622..83dd534c5c 100644 --- a/ssh/test/cert_test.go +++ b/ssh/test/cert_test.go @@ -18,7 +18,6 @@ import ( // Test both logging in with a cert, and also that the certificate presented by an OpenSSH host can be validated correctly func TestCertLogin(t *testing.T) { s := newServer(t) - defer s.Shutdown() // Use a key different from the default. clientKey := testSigners["dsa"] diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go index d3e3d54ed4..4a7ec31737 100644 --- a/ssh/test/dial_unix_test.go +++ b/ssh/test/dial_unix_test.go @@ -24,7 +24,6 @@ type dialTester interface { func testDial(t *testing.T, n, listenAddr string, x dialTester) { server := newServer(t) - defer server.Shutdown() sshConn := server.Dial(clientConfig()) defer sshConn.Close() diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go index f0595af75e..1171bc3a14 100644 --- a/ssh/test/forward_unix_test.go +++ b/ssh/test/forward_unix_test.go @@ -23,7 +23,6 @@ type closeWriter interface { func testPortForward(t *testing.T, n, listenAddr string) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -120,7 +119,6 @@ func TestPortForwardUnix(t *testing.T) { func testAcceptClose(t *testing.T, n, listenAddr string) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) sshListener, err := conn.Listen(n, listenAddr) @@ -162,10 +160,9 @@ func TestAcceptCloseUnix(t *testing.T) { // Check that listeners exit if the underlying client transport dies. func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { server := newServer(t) - defer server.Shutdown() - conn := server.Dial(clientConfig()) + client := server.Dial(clientConfig()) - sshListener, err := conn.Listen(n, listenAddr) + sshListener, err := client.Listen(n, listenAddr) if err != nil { t.Fatal(err) } @@ -184,14 +181,10 @@ func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { // It would be even nicer if we closed the server side, but it // is more involved as the fd for that side is dup()ed. - server.clientConn.Close() + server.lastDialConn.Close() - select { - case <-time.After(1 * time.Second): - t.Errorf("timeout: listener did not close.") - case err := <-quit: - t.Logf("quit as expected (error %v)", err) - } + err = <-quit + t.Logf("quit as expected (error %v)", err) } func TestPortForwardConnectionCloseTCP(t *testing.T) { diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go index da8f674b3e..6c253a7547 100644 --- a/ssh/test/multi_auth_test.go +++ b/ssh/test/multi_auth_test.go @@ -108,7 +108,6 @@ func TestMultiAuth(t *testing.T) { ctx := newMultiAuthTestCtx(t) server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")}) - defer server.Shutdown() clientConfig := clientConfig() server.setTestPassword(clientConfig.User, ctx.password) diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go index 7d96ced35d..e98b7865e5 100644 --- a/ssh/test/session_test.go +++ b/ssh/test/session_test.go @@ -25,7 +25,6 @@ import ( func TestRunCommandSuccess(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -42,7 +41,6 @@ func TestRunCommandSuccess(t *testing.T) { func TestHostKeyCheck(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() hostDB := hostKeyDB() @@ -64,7 +62,6 @@ func TestHostKeyCheck(t *testing.T) { func TestRunCommandStdin(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -87,7 +84,6 @@ func TestRunCommandStdin(t *testing.T) { func TestRunCommandStdinError(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -111,7 +107,6 @@ func TestRunCommandStdinError(t *testing.T) { func TestRunCommandFailed(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -128,7 +123,6 @@ func TestRunCommandFailed(t *testing.T) { func TestRunCommandWeClosed(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -148,7 +142,6 @@ func TestRunCommandWeClosed(t *testing.T) { func TestFuncLargeRead(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -180,7 +173,6 @@ func TestFuncLargeRead(t *testing.T) { func TestKeyChange(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() hostDB := hostKeyDB() conf.HostKeyCallback = hostDB.Check @@ -227,7 +219,6 @@ func TestValidTerminalMode(t *testing.T) { t.Skipf("skipping on %s", runtime.GOOS) } server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -292,7 +283,6 @@ func TestWindowChange(t *testing.T) { t.Skipf("skipping on %s", runtime.GOOS) } server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -340,7 +330,6 @@ func TestWindowChange(t *testing.T) { func testOneCipher(t *testing.T, cipher string, cipherOrder []string) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() conf.Ciphers = []string{cipher} // Don't fail if sshd doesn't have the cipher. @@ -399,7 +388,6 @@ func TestMACs(t *testing.T) { for _, mac := range macOrder { t.Run(mac, func(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() conf.MACs = []string{mac} // Don't fail if sshd doesn't have the MAC. @@ -425,7 +413,6 @@ func TestKeyExchanges(t *testing.T) { for _, kex := range kexOrder { t.Run(kex, func(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() // Don't fail if sshd doesn't have the kex. conf.KeyExchanges = append([]string{kex}, kexOrder...) @@ -460,8 +447,6 @@ func TestClientAuthAlgorithms(t *testing.T) { } else { t.Errorf("failed for key %q", key) } - - server.Shutdown() }) } } diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go index 3012a9787f..f3f55db128 100644 --- a/ssh/test/test_unix_test.go +++ b/ssh/test/test_unix_test.go @@ -23,6 +23,7 @@ import ( "testing" "text/template" + "golang.org/x/crypto/internal/testenv" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/testdata" ) @@ -67,17 +68,13 @@ var configTmpl = map[string]*template.Template{ type server struct { t *testing.T - cleanup func() // executed during Shutdown configfile string - cmd *exec.Cmd - output bytes.Buffer // holds stderr from sshd process testUser string // test username for sshd testPasswd string // test password for sshd sshdTestPwSo string // dynamic library to inject a custom password into sshd - // Client half of the network connection. - clientConn net.Conn + lastDialConn net.Conn } func username() string { @@ -193,15 +190,15 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl s.t.Fatalf("unixConnection: %v", err) } - s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") + cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e") f, err := c2.File() if err != nil { s.t.Fatalf("UnixConn.File: %v", err) } defer f.Close() - s.cmd.Stdin = f - s.cmd.Stdout = f - s.cmd.Stderr = &s.output + cmd.Stdin = f + cmd.Stdout = f + cmd.Stderr = new(bytes.Buffer) if s.sshdTestPwSo != "" { if s.testUser == "" { @@ -210,18 +207,29 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl if s.testPasswd == "" { s.t.Fatal("password missing from sshd_test_pw.so config") } - s.cmd.Env = append(os.Environ(), + cmd.Env = append(os.Environ(), fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo), fmt.Sprintf("TEST_USER=%s", s.testUser), fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd)) } - if err := s.cmd.Start(); err != nil { - s.t.Fail() - s.Shutdown() + if err := cmd.Start(); err != nil { s.t.Fatalf("s.cmd.Start: %v", err) } - s.clientConn = c1 + s.lastDialConn = c1 + s.t.Cleanup(func() { + // Don't check for errors; if it fails it's most + // likely "os: process already finished", and we don't + // care about that. Use os.Interrupt, so child + // processes are killed too. + cmd.Process.Signal(os.Interrupt) + cmd.Wait() + if s.t.Failed() { + // log any output from sshd process + s.t.Logf("sshd:\n%s", cmd.Stderr) + } + }) + conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config) if err != nil { return nil, err @@ -232,29 +240,11 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { conn, err := s.TryDial(config) if err != nil { - s.t.Fail() - s.Shutdown() s.t.Fatalf("ssh.Client: %v", err) } return conn } -func (s *server) Shutdown() { - if s.cmd != nil && s.cmd.Process != nil { - // Don't check for errors; if it fails it's most - // likely "os: process already finished", and we don't - // care about that. Use os.Interrupt, so child - // processes are killed too. - s.cmd.Process.Signal(os.Interrupt) - s.cmd.Wait() - } - if s.t.Failed() { - // log any output from sshd process - s.t.Logf("sshd: %s", s.output.String()) - } - s.cleanup() -} - func writeFile(path string, contents []byte) { f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) if err != nil { @@ -351,15 +341,15 @@ func newServerForConfig(t *testing.T, config string, configVars map[string]strin authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) } writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) + t.Cleanup(func() { + if err := os.RemoveAll(dir); err != nil { + t.Error(err) + } + }) return &server{ t: t, configfile: f.Name(), - cleanup: func() { - if err := os.RemoveAll(dir); err != nil { - t.Error(err) - } - }, } }