From 890e4597beaeb7980caa378db55dc34f0762de5c Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Thu, 25 Apr 2019 20:14:38 +0800 Subject: [PATCH 1/5] ddl: check expr restriction for hash partitioned table The return type of the expr should be int, and only part of functions are allowed --- ddl/db_partition_test.go | 23 +++++++++++++---------- ddl/ddl_api.go | 28 +++++++++++++++++----------- ddl/partition.go | 10 ++++++++-- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/ddl/db_partition_test.go b/ddl/db_partition_test.go index 3ebab9da7b640..52d04e67c1883 100644 --- a/ddl/db_partition_test.go +++ b/ddl/db_partition_test.go @@ -106,12 +106,12 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { sql4 := `create table t4 ( a int not null, - b int not null + b int not null ) partition by range( id ) ( partition p1 values less than maxvalue, - partition p2 values less than (1991), - partition p3 values less than (1995) + partition p2 values less than (1991), + partition p3 values less than (1995) );` assertErrorCode(c, tk, sql4, tmysql.ErrPartitionMaxvalue) @@ -121,10 +121,10 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { c INT NOT NULL ) partition by range columns(a,b,c) ( - partition p0 values less than (10,5,1), - partition p2 values less than (50,maxvalue,10), - partition p3 values less than (65,30,13), - partition p4 values less than (maxvalue,30,40) + partition p0 values less than (10,5,1), + partition p2 values less than (50,maxvalue,10), + partition p3 values less than (65,30,13), + partition p4 values less than (maxvalue,30,40) );`) c.Assert(err, IsNil) @@ -139,13 +139,13 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { sql7 := `create table t7 ( a int not null, - b int not null + b int not null ) partition by range( id ) ( partition p1 values less than (1991), partition p2 values less than maxvalue, - partition p3 values less than maxvalue, - partition p4 values less than (1995), + partition p3 values less than maxvalue, + partition p4 values less than (1995), partition p5 values less than maxvalue );` assertErrorCode(c, tk, sql7, tmysql.ErrPartitionMaxvalue) @@ -230,6 +230,9 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { assertErrorCode(c, tk, `create table t31 (a int not null) partition by range( a );`, tmysql.ErrPartitionsMustBeDefined) assertErrorCode(c, tk, `create table t32 (a int not null) partition by range columns( a );`, tmysql.ErrPartitionsMustBeDefined) assertErrorCode(c, tk, `create table t33 (a int, b int) partition by hash(a) partitions 0;`, tmysql.ErrNoParts) + assertErrorCode(c, tk, `create table t33 (a timestamp, b int) partition by hash(a) partitions 30;`, tmysql.ErrFieldTypeNotAllowedAsPartitionField) + // TODO: fix this one + // assertErrorCode(c, tk, `create table t33 (a timestamp, b int) partition by hash(unix_timestamp(a)) partitions 30;`, tmysql.ErrPartitionFuncNotAllowed) } func (s *testIntegrationSuite7) TestCreateTableWithHashPartition(c *C) { diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index d49b228c46677..296365d85624b 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1141,7 +1141,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS err = checkPartitionByRangeColumn(ctx, tbInfo, pi, s) } case model.PartitionTypeHash: - err = checkPartitionByHash(pi) + err = checkPartitionByHash(ctx, pi, s, tbInfo) } if err != nil { return nil, errors.Trace(err) @@ -1353,39 +1353,45 @@ func buildViewInfoWithTableColumns(ctx sessionctx.Context, s *ast.CreateViewStmt return viewInfo, tableColumns } -func checkPartitionByHash(pi *model.PartitionInfo) error { +func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *ast.CreateTableStmt, tbInfo *model.TableInfo) error { if err := checkAddPartitionTooManyPartitions(pi.Num); err != nil { - return errors.Trace(err) + return err } - if err := checkNoHashPartitions(pi.Num); err != nil { - return errors.Trace(err) + if err := checkNoHashPartitions(ctx, pi.Num); err != nil { + return err + } + if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { + return err + } + if err := checkPartitionFuncType(ctx, s, nil, tbInfo); err != nil { + return err } return nil } func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error { if err := checkPartitionNameUnique(tbInfo, pi); err != nil { - return errors.Trace(err) + return err } if err := checkCreatePartitionValue(ctx, tbInfo, pi, cols); err != nil { - return errors.Trace(err) + return err } if err := checkAddPartitionTooManyPartitions(uint64(len(pi.Definitions))); err != nil { - return errors.Trace(err) + return err } if err := checkNoRangePartitions(len(pi.Definitions)); err != nil { - return errors.Trace(err) + return err } if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { - return errors.Trace(err) + return err } if err := checkPartitionFuncType(ctx, s, cols, tbInfo); err != nil { - return errors.Trace(err) + return err } return nil } diff --git a/ddl/partition.go b/ddl/partition.go index 62481a2bdaf53..d0ac01d6ed781 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -215,13 +215,19 @@ func checkPartitionFuncType(ctx sessionctx.Context, s *ast.CreateTableStmt, cols } } - e, err := expression.ParseSimpleExprWithTableInfo(ctx, buf.String(), tblInfo) + e, err := expression.ParseSimpleExprWithTableInfo(ctx, exprStr, tblInfo) if err != nil { return errors.Trace(err) } if e.GetType().EvalType() == types.ETInt { return nil } + if s.Partition.Tp == model.PartitionTypeHash { + if _, ok := s.Partition.Expr.(*ast.ColumnNameExpr); ok { + return ErrNotAllowedTypeInPartition.GenWithStackByArgs(exprStr) + } + } + return ErrPartitionFuncNotAllowed.GenWithStackByArgs("PARTITION") } @@ -428,7 +434,7 @@ func checkAddPartitionTooManyPartitions(piDefs uint64) error { return nil } -func checkNoHashPartitions(partitionNum uint64) error { +func checkNoHashPartitions(ctx sessionctx.Context, partitionNum uint64) error { if partitionNum == 0 { return ErrNoParts.GenWithStackByArgs("partitions") } From 8d872ecc4bd0162586253b0484131009e43a6d31 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Thu, 25 Apr 2019 20:24:14 +0800 Subject: [PATCH 2/5] make golint happy --- ddl/ddl_api.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 296365d85624b..9fcb362933b8f 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1363,10 +1363,7 @@ func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *as if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { return err } - if err := checkPartitionFuncType(ctx, s, nil, tbInfo); err != nil { - return err - } - return nil + return checkPartitionFuncType(ctx, s, nil, tbInfo) } func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error { @@ -1390,10 +1387,7 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi * return err } - if err := checkPartitionFuncType(ctx, s, cols, tbInfo); err != nil { - return err - } - return nil + return checkPartitionFuncType(ctx, s, cols, tbInfo) } func checkPartitionByRangeColumn(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt) error { From 2a63e29227b38daa1c9af5815579b7d613b1da5b Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Tue, 7 May 2019 11:50:58 +0800 Subject: [PATCH 3/5] address comment --- ddl/partition.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ddl/partition.go b/ddl/partition.go index d0ac01d6ed781..535ca98f1674b 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -202,7 +202,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, s *ast.CreateTableStmt, cols buf := new(bytes.Buffer) s.Partition.Expr.Format(buf) exprStr := buf.String() - if s.Partition.Tp == model.PartitionTypeRange { + if s.Partition.Tp == model.PartitionTypeRange || s.Partition.Tp == model.PartitionTypeHash { // if partition by columnExpr, check the column type if _, ok := s.Partition.Expr.(*ast.ColumnNameExpr); ok { for _, col := range cols { From cc6fbee199ed43596b78a2faa6354792672f1dec Mon Sep 17 00:00:00 2001 From: crazycs Date: Wed, 8 May 2019 13:32:29 +0800 Subject: [PATCH 4/5] Update ddl/ddl_api.go Co-Authored-By: tiancaiamao --- ddl/ddl_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 53287061c7bee..67445b3b14060 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1375,7 +1375,7 @@ func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *as if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { return err } - return checkPartitionFuncType(ctx, s, nil, tbInfo) + return checkPartitionFuncType(ctx, s, tbInfo.Columns, tbInfo) } func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error { From 23abe5b50a81db2875ee08611b81a38978751fde Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Wed, 8 May 2019 13:59:51 +0800 Subject: [PATCH 5/5] fix CI --- ddl/ddl_api.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 67445b3b14060..73df8ba066690 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1141,7 +1141,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS err = checkPartitionByRangeColumn(ctx, tbInfo, pi, s) } case model.PartitionTypeHash: - err = checkPartitionByHash(ctx, pi, s, tbInfo) + err = checkPartitionByHash(ctx, pi, s, cols, tbInfo) } if err != nil { return nil, errors.Trace(err) @@ -1365,7 +1365,7 @@ func buildViewInfoWithTableColumns(ctx sessionctx.Context, s *ast.CreateViewStmt return viewInfo, tableColumns } -func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *ast.CreateTableStmt, tbInfo *model.TableInfo) error { +func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, tbInfo *model.TableInfo) error { if err := checkAddPartitionTooManyPartitions(pi.Num); err != nil { return err } @@ -1375,7 +1375,7 @@ func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *as if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { return err } - return checkPartitionFuncType(ctx, s, tbInfo.Columns, tbInfo) + return checkPartitionFuncType(ctx, s, cols, tbInfo) } func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error {