diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index af40efa7a5f91..8a8a9c4cab7fe 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -521,6 +521,9 @@ func (e *Execute) rebuildRange(p Plan) error { } x.Handle = kv.IntHandle(iv) if x.PartitionInfo != nil { + if x.TblInfo.Partition.Type != model.PartitionTypeHash { + return errors.New("range partition table can not use plan cache") + } num := x.TblInfo.Partition.Num pos := math.Abs(iv) % int64(num) x.PartitionInfo = &x.TblInfo.Partition.Definitions[pos] @@ -533,6 +536,9 @@ func (e *Execute) rebuildRange(p Plan) error { } } if x.PartitionInfo != nil { + if x.TblInfo.Partition.Type != model.PartitionTypeHash { + return errors.New("range partition table can not use plan cache") + } val := x.IndexValues[x.partitionColumnPos].GetInt64() partitionID := val % int64(x.TblInfo.Partition.Num) x.PartitionInfo = &x.TblInfo.Partition.Definitions[partitionID] diff --git a/table/tables/partition.go b/table/tables/partition.go index 7d80e4d91c0a5..4cfe8092d1315 100644 --- a/table/tables/partition.go +++ b/table/tables/partition.go @@ -21,11 +21,13 @@ import ( "sort" "strconv" "strings" + "sync" "github.com/pingcap/errors" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" @@ -65,8 +67,10 @@ func (p *partition) GetPhysicalID() int64 { // partitionedTable is a table, it contains many Partitions. type partitionedTable struct { TableCommon - partitionExpr *PartitionExpr - partitions map[int64]*partition + partitionExpr *PartitionExpr + partitions map[int64]*partition + evalBufferTypes []*types.FieldType + evalBufferPool sync.Pool } func newPartitionedTable(tbl *TableCommon, tblInfo *model.TableInfo) (table.Table, error) { @@ -76,7 +80,12 @@ func newPartitionedTable(tbl *TableCommon, tblInfo *model.TableInfo) (table.Tabl return nil, errors.Trace(err) } ret.partitionExpr = partitionExpr - + initEvalBufferType(ret) + ret.evalBufferPool = sync.Pool{ + New: func() interface{} { + return initEvalBuffer(ret) + }, + } if err := initTableIndices(&ret.TableCommon); err != nil { return nil, errors.Trace(err) } @@ -125,6 +134,28 @@ type PartitionExpr struct { *ForRangeColumnsPruning } +func initEvalBufferType(t *partitionedTable) { + hasExtraHandle := false + numCols := len(t.Cols()) + if !t.Meta().PKIsHandle { + hasExtraHandle = true + numCols++ + } + t.evalBufferTypes = make([]*types.FieldType, numCols) + for i, col := range t.Cols() { + t.evalBufferTypes[i] = &col.FieldType + } + + if hasExtraHandle { + t.evalBufferTypes[len(t.evalBufferTypes)-1] = types.NewFieldType(mysql.TypeLonglong) + } +} + +func initEvalBuffer(t *partitionedTable) *chunk.MutRow { + evalBuffer := chunk.MutRowFromTypes(t.evalBufferTypes) + return &evalBuffer +} + // ForRangeColumnsPruning is used for range partition pruning. type ForRangeColumnsPruning struct { LessThan []expression.Expression @@ -237,9 +268,9 @@ func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, // The caller should assure partition info is not nil. locateExprs := make([]expression.Expression, 0, len(pi.Definitions)) var buf bytes.Buffer + p := parser.New() schema := expression.NewSchema(columns...) partStr := rangePartitionString(pi) - p := parser.New() for i := 0; i < len(pi.Definitions); i++ { if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") { // Expr less than maxvalue is always true. @@ -263,10 +294,15 @@ func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, switch len(pi.Columns) { case 0: + exprs, err := parseSimpleExprWithNames(p, ctx, pi.Expr, schema, names) + if err != nil { + return nil, err + } tmp, err := dataForRangePruning(ctx, pi) if err != nil { return nil, errors.Trace(err) } + ret.Expr = exprs ret.ForRangePruning = tmp case 1: tmp, err := dataForRangeColumnsPruning(ctx, pi, schema, names, p) @@ -318,7 +354,11 @@ func (t *partitionedTable) locatePartition(ctx sessionctx.Context, pi *model.Par var idx int switch t.meta.Partition.Type { case model.PartitionTypeRange: - idx, err = t.locateRangePartition(ctx, pi, r) + if len(pi.Columns) == 0 { + idx, err = t.locateRangePartition(ctx, pi, r) + } else { + idx, err = t.locateRangeColumnPartition(ctx, pi, r) + } case model.PartitionTypeHash: idx, err = t.locateHashPartition(ctx, pi, r) } @@ -328,13 +368,15 @@ func (t *partitionedTable) locatePartition(ctx sessionctx.Context, pi *model.Par return pi.Definitions[idx].ID, nil } -func (t *partitionedTable) locateRangePartition(ctx sessionctx.Context, pi *model.PartitionInfo, r []types.Datum) (int, error) { +func (t *partitionedTable) locateRangeColumnPartition(ctx sessionctx.Context, pi *model.PartitionInfo, r []types.Datum) (int, error) { var err error var isNull bool partitionExprs := t.partitionExpr.UpperBounds + evalBuffer := t.evalBufferPool.Get().(*chunk.MutRow) + defer t.evalBufferPool.Put(evalBuffer) idx := sort.Search(len(partitionExprs), func(i int) bool { - var ret int64 - ret, isNull, err = partitionExprs[i].EvalInt(ctx, chunk.MutRowFromDatums(r).ToRow()) + evalBuffer.SetDatums(r...) + ret, isNull, err := partitionExprs[i].EvalInt(ctx, evalBuffer.ToRow()) if err != nil { return true // Break the search. } @@ -371,9 +413,74 @@ func (t *partitionedTable) locateRangePartition(ctx sessionctx.Context, pi *mode return idx, nil } +func (t *partitionedTable) locateRangePartition(ctx sessionctx.Context, pi *model.PartitionInfo, r []types.Datum) (int, error) { + var ( + ret int64 + val int64 + isNull bool + err error + ) + if col, ok := t.partitionExpr.Expr.(*expression.Column); ok { + if r[col.Index].IsNull() { + isNull = true + } + ret = r[col.Index].GetInt64() + } else { + evalBuffer := t.evalBufferPool.Get().(*chunk.MutRow) + defer t.evalBufferPool.Put(evalBuffer) + evalBuffer.SetDatums(r...) + val, isNull, err = t.partitionExpr.Expr.EvalInt(ctx, evalBuffer.ToRow()) + if err != nil { + return 0, err + } + ret = val + } + unsigned := mysql.HasUnsignedFlag(t.partitionExpr.Expr.GetType().Flag) + ranges := t.partitionExpr.ForRangePruning + length := len(ranges.LessThan) + pos := sort.Search(length, func(i int) bool { + if isNull { + return true + } + return ranges.compare(i, ret, unsigned) > 0 + }) + if isNull { + pos = 0 + } + if pos < 0 || pos >= length { + // The data does not belong to any of the partition returns `table has no partition for value %s`. + var valueMsg string + if pi.Expr != "" { + e, err := expression.ParseSimpleExprWithTableInfo(ctx, pi.Expr, t.meta) + if err == nil { + val, _, err := e.EvalInt(ctx, chunk.MutRowFromDatums(r).ToRow()) + if err == nil { + valueMsg = fmt.Sprintf("%d", val) + } + } + } else { + // When the table is partitioned by range columns. + valueMsg = "from column_list" + } + return 0, table.ErrNoPartitionForGivenValue.GenWithStackByArgs(valueMsg) + } + return pos, nil +} + // TODO: supports linear hashing func (t *partitionedTable) locateHashPartition(ctx sessionctx.Context, pi *model.PartitionInfo, r []types.Datum) (int, error) { - ret, isNull, err := t.partitionExpr.Expr.EvalInt(ctx, chunk.MutRowFromDatums(r).ToRow()) + if col, ok := t.partitionExpr.Expr.(*expression.Column); ok { + ret := r[col.Index].GetInt64() + ret = ret % int64(t.meta.Partition.Num) + if ret < 0 { + ret = -ret + } + return int(ret), nil + } + evalBuffer := t.evalBufferPool.Get().(*chunk.MutRow) + defer t.evalBufferPool.Put(evalBuffer) + evalBuffer.SetDatums(r...) + ret, isNull, err := t.partitionExpr.Expr.EvalInt(ctx, evalBuffer.ToRow()) if err != nil { return 0, err } @@ -542,3 +649,31 @@ func rewritePartitionExpr(ctx sessionctx.Context, field ast.ExprNode, schema *ex expr, err := expression.RewriteSimpleExprWithNames(ctx, field, schema, names) return expr, err } + +func compareUnsigned(v1, v2 int64) int { + switch { + case uint64(v1) > uint64(v2): + return 1 + case uint64(v1) == uint64(v2): + return 0 + } + return -1 +} + +func (lt *ForRangePruning) compare(ith int, v int64, unsigned bool) int { + if ith == len(lt.LessThan)-1 { + if lt.MaxValue { + return 1 + } + } + if unsigned { + return compareUnsigned(lt.LessThan[ith], v) + } + switch { + case lt.LessThan[ith] > v: + return 1 + case lt.LessThan[ith] == v: + return 0 + } + return -1 +} diff --git a/table/tables/partition_test.go b/table/tables/partition_test.go index d67f54954e85c..51287e0f222a4 100644 --- a/table/tables/partition_test.go +++ b/table/tables/partition_test.go @@ -278,7 +278,7 @@ func (ts *testSuite) TestGeneratePartitionExpr(c *C) { } } -func (ts *testSuite) TestLocateRangePartitionErr(c *C) { +func (ts *testSuite) TestLocateRangeColumnPartitionErr(c *C) { tk := testkit.NewTestKitWithInit(c, ts.store) tk.MustExec("use test") tk.MustExec(`CREATE TABLE t_month_data_monitor ( @@ -294,6 +294,88 @@ func (ts *testSuite) TestLocateRangePartitionErr(c *C) { c.Assert(table.ErrNoPartitionForGivenValue.Equal(err), IsTrue) } +func (ts *testSuite) TestLocateRangePartitionErr(c *C) { + tk := testkit.NewTestKitWithInit(c, ts.store) + tk.MustExec("use test") + tk.MustExec(`CREATE TABLE t_range_locate ( + id int(20) NOT NULL AUTO_INCREMENT, + data_date date NOT NULL, + PRIMARY KEY (id, data_date) + ) PARTITION BY RANGE(id) ( + PARTITION p0 VALUES LESS THAN (1024), + PARTITION p1 VALUES LESS THAN (4096) + )`) + + _, err := tk.Exec("INSERT INTO t_range_locate VALUES (5000, '2019-04-04')") + c.Assert(table.ErrNoPartitionForGivenValue.Equal(err), IsTrue) +} + +func (ts *testSuite) TestLocatePartitionWithExtraHandle(c *C) { + tk := testkit.NewTestKitWithInit(c, ts.store) + tk.MustExec("use test") + tk.MustExec(`CREATE TABLE t_extra ( + id int(20) NOT NULL AUTO_INCREMENT, + x int(10) not null, + PRIMARY KEY (id, x) + ) PARTITION BY RANGE(id) ( + PARTITION p0 VALUES LESS THAN (1024), + PARTITION p1 VALUES LESS THAN (4096) + )`) + tk.MustExec("INSERT INTO t_extra VALUES (1000, 1000), (2000, 2000)") + tk.MustExec("set autocommit=0;") + tk.MustQuery("select * from t_extra where id = 1000 for update").Check(testkit.Rows("1000 1000")) + tk.MustExec("commit") +} + +func (ts *testSuite) TestMultiTableUpdate(c *C) { + tk := testkit.NewTestKitWithInit(c, ts.store) + tk.MustExec("use test") + tk.MustExec(`CREATE TABLE t_a ( + id int(20), + data_date date + ) partition by hash(id) partitions 10`) + tk.MustExec(`CREATE TABLE t_b ( + id int(20), + data_date date + ) PARTITION BY RANGE(id) ( + PARTITION p0 VALUES LESS THAN (2), + PARTITION p1 VALUES LESS THAN (4), + PARTITION p2 VALUES LESS THAN (6) + )`) + tk.MustExec("INSERT INTO t_a VALUES (1, '2020-08-25'), (2, '2020-08-25'), (3, '2020-08-25'), (4, '2020-08-25'), (5, '2020-08-25')") + tk.MustExec("INSERT INTO t_b VALUES (1, '2020-08-25'), (2, '2020-08-25'), (3, '2020-08-25'), (4, '2020-08-25'), (5, '2020-08-25')") + tk.MustExec("update t_a, t_b set t_a.data_date = '2020-08-24', t_a.data_date = '2020-08-23', t_a.id = t_a.id + t_b.id where t_a.id = t_b.id") + tk.MustQuery("select id from t_a order by id").Check(testkit.Rows("2", "4", "6", "8", "10")) +} + +func (ts *testSuite) TestLocatePartitionSingleColumn(c *C) { + tk := testkit.NewTestKitWithInit(c, ts.store) + tk.MustExec("use test") + tk.MustExec(`CREATE TABLE t_hash_locate ( + id int(20), + data_date date + ) partition by hash(id) partitions 10`) + + tk.MustExec(`CREATE TABLE t_range ( + id int(10) NOT NULL, + data_date date, + PRIMARY KEY (id) + ) PARTITION BY RANGE(id) ( + PARTITION p0 VALUES LESS THAN (1), + PARTITION p1 VALUES LESS THAN (2), + PARTITION p2 VALUES LESS THAN (4) + )`) + + tk.MustExec("INSERT INTO t_hash_locate VALUES (), (), (), ()") + tk.MustQuery("SELECT count(*) FROM t_hash_locate PARTITION (p0)").Check(testkit.Rows("4")) + tk.MustExec("INSERT INTO t_range VALUES (-1, NULL), (1, NULL), (2, NULL), (3, NULL)") + tk.MustQuery("SELECT count(*) FROM t_range PARTITION (p0)").Check(testkit.Rows("1")) + tk.MustQuery("SELECT count(*) FROM t_range PARTITION (p1)").Check(testkit.Rows("1")) + tk.MustQuery("SELECT count(*) FROM t_range PARTITION (p2)").Check(testkit.Rows("2")) + _, err := tk.Exec("INSERT INTO t_range VALUES (4, NULL)") + c.Assert(table.ErrNoPartitionForGivenValue.Equal(err), IsTrue) +} + func (ts *testSuite) TestTimeZoneChange(c *C) { tk := testkit.NewTestKitWithInit(c, ts.store) tk.MustExec("use test")