diff --git a/driver/pgdriver/config.go b/driver/pgdriver/config.go index a3bfcfee1..6b2217848 100644 --- a/driver/pgdriver/config.go +++ b/driver/pgdriver/config.go @@ -3,8 +3,10 @@ package pgdriver import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" + "io/ioutil" "net" "net/url" "os" @@ -246,15 +248,69 @@ func parseDSN(dsn string) ([]Option, error) { opts = append(opts, WithApplicationName(appName)) } - switch sslMode := q.string("sslmode"); sslMode { - case "verify-ca", "verify-full": - opts = append(opts, WithTLSConfig(new(tls.Config))) - case "allow", "prefer", "require", "": - opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) - case "disable": - opts = append(opts, WithInsecure(true)) - default: - return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode) + if sslMode, sslRootCert := q.string("sslmode"), q.string("sslrootcert"); sslMode != "" || sslRootCert != "" { + tlsConfig := &tls.Config{} + switch sslMode { + case "disable": + tlsConfig = nil + case "allow", "prefer", "": + tlsConfig.InsecureSkipVerify = true + case "require": + if sslRootCert == "" { + tlsConfig.InsecureSkipVerify = true + break + } + // For backwards compatibility reasons, in the presence of `sslrootcert`, + // `sslmode` = `require` must act as if `sslmode` = `verify-ca`. See the note at + // https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES . + fallthrough + case "verify-ca": + // The default certificate verification will also verify the host name + // which is not the behavior of `verify-ca`. As such, we need to manually + // check the certificate chain. + // At the time of writing, tls.Config has no option for this behavior + // (verify chain, but skip server name). + // See https://github.com/golang/go/issues/21971 . + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, 0, len(rawCerts)) + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return fmt.Errorf("pgdriver: failed to parse certificate: %w", err) + } + certs = append(certs, cert) + } + intermediates := x509.NewCertPool() + for _, cert := range certs[1:] { + intermediates.AddCert(cert) + } + _, err := certs[0].Verify(x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: intermediates, + }) + return err + } + case "verify-full": + tlsConfig.ServerName = u.Host + if host, _, err := net.SplitHostPort(u.Host); err == nil { + tlsConfig.ServerName = host + } + default: + return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode) + } + if tlsConfig != nil && sslRootCert != "" { + rawCA, err := ioutil.ReadFile(sslRootCert) + if err != nil { + return nil, fmt.Errorf("pgdriver: failed to read root CA: %w", err) + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(rawCA) { + return nil, fmt.Errorf("pgdriver: failed to append root CA") + } + tlsConfig.RootCAs = certPool + } + opts = append(opts, WithTLSConfig(tlsConfig)) } if d := q.duration("timeout"); d != 0 { diff --git a/driver/pgdriver/proto.go b/driver/pgdriver/proto.go index 310b380ba..a72742cd8 100644 --- a/driver/pgdriver/proto.go +++ b/driver/pgdriver/proto.go @@ -121,7 +121,11 @@ func enableSSL(ctx context.Context, cn *Conn, tlsConf *tls.Config) error { return errors.New("pgdriver: SSL is not enabled on the server") } - cn.netConn = tls.Client(cn.netConn, tlsConf) + tlsCN := tls.Client(cn.netConn, tlsConf) + if err := tlsCN.HandshakeContext(ctx); err != nil { + return fmt.Errorf("pgdriver: TLS handshake failed: %w", err) + } + cn.netConn = tlsCN rd.Reset(cn.netConn) return nil