diff --git a/executor/builder.go b/executor/builder.go index 7d5a2fd50c009..86f503e82f098 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -559,6 +559,7 @@ func (b *executorBuilder) buildSelectLock(v *plannercore.PhysicalLock) Executor e := &SelectLockExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src), Lock: v.Lock, + partitionedTable: v.PartitionedTable, } return e } diff --git a/executor/executor.go b/executor/executor.go index 0bde512172b76..b879adf76b95b 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -742,6 +742,11 @@ type SelectLockExec struct { Lock ast.SelectLockType keys []kv.Key + + partitionedTable []table.PartitionedTable + + // tblID2Table is cached to reduce cost. + tblID2Table map[int64]table.PartitionedTable } // Open implements the Executor Open interface. @@ -755,6 +760,18 @@ func (e *SelectLockExec) Open(ctx context.Context) error { // This operation is only for schema validator check. txnCtx.UpdateDeltaForTable(id, 0, 0, map[int64]int64{}) } + + if len(e.Schema().TblID2Handle) > 0 && len(e.partitionedTable) > 0 { + e.tblID2Table = make(map[int64]table.PartitionedTable, len(e.partitionedTable)) + for id := range e.Schema().TblID2Handle { + for _, p := range e.partitionedTable { + if id == p.Meta().ID { + e.tblID2Table[id] = p + } + } + } + } + return nil } @@ -774,12 +791,23 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { if len(e.Schema().TblID2Handle) == 0 || e.Lock != ast.SelectLockForUpdate { return nil } - if req.NumRows() != 0 { + + if req.NumRows() > 0 { iter := chunk.NewIterator4Chunk(req) - for id, cols := range e.Schema().TblID2Handle { - for _, col := range cols { - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(id, row.GetInt64(col.Index))) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + for id, cols := range e.Schema().TblID2Handle { + physicalID := id + if pt, ok := e.tblID2Table[id]; ok { + // On a partitioned table, we have to use physical ID to encode the lock key! + p, err := pt.GetPartitionByRow(e.ctx, row.GetDatumRow(e.base().retFieldTypes)) + if err != nil { + return err + } + physicalID = p.GetPhysicalID() + } + + for _, col := range cols { + e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(physicalID, row.GetInt64(col.Index))) } } } diff --git a/executor/write.go b/executor/write.go index 09506aed09587..356656dad4ac1 100644 --- a/executor/write.go +++ b/executor/write.go @@ -117,7 +117,17 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) } - unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(t.Meta().ID, h) + + physicalID := t.Meta().ID + if pt, ok := t.(table.PartitionedTable); ok { + p, err := pt.GetPartitionByRow(ctx, oldData) + if err != nil { + return false, false, 0, err + } + physicalID = p.GetPhysicalID() + } + + unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) txnCtx := ctx.GetSessionVars().TxnCtx if txnCtx.IsPessimistic { txnCtx.AddUnchangedRowKey(unchangedRowKey) diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index f32e566aeaa2c..1b0a4661c7d81 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -1280,7 +1280,8 @@ func (p *LogicalLimit) exhaustPhysicalPlans(prop *property.PhysicalProperty) []P func (p *LogicalLock) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan { childProp := prop.Clone() lock := PhysicalLock{ - Lock: p.Lock, + Lock: p.Lock, + PartitionedTable: p.partitionedTable, }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) return []PhysicalPlan{lock} } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 3137712f0d05c..aa4fca92e3b79 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2239,6 +2239,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName) (L if tableInfo.GetPartitionInfo() != nil { b.optFlag = b.optFlag | flagPartitionProcessor + b.partitionedTable = append(b.partitionedTable, tbl.(table.PartitionedTable)) // check partition by name. for _, name := range tn.PartitionNames { _, err = tables.FindPartitionByName(tableInfo, name.L) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 919e8168fa65f..3cea589b8d8d3 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -684,7 +684,8 @@ type LogicalLimit struct { type LogicalLock struct { baseLogicalPlan - Lock ast.SelectLockType + Lock ast.SelectLockType + partitionedTable []table.PartitionedTable } // WindowFrame represents a window function frame. diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 8b561e9bcc005..ffabc868fe4dd 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/ranger" ) @@ -280,6 +281,8 @@ type PhysicalLock struct { basePhysicalPlan Lock ast.SelectLockType + + PartitionedTable []table.PartitionedTable } // PhysicalLimit is the physical operator of Limit. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index bde5ea183e54e..68c1447dc517d 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -200,6 +200,9 @@ type PlanBuilder struct { inStraightJoin bool windowSpecs map[string]*ast.WindowSpec + + // SelectLock need this information to locate the lock on partitions. + partitionedTable []table.PartitionedTable } // GetVisitInfo gets the visitInfo of the PlanBuilder. @@ -575,7 +578,7 @@ func removeIgnoredPaths(paths, ignoredPaths []*accessPath, tblInfo *model.TableI } func (b *PlanBuilder) buildSelectLock(src LogicalPlan, lock ast.SelectLockType) *LogicalLock { - selectLock := LogicalLock{Lock: lock}.Init(b.ctx) + selectLock := LogicalLock{Lock: lock, partitionedTable: b.partitionedTable}.Init(b.ctx) selectLock.SetChildren(src) return selectLock } diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 35bdc1931712e..4297915284922 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -368,6 +368,12 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) error { return p.baseLogicalPlan.PruneColumns(parentUsedCols) } + if len(p.partitionedTable) > 0 { + // If the children include partitioned tables, do not prune columns. + // Because the executor needs the partitioned columns to calculate the lock key. + return p.children[0].PruneColumns(p.Schema().Columns) + } + for _, cols := range p.children[0].Schema().TblID2Handle { parentUsedCols = append(parentUsedCols, cols...) } diff --git a/session/session_test.go b/session/session_test.go index bc74ef40bdf78..d48b261678cb4 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -2799,3 +2799,69 @@ func (s *testSessionSuite) TestGrantViewRelated(c *C) { tkUser.MustQuery("select current_user();").Check(testkit.Rows("u_version29@%")) tkUser.MustExec("create view v_version29_c as select * from v_version29;") } + +func (s *testSessionSuite) TestPessimisticLockOnPartition(c *C) { + // This test checks that 'select ... for update' locks the partition instead of the table. + // Cover a bug that table ID is used to encode the lock key mistakenly. + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`create table if not exists forupdate_on_partition ( + age int not null primary key, + nickname varchar(20) not null, + gender int not null default 0, + first_name varchar(30) not null default '', + last_name varchar(20) not null default '', + full_name varchar(60) as (concat(first_name, ' ', last_name)), + index idx_nickname (nickname) +) partition by range (age) ( + partition child values less than (18), + partition young values less than (30), + partition middle values less than (50), + partition old values less than (123) +);`) + tk.MustExec("insert into forupdate_on_partition (`age`, `nickname`) values (25, 'cosven');") + + tk1 := testkit.NewTestKit(c, s.store) + tk1.MustExec("use test") + + tk.MustExec("begin pessimistic") + tk.MustQuery("select * from forupdate_on_partition where age=25 for update").Check(testkit.Rows("25 cosven 0 ")) + tk1.MustExec("begin pessimistic") + + ch := make(chan int32, 5) + go func() { + tk1.MustExec("update forupdate_on_partition set first_name='sw' where age=25") + ch <- 0 + tk1.MustExec("commit") + }() + + // Leave 50ms for tk1 to run, tk1 should be blocked at the update operation. + time.Sleep(50 * time.Millisecond) + ch <- 1 + + tk.MustExec("commit") + // tk1 should be blocked until tk commit, check the order. + c.Assert(<-ch, Equals, int32(1)) + c.Assert(<-ch, Equals, int32(0)) + + // Once again... + // This time, test for the update-update conflict. + tk.MustExec("begin pessimistic") + tk.MustExec("update forupdate_on_partition set first_name='sw' where age=25") + tk1.MustExec("begin pessimistic") + + go func() { + tk1.MustExec("update forupdate_on_partition set first_name = 'xxx' where age=25") + ch <- 0 + tk1.MustExec("commit") + }() + + // Leave 50ms for tk1 to run, tk1 should be blocked at the update operation. + time.Sleep(50 * time.Millisecond) + ch <- 1 + + tk.MustExec("commit") + // tk1 should be blocked until tk commit, check the order. + c.Assert(<-ch, Equals, int32(1)) + c.Assert(<-ch, Equals, int32(0)) +} diff --git a/table/table.go b/table/table.go index 5f852dac47e18..5931a7879270b 100644 --- a/table/table.go +++ b/table/table.go @@ -212,7 +212,7 @@ type PhysicalTable interface { type PartitionedTable interface { Table GetPartition(physicalID int64) PhysicalTable - GetPartitionByRow(sessionctx.Context, []types.Datum) (Table, error) + GetPartitionByRow(sessionctx.Context, []types.Datum) (PhysicalTable, error) } // TableFromMeta builds a table.Table from *model.TableInfo. diff --git a/table/tables/partition.go b/table/tables/partition.go index e6b3ad298d6a9..37b8bf2bc477c 100644 --- a/table/tables/partition.go +++ b/table/tables/partition.go @@ -339,7 +339,7 @@ func (t *partitionedTable) GetPartition(pid int64) table.PhysicalTable { } // GetPartitionByRow returns a Table, which is actually a Partition. -func (t *partitionedTable) GetPartitionByRow(ctx sessionctx.Context, r []types.Datum) (table.Table, error) { +func (t *partitionedTable) GetPartitionByRow(ctx sessionctx.Context, r []types.Datum) (table.PhysicalTable, error) { pid, err := t.locatePartition(ctx, t.Meta().GetPartitionInfo(), r) if err != nil { return nil, errors.Trace(err)