Skip to content

Commit

Permalink
privilege: fix atomic problem of GRANT and REVOKE (#14219) (#15092)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lingyu Song authored Mar 4, 2020
1 parent b93defa commit 3bfee86
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 75 deletions.
178 changes: 133 additions & 45 deletions executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
)

/***
Expand Down Expand Up @@ -90,9 +92,35 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
}

// Grant for each user
for idx, user := range e.Users {
// Check if user exists.
// Commit the old transaction, like DDL.
if err := e.ctx.NewTxn(ctx); err != nil {
return err
}
defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }()

// Create internal session to start internal transaction.
isCommit := false
internalSession, err := e.getSysSession()
if err != nil {
return err
}
defer func() {
if !isCommit {
_, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback")
if err != nil {
logutil.Logger(context.Background()).Error("rollback error occur at grant privilege", zap.Error(err))
}
}
e.releaseSysSession(internalSession)
}()

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin")
if err != nil {
return err
}

// Check which user is not exist.
for _, user := range e.Users {
exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname)
if err != nil {
return err
Expand All @@ -106,31 +134,34 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
user := fmt.Sprintf(`('%s', '%s', '%s')`, user.User.Hostname, user.User.Username, pwd)
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, user)
_, err := e.ctx.(sqlexec.SQLExecutor).Execute(ctx, sql)
_, err := internalSession.(sqlexec.SQLExecutor).Execute(ctx, sql)
if err != nil {
return err
}
}
}

// Grant for each user
for _, user := range e.Users {
// If there is no privilege entry in corresponding table, insert a new one.
// Global scope: mysql.global_priv
// DB scope: mysql.DB
// Table scope: mysql.Tables_priv
// Column scope: mysql.Columns_priv
if e.TLSOptions != nil {
err = checkAndInitGlobalPriv(e.ctx, user.User.Username, user.User.Hostname)
err = checkAndInitGlobalPriv(internalSession, user.User.Username, user.User.Hostname)
if err != nil {
return err
}
}
switch e.Level.Level {
case ast.GrantLevelDB:
err := checkAndInitDBPriv(e.ctx, dbName, e.is, user.User.Username, user.User.Hostname)
err := checkAndInitDBPriv(internalSession, dbName, e.is, user.User.Username, user.User.Hostname)
if err != nil {
return err
}
case ast.GrantLevelTable:
err := checkAndInitTablePriv(e.ctx, dbName, e.Level.TableName, e.is, user.User.Username, user.User.Hostname)
err := checkAndInitTablePriv(internalSession, dbName, e.Level.TableName, e.is, user.User.Username, user.User.Hostname)
if err != nil {
return err
}
Expand All @@ -140,15 +171,8 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
privs = append(privs, &ast.PrivElem{Priv: mysql.GrantPriv})
}

if idx == 0 {
// Commit the old transaction, like DDL.
if err := e.ctx.NewTxn(ctx); err != nil {
return err
}
defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }()
}
// Grant global priv to user.
err = e.grantGlobalPriv(user)
err = e.grantGlobalPriv(internalSession, user)
if err != nil {
return err
}
Expand All @@ -157,17 +181,23 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
if len(priv.Cols) > 0 {
// Check column scope privilege entry.
// TODO: Check validity before insert new entry.
err := e.checkAndInitColumnPriv(user.User.Username, user.User.Hostname, priv.Cols)
err := e.checkAndInitColumnPriv(user.User.Username, user.User.Hostname, priv.Cols, internalSession)
if err != nil {
return err
}
}
err := e.grantLevelPriv(priv, user)
err := e.grantLevelPriv(priv, user, internalSession)
if err != nil {
return err
}
}
}

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit")
if err != nil {
return err
}
isCommit = true
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
}
Expand Down Expand Up @@ -216,7 +246,7 @@ func checkAndInitTablePriv(ctx sessionctx.Context, dbName, tblName string, is in

// checkAndInitColumnPriv checks if column scope privilege entry exists in mysql.Columns_priv.
// If unexists, insert a new one.
func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast.ColumnName) error {
func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast.ColumnName, internalSession sessionctx.Context) error {
dbName, tbl, err := getTargetSchemaAndTable(e.ctx, e.Level.DBName, e.Level.TableName, e.is)
if err != nil {
return err
Expand All @@ -226,15 +256,15 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast
if col == nil {
return errors.Errorf("Unknown column: %s", c.Name.O)
}
ok, err := columnPrivEntryExists(e.ctx, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
ok, err := columnPrivEntryExists(internalSession, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
if ok {
continue
}
// Entry does not exist for user-host-db-tbl-col. Insert a new entry.
err = initColumnPrivEntry(e.ctx, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
err = initColumnPrivEntry(internalSession, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
Expand All @@ -245,33 +275,33 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast
// initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege.
func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}")
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// initDBPrivEntry inserts a new row into mysql.DB with empty privilege.
func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db)
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// initTablePrivEntry inserts a new row into mysql.Tables_priv with empty privilege.
func initTablePrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES ('%s', '%s', '%s', '%s', '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl)
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// initColumnPrivEntry inserts a new row into mysql.Columns_priv with empty privilege.
func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string, col string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Column_name, Column_priv) VALUES ('%s', '%s', '%s', '%s', '%s', '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col)
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantGlobalPriv grants priv to user in global scope.
func (e *GrantExec) grantGlobalPriv(user *ast.UserSpec) error {
func (e *GrantExec) grantGlobalPriv(ctx sessionctx.Context, user *ast.UserSpec) error {
if len(e.TLSOptions) == 0 {
return nil
}
Expand All @@ -280,7 +310,7 @@ func (e *GrantExec) grantGlobalPriv(user *ast.UserSpec) error {
return errors.Trace(err)
}
sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

Expand Down Expand Up @@ -356,24 +386,24 @@ func tlsOption2GlobalPriv(tlsOptions []*ast.TLSOption) (priv []byte, err error)
}

// grantLevelPriv grants priv to user in s.Level scope.
func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
switch e.Level.Level {
case ast.GrantLevelGlobal:
return e.grantGlobalLevel(priv, user)
return e.grantGlobalLevel(priv, user, internalSession)
case ast.GrantLevelDB:
return e.grantDBLevel(priv, user)
return e.grantDBLevel(priv, user, internalSession)
case ast.GrantLevelTable:
if len(priv.Cols) == 0 {
return e.grantTableLevel(priv, user)
return e.grantTableLevel(priv, user, internalSession)
}
return e.grantColumnLevel(priv, user)
return e.grantColumnLevel(priv, user, internalSession)
default:
return errors.Errorf("Unknown grant level: %#v", e.Level)
}
}

// grantGlobalLevel manipulates mysql.user table.
func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
if priv.Priv == 0 {
return nil
}
Expand All @@ -382,12 +412,12 @@ func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec) err
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user.User.Username, user.User.Hostname)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantDBLevel manipulates mysql.db table.
func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
dbName := e.Level.DBName
if len(dbName) == 0 {
dbName = e.ctx.GetSessionVars().CurrentDB
Expand All @@ -397,28 +427,28 @@ func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, user.User.Username, user.User.Hostname, dbName)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantTableLevel manipulates mysql.tables_priv table.
func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
dbName := e.Level.DBName
if len(dbName) == 0 {
dbName = e.ctx.GetSessionVars().CurrentDB
}
tblName := e.Level.TableName
asgns, err := composeTablePrivUpdateForGrant(e.ctx, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName)
asgns, err := composeTablePrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tblName)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantColumnLevel manipulates mysql.tables_priv table.
func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
dbName, tbl, err := getTargetSchemaAndTable(e.ctx, e.Level.DBName, e.Level.TableName, e.is)
if err != nil {
return err
Expand All @@ -429,12 +459,12 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec) err
if col == nil {
return errors.Errorf("Unknown column: %s", c)
}
asgns, err := composeColumnPrivUpdateForGrant(e.ctx, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O)
asgns, err := composeColumnPrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return err
}
Expand Down Expand Up @@ -610,7 +640,11 @@ func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.Privile

// recordExists is a helper function to check if the sql returns any row.
func recordExists(ctx sessionctx.Context, sql string) (bool, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
recordSets, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return false, err
}
rows, _, err := getRowsAndFields(ctx, recordSets)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -645,14 +679,21 @@ func columnPrivEntryExists(ctx sessionctx.Context, name string, host string, db
// Return Table_priv and Column_priv.
func getTablePriv(ctx sessionctx.Context, name string, host string, db string, tbl string) (string, string, error) {
sql := fmt.Sprintf(`SELECT Table_priv, Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl)
rows, fields, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return "", "", err
}
if len(rows) < 1 {
if len(rs) < 1 {
return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl)
}
var tPriv, cPriv string
rows, fields, err := getRowsAndFields(ctx, rs)
if err != nil {
return "", "", err
}
if len(rows) < 1 {
return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl)
}
row := rows[0]
if fields[0].Column.Tp == mysql.TypeSet {
tablePriv := row.GetSet(0)
Expand All @@ -669,7 +710,14 @@ func getTablePriv(ctx sessionctx.Context, name string, host string, db string, t
// Return Column_priv.
func getColumnPriv(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (string, error) {
sql := fmt.Sprintf(`SELECT Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col)
rows, fields, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return "", err
}
if len(rs) < 1 {
return "", errors.Errorf("get column privilege fail for %s %s %s %s", name, host, db, tbl)
}
rows, fields, err := getRowsAndFields(ctx, rs)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -699,3 +747,43 @@ func getTargetSchemaAndTable(ctx sessionctx.Context, dbName, tableName string, i
}
return dbName, tbl, nil
}

// getRowsAndFields is used to extract rows from record sets.
func getRowsAndFields(ctx sessionctx.Context, recordSets []sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) {
var (
rows []chunk.Row
fields []*ast.ResultField
)

for i, rs := range recordSets {
tmp, err := getRowFromRecordSet(context.Background(), ctx, rs)
if err != nil {
return nil, nil, err
}
if err = rs.Close(); err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
}
return rows, fields, nil
}

func getRowFromRecordSet(ctx context.Context, se sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) {
var rows []chunk.Row
req := rs.NewChunk()
for {
err := rs.Next(ctx, req)
if err != nil || req.NumRows() == 0 {
return rows, err
}
iter := chunk.NewIterator4Chunk(req)
for r := iter.Begin(); r != iter.End(); r = iter.Next() {
rows = append(rows, r)
}
req = chunk.Renew(req, se.GetSessionVars().MaxChunkSize)
}
}
Loading

0 comments on commit 3bfee86

Please sign in to comment.