Skip to content

Commit

Permalink
*: support reload tls used by mysql protocol in place (#14749)
Browse files Browse the repository at this point in the history
  • Loading branch information
lysu authored Mar 3, 2020
1 parent 41d9b44 commit 5c68d53
Show file tree
Hide file tree
Showing 12 changed files with 324 additions and 73 deletions.
2 changes: 2 additions & 0 deletions domain/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo

func (msm *mockSessionManager) Kill(cid uint64, query bool) {}

func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {}

func (*testSuite) TestT(c *C) {
defer testleak.AfterTest(c)()
store, err := mockstore.NewMockTikvStore()
Expand Down
4 changes: 4 additions & 0 deletions executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package executor

import (
"context"
"crypto/tls"

. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -73,6 +74,9 @@ func (msm *mockSessionManager) Kill(cid uint64, query bool) {

}

func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {
}

func (s *testExecSuite) TestShowProcessList(c *C) {
// Compose schema.
names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"}
Expand Down
4 changes: 4 additions & 0 deletions executor/explainfor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package executor_test

import (
"crypto/tls"
"fmt"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -51,6 +52,9 @@ func (msm *mockSessionManager1) Kill(cid uint64, query bool) {

}

func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) {
}

func (s *testSuite) TestExplainFor(c *C) {
tkRoot := testkit.NewTestKitWithInit(c, s.store)
tkUser := testkit.NewTestKitWithInit(c, s.store)
Expand Down
23 changes: 23 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
Expand Down Expand Up @@ -108,6 +109,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
err = e.executeUse(x)
case *ast.FlushStmt:
err = e.executeFlush(x)
case *ast.AlterInstanceStmt:
err = e.executeAlterInstance(x)
case *ast.BeginStmt:
err = e.executeBegin(ctx, x)
case *ast.CommitStmt:
Expand Down Expand Up @@ -1098,6 +1101,26 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error {
return nil
}

func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
if s.ReloadTLS {
logutil.BgLogger().Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError))
sm := e.ctx.GetSessionManager()
tlsCfg, err := util.LoadTLSCertificates(
variable.SysVars["ssl_ca"].Value,
variable.SysVars["ssl_key"].Value,
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
return err
}
logutil.BgLogger().Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
}
sm.UpdateTLSConfig(tlsCfg)
}
return nil
}

func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error {
h := domain.GetDomain(e.ctx).StatsHandle()
err := h.DeleteTableStatsFromKV(s.Table.TableInfo.ID)
Expand Down
2 changes: 2 additions & 0 deletions infoschema/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool

func (sm *mockSessionManager) Kill(connectionID uint64, query bool) {}

func (sm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {}

func (s *testTableSuite) TestSomeTables(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
5 changes: 4 additions & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) {
case *ast.AnalyzeTableStmt:
return b.buildAnalyze(x)
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt,
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt:
return b.buildSimple(node.(ast.StmtNode))
Expand Down Expand Up @@ -1690,6 +1690,9 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
case *ast.FlushStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("RELOAD")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ReloadPriv, "", "", "", err)
case *ast.AlterInstanceStmt:
err := ErrSpecificAccessDenied.GenWithStack("ALTER INSTANCE")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err)
case *ast.AlterUserStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
Expand Down
37 changes: 20 additions & 17 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,23 +500,26 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}

if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil {
// The packet is a SSLRequest, let's switch to TLS.
if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil {
return err
}
// Read the following HandshakeResponse packet.
data, err = cc.readPacket()
if err != nil {
return err
}
if isOldVersion {
pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
} else {
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
}
if err != nil {
return err
if resp.Capability&mysql.ClientSSL > 0 {
tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig))
if tlsConfig != nil {
// The packet is a SSLRequest, let's switch to TLS.
if err = cc.upgradeToTLS(tlsConfig); err != nil {
return err
}
// Read the following HandshakeResponse packet.
data, err = cc.readPacket()
if err != nil {
return err
}
if isOldVersion {
pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
} else {
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
}
if err != nil {
return err
}
}
}

Expand Down
78 changes: 27 additions & 51 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ package server
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
"unsafe"
// For pprof
_ "net/http/pprof"
"os"
Expand Down Expand Up @@ -104,7 +103,7 @@ const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
// Server is the MySQL protocol server
type Server struct {
cfg *config.Config
tlsConfig *tls.Config
tlsConfig unsafe.Pointer // *tls.Config
driver IDriver
listener net.Listener
socket net.Listener
Expand Down Expand Up @@ -209,16 +208,23 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
clients: make(map[uint32]*clientConn),
stopListenerCh: make(chan struct{}, 1),
}
s.loadTLSCertificates()

tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
if err != nil {
logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err))
return nil, err
}
logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))

