Skip to content

Commit

Permalink
plan: refactor the code of building Insert. (#7068)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored Jul 25, 2018
1 parent 85cd246 commit b5f9b35
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 178 deletions.
35 changes: 23 additions & 12 deletions ddl/db_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ func (s *testStateChangeSuite) TestTwoStates(c *C) {
testInfo.sqlInfos[0].sql = "insert into t (c1, c2, c3, c4) value(2, 'b', 'N', '2017-07-02')"
testInfo.sqlInfos[1].sql = "insert into t (c1, c2, c3, d3, c4) value(3, 'b', 'N', 'a', '2017-07-03')"
unknownColErr := errors.New("unknown column d3")
testInfo.sqlInfos[1].cases[0].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[1].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[2].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[3].expectedErr = unknownColErr
testInfo.sqlInfos[1].cases[0].expectedCompileErr = unknownColErr
testInfo.sqlInfos[1].cases[1].expectedCompileErr = unknownColErr
testInfo.sqlInfos[1].cases[2].expectedCompileErr = unknownColErr
testInfo.sqlInfos[1].cases[3].expectedCompileErr = unknownColErr
testInfo.sqlInfos[2].sql = "update t set c2 = 'c2_update'"
testInfo.sqlInfos[3].sql = "replace into t values(5, 'e', 'N', '2017-07-05')'"
testInfo.sqlInfos[3].cases[4].expectedErr = errors.New("Column count doesn't match value count at row 1")
testInfo.sqlInfos[3].cases[4].expectedCompileErr = errors.New("Column count doesn't match value count at row 1")
alterTableSQL := "alter table t add column d3 enum('a', 'b') not null default 'a' after c3"
s.test(c, "", alterTableSQL, testInfo)
// TODO: Add more DDL statements.
Expand Down Expand Up @@ -227,10 +227,11 @@ func (s *testStateChangeSuite) test(c *C, tableName, alterTableSQL string, testI
}

type stateCase struct {
session session.Session
rawStmt ast.StmtNode
stmt ast.Statement
expectedErr error
session session.Session
rawStmt ast.StmtNode
stmt ast.Statement
expectedExecErr error
expectedCompileErr error
}

type sqlInfo struct {
Expand Down Expand Up @@ -299,6 +300,13 @@ func (t *testExecInfo) compileSQL(idx int) (err error) {
return errors.Trace(err)
}
c.stmt, err = compiler.Compile(ctx, c.rawStmt)
if c.expectedCompileErr != nil {
if err == nil {
err = errors.Errorf("expected error %s but got nil", c.expectedCompileErr)
} else if strings.Contains(err.Error(), c.expectedCompileErr.Error()) {
err = nil
}
}
if err != nil {
return errors.Trace(err)
}
Expand All @@ -309,11 +317,14 @@ func (t *testExecInfo) compileSQL(idx int) (err error) {
func (t *testExecInfo) execSQL(idx int) error {
for _, sqlInfo := range t.sqlInfos {
c := sqlInfo.cases[idx]
if c.expectedCompileErr != nil {
continue
}
_, err := c.stmt.Exec(context.TODO())
if c.expectedErr != nil {
if c.expectedExecErr != nil {
if err == nil {
err = errors.Errorf("expected error %s but got nil", c.expectedErr)
} else if strings.Contains(err.Error(), c.expectedErr.Error()) {
err = errors.Errorf("expected error %s but got nil", c.expectedExecErr)
} else if strings.Contains(err.Error(), c.expectedExecErr.Error()) {
err = nil
}
}
Expand Down
1 change: 0 additions & 1 deletion executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ var (

ErrPasswordNoMatch = terror.ClassExecutor.New(mysql.ErrPasswordNoMatch, mysql.MySQLErrName[mysql.ErrPasswordNoMatch])
ErrCannotUser = terror.ClassExecutor.New(mysql.ErrCannotUser, mysql.MySQLErrName[mysql.ErrCannotUser])
ErrWrongValueCountOnRow = terror.ClassExecutor.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow])
ErrPasswordFormat = terror.ClassExecutor.New(mysql.ErrPasswordFormat, mysql.MySQLErrName[mysql.ErrPasswordFormat])
ErrCantChangeTxCharacteristics = terror.ClassExecutor.New(mysql.ErrCantChangeTxCharacteristics, mysql.MySQLErrName[mysql.ErrCantChangeTxCharacteristics])
)
Expand Down
9 changes: 6 additions & 3 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ func (s *testSuite) TestMultiUpdate(c *C) {
func (s *testSuite) TestGeneratedColumnWrite(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`CREATE TABLE test_gc_write (a int primary key, b int, c int as (a+8) virtual)`)
tk.MustExec(`CREATE TABLE test_gc_write (a int primary key auto_increment, b int, c int as (a+8) virtual)`)
tk.MustExec(`CREATE TABLE test_gc_write_1 (a int primary key, b int, c int)`)

tests := []struct {
Expand All @@ -1336,20 +1336,23 @@ func (s *testSuite) TestGeneratedColumnWrite(c *C) {
// Can insert without generated columns.
{`insert into test_gc_write (a, b) values (1, 1)`, 0},
{`insert into test_gc_write set a = 2, b = 2`, 0},
{`insert into test_gc_write (b) select c from test_gc_write`, 0},
// Can update without generated columns.
{`update test_gc_write set b = 2 where a = 2`, 0},
{`update test_gc_write t1, test_gc_write_1 t2 set t1.b = 3, t2.b = 4`, 0},

// But now we can't do this, just as same with MySQL 5.7:
{`insert into test_gc_write values (1, 1)`, mysql.ErrWrongValueCountOnRow},
{`insert into test_gc_write select 1, 1`, mysql.ErrWrongValueCountOnRow},
{`insert into test_gc_write (c) select a, b from test_gc_write`, mysql.ErrWrongValueCountOnRow},
{`insert into test_gc_write (b, c) select a, b from test_gc_write`, mysql.ErrBadGeneratedColumn},
}
for _, tt := range tests {
_, err := tk.Exec(tt.stmt)
if tt.err != 0 {
c.Assert(err, NotNil)
c.Assert(err, NotNil, Commentf("sql is `%v`", tt.stmt))
terr := errors.Trace(err).(*errors.Err).Cause().(*terror.Error)
c.Assert(terr.Code(), Equals, terror.ErrCode(tt.err))
c.Assert(terr.Code(), Equals, terror.ErrCode(tt.err), Commentf("sql is %v", tt.stmt))
} else {
c.Assert(err, IsNil)
}
Expand Down
36 changes: 0 additions & 36 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,47 +148,14 @@ func (e *InsertValues) fillValueList() error {
return nil
}

func (e *InsertValues) checkValueCount(insertValueCount, valueCount, genColsCount, num int, cols []*table.Column) error {
// TODO: This check should be done in plan builder.
if insertValueCount != valueCount {
// "insert into t values (), ()" is valid.
// "insert into t values (), (1)" is not valid.
// "insert into t values (1), ()" is not valid.
// "insert into t values (1,2), (1)" is not valid.
// So the value count must be same for all insert list.
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
}
if valueCount == 0 && len(e.Columns) > 0 {
// "insert into t (c1) values ()" is not valid.
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
} else if valueCount > 0 {
explicitSetLen := 0
if len(e.Columns) != 0 {
explicitSetLen = len(e.Columns)
} else {
explicitSetLen = len(e.SetList)
}
if explicitSetLen > 0 && valueCount+genColsCount != len(cols) {
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
} else if explicitSetLen == 0 && valueCount != len(cols) {
return ErrWrongValueCountOnRow.GenByArgs(num + 1)
}
}
return nil
}

func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows []types.DatumRow) error) (err error) {
// process `insert|replace ... set x=y...`
if err = e.fillValueList(); err != nil {
return errors.Trace(err)
}

rows := make([]types.DatumRow, len(e.Lists))
length := len(e.Lists[0])
for i, list := range e.Lists {
if err = e.checkValueCount(length, len(list), len(e.GenColumns), i, cols); err != nil {
return errors.Trace(err)
}
e.rowCount = uint64(i)
rows[i], err = e.getRow(cols, list, i)
if err != nil {
Expand Down Expand Up @@ -277,9 +244,6 @@ func (e *InsertValues) fillDefaultValues(row types.DatumRow, hasValue []bool) er
func (e *InsertValues) insertRowsFromSelect(ctx context.Context, cols []*table.Column, exec func(rows []types.DatumRow) error) error {
// process `insert|replace into ... select ... from ...`
selectExec := e.children[0]
if selectExec.Schema().Len() != len(cols) {
return ErrWrongValueCountOnRow.GenByArgs(1)
}
fields := selectExec.retTypes()
chk := selectExec.newChunk()
iter := chunk.NewIterator4Chunk(chk)
Expand Down
21 changes: 21 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ func (s *testSuite) TestInsert(c *C) {
c.Assert(err, NotNil)
tk.MustExec("rollback")

errInsertSelectSQL = `insert insert_test_1 values(default, default, default, default, default)`
tk.MustExec("begin")
_, err = tk.Exec(errInsertSelectSQL)
c.Assert(err, NotNil)
tk.MustExec("rollback")

// Updating column is PK handle.
// Make sure the record is "1, 1, nil, 1".
r := tk.MustQuery("select * from insert_test where id = 1;")
Expand Down Expand Up @@ -240,6 +246,21 @@ func (s *testSuite) TestInsert(c *C) {
Check(testkit.Rows("Warning 1690 constant -1.1 overflows float", "Warning 1690 constant -1.1 overflows double",
"Warning 1690 constant -2.1 overflows float", "Warning 1690 constant -2.1 overflows double"))
tk.MustQuery("select * from t").Check(testkit.Rows("0 0", "0 0", "0 0", "1.1 1.1"))

// issue 7061
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int default 1, b int default 2)")
tk.MustExec("insert into t values(default, default)")
tk.MustQuery("select * from t").Check(testkit.Rows("1 2"))
tk.MustExec("truncate table t")
tk.MustExec("insert into t values(default(b), default(a))")
tk.MustQuery("select * from t").Check(testkit.Rows("2 1"))
tk.MustExec("truncate table t")
tk.MustExec("insert into t (b) values(default)")
tk.MustQuery("select * from t").Check(testkit.Rows("1 2"))
tk.MustExec("truncate table t")
tk.MustExec("insert into t (b) values(default(a))")
tk.MustQuery("select * from t").Check(testkit.Rows("1 1"))
}

func (s *testSuite) TestInsertAutoInc(c *C) {
Expand Down
3 changes: 3 additions & 0 deletions plan/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (
codeMixOfGroupFuncAndFields = mysql.ErrMixOfGroupFuncAndFields
codeNonUniqTable = mysql.ErrNonuniqTable
codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect
codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow
)

// error definitions.
Expand Down Expand Up @@ -81,6 +82,7 @@ var (
ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal])
ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by")
ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable])
ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow])
)

func init() {
Expand All @@ -107,6 +109,7 @@ func init() {
codeMixOfGroupFuncAndFields: mysql.ErrMixOfGroupFuncAndFields,
codeNonUniqTable: mysql.ErrNonuniqTable,
codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect,
codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow,
}
terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap
}
2 changes: 1 addition & 1 deletion plan/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) {
ans []visitInfo
}{
{
sql: "insert into t values (1)",
sql: "insert into t (a) values (1)",
ans: []visitInfo{
{mysql.InsertPriv, "test", "t", ""},
},
Expand Down
2 changes: 1 addition & 1 deletion plan/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderBasePhysicalPlan(c *C) {
},
// Test simple insert.
{
sql: "insert into t values(0,0,0,0,0,0,0)",
sql: "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)",
best: "Insert",
},
// Test dual.
Expand Down
Loading

0 comments on commit b5f9b35

Please sign in to comment.