Skip to content

Commit

Permalink
*: support require-secure-transport startup option (#15341) (#15415)
Browse files Browse the repository at this point in the history
  • Loading branch information
sre-bot authored Mar 18, 2020
1 parent 0043698 commit d80855c
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 52 deletions.
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,15 @@ type Log struct {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
RequireSecureTransport bool `toml:"require-secure-transport" json:"require-secure-transport"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// The ErrConfigValidationFailed error is used so that external callers can do a type assertion
Expand Down
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
if !s.NoRollbackOnError || config.GetGlobalConfig().Security.RequireSecureTransport {
return err
}
logutil.Logger(context.Background()).Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ require (
github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e
github.com/pingcap/kvproto v0.0.0-20200311073257-e53d835099b0
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6
github.com/pingcap/parser v3.0.12-0.20200317072324-41ea4b21f5aa+incompatible
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2
github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible
github.com/pingcap/tipb v0.0.0-20191120045257-1b9900292ab6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ github.com/pingcap/kvproto v0.0.0-20200311073257-e53d835099b0 h1:dXXNHvDwAEN1YNg
github.com/pingcap/kvproto v0.0.0-20200311073257-e53d835099b0/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY=
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg=
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6 h1:Xm46UzfGEzxovTaj/hhIX8Q+o/mL4iB6SbwktExvMAY=
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/parser v3.0.12-0.20200317072324-41ea4b21f5aa+incompatible h1:i8348dPpUM748ZtMPHvjCgagg/By7OJlzXHKgkc1tyY=
github.com/pingcap/parser v3.0.12-0.20200317072324-41ea4b21f5aa+incompatible/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2 h1:NL23b8tsg6M1QpSQedK14/Jx++QeyKL2rGiBvXAQVfA=
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2/go.mod h1:b4gaAPSxaVVtaB+EHamV4Nsv8JmTdjlw0cTKmp4+dRQ=
github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible h1:I8HirWsu1MZp6t9G/g8yKCEjJJxtHooKakEgccvdJ4M=
Expand Down
3 changes: 3 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -516,6 +517,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}
}
} else if config.GetGlobalConfig().Security.RequireSecureTransport {
return errSecureTransportRequired.FastGenByArgs()
}

// Read the remaining part of the packet.
Expand Down
31 changes: 18 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
errSecureTransportRequired = terror.ClassServer.New(codeSecureTransportRequired, mysql.MySQLErrName[mysql.ErrSecureTransportRequired])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down Expand Up @@ -205,6 +206,8 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.Logger(context.Background()).Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
}

setSystemTimeZoneVariable()
Expand Down Expand Up @@ -595,16 +598,18 @@ const (
codeInvalidSequence = 3
codeInvalidType = 4

codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
codeSecureTransportRequired = mysql.ErrSecureTransportRequired
)

func init() {
serverMySQLErrCodes := map[terror.ErrCode]uint16{
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
codeSecureTransportRequired: mysql.ErrSecureTransportRequired,
}
terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}
9 changes: 9 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,15 @@ func (ts *TidbTestSuite) TestErrorNoRollback(c *C) {
cfg.Port = 4006
cfg.Status.ReportStatus = false

cfg.Security = config.Security{
RequireSecureTransport: true,
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
}
_, err = NewServer(cfg, ts.tidbdrv)
c.Assert(err, NotNil)

// test reload tls fail with/without "error no rollback option"
cfg.Security = config.Security{
SSLCA: "/tmp/ca-cert-rollback.pem",
Expand Down
56 changes: 30 additions & 26 deletions tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,32 @@ import (

// Flag Names
const (
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRequireSecureTransport = "require-secure-transport"

nmProxyProtocolNetworks = "proxy-protocol-networks"
nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout"
Expand All @@ -114,6 +115,7 @@ var (
tokenLimit = flag.Int(nmTokenLimit, 1000, "the limit of concurrent executed sessions")
pluginDir = flag.String(nmPluginDir, "/data/deploy/plugin", "the folder that hold plugin")
pluginLoad = flag.String(nmPluginLoad, "", "wait load plugin name(separated by comma)")
requireTLS = flag.Bool(nmRequireSecureTransport, false, "require client use secure transport")

// Log
logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
Expand Down Expand Up @@ -440,7 +442,9 @@ func overrideConfig() {
if actualFlags[nmPluginDir] {
cfg.Plugin.Dir = *pluginDir
}

if actualFlags[nmRequireSecureTransport] {
cfg.Security.RequireSecureTransport = *requireTLS
}
// Log
if actualFlags[nmLogLevel] {
cfg.Log.Level = *logLevel
Expand Down
11 changes: 10 additions & 1 deletion util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,13 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
return
}

requireTLS := config.GetGlobalConfig().Security.RequireSecureTransport

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
if requireTLS {
clientAuthPolicy = tls.RequestClientCert
}
var certPool *x509.CertPool
if len(ca) > 0 {
var caCert []byte
Expand All @@ -333,7 +338,11 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
}
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
if requireTLS {
clientAuthPolicy = tls.RequireAndVerifyClientCert
} else {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
}
}
tlsConfig = &tls.Config{
Expand Down

0 comments on commit d80855c

Please sign in to comment.