Skip to content

Commit

Permalink
*: support reload tls used by mysql protocol in place (#14749) (#15080)
Browse files Browse the repository at this point in the history
  • Loading branch information
sre-bot authored Mar 5, 2020
1 parent 3bfee86 commit 819603f
Show file tree
Hide file tree
Showing 13 changed files with 319 additions and 77 deletions.
4 changes: 4 additions & 0 deletions executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package executor

import (
"context"
"crypto/tls"

. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -62,6 +63,9 @@ func (msm *mockSessionManager) Kill(cid uint64, query bool) {

}

func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {
}

func (s *testExecSuite) TestShowProcessList(c *C) {
// Compose schema.
names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"}
Expand Down
4 changes: 4 additions & 0 deletions executor/explainfor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package executor_test

import (
"crypto/tls"
"fmt"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -51,6 +52,9 @@ func (msm *mockSessionManager1) Kill(cid uint64, query bool) {

}

func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) {
}

func (s *testSuite) TestExplainFor(c *C) {
tkRoot := testkit.NewTestKitWithInit(c, s.store)
tkUser := testkit.NewTestKitWithInit(c, s.store)
Expand Down
23 changes: 23 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
Expand Down Expand Up @@ -108,6 +109,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
err = e.executeUse(x)
case *ast.FlushStmt:
err = e.executeFlush(x)
case *ast.AlterInstanceStmt:
err = e.executeAlterInstance(x)
case *ast.BeginStmt:
err = e.executeBegin(ctx, x)
case *ast.CommitStmt:
Expand Down Expand Up @@ -1093,6 +1096,26 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error {
return nil
}

func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
if s.ReloadTLS {
logutil.Logger(context.Background()).Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError))
sm := e.ctx.GetSessionManager()
tlsCfg, err := util.LoadTLSCertificates(
variable.SysVars["ssl_ca"].Value,
variable.SysVars["ssl_key"].Value,
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
return err
}
logutil.Logger(context.Background()).Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
}
sm.UpdateTLSConfig(tlsCfg)
}
return nil
}

