diff --git a/executor/executor.go b/executor/executor.go index 63f65f3d816b3..621b6e65d1cf0 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -887,6 +887,9 @@ type SelectLockExec struct { // tblID2Table is cached to reduce cost. tblID2Table map[int64]table.PartitionedTable + + // ptCol2RowIndexes is partitioned table column map to row indexes + ptCol2RowIndexes map[int64][]int } // Open implements the Executor Open interface. @@ -895,12 +898,18 @@ func (e *SelectLockExec) Open(ctx context.Context) error { return err } + is := domain.GetDomain(e.ctx).InfoSchema() if len(e.tblID2Handle) > 0 && len(e.partitionedTable) > 0 { e.tblID2Table = make(map[int64]table.PartitionedTable, len(e.partitionedTable)) + e.ptCol2RowIndexes = make(map[int64][]int, len(e.partitionedTable)) for id := range e.tblID2Handle { for _, p := range e.partitionedTable { if id == p.Meta().ID { e.tblID2Table[id] = p + err := e.generatePartitionedTableColumnMap(p, is) + if err != nil { + return err + } } } } @@ -909,6 +918,55 @@ func (e *SelectLockExec) Open(ctx context.Context) error { return nil } +func (e *SelectLockExec) generatePartitionedTableColumnMap(pt table.PartitionedTable, is infoschema.InfoSchema) error { + // Get Table Name and DB name + tblInfo := pt.Meta() + dbInfo, ok := is.SchemaByTable(tblInfo) + if !ok { + return errors.Trace(errors.Errorf("Cannot get schema info for table %s", tblInfo.Name.O)) + } + colNamePrefix := fmt.Sprintf("%s.%s.", dbInfo.Name.L, tblInfo.Name.L) + cols := pt.VisibleCols() + matched := false + ret := make([]int, 0, len(cols)) + for _, colInfo := range cols { + colFullName := colNamePrefix + colInfo.Name.L + matched = false + for i, col := range e.schema.Columns { + if col.OrigName == colFullName { + ret = append(ret, i) + matched = true + break + } + } + if !matched { + return errors.Trace(errors.Errorf("Table %s column %s cannot find data with select result", tblInfo.Name.O, colInfo.Name.L)) + } + } + e.ptCol2RowIndexes[tblInfo.ID] = ret + return nil +} + +func (e *SelectLockExec) projectRowToPartitionedTableRow(row chunk.Row, ptID int64) ([]types.Datum, error) { + rowDatums := row.GetDatumRow(e.base().retFieldTypes) + numDatums := len(rowDatums) + if len(e.schema.Columns) != numDatums { + return nil, errors.Trace(errors.Errorf("Columns length not match row fields length")) + } + proj, have := e.ptCol2RowIndexes[ptID] + if !have { + return nil, errors.Trace(errors.Errorf("Cannot get column maps")) + } + ret := make([]types.Datum, 0, numDatums) + for _, idx := range proj { + if idx >= numDatums { + return nil, errors.Trace(errors.Errorf("Column maps index is overflow!")) + } + ret = append(ret, rowDatums[idx]) + } + return ret, nil +} + // Next implements the Executor Next interface. func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { req.GrowAndReset(e.maxChunkSize) @@ -927,8 +985,11 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { for id, cols := range e.tblID2Handle { physicalID := id if pt, ok := e.tblID2Table[id]; ok { - // On a partitioned table, we have to use physical ID to encode the lock key! - p, err := pt.GetPartitionByRow(e.ctx, row.GetDatumRow(e.base().retFieldTypes)) + ptRowData, err := e.projectRowToPartitionedTableRow(row, id) + if err != nil { + return err + } + p, err := pt.GetPartitionByRow(e.ctx, ptRowData) if err != nil { return err } diff --git a/executor/executor_test.go b/executor/executor_test.go index 4b7317b638b41..f2fecc17f62cb 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -7367,6 +7367,54 @@ func (s *testSuite) Test13004(c *C) { tk.MustQuery("SELECT TIMESTAMP '9999-01-01 00:00:00'").Check(testkit.Rows("9999-01-01 00:00:00")) } +func (s *testSuite) Test21509(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t0") + tk.MustExec("create table t0 (c_int int, c_timestamp timestamp, primary key (c_int), key(c_timestamp)) partition by hash (c_int) partitions 4") + tk.MustExec("insert into t0 values (1, '2020-12-05 01:02:03')") + tk.MustExec("begin") + // the select for update should not got error + tk.MustQuery("select * from t0 where c_timestamp in (select c_timestamp from t0 where c_int = 1) for update") + tk.MustExec("commit") +} + +func (s *testSuite) Test21618(c *C) { + tk1 := testkit.NewTestKit(c, s.store) + tk2 := testkit.NewTestKit(c, s.store) + // Prepare + tk1.MustExec("use test") + tk2.MustExec("use test") + tk1.MustExec("drop table if exists t") + tk1.MustExec("create table t (c_int int, d_int int, primary key (c_int), key(d_int)) partition by hash (c_int) partitions 4") + tk1.MustExec("insert into t values (1, 2)") + // Transaction 1 execute + tk1.MustExec("begin pessimistic") + tk1.MustExec("select * from t where d_int in (select d_int from t where c_int = 1) for update") + fc := make(chan int) + go func() { + // Transaction 2 execute + tk2.MustExec("begin pessimistic") + tk2.MustExec("select * from t where d_int = 2 for update") + tk2.MustExec("commit") + fc <- 1 + }() + timer := time.NewTimer(1 * time.Second) + select { + case <-fc: + c.Assert(false, IsTrue, Commentf("Should not finish transaction 2")) + case <-timer.C: + } + tk1.MustExec("commit") + + timer = time.NewTimer(1 * time.Second) + select { + case <-fc: + case <-timer.C: + c.Assert(false, IsTrue, Commentf("Transaction 2 should be finished")) + } +} + func (s *testSuite) Test12178(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test")