Skip to content

Commit

Permalink
cherry pick #15341 to release-3.0
Browse files Browse the repository at this point in the history
Signed-off-by: sre-bot <sre-bot@pingcap.com>
  • Loading branch information
lysu authored and sre-bot committed Mar 17, 2020
1 parent 3400fe5 commit 222a0ed
Show file tree
Hide file tree
Showing 9 changed files with 2,268 additions and 10 deletions.
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,15 @@ type Log struct {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
RequireSecureTransport bool `toml:"require-secure-transport" json:"require-secure-transport"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// The ErrConfigValidationFailed error is used so that external callers can do a type assertion
Expand Down
1,082 changes: 1,082 additions & 0 deletions errno/errcode.go

Large diffs are not rendered by default.

1,079 changes: 1,079 additions & 0 deletions errno/errname.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
if !s.NoRollbackOnError || config.GetGlobalConfig().Security.RequireSecureTransport {
return err
}
logutil.Logger(context.Background()).Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
Expand Down
3 changes: 3 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -516,6 +517,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}
}
} else if config.GetGlobalConfig().Security.RequireSecureTransport {
return errSecureTransportRequired.FastGenByArgs()
}

// Read the remaining part of the packet.
Expand Down
20 changes: 20 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,23 @@ func init() {
}

var (
<<<<<<< HEAD
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
=======
errUnknownFieldType = terror.ClassServer.New(errno.ErrUnknownFieldType, errno.MySQLErrName[errno.ErrUnknownFieldType])
errInvalidSequence = terror.ClassServer.New(errno.ErrInvalidSequence, errno.MySQLErrName[errno.ErrInvalidSequence])
errInvalidType = terror.ClassServer.New(errno.ErrInvalidType, errno.MySQLErrName[errno.ErrInvalidType])
errNotAllowedCommand = terror.ClassServer.New(errno.ErrNotAllowedCommand, errno.MySQLErrName[errno.ErrNotAllowedCommand])
errAccessDenied = terror.ClassServer.New(errno.ErrAccessDenied, errno.MySQLErrName[errno.ErrAccessDenied])
errConCount = terror.ClassServer.New(errno.ErrConCount, errno.MySQLErrName[errno.ErrConCount])
errSecureTransportRequired = terror.ClassServer.New(errno.ErrSecureTransportRequired, errno.MySQLErrName[errno.ErrSecureTransportRequired])
>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down Expand Up @@ -199,12 +209,22 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {

tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
if err != nil {
<<<<<<< HEAD
logutil.Logger(context.Background()).Error("secure connection cert/key/ca load fail", zap.Error(err))
=======
logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err))
>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)
}
if tlsConfig != nil {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
<<<<<<< HEAD
logutil.Logger(context.Background()).Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
=======
logutil.BgLogger().Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)
}

setSystemTimeZoneVariable()
Expand Down
12 changes: 12 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,18 @@ func (ts *TidbTestSuite) TestErrorNoRollback(c *C) {
cfg.Port = 4006
cfg.Status.ReportStatus = false

<<<<<<< HEAD
=======
cfg.Security = config.Security{
RequireSecureTransport: true,
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
}
_, err = NewServer(cfg, ts.tidbdrv)
c.Assert(err, NotNil)

>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)
// test reload tls fail with/without "error no rollback option"
cfg.Security = config.Security{
SSLCA: "/tmp/ca-cert-rollback.pem",
Expand Down
52 changes: 52 additions & 0 deletions tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ import (

// Flag Names
const (
<<<<<<< HEAD
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
Expand All @@ -88,6 +89,36 @@ const (
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
=======
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRepairMode = "repair-mode"
nmRepairList = "repair-list"
nmRequireSecureTransport = "require-secure-transport"
>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)

nmProxyProtocolNetworks = "proxy-protocol-networks"
nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout"
Expand All @@ -113,6 +144,13 @@ var (
tokenLimit = flag.Int(nmTokenLimit, 1000, "the limit of concurrent executed sessions")
pluginDir = flag.String(nmPluginDir, "/data/deploy/plugin", "the folder that hold plugin")
pluginLoad = flag.String(nmPluginLoad, "", "wait load plugin name(separated by comma)")
<<<<<<< HEAD
=======
affinityCPU = flag.String(nmAffinityCPU, "", "affinity cpu (cpu-no. separated by comma, e.g. 1,2,3)")
repairMode = flagBoolean(nmRepairMode, false, "enable admin repair mode")
repairList = flag.String(nmRepairList, "", "admin repair table list")
requireTLS = flag.Bool(nmRequireSecureTransport, false, "require client use secure transport")
>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)

// Log
logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
Expand Down Expand Up @@ -439,6 +477,20 @@ func overrideConfig() {
if actualFlags[nmPluginDir] {
cfg.Plugin.Dir = *pluginDir
}
<<<<<<< HEAD
=======
if actualFlags[nmRequireSecureTransport] {
cfg.Security.RequireSecureTransport = *requireTLS
}
if actualFlags[nmRepairMode] {
cfg.RepairMode = *repairMode
}
if actualFlags[nmRepairList] {
if cfg.RepairMode {
cfg.RepairTableList = stringToList(*repairList)
}
}
>>>>>>> aec6143... *: support require-secure-transport startup option (#15341)

// Log
if actualFlags[nmLogLevel] {
Expand Down
11 changes: 10 additions & 1 deletion util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,13 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
return
}

requireTLS := config.GetGlobalConfig().Security.RequireSecureTransport

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
if requireTLS {
clientAuthPolicy = tls.RequestClientCert
}
var certPool *x509.CertPool
if len(ca) > 0 {
var caCert []byte
Expand All @@ -330,7 +335,11 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
}
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
if requireTLS {
clientAuthPolicy = tls.RequireAndVerifyClientCert
} else {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
}
}
tlsConfig = &tls.Config{
Expand Down

0 comments on commit 222a0ed

Please sign in to comment.