From 0aa5087de0d907f2fcbf69b1fd9c106ccbbe9977 Mon Sep 17 00:00:00 2001 From: Sleeyax Date: Thu, 12 May 2022 16:53:57 +0200 Subject: [PATCH] Add optional method ProxyTLSConnection (closes #779) Removed the call to NetDialTLSContext from the HTTP proxy CONNECT step and replaced it with a regular net.Dial in order to prevent connection issues. Custom TLS connections can now be made via the new optional ProxyTLSConnection method, after the proxy connection has been successfully established. --- client.go | 45 ++++++++++++++++++++++++++++----------------- proxy.go | 2 +- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index d451edd1..ff31cb7c 100644 --- a/client.go +++ b/client.go @@ -65,6 +65,12 @@ type Dialer struct { // TLSClientConfig is ignored. NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // ProxyTLSConnection specifies the dial function for creating TLS connections through a Proxy. If + // ProxyTLSConnection is nil, NetDialTLSContext is used. + // If ProxyTLSConnection is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + ProxyTLSConnection func(ctx context.Context, proxyConn net.Conn) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -346,26 +352,31 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" && d.NetDialTLSContext == nil { - // If NetDialTLSContext is set, assume that the TLS handshake has already been done + if u.Scheme == "https" { + if d.ProxyTLSConnection != nil && d.Proxy != nil { + // If we are connected to a proxy, perform the TLS handshake through the existing tunnel + netConn, err = d.ProxyTLSConnection(ctx, netConn) + } else if d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done - cfg := cloneTLSConfig(d.TLSClientConfig) - if cfg.ServerName == "" { - cfg.ServerName = hostNoPort - } - tlsConn := tls.Client(netConn, cfg) - netConn = tlsConn + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn - if trace != nil && trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := doHandshake(ctx, tlsConn, cfg) - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } - if err != nil { - return nil, nil, err + if err != nil { + return nil, nil, err + } } } diff --git a/proxy.go b/proxy.go index e0f466b7..07266f0d 100644 --- a/proxy.go +++ b/proxy.go @@ -33,7 +33,7 @@ type httpProxyDialer struct { func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { hostPort, _ := hostPortNoPort(hpd.proxyURL) - conn, err := hpd.forwardDial(network, hostPort) + conn, err := net.Dial(network, hostPort) if err != nil { return nil, err }