setSystemTimeZoneVariable()

s.capability = defaultCapability
if s.tlsConfig != nil {
s.capability |= mysql.ClientSSL
}

var err error

if s.cfg.Host != "" && s.cfg.Port != 0 {
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
if s.listener, err = net.Listen("tcp", addr); err == nil {
Expand Down Expand Up @@ -258,51 +264,12 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
return s, nil
}

func (s *Server) loadTLSCertificates() {
defer func() {
if s.tlsConfig != nil {
logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
variable.SysVars["have_openssl"].Value = "YES"
variable.SysVars["have_ssl"].Value = "YES"
variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert
variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey
} else {
logutil.BgLogger().Warn("secure connection is not enabled")
}
}()

if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 {
s.tlsConfig = nil
return
}

tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey)
if err != nil {
logutil.BgLogger().Warn("load x509 failed", zap.Error(err))
s.tlsConfig = nil
return
}

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
var certPool *x509.CertPool
if len(s.cfg.Security.SSLCA) > 0 {
caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA)
if err != nil {
logutil.BgLogger().Warn("read file failed", zap.Error(err))
} else {
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA
}
}
s.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{tlsCert},
ClientCAs: certPool,
ClientAuth: clientAuthPolicy,
}
func setSSLVariable(ca, key, cert string) {
variable.SysVars["have_openssl"].Value = "YES"
variable.SysVars["have_ssl"].Value = "YES"
variable.SysVars["ssl_cert"].Value = cert
variable.SysVars["ssl_key"].Value = key
variable.SysVars["ssl_ca"].Value = ca
}

// Run runs the server.
Expand Down Expand Up @@ -564,6 +531,15 @@ func (s *Server) Kill(connectionID uint64, query bool) {
killConn(conn)
}

// UpdateTLSConfig implements the SessionManager interface.
func (s *Server) UpdateTLSConfig(cfg *tls.Config) {
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg))
}

func (s *Server) getTLSConfig() *tls.Config {
return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
}

func killConn(conn *clientConn) {
sessVars := conn.ctx.GetSessionVars()
atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1)
Expand Down
19 changes: 18 additions & 1 deletion server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/go-sql-driver/mysql"
. "github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/log"
tmysql "github.com/pingcap/parser/mysql"
Expand Down Expand Up @@ -1155,10 +1156,26 @@ func (cli *testServerClient) runTestStmtCount(t *C) {
}

func (cli *testServerClient) runTestTLSConnection(t *C, overrider configOverrider) error {
db, err := sql.Open("mysql", cli.getDSN(overrider))
dsn := cli.getDSN(overrider)
db, err := sql.Open("mysql", dsn)
t.Assert(err, IsNil)
defer db.Close()
_, err = db.Exec("USE test")
if err != nil {
return errors.Annotate(err, "dsn:"+dsn)
}
return err
}

func (cli *testServerClient) runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error {
db, err := sql.Open("mysql", cli.getDSN(overrider))
t.Assert(err, IsNil)
defer db.Close()
sql := "alter instance reload tls"
if errorNoRollback {
sql += " no rollback on error"
}
_, err = db.Exec(sql)
return err
}

Expand Down
Loading

0 comments on commit 5c68d53

Please sign in to comment.