diff --git a/client.go b/client.go index cd55961f92..9ca53d2c3a 100644 --- a/client.go +++ b/client.go @@ -1537,7 +1537,7 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) { for n > 0 { addr := c.nextAddr() tlsConfig := c.cachedTLSConfig(addr) - conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig) + conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) if err == nil { return conn, nil } @@ -1568,7 +1568,43 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config { return cfg } -func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config) (net.Conn, error) { +var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out") + +var timeoutErrorChPool sync.Pool + +func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { + tc := AcquireTimer(timeout) + defer ReleaseTimer(tc) + + var ch chan error + chv := timeoutErrorChPool.Get() + if chv == nil { + chv = make(chan error) + } + ch = chv.(chan error) + defer timeoutErrorChPool.Put(chv) + + conn := tls.Client(rawConn, tlsConfig) + + go func() { + ch <- conn.Handshake() + }() + + select { + case <-tc.C: + rawConn.Close() + <-ch + return nil, ErrTLSHandshakeTimeout + case err := <-ch: + if err != nil { + rawConn.Close() + return nil, err + } + return conn, nil + } +} + +func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { if dial == nil { if dialDualStack { dial = DialDualStack @@ -1585,7 +1621,10 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig * panic("BUG: DialFunc returned (nil, nil)") } if isTLS { - conn = tls.Client(conn, tlsConfig) + if timeout == 0 { + return tls.Client(conn, tlsConfig), nil + } + return tlsClientHandshake(conn, tlsConfig, timeout) } return conn, nil } @@ -1992,7 +2031,7 @@ func (c *pipelineConnClient) init() { func (c *pipelineConnClient) worker() error { tlsConfig := c.cachedTLSConfig() - conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig) + conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) if err != nil { return err } diff --git a/client_test.go b/client_test.go index 8cd3679541..987135011b 100644 --- a/client_test.go +++ b/client_test.go @@ -1773,3 +1773,43 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch t: t, } } + +func TestClientTLSHandshakeTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + addr := listener.Addr().String() + defer listener.Close() + + complete := make(chan bool) + defer close(complete) + + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error(err) + return + } + <-complete + conn.Close() + }() + + client := Client{ + WriteTimeout: 1 * time.Second, + ReadTimeout: 1 * time.Second, + } + + _, _, err = client.Get(nil, "https://"+addr) + if err == nil { + t.Fatal("tlsClientHandshake completed successfully") + } + + if err != ErrTLSHandshakeTimeout { + t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) + } +}