From 0ff60057bbafb685e9f9a97af5261f484f8283d1 Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Mon, 22 May 2023 10:51:18 -0400 Subject: [PATCH] ssh/test: set a timeout and WaitDelay on sshd subcommands This uses a copy of testenv.Command copied from the main repo, with light edits to allow the testenv helpers to build with Go 1.19. The testenv helper revealed an exec.Command leak in TestCertLogin, so we also fix that leak and simplify server cleanup using testing.T.Cleanup. For golang/go#60099. Fixes golang/go#60343. Change-Id: I7f79fcdb559498b987ee7689972ac53b83870aaf Reviewed-on: https://go-review.googlesource.com/c/crypto/+/496935 Auto-Submit: Bryan Mills TryBot-Result: Gopher Robot Reviewed-by: Roland Shoemaker Run-TryBot: Bryan Mills --- internal/testenv/exec.go | 120 ++++++++++++++++++++++++++++ internal/testenv/testenv_notunix.go | 15 ++++ internal/testenv/testenv_unix.go | 15 ++++ ssh/test/agent_unix_test.go | 1 - ssh/test/banner_test.go | 1 - ssh/test/cert_test.go | 1 - ssh/test/dial_unix_test.go | 1 - ssh/test/forward_unix_test.go | 17 ++-- ssh/test/multi_auth_test.go | 1 - ssh/test/session_test.go | 15 ---- ssh/test/test_unix_test.go | 64 +++++++-------- 11 files changed, 182 insertions(+), 69 deletions(-) create mode 100644 internal/testenv/exec.go create mode 100644 internal/testenv/testenv_notunix.go create mode 100644 internal/testenv/testenv_unix.go diff --git a/internal/testenv/exec.go b/internal/testenv/exec.go new file mode 100644 index 0000000000..4bacdc3ce8 --- /dev/null +++ b/internal/testenv/exec.go @@ -0,0 +1,120 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testenv + +import ( + "context" + "os" + "os/exec" + "reflect" + "strconv" + "testing" + "time" +) + +// CommandContext is like exec.CommandContext, but: +// - skips t if the platform does not support os/exec, +// - sends SIGQUIT (if supported by the platform) instead of SIGKILL +// in its Cancel function +// - if the test has a deadline, adds a Context timeout and WaitDelay +// for an arbitrary grace period before the test's deadline expires, +// - fails the test if the command does not complete before the test's deadline, and +// - sets a Cleanup function that verifies that the test did not leak a subprocess. +func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd { + t.Helper() + + var ( + cancelCtx context.CancelFunc + gracePeriod time.Duration // unlimited unless the test has a deadline (to allow for interactive debugging) + ) + + if t, ok := t.(interface { + testing.TB + Deadline() (time.Time, bool) + }); ok { + if td, ok := t.Deadline(); ok { + // Start with a minimum grace period, just long enough to consume the + // output of a reasonable program after it terminates. + gracePeriod = 100 * time.Millisecond + if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" { + scale, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("invalid GO_TEST_TIMEOUT_SCALE: %v", err) + } + gracePeriod *= time.Duration(scale) + } + + // If time allows, increase the termination grace period to 5% of the + // test's remaining time. + testTimeout := time.Until(td) + if gp := testTimeout / 20; gp > gracePeriod { + gracePeriod = gp + } + + // When we run commands that execute subprocesses, we want to reserve two + // grace periods to clean up: one for the delay between the first + // termination signal being sent (via the Cancel callback when the Context + // expires) and the process being forcibly terminated (via the WaitDelay + // field), and a second one for the delay becween the process being + // terminated and and the test logging its output for debugging. + // + // (We want to ensure that the test process itself has enough time to + // log the output before it is also terminated.) + cmdTimeout := testTimeout - 2*gracePeriod + + if cd, ok := ctx.Deadline(); !ok || time.Until(cd) > cmdTimeout { + // Either ctx doesn't have a deadline, or its deadline would expire + // after (or too close before) the test has already timed out. + // Add a shorter timeout so that the test will produce useful output. + ctx, cancelCtx = context.WithTimeout(ctx, cmdTimeout) + } + } + } + + cmd := exec.CommandContext(ctx, name, args...) + // Set the Cancel and WaitDelay fields only if present (go 1.20 and later). + // TODO: When Go 1.19 is no longer supported, remove this use of reflection + // and instead set the fields directly. + if cmdCancel := reflect.ValueOf(cmd).Elem().FieldByName("Cancel"); cmdCancel.IsValid() { + cmdCancel.Set(reflect.ValueOf(func() error { + if cancelCtx != nil && ctx.Err() == context.DeadlineExceeded { + // The command timed out due to running too close to the test's deadline. + // There is no way the test did that intentionally — it's too close to the + // wire! — so mark it as a test failure. That way, if the test expects the + // command to fail for some other reason, it doesn't have to distinguish + // between that reason and a timeout. + t.Errorf("test timed out while running command: %v", cmd) + } else { + // The command is being terminated due to ctx being canceled, but + // apparently not due to an explicit test deadline that we added. + // Log that information in case it is useful for diagnosing a failure, + // but don't actually fail the test because of it. + t.Logf("%v: terminating command: %v", ctx.Err(), cmd) + } + return cmd.Process.Signal(Sigquit) + })) + } + if cmdWaitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); cmdWaitDelay.IsValid() { + cmdWaitDelay.Set(reflect.ValueOf(gracePeriod)) + } + + t.Cleanup(func() { + if cancelCtx != nil { + cancelCtx() + } + if cmd.Process != nil && cmd.ProcessState == nil { + t.Errorf("command was started, but test did not wait for it to complete: %v", cmd) + } + }) + + return cmd +} + +// Command is like exec.Command, but applies the same changes as +// testenv.CommandContext (with a default Context). +func Command(t testing.TB, name string, args ...string) *exec.Cmd { + t.Helper() + return CommandContext(t, context.Background(), name, args...) +} diff --git a/internal/testenv/testenv_notunix.go b/internal/testenv/testenv_notunix.go new file mode 100644 index 0000000000..c8918ce592 --- /dev/null +++ b/internal/testenv/testenv_notunix.go @@ -0,0 +1,15 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows || plan9 || (js && wasm) || wasip1 + +package testenv + +import ( + "os" +) + +// Sigquit is the signal to send to kill a hanging subprocess. +// On Unix we send SIGQUIT, but on non-Unix we only have os.Kill. +var Sigquit = os.Kill diff --git a/internal/testenv/testenv_unix.go b/internal/testenv/testenv_unix.go new file mode 100644 index 0000000000..4f51823ec6 --- /dev/null +++ b/internal/testenv/testenv_unix.go @@ -0,0 +1,15 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix + +package testenv + +import ( + "syscall" +) + +// Sigquit is the signal to send to kill a hanging subprocess. +// Send SIGQUIT to get a stack trace. +var Sigquit = syscall.SIGQUIT diff --git a/ssh/test/agent_unix_test.go b/ssh/test/agent_unix_test.go index d90526c5cf..43fbdb22eb 100644 --- a/ssh/test/agent_unix_test.go +++ b/ssh/test/agent_unix_test.go @@ -17,7 +17,6 @@ import ( func TestAgentForward(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() diff --git a/ssh/test/banner_test.go b/ssh/test/banner_test.go index 22bdd67d19..3bfdd4b059 100644 --- a/ssh/test/banner_test.go +++ b/ssh/test/banner_test.go @@ -13,7 +13,6 @@ import ( func TestBannerCallbackAgainstOpenSSH(t *testing.T) { server := newServer(t) - defer server.Shutdown() clientConf := clientConfig() diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go index 77891e3622..83dd534c5c 100644 --- a/ssh/test/cert_test.go +++ b/ssh/test/cert_test.go @@ -18,7 +18,6 @@ import ( // Test both logging in with a cert, and also that the certificate presented by an OpenSSH host can be validated correctly func TestCertLogin(t *testing.T) { s := newServer(t) - defer s.Shutdown() // Use a key different from the default. clientKey := testSigners["dsa"] diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go index d3e3d54ed4..4a7ec31737 100644 --- a/ssh/test/dial_unix_test.go +++ b/ssh/test/dial_unix_test.go @@ -24,7 +24,6 @@ type dialTester interface { func testDial(t *testing.T, n, listenAddr string, x dialTester) { server := newServer(t) - defer server.Shutdown() sshConn := server.Dial(clientConfig()) defer sshConn.Close() diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go index f0595af75e..1171bc3a14 100644 --- a/ssh/test/forward_unix_test.go +++ b/ssh/test/forward_unix_test.go @@ -23,7 +23,6 @@ type closeWriter interface { func testPortForward(t *testing.T, n, listenAddr string) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -120,7 +119,6 @@ func TestPortForwardUnix(t *testing.T) { func testAcceptClose(t *testing.T, n, listenAddr string) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) sshListener, err := conn.Listen(n, listenAddr) @@ -162,10 +160,9 @@ func TestAcceptCloseUnix(t *testing.T) { // Check that listeners exit if the underlying client transport dies. func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { server := newServer(t) - defer server.Shutdown() - conn := server.Dial(clientConfig()) + client := server.Dial(clientConfig()) - sshListener, err := conn.Listen(n, listenAddr) + sshListener, err := client.Listen(n, listenAddr) if err != nil { t.Fatal(err) } @@ -184,14 +181,10 @@ func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { // It would be even nicer if we closed the server side, but it // is more involved as the fd for that side is dup()ed. - server.clientConn.Close() + server.lastDialConn.Close() - select { - case <-time.After(1 * time.Second): - t.Errorf("timeout: listener did not close.") - case err := <-quit: - t.Logf("quit as expected (error %v)", err) - } + err = <-quit + t.Logf("quit as expected (error %v)", err) } func TestPortForwardConnectionCloseTCP(t *testing.T) { diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go index da8f674b3e..6c253a7547 100644 --- a/ssh/test/multi_auth_test.go +++ b/ssh/test/multi_auth_test.go @@ -108,7 +108,6 @@ func TestMultiAuth(t *testing.T) { ctx := newMultiAuthTestCtx(t) server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")}) - defer server.Shutdown() clientConfig := clientConfig() server.setTestPassword(clientConfig.User, ctx.password) diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go index 7d96ced35d..e98b7865e5 100644 --- a/ssh/test/session_test.go +++ b/ssh/test/session_test.go @@ -25,7 +25,6 @@ import ( func TestRunCommandSuccess(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -42,7 +41,6 @@ func TestRunCommandSuccess(t *testing.T) { func TestHostKeyCheck(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() hostDB := hostKeyDB() @@ -64,7 +62,6 @@ func TestHostKeyCheck(t *testing.T) { func TestRunCommandStdin(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -87,7 +84,6 @@ func TestRunCommandStdin(t *testing.T) { func TestRunCommandStdinError(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -111,7 +107,6 @@ func TestRunCommandStdinError(t *testing.T) { func TestRunCommandFailed(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -128,7 +123,6 @@ func TestRunCommandFailed(t *testing.T) { func TestRunCommandWeClosed(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -148,7 +142,6 @@ func TestRunCommandWeClosed(t *testing.T) { func TestFuncLargeRead(t *testing.T) { server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -180,7 +173,6 @@ func TestFuncLargeRead(t *testing.T) { func TestKeyChange(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() hostDB := hostKeyDB() conf.HostKeyCallback = hostDB.Check @@ -227,7 +219,6 @@ func TestValidTerminalMode(t *testing.T) { t.Skipf("skipping on %s", runtime.GOOS) } server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -292,7 +283,6 @@ func TestWindowChange(t *testing.T) { t.Skipf("skipping on %s", runtime.GOOS) } server := newServer(t) - defer server.Shutdown() conn := server.Dial(clientConfig()) defer conn.Close() @@ -340,7 +330,6 @@ func TestWindowChange(t *testing.T) { func testOneCipher(t *testing.T, cipher string, cipherOrder []string) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() conf.Ciphers = []string{cipher} // Don't fail if sshd doesn't have the cipher. @@ -399,7 +388,6 @@ func TestMACs(t *testing.T) { for _, mac := range macOrder { t.Run(mac, func(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() conf.MACs = []string{mac} // Don't fail if sshd doesn't have the MAC. @@ -425,7 +413,6 @@ func TestKeyExchanges(t *testing.T) { for _, kex := range kexOrder { t.Run(kex, func(t *testing.T) { server := newServer(t) - defer server.Shutdown() conf := clientConfig() // Don't fail if sshd doesn't have the kex. conf.KeyExchanges = append([]string{kex}, kexOrder...) @@ -460,8 +447,6 @@ func TestClientAuthAlgorithms(t *testing.T) { } else { t.Errorf("failed for key %q", key) } - - server.Shutdown() }) } } diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go index 3012a9787f..f3f55db128 100644 --- a/ssh/test/test_unix_test.go +++ b/ssh/test/test_unix_test.go @@ -23,6 +23,7 @@ import ( "testing" "text/template" + "golang.org/x/crypto/internal/testenv" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/testdata" ) @@ -67,17 +68,13 @@ var configTmpl = map[string]*template.Template{ type server struct { t *testing.T - cleanup func() // executed during Shutdown configfile string - cmd *exec.Cmd - output bytes.Buffer // holds stderr from sshd process testUser string // test username for sshd testPasswd string // test password for sshd sshdTestPwSo string // dynamic library to inject a custom password into sshd - // Client half of the network connection. - clientConn net.Conn + lastDialConn net.Conn } func username() string { @@ -193,15 +190,15 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl s.t.Fatalf("unixConnection: %v", err) } - s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") + cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e") f, err := c2.File() if err != nil { s.t.Fatalf("UnixConn.File: %v", err) } defer f.Close() - s.cmd.Stdin = f - s.cmd.Stdout = f - s.cmd.Stderr = &s.output + cmd.Stdin = f + cmd.Stdout = f + cmd.Stderr = new(bytes.Buffer) if s.sshdTestPwSo != "" { if s.testUser == "" { @@ -210,18 +207,29 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl if s.testPasswd == "" { s.t.Fatal("password missing from sshd_test_pw.so config") } - s.cmd.Env = append(os.Environ(), + cmd.Env = append(os.Environ(), fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo), fmt.Sprintf("TEST_USER=%s", s.testUser), fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd)) } - if err := s.cmd.Start(); err != nil { - s.t.Fail() - s.Shutdown() + if err := cmd.Start(); err != nil { s.t.Fatalf("s.cmd.Start: %v", err) } - s.clientConn = c1 + s.lastDialConn = c1 + s.t.Cleanup(func() { + // Don't check for errors; if it fails it's most + // likely "os: process already finished", and we don't + // care about that. Use os.Interrupt, so child + // processes are killed too. + cmd.Process.Signal(os.Interrupt) + cmd.Wait() + if s.t.Failed() { + // log any output from sshd process + s.t.Logf("sshd:\n%s", cmd.Stderr) + } + }) + conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config) if err != nil { return nil, err @@ -232,29 +240,11 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { conn, err := s.TryDial(config) if err != nil { - s.t.Fail() - s.Shutdown() s.t.Fatalf("ssh.Client: %v", err) } return conn } -func (s *server) Shutdown() { - if s.cmd != nil && s.cmd.Process != nil { - // Don't check for errors; if it fails it's most - // likely "os: process already finished", and we don't - // care about that. Use os.Interrupt, so child - // processes are killed too. - s.cmd.Process.Signal(os.Interrupt) - s.cmd.Wait() - } - if s.t.Failed() { - // log any output from sshd process - s.t.Logf("sshd: %s", s.output.String()) - } - s.cleanup() -} - func writeFile(path string, contents []byte) { f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) if err != nil { @@ -351,15 +341,15 @@ func newServerForConfig(t *testing.T, config string, configVars map[string]strin authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) } writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) + t.Cleanup(func() { + if err := os.RemoveAll(dir); err != nil { + t.Error(err) + } + }) return &server{ t: t, configfile: f.Name(), - cleanup: func() { - if err := os.RemoveAll(dir); err != nil { - t.Error(err) - } - }, } }