Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner/core: fix point-get db privilege check #12268

Merged
merged 5 commits into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
// This plan is much faster to build and to execute because it avoid the optimization and coprocessor cost.
type PointGetPlan struct {
basePlan
dbName string
schema *expression.Schema
TblInfo *model.TableInfo
IndexInfo *model.IndexInfo
Expand Down Expand Up @@ -299,10 +300,6 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if tbl == nil {
return nil
}
dbName := tblName.Schema
if dbName.L == "" {
dbName = model.NewCIStr(ctx.GetSessionVars().CurrentDB)
}
// Do not handle partitioned table.
// Table partition implementation translates LogicalPlan from `DataSource` to
// `Union -> DataSource` in the logical plan optimization pass, since PointGetPlan
Expand Down Expand Up @@ -331,7 +328,11 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if schema == nil {
return nil
}
p := newPointGetPlan(ctx, schema, tbl, names)
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
Expand Down Expand Up @@ -371,7 +372,11 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if schema == nil {
return nil
}
p := newPointGetPlan(ctx, schema, tbl, names)
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
p.IndexInfo = idxInfo
p.IndexValues = idxValues
p.IndexValueParams = idxValueParams
Expand All @@ -380,9 +385,10 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
return nil
}

func newPointGetPlan(ctx sessionctx.Context, schema *expression.Schema, tbl *model.TableInfo, names []*types.FieldName) *PointGetPlan {
func newPointGetPlan(ctx sessionctx.Context, dbName string, schema *expression.Schema, tbl *model.TableInfo, names []*types.FieldName) *PointGetPlan {
p := &PointGetPlan{
basePlan: newBasePlan(ctx, "Point_Get", 0),
dbName: dbName,
schema: schema,
TblInfo: tbl,
outputNames: names,
Expand All @@ -396,9 +402,8 @@ func checkFastPlanPrivilege(ctx sessionctx.Context, fastPlan *PointGetPlan, chec
if pm == nil {
return nil
}
dbName := ctx.GetSessionVars().CurrentDB
for _, checkType := range checkTypes {
if !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, dbName, fastPlan.TblInfo.Name.L, "", checkType) {
if !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, fastPlan.dbName, fastPlan.TblInfo.Name.L, "", checkType) {
return errors.New("privilege check fail")
}
}
Expand Down
18 changes: 18 additions & 0 deletions privilege/privileges/privileges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,24 @@ func (s *testPrivilegeSuite) TestCheckDBPrivilege(c *C) {
c.Assert(pc.RequestVerification(activeRoles, "test", "", "", mysql.UpdatePriv), IsTrue)
}

func (s *testPrivilegeSuite) TestCheckPointGetDBPrivilege(c *C) {
rootSe := newSession(c, s.store, s.dbName)
mustExec(c, rootSe, `CREATE USER 'tester'@'localhost';`)
mustExec(c, rootSe, `GRANT SELECT,UPDATE ON test.* TO 'tester'@'localhost';`)
mustExec(c, rootSe, `flush privileges;`)
mustExec(c, rootSe, `create database test2`)
mustExec(c, rootSe, `create table test2.t(id int, v int, primary key(id))`)
mustExec(c, rootSe, `insert into test2.t(id, v) values(1, 1)`)

se := newSession(c, s.store, s.dbName)
c.Assert(se.Auth(&auth.UserIdentity{Username: "tester", Hostname: "localhost"}, nil, nil), IsTrue)
mustExec(c, se, `use test;`)
_, err := se.Execute(context.Background(), `select * from test2.t where id = 1`)
c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
_, err = se.Execute(context.Background(), "update test2.t set v = 2 where id = 1")
c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue)
}

func (s *testPrivilegeSuite) TestCheckTablePrivilege(c *C) {
rootSe := newSession(c, s.store, s.dbName)
mustExec(c, rootSe, `CREATE USER 'test1'@'localhost';`)
Expand Down