From b5f9b35a453a0b5f06fd4a7fa77b94f251e7716e Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Wed, 25 Jul 2018 22:41:42 +0800 Subject: [PATCH] plan: refactor the code of building `Insert`. (#7068) --- ddl/db_change_test.go | 35 ++-- executor/errors.go | 1 - executor/executor_test.go | 9 +- executor/insert_common.go | 36 ---- executor/write_test.go | 21 +++ plan/errors.go | 3 + plan/logical_plan_test.go | 2 +- plan/physical_plan_test.go | 2 +- plan/planbuilder.go | 330 +++++++++++++++++++++++-------------- session/session_test.go | 5 +- 10 files changed, 266 insertions(+), 178 deletions(-) diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 907ded50ca487..ad66b651f721d 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -132,13 +132,13 @@ func (s *testStateChangeSuite) TestTwoStates(c *C) { testInfo.sqlInfos[0].sql = "insert into t (c1, c2, c3, c4) value(2, 'b', 'N', '2017-07-02')" testInfo.sqlInfos[1].sql = "insert into t (c1, c2, c3, d3, c4) value(3, 'b', 'N', 'a', '2017-07-03')" unknownColErr := errors.New("unknown column d3") - testInfo.sqlInfos[1].cases[0].expectedErr = unknownColErr - testInfo.sqlInfos[1].cases[1].expectedErr = unknownColErr - testInfo.sqlInfos[1].cases[2].expectedErr = unknownColErr - testInfo.sqlInfos[1].cases[3].expectedErr = unknownColErr + testInfo.sqlInfos[1].cases[0].expectedCompileErr = unknownColErr + testInfo.sqlInfos[1].cases[1].expectedCompileErr = unknownColErr + testInfo.sqlInfos[1].cases[2].expectedCompileErr = unknownColErr + testInfo.sqlInfos[1].cases[3].expectedCompileErr = unknownColErr testInfo.sqlInfos[2].sql = "update t set c2 = 'c2_update'" testInfo.sqlInfos[3].sql = "replace into t values(5, 'e', 'N', '2017-07-05')'" - testInfo.sqlInfos[3].cases[4].expectedErr = errors.New("Column count doesn't match value count at row 1") + testInfo.sqlInfos[3].cases[4].expectedCompileErr = errors.New("Column count doesn't match value count at row 1") alterTableSQL := "alter table t add column d3 enum('a', 'b') not null default 'a' after c3" s.test(c, "", alterTableSQL, testInfo) // TODO: Add more DDL statements. @@ -227,10 +227,11 @@ func (s *testStateChangeSuite) test(c *C, tableName, alterTableSQL string, testI } type stateCase struct { - session session.Session - rawStmt ast.StmtNode - stmt ast.Statement - expectedErr error + session session.Session + rawStmt ast.StmtNode + stmt ast.Statement + expectedExecErr error + expectedCompileErr error } type sqlInfo struct { @@ -299,6 +300,13 @@ func (t *testExecInfo) compileSQL(idx int) (err error) { return errors.Trace(err) } c.stmt, err = compiler.Compile(ctx, c.rawStmt) + if c.expectedCompileErr != nil { + if err == nil { + err = errors.Errorf("expected error %s but got nil", c.expectedCompileErr) + } else if strings.Contains(err.Error(), c.expectedCompileErr.Error()) { + err = nil + } + } if err != nil { return errors.Trace(err) } @@ -309,11 +317,14 @@ func (t *testExecInfo) compileSQL(idx int) (err error) { func (t *testExecInfo) execSQL(idx int) error { for _, sqlInfo := range t.sqlInfos { c := sqlInfo.cases[idx] + if c.expectedCompileErr != nil { + continue + } _, err := c.stmt.Exec(context.TODO()) - if c.expectedErr != nil { + if c.expectedExecErr != nil { if err == nil { - err = errors.Errorf("expected error %s but got nil", c.expectedErr) - } else if strings.Contains(err.Error(), c.expectedErr.Error()) { + err = errors.Errorf("expected error %s but got nil", c.expectedExecErr) + } else if strings.Contains(err.Error(), c.expectedExecErr.Error()) { err = nil } } diff --git a/executor/errors.go b/executor/errors.go index 1deb212bf1b64..2ea951f3f5a1f 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -39,7 +39,6 @@ var ( ErrPasswordNoMatch = terror.ClassExecutor.New(mysql.ErrPasswordNoMatch, mysql.MySQLErrName[mysql.ErrPasswordNoMatch]) ErrCannotUser = terror.ClassExecutor.New(mysql.ErrCannotUser, mysql.MySQLErrName[mysql.ErrCannotUser]) - ErrWrongValueCountOnRow = terror.ClassExecutor.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) ErrPasswordFormat = terror.ClassExecutor.New(mysql.ErrPasswordFormat, mysql.MySQLErrName[mysql.ErrPasswordFormat]) ErrCantChangeTxCharacteristics = terror.ClassExecutor.New(mysql.ErrCantChangeTxCharacteristics, mysql.MySQLErrName[mysql.ErrCantChangeTxCharacteristics]) ) diff --git a/executor/executor_test.go b/executor/executor_test.go index 187539e10b988..ba978248e19c9 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1312,7 +1312,7 @@ func (s *testSuite) TestMultiUpdate(c *C) { func (s *testSuite) TestGeneratedColumnWrite(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.MustExec(`CREATE TABLE test_gc_write (a int primary key, b int, c int as (a+8) virtual)`) + tk.MustExec(`CREATE TABLE test_gc_write (a int primary key auto_increment, b int, c int as (a+8) virtual)`) tk.MustExec(`CREATE TABLE test_gc_write_1 (a int primary key, b int, c int)`) tests := []struct { @@ -1336,6 +1336,7 @@ func (s *testSuite) TestGeneratedColumnWrite(c *C) { // Can insert without generated columns. {`insert into test_gc_write (a, b) values (1, 1)`, 0}, {`insert into test_gc_write set a = 2, b = 2`, 0}, + {`insert into test_gc_write (b) select c from test_gc_write`, 0}, // Can update without generated columns. {`update test_gc_write set b = 2 where a = 2`, 0}, {`update test_gc_write t1, test_gc_write_1 t2 set t1.b = 3, t2.b = 4`, 0}, @@ -1343,13 +1344,15 @@ func (s *testSuite) TestGeneratedColumnWrite(c *C) { // But now we can't do this, just as same with MySQL 5.7: {`insert into test_gc_write values (1, 1)`, mysql.ErrWrongValueCountOnRow}, {`insert into test_gc_write select 1, 1`, mysql.ErrWrongValueCountOnRow}, + {`insert into test_gc_write (c) select a, b from test_gc_write`, mysql.ErrWrongValueCountOnRow}, + {`insert into test_gc_write (b, c) select a, b from test_gc_write`, mysql.ErrBadGeneratedColumn}, } for _, tt := range tests { _, err := tk.Exec(tt.stmt) if tt.err != 0 { - c.Assert(err, NotNil) + c.Assert(err, NotNil, Commentf("sql is `%v`", tt.stmt)) terr := errors.Trace(err).(*errors.Err).Cause().(*terror.Error) - c.Assert(terr.Code(), Equals, terror.ErrCode(tt.err)) + c.Assert(terr.Code(), Equals, terror.ErrCode(tt.err), Commentf("sql is %v", tt.stmt)) } else { c.Assert(err, IsNil) } diff --git a/executor/insert_common.go b/executor/insert_common.go index 41deed7c34e2e..10ffaecc843c1 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -148,35 +148,6 @@ func (e *InsertValues) fillValueList() error { return nil } -func (e *InsertValues) checkValueCount(insertValueCount, valueCount, genColsCount, num int, cols []*table.Column) error { - // TODO: This check should be done in plan builder. - if insertValueCount != valueCount { - // "insert into t values (), ()" is valid. - // "insert into t values (), (1)" is not valid. - // "insert into t values (1), ()" is not valid. - // "insert into t values (1,2), (1)" is not valid. - // So the value count must be same for all insert list. - return ErrWrongValueCountOnRow.GenByArgs(num + 1) - } - if valueCount == 0 && len(e.Columns) > 0 { - // "insert into t (c1) values ()" is not valid. - return ErrWrongValueCountOnRow.GenByArgs(num + 1) - } else if valueCount > 0 { - explicitSetLen := 0 - if len(e.Columns) != 0 { - explicitSetLen = len(e.Columns) - } else { - explicitSetLen = len(e.SetList) - } - if explicitSetLen > 0 && valueCount+genColsCount != len(cols) { - return ErrWrongValueCountOnRow.GenByArgs(num + 1) - } else if explicitSetLen == 0 && valueCount != len(cols) { - return ErrWrongValueCountOnRow.GenByArgs(num + 1) - } - } - return nil -} - func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows []types.DatumRow) error) (err error) { // process `insert|replace ... set x=y...` if err = e.fillValueList(); err != nil { @@ -184,11 +155,7 @@ func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows []types.D } rows := make([]types.DatumRow, len(e.Lists)) - length := len(e.Lists[0]) for i, list := range e.Lists { - if err = e.checkValueCount(length, len(list), len(e.GenColumns), i, cols); err != nil { - return errors.Trace(err) - } e.rowCount = uint64(i) rows[i], err = e.getRow(cols, list, i) if err != nil { @@ -277,9 +244,6 @@ func (e *InsertValues) fillDefaultValues(row types.DatumRow, hasValue []bool) er func (e *InsertValues) insertRowsFromSelect(ctx context.Context, cols []*table.Column, exec func(rows []types.DatumRow) error) error { // process `insert|replace into ... select ... from ...` selectExec := e.children[0] - if selectExec.Schema().Len() != len(cols) { - return ErrWrongValueCountOnRow.GenByArgs(1) - } fields := selectExec.retTypes() chk := selectExec.newChunk() iter := chunk.NewIterator4Chunk(chk) diff --git a/executor/write_test.go b/executor/write_test.go index 845b3fef0b791..cdcf9967c8779 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -100,6 +100,12 @@ func (s *testSuite) TestInsert(c *C) { c.Assert(err, NotNil) tk.MustExec("rollback") + errInsertSelectSQL = `insert insert_test_1 values(default, default, default, default, default)` + tk.MustExec("begin") + _, err = tk.Exec(errInsertSelectSQL) + c.Assert(err, NotNil) + tk.MustExec("rollback") + // Updating column is PK handle. // Make sure the record is "1, 1, nil, 1". r := tk.MustQuery("select * from insert_test where id = 1;") @@ -240,6 +246,21 @@ func (s *testSuite) TestInsert(c *C) { Check(testkit.Rows("Warning 1690 constant -1.1 overflows float", "Warning 1690 constant -1.1 overflows double", "Warning 1690 constant -2.1 overflows float", "Warning 1690 constant -2.1 overflows double")) tk.MustQuery("select * from t").Check(testkit.Rows("0 0", "0 0", "0 0", "1.1 1.1")) + + // issue 7061 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int default 1, b int default 2)") + tk.MustExec("insert into t values(default, default)") + tk.MustQuery("select * from t").Check(testkit.Rows("1 2")) + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(default(b), default(a))") + tk.MustQuery("select * from t").Check(testkit.Rows("2 1")) + tk.MustExec("truncate table t") + tk.MustExec("insert into t (b) values(default)") + tk.MustQuery("select * from t").Check(testkit.Rows("1 2")) + tk.MustExec("truncate table t") + tk.MustExec("insert into t (b) values(default(a))") + tk.MustQuery("select * from t").Check(testkit.Rows("1 1")) } func (s *testSuite) TestInsertAutoInc(c *C) { diff --git a/plan/errors.go b/plan/errors.go index 086e11e7f1434..292c7888fd509 100644 --- a/plan/errors.go +++ b/plan/errors.go @@ -48,6 +48,7 @@ const ( codeMixOfGroupFuncAndFields = mysql.ErrMixOfGroupFuncAndFields codeNonUniqTable = mysql.ErrNonuniqTable codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect + codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow ) // error definitions. @@ -81,6 +82,7 @@ var ( ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal]) ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by") ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable]) + ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) ) func init() { @@ -107,6 +109,7 @@ func init() { codeMixOfGroupFuncAndFields: mysql.ErrMixOfGroupFuncAndFields, codeNonUniqTable: mysql.ErrNonuniqTable, codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect, + codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap } diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index 6a0f0133f79bc..0e62a231e2c39 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -1332,7 +1332,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { ans []visitInfo }{ { - sql: "insert into t values (1)", + sql: "insert into t (a) values (1)", ans: []visitInfo{ {mysql.InsertPriv, "test", "t", ""}, }, diff --git a/plan/physical_plan_test.go b/plan/physical_plan_test.go index be3d249e1f983..4d83a5618e80c 100644 --- a/plan/physical_plan_test.go +++ b/plan/physical_plan_test.go @@ -620,7 +620,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderBasePhysicalPlan(c *C) { }, // Test simple insert. { - sql: "insert into t values(0,0,0,0,0,0,0)", + sql: "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)", best: "Insert", }, // Test dual. diff --git a/plan/planbuilder.go b/plan/planbuilder.go index cd2235b040227..5fe9ee5f846de 100644 --- a/plan/planbuilder.go +++ b/plan/planbuilder.go @@ -17,7 +17,6 @@ import ( "fmt" "strings" - "github.com/cznic/mathutil" "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" @@ -922,21 +921,6 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { table: tableInfo.Name.L, }) - columnByName := make(map[string]*table.Column, len(insertPlan.Table.Cols())) - for _, col := range insertPlan.Table.Cols() { - columnByName[col.Name.L] = col - } - - // Check insert.Columns contains generated columns or not. - // It's for INSERT INTO t (...) VALUES (...) - for _, col := range insert.Columns { - column, ok := columnByName[col.Name.L] - if ok && column.IsGenerated() { - b.err = ErrBadGeneratedColumn.GenByArgs(col.Name.O, tableInfo.Name.O) - return nil - } - } - mockTablePlan := LogicalTableDual{}.init(b.ctx) mockTablePlan.SetSchema(insertPlan.tableSchema) @@ -951,154 +935,256 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { return n } - cols := insertPlan.Table.Cols() - maxValuesItemLength := 0 // the max length of items in VALUES list. - for _, valuesItem := range insert.Lists { - exprList := make([]expression.Expression, 0, len(valuesItem)) - for i, valueItem := range valuesItem { - var expr expression.Expression - var err error - if dft, ok := valueItem.(*ast.DefaultExpr); ok { - if dft.Name != nil { - expr, err = b.findDefaultValue(cols, dft.Name) - } else { - expr, err = b.getDefaultValue(cols[i]) - } - } else if val, ok := valueItem.(*ast.ValueExpr); ok { - expr = &expression.Constant{ - Value: val.Datum, - RetType: &val.Type, - } - } else { - expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, true, checkRefColumn) - } - if err != nil { - b.err = errors.Trace(err) - } - exprList = append(exprList, expr) + if len(insert.Setlist) > 0 { + // Branch for `INSERT ... SET ...`. + b.buildSetValuesOfInsert(insert, insertPlan, mockTablePlan, checkRefColumn) + if b.err != nil { + return nil + } + } else if len(insert.Lists) > 0 { + // Branch for `INSERT ... VALUES ...`. + b.buildValuesListOfInsert(insert, insertPlan, mockTablePlan, checkRefColumn) + if b.err != nil { + return nil } - if len(valuesItem) > maxValuesItemLength { - maxValuesItemLength = len(valuesItem) + } else { + // Branch for `INSERT ... SELECT ...`. + b.buildSelectPlanOfInsert(insert, insertPlan) + if b.err != nil { + return nil } - insertPlan.Lists = append(insertPlan.Lists, exprList) } - // It's for INSERT INTO t VALUES (...) - if len(insert.Columns) == 0 { - // The length of VALUES list maybe exceed table width, - // we ignore this here but do checking in executor. - var effectiveValuesLen int - if maxValuesItemLength <= len(tableInfo.Columns) { - effectiveValuesLen = maxValuesItemLength - } else { - effectiveValuesLen = len(tableInfo.Columns) + mockTablePlan.SetSchema(insertPlan.Schema4OnDuplicate) + columnByName := make(map[string]*table.Column, len(insertPlan.Table.Cols())) + for _, col := range insertPlan.Table.Cols() { + columnByName[col.Name.L] = col + } + onDupColSet, dupCols, err := insertPlan.validateOnDup(insert.OnDuplicate, columnByName, tableInfo) + if err != nil { + b.err = errors.Trace(err) + return nil + } + for i, assign := range insert.OnDuplicate { + // Construct the function which calculates the assign value of the column. + expr, err := b.rewriteInsertOnDuplicateUpdate(assign.Expr, mockTablePlan, insertPlan) + if err != nil { + b.err = errors.Trace(err) + return nil } - for i := 0; i < effectiveValuesLen; i++ { - col := tableInfo.Columns[i] - if col.IsGenerated() { - b.err = ErrBadGeneratedColumn.GenByArgs(col.Name.O, tableInfo.Name.O) - return nil - } + + insertPlan.OnDuplicate = append(insertPlan.OnDuplicate, &expression.Assignment{ + Col: dupCols[i], + Expr: expr, + }) + } + + // Calculate generated columns. + mockTablePlan.schema = insertPlan.tableSchema + insertPlan.GenCols = b.resolveGeneratedColumns(insertPlan.Table.Cols(), onDupColSet, mockTablePlan) + if b.err != nil { + b.err = errors.Trace(b.err) + return nil + } + + insertPlan.ResolveIndices() + return insertPlan +} + +func (p *Insert) validateOnDup(onDup []*ast.Assignment, colMap map[string]*table.Column, tblInfo *model.TableInfo) (map[string]struct{}, []*expression.Column, error) { + onDupColSet := make(map[string]struct{}, len(onDup)) + dupCols := make([]*expression.Column, 0, len(onDup)) + for _, assign := range onDup { + // Check whether the column to be updated exists in the source table. + col, err := p.tableSchema.FindColumn(assign.Column) + if err != nil { + return nil, nil, errors.Trace(err) + } else if col == nil { + return nil, nil, ErrUnknownColumn.GenByArgs(assign.Column.OrigColName(), "field list") + } + + // Check whether the column to be updated is the generated column. + column := colMap[assign.Column.Name.L] + if column.IsGenerated() { + return nil, nil, ErrBadGeneratedColumn.GenByArgs(assign.Column.Name.O, tblInfo.Name.O) + } + onDupColSet[column.Name.L] = struct{}{} + dupCols = append(dupCols, col) + } + return onDupColSet, dupCols, nil +} + +func (b *planBuilder) getAffectCols(insertStmt *ast.InsertStmt, insertPlan *Insert) (affectedValuesCols []*table.Column) { + if len(insertStmt.Columns) > 0 { + // This branch is for the following scenarios: + // 1. `INSERT INTO tbl_name (col_name [, col_name] ...) {VALUES | VALUE} (value_list) [, (value_list)] ...`, + // 2. `INSERT INTO tbl_name (col_name [, col_name] ...) SELECT ...`. + colName := make([]string, 0, len(insertStmt.Columns)) + for _, col := range insertStmt.Columns { + colName = append(colName, col.Name.O) } + affectedValuesCols, b.err = table.FindCols(insertPlan.Table.Cols(), colName, insertPlan.Table.Meta().PKIsHandle) + if b.err != nil { + b.err = errors.Trace(b.err) + return + } + + } else if len(insertStmt.Setlist) == 0 { + // This branch is for the following scenarios: + // 1. `INSERT INTO tbl_name {VALUES | VALUE} (value_list) [, (value_list)] ...`, + // 2. `INSERT INTO tbl_name SELECT ...`. + affectedValuesCols = insertPlan.Table.Cols() } + return +} +func (b *planBuilder) buildSetValuesOfInsert(insert *ast.InsertStmt, insertPlan *Insert, mockTablePlan *LogicalTableDual, + checkRefColumn func(n ast.Node) ast.Node) { + tableInfo := insertPlan.Table.Meta() + colNames := make([]string, 0, len(insert.Setlist)) + exprCols := make([]*expression.Column, 0, len(insert.Setlist)) for _, assign := range insert.Setlist { - col, err := insertPlan.tableSchema.FindColumn(assign.Column) + exprCol, err := insertPlan.tableSchema.FindColumn(assign.Column) if err != nil { b.err = errors.Trace(err) - return nil + return } - if col == nil { + if exprCol == nil { b.err = errors.Errorf("Can't find column %s", assign.Column) - return nil + return } + colNames = append(colNames, assign.Column.Name.L) + exprCols = append(exprCols, exprCol) + } - // Check whether the column to be updated is the generated column. - if columnByName[assign.Column.Name.L].IsGenerated() { - b.err = ErrBadGeneratedColumn.GenByArgs(assign.Column.Name.O, tableInfo.Name.O) - return nil + // Check whether the column to be updated is the generated column. + tCols, err := table.FindCols(insertPlan.Table.Cols(), colNames, tableInfo.PKIsHandle) + if err != nil { + b.err = errors.Trace(err) + return + } + for _, tCol := range tCols { + if tCol.IsGenerated() { + b.err = ErrBadGeneratedColumn.GenByArgs(tCol.Name.O, tableInfo.Name.O) + return } + } + for i, assign := range insert.Setlist { expr, _, err := b.rewriteWithPreprocess(assign.Expr, mockTablePlan, nil, true, checkRefColumn) if err != nil { b.err = errors.Trace(err) - return nil } insertPlan.SetList = append(insertPlan.SetList, &expression.Assignment{ - Col: col, + Col: exprCols[i], Expr: expr, }) } - insertPlan.Schema4OnDuplicate = insertPlan.tableSchema - if insert.Select != nil { - selectPlan := b.build(insert.Select) - if b.err != nil { - return nil - } +} - numInsertCols := mathutil.Min(selectPlan.Schema().Len(), len(tableInfo.Columns)) - // If the column to be inserted in the insert table is a generated - // column, raises a "ErrBadGeneratedColumn" error here. - for _, col := range tableInfo.Columns[:numInsertCols] { +func (b *planBuilder) buildValuesListOfInsert(insert *ast.InsertStmt, insertPlan *Insert, mockTablePlan *LogicalTableDual, + checkRefColumn func(n ast.Node) ast.Node) { + affectedValuesCols := b.getAffectCols(insert, insertPlan) + if b.err != nil { + b.err = errors.Trace(b.err) + return + } + + // If the value_list and col_list is empty and we have generated column, we can still write to this table. + // i.e. insert into t values(); can be executed successfully if t have generated column. + if len(insert.Columns) > 0 || len(insert.Lists[0]) > 0 { + // If value_list is not empty or the col_list is not empty, length of value_list should be the same with col_list's. + if len(insert.Lists[0]) != len(affectedValuesCols) { + b.err = ErrWrongValueCountOnRow.GenByArgs(1) + return + } + // No generated column is allowed. + for _, col := range affectedValuesCols { if col.IsGenerated() { - b.err = ErrBadGeneratedColumn.GenByArgs(col.Name.O, tableInfo.Name.O) - return nil + b.err = ErrBadGeneratedColumn.GenByArgs(col.Name.O, insertPlan.Table.Meta().Name.O) + return } } + } - insertPlan.SelectPlan, b.err = doOptimize(b.optFlag, selectPlan.(LogicalPlan)) - if b.err != nil { - return nil + totalTableCols := insertPlan.Table.Cols() + for i, valuesItem := range insert.Lists { + // The length of the all the value_list should be the same. + // "insert into t values (), ()" is valid. + // "insert into t values (), (1)" is not valid. + // "insert into t values (1), ()" is not valid. + // "insert into t values (1,2), (1)" is not valid. + if i > 0 && len(insert.Lists[i-1]) != len(insert.Lists[i]) { + b.err = ErrWrongValueCountOnRow.GenByArgs(i + 1) + return } - - insertPlan.Schema4OnDuplicate = expression.MergeSchema(insertPlan.tableSchema, insertPlan.SelectPlan.Schema()) + exprList := make([]expression.Expression, 0, len(valuesItem)) + for j, valueItem := range valuesItem { + var expr expression.Expression + var err error + switch x := valueItem.(type) { + case *ast.DefaultExpr: + if x.Name != nil { + expr, err = b.findDefaultValue(totalTableCols, x.Name) + } else { + expr, err = b.getDefaultValue(affectedValuesCols[j]) + } + case *ast.ValueExpr: + expr = &expression.Constant{ + Value: x.Datum, + RetType: &x.Type, + } + default: + expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, true, checkRefColumn) + } + if err != nil { + b.err = errors.Trace(err) + } + exprList = append(exprList, expr) + } + insertPlan.Lists = append(insertPlan.Lists, exprList) } + insertPlan.Schema4OnDuplicate = insertPlan.tableSchema +} - mockTablePlan.SetSchema(insertPlan.Schema4OnDuplicate) - onDupCols := make(map[string]struct{}, len(insert.OnDuplicate)) - for _, assign := range insert.OnDuplicate { - // Check whether the column to be updated exists in the source table. - col, err := insertPlan.tableSchema.FindColumn(assign.Column) - if err != nil { - b.err = errors.Trace(err) - return nil - } else if col == nil { - b.err = ErrUnknownColumn.GenByArgs(assign.Column.OrigColName(), "field list") - return nil - } +func (b *planBuilder) buildSelectPlanOfInsert(insert *ast.InsertStmt, insertPlan *Insert) { + affectedValuesCols := b.getAffectCols(insert, insertPlan) + if b.err != nil { + b.err = errors.Trace(b.err) + return + } + selectPlan := b.build(insert.Select) + if b.err != nil { + return + } - // Check whether the column to be updated is the generated column. - column := columnByName[assign.Column.Name.L] - if column.IsGenerated() { - b.err = ErrBadGeneratedColumn.GenByArgs(assign.Column.Name.O, tableInfo.Name.O) - return nil - } + // Check that the length of select' row is equal to the col list. + if selectPlan.Schema().Len() != len(affectedValuesCols) { + b.err = ErrWrongValueCountOnRow.GenByArgs(1) + return + } - // Construct the function which calculates the assign value of the column. - expr, err := b.rewriteInsertOnDuplicateUpdate(assign.Expr, mockTablePlan, insertPlan) - if err != nil { - b.err = errors.Trace(err) - return nil + // Check to guarantee that there's no generated column. + // This check should be done after the above one to make its behavior compatible with MySQL. + // For example, table t has two columns, namely a and b, and b is a generated column. + // "insert into t (b) select * from t" will raise an error that the column count is not matched. + // "insert into t select * from t" will raise an error that there's a generated column in the column list. + // If we do this check before the above one, "insert into t (b) select * from t" will raise an error + // that there's a generated column in the column list. + for _, col := range affectedValuesCols { + if col.IsGenerated() { + b.err = ErrBadGeneratedColumn.GenByArgs(col.Name.O, insertPlan.Table.Meta().Name.O) + return } - - insertPlan.OnDuplicate = append(insertPlan.OnDuplicate, &expression.Assignment{ - Col: col, - Expr: expr, - }) - onDupCols[column.Name.L] = struct{}{} } - // Calculate generated columns. - mockTablePlan.schema = insertPlan.tableSchema - insertPlan.GenCols = b.resolveGeneratedColumns(insertPlan.Table.Cols(), onDupCols, mockTablePlan) + insertPlan.SelectPlan, b.err = doOptimize(b.optFlag, selectPlan.(LogicalPlan)) if b.err != nil { - b.err = errors.Trace(b.err) - return nil + return } - insertPlan.ResolveIndices() - return insertPlan + insertPlan.Schema4OnDuplicate = expression.MergeSchema(insertPlan.tableSchema, insertPlan.SelectPlan.Schema()) } func (b *planBuilder) buildLoadData(ld *ast.LoadDataStmt) Plan { diff --git a/session/session_test.go b/session/session_test.go index ed4b0aa5ea622..0dda8ba3cf989 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" @@ -1578,7 +1579,7 @@ func (s *testSchemaSuite) TestPrepareStmtCommitWhenSchemaChanged(c *C) { tk.MustExec("alter table t drop column b") tk1.MustExec("execute stmt using @a, @a") _, err := tk1.Exec("commit") - c.Assert(terror.ErrorEqual(err, executor.ErrWrongValueCountOnRow), IsTrue, Commentf("err %v", err)) + c.Assert(terror.ErrorEqual(err, plan.ErrWrongValueCountOnRow), IsTrue, Commentf("err %v", err)) } func (s *testSchemaSuite) TestCommitWhenSchemaChanged(c *C) { @@ -1594,7 +1595,7 @@ func (s *testSchemaSuite) TestCommitWhenSchemaChanged(c *C) { // When tk1 commit, it will find schema already changed. tk1.MustExec("insert into t values (4, 4)") _, err := tk1.Exec("commit") - c.Assert(terror.ErrorEqual(err, executor.ErrWrongValueCountOnRow), IsTrue) + c.Assert(terror.ErrorEqual(err, plan.ErrWrongValueCountOnRow), IsTrue) } func (s *testSchemaSuite) TestRetrySchemaChange(c *C) {