diff --git a/config/config.go b/config/config.go index 7eacb1e92631e..36e0edef05cc2 100644 --- a/config/config.go +++ b/config/config.go @@ -282,7 +282,8 @@ func (s *Security) ToTLSConfig() (tlsConfig *tls.Config, err error) { return } tlsConfig = &tls.Config{ - RootCAs: certPool, + RootCAs: certPool, + ClientCAs: certPool, } if len(s.ClusterSSLCert) != 0 && len(s.ClusterSSLKey) != 0 { diff --git a/server/http_status.go b/server/http_status.go index 64ccf9ffb569b..74d29aa818224 100644 --- a/server/http_status.go +++ b/server/http_status.go @@ -295,6 +295,9 @@ func (s *Server) setupStatusServerAndRPCServer(addr string, serverMux *http.Serv logutil.BgLogger().Info("listen failed", zap.Error(err)) return } + if tlsConfig != nil { + logutil.BgLogger().Info("HTTP/gRPC status server secure connection is enabled", zap.Bool("CN verification enabled", tlsConfig.VerifyPeerCertificate != nil)) + } m := cmux.New(l) // Match connections in order: // First HTTP, and otherwise grpc. diff --git a/server/server.go b/server/server.go index c834d80e8a004..e573eab87d9d4 100644 --- a/server/server.go +++ b/server/server.go @@ -217,9 +217,11 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err)) return nil, err } - logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) - setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) - atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig)) + if tlsConfig != nil { + setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig)) + logutil.BgLogger().Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) + } setSystemTimeZoneVariable() @@ -383,6 +385,7 @@ func (s *Server) Close() { func (s *Server) onConn(conn *clientConn) { ctx := logutil.WithConnID(context.Background(), conn.connectionID) if err := conn.handshake(ctx); err != nil { + terror.Log(err) if plugin.IsEnable(plugin.Audit) { conn.ctx.GetSessionVars().ConnectionInfo = conn.connectInfo() } diff --git a/server/tidb_test.go b/server/tidb_test.go index ec8b31c7fb08a..3c0101977df4f 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -218,34 +218,51 @@ func (ts *tidbTestSuite) TestStatusAPIWithTLS(c *C) { } func (ts *tidbTestSuite) TestStatusAPIWithTLSCNCheck(c *C) { - c.Skip("need add ca-tidb-test-1.crt to OS") - root := filepath.Join(os.Getenv("GOPATH"), "/src/github.com/pingcap/tidb") - ca := filepath.Join(root, "/tests/cncheckcert/ca-tidb-test-1.crt") + caPath := filepath.Join(os.TempDir(), "ca-cert-cn.pem") + serverKeyPath := filepath.Join(os.TempDir(), "server-key-cn.pem") + serverCertPath := filepath.Join(os.TempDir(), "server-cert-cn.pem") + client1KeyPath := filepath.Join(os.TempDir(), "client-key-cn-check-a.pem") + client1CertPath := filepath.Join(os.TempDir(), "client-cert-cn-check-a.pem") + client2KeyPath := filepath.Join(os.TempDir(), "client-key-cn-check-b.pem") + client2CertPath := filepath.Join(os.TempDir(), "client-cert-cn-check-b.pem") + + caCert, caKey, err := generateCert(0, "TiDB CA CN CHECK", nil, nil, filepath.Join(os.TempDir(), "ca-key-cn.pem"), caPath) + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server-cn-check", caCert, caKey, serverKeyPath, serverCertPath) + c.Assert(err, IsNil) + _, _, err = generateCert(2, "tidb-client-cn-check-a", caCert, caKey, client1KeyPath, client1CertPath, func(c *x509.Certificate) { + c.Subject.CommonName = "tidb-client-1" + }) + c.Assert(err, IsNil) + _, _, err = generateCert(3, "tidb-client-cn-check-b", caCert, caKey, client2KeyPath, client2CertPath, func(c *x509.Certificate) { + c.Subject.CommonName = "tidb-client-2" + }) + c.Assert(err, IsNil) cli := newTestServerClient() cli.statusScheme = "https" cfg := config.NewConfig() cfg.Port = cli.port cfg.Status.StatusPort = cli.statusPort - cfg.Security.ClusterSSLCA = ca - cfg.Security.ClusterSSLCert = filepath.Join(root, "/tests/cncheckcert/server-cert.pem") - cfg.Security.ClusterSSLKey = filepath.Join(root, "/tests/cncheckcert/server-key.pem") + cfg.Security.ClusterSSLCA = caPath + cfg.Security.ClusterSSLCert = serverCertPath + cfg.Security.ClusterSSLKey = serverKeyPath cfg.Security.ClusterVerifyCN = []string{"tidb-client-2"} server, err := NewServer(cfg, ts.tidbdrv) c.Assert(err, IsNil) go server.Run() time.Sleep(time.Millisecond * 100) - hc := newTLSHttpClient(c, ca, - filepath.Join(root, "/tests/cncheckcert/client-cert-1.pem"), - filepath.Join(root, "/tests/cncheckcert/client-key-1.pem"), + hc := newTLSHttpClient(c, caPath, + client1CertPath, + client1KeyPath, ) _, err = hc.Get(cli.statusURL("/status")) c.Assert(err, NotNil) - hc = newTLSHttpClient(c, ca, - filepath.Join(root, "/tests/cncheckcert/client-cert-2.pem"), - filepath.Join(root, "/tests/cncheckcert/client-key-2.pem"), + hc = newTLSHttpClient(c, caPath, + client2CertPath, + client2KeyPath, ) _, err = hc.Get(cli.statusURL("/status")) c.Assert(err, IsNil)