Skip to content

Commit

Permalink
Add ability to set timeout for handshake (#631)
Browse files Browse the repository at this point in the history
* Fixed issue with handshake timeout
  • Loading branch information
moredure authored and erikdubbelboer committed Aug 18, 2019
1 parent 2edabf3 commit ce02b85
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
47 changes: 43 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1537,7 +1537,7 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) {
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig)
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
if err == nil {
return conn, nil
}
Expand Down Expand Up @@ -1568,7 +1568,43 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
return cfg
}

func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config) (net.Conn, error) {
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")

var timeoutErrorChPool sync.Pool

func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)

var ch chan error
chv := timeoutErrorChPool.Get()
if chv == nil {
chv = make(chan error)
}
ch = chv.(chan error)
defer timeoutErrorChPool.Put(chv)

conn := tls.Client(rawConn, tlsConfig)

go func() {
ch <- conn.Handshake()
}()

select {
case <-tc.C:
rawConn.Close()
<-ch
return nil, ErrTLSHandshakeTimeout
case err := <-ch:
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
}

func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
if dial == nil {
if dialDualStack {
dial = DialDualStack
Expand All @@ -1585,7 +1621,10 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
panic("BUG: DialFunc returned (nil, nil)")
}
if isTLS {
conn = tls.Client(conn, tlsConfig)
if timeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, timeout)
}
return conn, nil
}
Expand Down Expand Up @@ -1992,7 +2031,7 @@ func (c *pipelineConnClient) init() {

func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig)
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
if err != nil {
return err
}
Expand Down
40 changes: 40 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1773,3 +1773,43 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
t: t,
}
}

func TestClientTLSHandshakeTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}

addr := listener.Addr().String()
defer listener.Close()

complete := make(chan bool)
defer close(complete)

go func() {
conn, err := listener.Accept()
if err != nil {
t.Error(err)
return
}
<-complete
conn.Close()
}()

client := Client{
WriteTimeout: 1 * time.Second,
ReadTimeout: 1 * time.Second,
}

_, _, err = client.Get(nil, "https://"+addr)
if err == nil {
t.Fatal("tlsClientHandshake completed successfully")
}

if err != ErrTLSHandshakeTimeout {
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
}
}

0 comments on commit ce02b85

Please sign in to comment.