Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support require-secure-transport startup option (#15341) #15415

Merged
merged 6 commits into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,15 @@ type Log struct {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
RequireSecureTransport bool `toml:"require-secure-transport" json:"require-secure-transport"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// The ErrConfigValidationFailed error is used so that external callers can do a type assertion
Expand Down
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
if !s.NoRollbackOnError || config.GetGlobalConfig().Security.RequireSecureTransport {
return err
}
logutil.Logger(context.Background()).Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ require (
github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e
github.com/pingcap/kvproto v0.0.0-20200311073257-e53d835099b0
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6
github.com/pingcap/parser v3.0.12-0.20200317072324-41ea4b21f5aa+incompatible
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2
github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible
github.com/pingcap/tipb v0.0.0-20191120045257-1b9900292ab6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ github.com/pingcap/kvproto v0.0.0-20200311073257-e53d835099b0 h1:dXXNHvDwAEN1YNg
github.com/pingcap/kvproto v0.0.0-20200311073257-e53d835099b0/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY=
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg=
github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6 h1:Xm46UzfGEzxovTaj/hhIX8Q+o/mL4iB6SbwktExvMAY=
github.com/pingcap/parser v0.0.0-20200303082314-9711ba384af6/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/parser v3.0.12-0.20200317072324-41ea4b21f5aa+incompatible h1:i8348dPpUM748ZtMPHvjCgagg/By7OJlzXHKgkc1tyY=
github.com/pingcap/parser v3.0.12-0.20200317072324-41ea4b21f5aa+incompatible/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2 h1:NL23b8tsg6M1QpSQedK14/Jx++QeyKL2rGiBvXAQVfA=
github.com/pingcap/pd v1.1.0-beta.0.20191223090411-ea2b748f6ee2/go.mod h1:b4gaAPSxaVVtaB+EHamV4Nsv8JmTdjlw0cTKmp4+dRQ=
github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible h1:I8HirWsu1MZp6t9G/g8yKCEjJJxtHooKakEgccvdJ4M=
Expand Down
3 changes: 3 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -516,6 +517,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}
}
} else if config.GetGlobalConfig().Security.RequireSecureTransport {
return errSecureTransportRequired.FastGenByArgs()
}

// Read the remaining part of the packet.
Expand Down
31 changes: 18 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
errSecureTransportRequired = terror.ClassServer.New(codeSecureTransportRequired, mysql.MySQLErrName[mysql.ErrSecureTransportRequired])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down Expand Up @@ -205,6 +206,8 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.Logger(context.Background()).Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
}

setSystemTimeZoneVariable()
Expand Down Expand Up @@ -595,16 +598,18 @@ const (
codeInvalidSequence = 3
codeInvalidType = 4

codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
codeSecureTransportRequired = mysql.ErrSecureTransportRequired
)

func init() {
serverMySQLErrCodes := map[terror.ErrCode]uint16{
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
codeSecureTransportRequired: mysql.ErrSecureTransportRequired,
}
terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}
9 changes: 9 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,15 @@ func (ts *TidbTestSuite) TestErrorNoRollback(c *C) {
cfg.Port = 4006
cfg.Status.ReportStatus = false

cfg.Security = config.Security{
RequireSecureTransport: true,
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
}
_, err = NewServer(cfg, ts.tidbdrv)
c.Assert(err, NotNil)

// test reload tls fail with/without "error no rollback option"
cfg.Security = config.Security{
SSLCA: "/tmp/ca-cert-rollback.pem",
Expand Down
56 changes: 30 additions & 26 deletions tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,32 @@ import (

// Flag Names
const (
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRequireSecureTransport = "require-secure-transport"

nmProxyProtocolNetworks = "proxy-protocol-networks"
nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout"
Expand All @@ -114,6 +115,7 @@ var (
tokenLimit = flag.Int(nmTokenLimit, 1000, "the limit of concurrent executed sessions")
pluginDir = flag.String(nmPluginDir, "/data/deploy/plugin", "the folder that hold plugin")
pluginLoad = flag.String(nmPluginLoad, "", "wait load plugin name(separated by comma)")
requireTLS = flag.Bool(nmRequireSecureTransport, false, "require client use secure transport")

// Log
logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
Expand Down Expand Up @@ -440,7 +442,9 @@ func overrideConfig() {
if actualFlags[nmPluginDir] {
cfg.Plugin.Dir = *pluginDir
}

if actualFlags[nmRequireSecureTransport] {
cfg.Security.RequireSecureTransport = *requireTLS
}
// Log
if actualFlags[nmLogLevel] {
cfg.Log.Level = *logLevel
Expand Down
11 changes: 10 additions & 1 deletion util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,13 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
return
}

requireTLS := config.GetGlobalConfig().Security.RequireSecureTransport

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
if requireTLS {
clientAuthPolicy = tls.RequestClientCert
}
var certPool *x509.CertPool
if len(ca) > 0 {
var caCert []byte
Expand All @@ -333,7 +338,11 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
}
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
if requireTLS {
clientAuthPolicy = tls.RequireAndVerifyClientCert
} else {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
}
}
tlsConfig = &tls.Config{
Expand Down