func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error {
h := domain.GetDomain(e.ctx).StatsHandle()
err := h.DeleteTableStatsFromKV(s.Table.TableInfo.ID)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ require (
github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e
github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd
github.com/pingcap/parser v0.0.0-20200301155133-79ec3dee69a5
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2
github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible
github.com/pingcap/tipb v0.0.0-20191120045257-1b9900292ab6
Expand Down
5 changes: 2 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d h1:zTHgLr8+0LTEJmj
github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY=
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg=
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
github.com/pingcap/parser v0.0.0-20200301155133-79ec3dee69a5 h1:r2c8RQynYNGCFDWFPgo3TNx7Roq94STRcYTrtTg3JQ4=
github.com/pingcap/parser v0.0.0-20200301155133-79ec3dee69a5/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6 h1:Xm46UzfGEzxovTaj/hhIX8Q+o/mL4iB6SbwktExvMAY=
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2 h1:NL23b8tsg6M1QpSQedK14/Jx++QeyKL2rGiBvXAQVfA=
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2/go.mod h1:b4gaAPSxaVVtaB+EHamV4Nsv8JmTdjlw0cTKmp4+dRQ=
github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible h1:I8HirWsu1MZp6t9G/g8yKCEjJJxtHooKakEgccvdJ4M=
Expand All @@ -183,7 +183,6 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFd
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 h1:FUL3b97ZY2EPqg2NbXKuMHs5pXJB9hjj1fDHnF2vl28=
github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc=
github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs=
github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
Expand Down
3 changes: 3 additions & 0 deletions infoschema/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package infoschema_test

import (
"crypto/tls"
"fmt"
"os"
"strconv"
Expand Down Expand Up @@ -321,6 +322,8 @@ func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool

func (sm *mockSessionManager) Kill(connectionID uint64, query bool) {}

func (sm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {}

func (s *testTableSuite) TestSomeTables(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
5 changes: 4 additions & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) {
case *ast.AnalyzeTableStmt:
return b.buildAnalyze(x)
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt,
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt:
return b.buildSimple(node.(ast.StmtNode))
Expand Down Expand Up @@ -1385,6 +1385,9 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
p := &Simple{Statement: node}

switch raw := node.(type) {
case *ast.AlterInstanceStmt:
err := ErrSpecificAccessDenied.GenWithStack("ALTER INSTANCE")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err)
case *ast.AlterUserStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
Expand Down
37 changes: 20 additions & 17 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,23 +495,26 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}

if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil {
// The packet is a SSLRequest, let's switch to TLS.
if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil {
return err
}
// Read the following HandshakeResponse packet.
data, err = cc.readPacket()
if err != nil {
return err
}
if isOldVersion {
pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
} else {
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
}
if err != nil {
return err
if resp.Capability&mysql.ClientSSL > 0 {
tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig))
if tlsConfig != nil {
// The packet is a SSLRequest, let's switch to TLS.
if err = cc.upgradeToTLS(tlsConfig); err != nil {
return err
}
// Read the following HandshakeResponse packet.
data, err = cc.readPacket()
if err != nil {
return err
}
if isOldVersion {
pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
} else {
pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
}
if err != nil {
return err
}
}
}

Expand Down
78 changes: 26 additions & 52 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ package server
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
Expand All @@ -43,6 +41,7 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"

// For pprof
_ "net/http/pprof"
Expand Down Expand Up @@ -104,7 +103,7 @@ const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
// Server is the MySQL protocol server
type Server struct {
cfg *config.Config
tlsConfig *tls.Config
tlsConfig unsafe.Pointer // *tls.Config
driver IDriver
listener net.Listener
socket net.Listener
Expand Down Expand Up @@ -203,16 +202,22 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
clients: make(map[uint32]*clientConn),
stopListenerCh: make(chan struct{}, 1),
}
s.loadTLSCertificates()

tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
if err != nil {
logutil.Logger(context.Background()).Error("secure connection cert/key/ca load fail", zap.Error(err))
}
logutil.Logger(context.Background()).Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))

setSystemTimeZoneVariable()

s.capability = defaultCapability
if s.tlsConfig != nil {
s.capability |= mysql.ClientSSL
}

var err error

if s.cfg.Host != "" && s.cfg.Port != 0 {
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
if s.listener, err = net.Listen("tcp", addr); err == nil {
Expand Down Expand Up @@ -252,52 +257,12 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
return s, nil
}

func (s *Server) loadTLSCertificates() {
defer func() {
if s.tlsConfig != nil {
logutil.Logger(context.Background()).Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
variable.SysVars["have_openssl"].Value = "YES"
variable.SysVars["have_ssl"].Value = "YES"
variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert
variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey
} else {
logutil.Logger(context.Background()).Warn("secure connection is not enabled")
}
}()

if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 {
s.tlsConfig = nil
return
}

tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey)
if err != nil {
logutil.Logger(context.Background()).Warn("load x509 failed", zap.Error(err))
s.tlsConfig = nil
return
}

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
var certPool *x509.CertPool
if len(s.cfg.Security.SSLCA) > 0 {
caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA)
if err != nil {
logutil.Logger(context.Background()).Warn("read file failed", zap.Error(err))
} else {
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA
}
}
s.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{tlsCert},
ClientCAs: certPool,
ClientAuth: clientAuthPolicy,
MinVersion: 0,
}
func setSSLVariable(ca, key, cert string) {
variable.SysVars["have_openssl"].Value = "YES"
variable.SysVars["have_ssl"].Value = "YES"
variable.SysVars["ssl_cert"].Value = cert
variable.SysVars["ssl_key"].Value = key
variable.SysVars["ssl_ca"].Value = ca
}

// Run runs the server.
Expand Down Expand Up @@ -545,6 +510,15 @@ func (s *Server) Kill(connectionID uint64, query bool) {
killConn(conn)
}

// UpdateTLSConfig implements the SessionManager interface.
func (s *Server) UpdateTLSConfig(cfg *tls.Config) {
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg))
}

func (s *Server) getTLSConfig() *tls.Config {
return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
}

func killConn(conn *clientConn) {
sessVars := conn.ctx.GetSessionVars()
atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1)
Expand Down
19 changes: 18 additions & 1 deletion server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"github.com/pingcap/errors"
"io/ioutil"
"net/http"
"os"
Expand Down Expand Up @@ -1012,10 +1013,26 @@ func runTestStmtCount(t *C) {
}

func runTestTLSConnection(t *C, overrider configOverrider) error {
db, err := sql.Open("mysql", getDSN(overrider))
dsn := getDSN(overrider)
db, err := sql.Open("mysql", dsn)
t.Assert(err, IsNil)
defer db.Close()
_, err = db.Exec("USE test")
if err != nil {
return errors.Annotate(err, "dsn:"+dsn)
}
return err
}

func runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error {
db, err := sql.Open("mysql", getDSN(overrider))
t.Assert(err, IsNil)
defer db.Close()
sql := "alter instance reload tls"
if errorNoRollback {
sql += " no rollback on error"
}
_, err = db.Exec(sql)
return err
}

Expand Down
Loading

0 comments on commit 819603f

Please sign in to comment.