Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server, sessionctx: improved mysql compatibility with support for init_connect (#23713) #26031

Merged
merged 11 commits into from
Jul 14, 2021
64 changes: 64 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,18 @@ func (cc *clientConn) handshake(ctx context.Context) error {
}
return err
}

// MySQL supports an "init_connect" query, which can be run on initial connection.
// The query must return a non-error or the client is disconnected.
if err := cc.initConnect(ctx); err != nil {
logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err))
initErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed")
if err1 := cc.writeError(ctx, initErr); err1 != nil {
terror.Log(err1)
}
return initErr
}

data := cc.alloc.AllocWithLen(4, 32)
data = append(data, mysql.OKHeader)
data = append(data, 0, 0)
Expand Down Expand Up @@ -713,6 +725,58 @@ func (cc *clientConn) PeerHost(hasPassword string) (host string, err error) {
return
}

// skipInitConnect follows MySQL's rules of when init-connect should be skipped.
// In 5.7 it is any user with SUPER privilege, but in 8.0 it is:
// - SUPER or the CONNECTION_ADMIN dynamic privilege.
// - (additional exception) users with expired passwords (not yet supported)
// In TiDB CONNECTION_ADMIN is satisfied by SUPER, so we only need to check once.
func (cc *clientConn) skipInitConnect() bool {
checker := cc.ctx.GetPrivilegeManager()
activeRoles := cc.ctx.GetSessionVars().ActiveRoles
return checker != nil && checker.RequestVerification(activeRoles, "", "", "", mysql.SuperPriv)
}

// initConnect runs the initConnect SQL statement if it has been specified.
// The semantics are MySQL compatible.
func (cc *clientConn) initConnect(ctx context.Context) error {
val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect)
if err != nil {
return err
}
if val == "" || cc.skipInitConnect() {
return nil
}
logutil.Logger(ctx).Debug("init_connect starting")
stmts, err := cc.ctx.Parse(ctx, val)
if err != nil {
return err
}
for _, stmt := range stmts {
rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
if err != nil {
return err
}
// init_connect does not care about the results,
// but they need to be drained because of lazy loading.
if rs != nil {
req := rs.NewChunk()
for {
if err = rs.Next(ctx, req); err != nil {
return err
}
if req.NumRows() == 0 {
break
}
}
if err := rs.Close(); err != nil {
return err
}
}
}
logutil.Logger(ctx).Debug("init_connect complete")
return nil
}

// Run reads client query and writes query result to client in for loop, if there is a panic during query handling,
// it will be recovered and log the panic error.
// This function returns and the connection is closed if there is an IO error or there is a panic.
Expand Down
3 changes: 3 additions & 0 deletions server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -106,6 +107,8 @@ type QueryCtx interface {

SetSessionManager(util.SessionManager)

GetPrivilegeManager() privilege.Manager

// GetTxnWriteThroughputSLI returns the TxnWriteThroughputSLI.
GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI
}
Expand Down
6 changes: 6 additions & 0 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
Expand Down Expand Up @@ -369,6 +370,11 @@ func (tc *TiDBContext) SetCommandValue(command byte) {
tc.session.SetCommandValue(command)
}

// GetPrivilegeManager implements QueryCtx GetPrivilegeManager method.
func (tc *TiDBContext) GetPrivilegeManager() privilege.Manager {
return privilege.GetPrivilegeManager(tc.session)
}

// GetSessionVars return SessionVars.
func (tc *TiDBContext) GetSessionVars() *variable.SessionVars {
return tc.session.GetSessionVars()
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var (
errConCount = dbterror.ClassServer.NewStd(errno.ErrConCount)
errSecureTransportRequired = dbterror.ClassServer.NewStd(errno.ErrSecureTransportRequired)
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down
48 changes: 48 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,54 @@ func (cli *testServerClient) waitUntilServerOnline() {
}
}

func (cli *testServerClient) runTestInitConnect(c *C) {

cli.runTests(c, nil, func(dbt *DBTest) {
dbt.mustExec(`SET GLOBAL init_connect="insert into test.ts VALUES (NOW());SET @a=1;"`)
dbt.mustExec(`CREATE USER init_nonsuper`)
dbt.mustExec(`CREATE USER init_super`)
dbt.mustExec(`GRANT SELECT, INSERT, DROP ON test.* TO init_nonsuper`)
dbt.mustExec(`GRANT SELECT, INSERT, DROP, SUPER ON *.* TO init_super`)
dbt.mustExec(`CREATE TABLE ts (a TIMESTAMP)`)
})

// test init_nonsuper
cli.runTests(c, func(config *mysql.Config) {
config.User = "init_nonsuper"
}, func(dbt *DBTest) {
rows := dbt.mustQuery(`SELECT @a`)
c.Assert(rows.Next(), IsTrue)
var a int
err := rows.Scan(&a)
c.Assert(err, IsNil)
dbt.Check(a, Equals, 1)
c.Assert(rows.Close(), IsNil)
})

// test init_super
cli.runTests(c, func(config *mysql.Config) {
config.User = "init_super"
}, func(dbt *DBTest) {
rows := dbt.mustQuery(`SELECT IFNULL(@a,"")`)
c.Assert(rows.Next(), IsTrue)
var a string
err := rows.Scan(&a)
c.Assert(err, IsNil)
dbt.Check(a, Equals, "") // null
c.Assert(rows.Close(), IsNil)
// change the init-connect to invalid.
dbt.mustExec(`SET GLOBAL init_connect="invalidstring"`)
})

db, err := sql.Open("mysql", cli.getDSN(func(config *mysql.Config) {
config.User = "init_nonsuper"
}))
c.Assert(err, IsNil, Commentf("Error connecting")) // doesn't fail because of lazy loading
defer db.Close() // may already be closed
_, err = db.Exec("SELECT 1") // fails because of init sql
c.Assert(err, NotNil)
}

// Client errors are only incremented when using the TiDB Server protocol,
// and not internal SQL statements. Thus, this test is in the server-test suite.
func (cli *testServerClient) runTestInfoschemaClientErrors(t *C) {
Expand Down
4 changes: 4 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,10 @@ func (ts *tidbTestSuite) TestClientErrors(c *C) {
ts.runTestInfoschemaClientErrors(c)
}

func (ts *tidbTestSuite) TestInitConnect(c *C) {
ts.runTestInitConnect(c)
}

func (ts *tidbTestSuite) TestSumAvg(c *C) {
c.Parallel()
ts.runTestSumAvg(c)
Expand Down
2 changes: 1 addition & 1 deletion sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ var defaultSysVars = []*SysVar{
{ScopeNone, "innodb_rollback_on_timeout", "0"},
{ScopeGlobal | ScopeSession, "query_alloc_block_size", "8192"},
{ScopeGlobal, SlaveCompressedProtocol, "0"},
{ScopeGlobal | ScopeSession, InitConnect, ""},
{ScopeGlobal, InitConnect, ""},
{ScopeGlobal, "rpl_semi_sync_slave_trace_level", ""},
{ScopeNone, "have_compress", "YES"},
{ScopeNone, "thread_concurrency", "10"},
Expand Down