From cba3e1a5c745397e857ebca6961f85a461a2bc61 Mon Sep 17 00:00:00 2001 From: Maxwell Date: Wed, 11 Dec 2019 10:53:28 +0800 Subject: [PATCH] planner: fixup some bugs with `DEFAULT` expression (#13168)(#13211)(#12550)(#11901) (#13682) --- ddl/db_integration_test.go | 62 ++++++++++++ ddl/db_test.go | 64 ++++++++++++ executor/insert_common.go | 10 +- executor/write_test.go | 146 +++++++++++++++++++++++++++ expression/expression.go | 16 +-- planner/core/logical_plan_builder.go | 48 ++++++--- planner/core/planbuilder.go | 97 ++++++++++++------ 7 files changed, 386 insertions(+), 57 deletions(-) diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 5fe8f5363571e..4a021883d33d2 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -1894,6 +1894,68 @@ func (s *testIntegrationSuite4) TestDropAutoIncrementIndex(c *C) { assertErrorCode(c, tk, dropIndexSQL, mysql.ErrWrongAutoKey) } +func (s *testIntegrationSuite3) TestInsertIntoGeneratedColumnWithDefaultExpr(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database if not exists test") + tk.MustExec("use test") + + // insert into virtual / stored columns + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int, b int as (-a) virtual, c int as (-a) stored)") + tk.MustExec("insert into t1 values (1, default, default)") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 -1 -1")) + tk.MustExec("delete from t1") + + // insert multiple rows + tk.MustExec("insert into t1(a,b) values (1, default), (2, default)") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 -1 -1", "2 -2 -2")) + tk.MustExec("delete from t1") + + // insert into generated columns only + tk.MustExec("insert into t1(b) values (default)") + tk.MustQuery("select * from t1").Check(testkit.Rows(" ")) + tk.MustExec("delete from t1") + tk.MustExec("insert into t1(c) values (default)") + tk.MustQuery("select * from t1").Check(testkit.Rows(" ")) + tk.MustExec("delete from t1") + + // generated columns with index + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t2 like t1") + tk.MustExec("alter table t2 add index idx1(a)") + tk.MustExec("alter table t2 add index idx2(b)") + tk.MustExec("insert into t2 values (1, default, default)") + tk.MustQuery("select * from t2").Check(testkit.Rows("1 -1 -1")) + tk.MustExec("delete from t2") + tk.MustExec("alter table t2 drop index idx1") + tk.MustExec("alter table t2 drop index idx2") + tk.MustExec("insert into t2 values (1, default, default)") + tk.MustQuery("select * from t2").Check(testkit.Rows("1 -1 -1")) + + // generated columns in different position + tk.MustExec("drop table if exists t3") + tk.MustExec("create table t3 (gc1 int as (r+1), gc2 int as (r+1) stored, gc3 int as (gc2+1), gc4 int as (gc1+1) stored, r int)") + tk.MustExec("insert into t3 values (default, default, default, default, 1)") + tk.MustQuery("select * from t3").Check(testkit.Rows("2 2 3 3 1")) + + // generated columns in replace statement + tk.MustExec("create table t4 (a int key, b int, c int as (a+1), d int as (b+1) stored)") + tk.MustExec("insert into t4 values (1, 10, default, default)") + tk.MustQuery("select * from t4").Check(testkit.Rows("1 10 2 11")) + tk.MustExec("replace into t4 values (1, 20, default, default)") + tk.MustQuery("select * from t4").Check(testkit.Rows("1 20 2 21")) + + // generated columns with default function is not allowed + tk.MustExec("create table t5 (a int default 10, b int as (a+1))") + assertErrorCode(c, tk, "insert into t5 values (20, default(a))", mysql.ErrBadGeneratedColumn) + + tk.MustExec("drop table t1") + tk.MustExec("drop table t2") + tk.MustExec("drop table t3") + tk.MustExec("drop table t4") + tk.MustExec("drop table t5") +} + func (s *testIntegrationSuite3) TestParserIssue284(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/ddl/db_test.go b/ddl/db_test.go index 49462225074b1..9dbf1809de9f4 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -3048,6 +3048,70 @@ func (s *testDBSuite5) TestModifyGeneratedColumn(c *C) { tk.MustQuery("select * from t1").Check(testkit.Rows("1 2")) } +func (s *testDBSuite5) TestDefaultSQLFunction(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database if not exists test;") + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1, t2, t3, t4;") + + // For issue #13189 + // Use `DEFAULT()` in `INSERT` / `INSERT ON DUPLICATE KEY UPDATE` statement + tk.MustExec("create table t1(a int primary key, b int default 20, c int default 30, d int default 40);") + tk.MustExec("insert into t1 set a = 1, b = default(c);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 30 30 40")) + tk.MustExec("insert into t1 set a = 2, b = default(c), c = default(d), d = default(b);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 30 30 40", "2 30 40 20")) + tk.MustExec("insert into t1 values (2, 3, 4, 5) on duplicate key update b = default(d), c = default(b);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 30 30 40", "2 40 20 20")) + tk.MustExec("delete from t1") + tk.MustExec("insert into t1 set a = default(b) + default(c) - default(d)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 20 30 40")) + // Use `DEFAULT()` in `UPDATE` statement + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 value (1, 2, 3, 4);") + tk.MustExec("update t1 set a = 1, c = default(b);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 2 20 4")) + tk.MustExec("insert into t1 value (2, 2, 3, 4);") + tk.MustExec("update t1 set c = default(b), b = default(c) where a = 2;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 2 20 4", "2 30 20 4")) + tk.MustExec("delete from t1") + tk.MustExec("insert into t1 set a = 10") + tk.MustExec("update t1 set a = 10, b = default(c) + default(d)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 70 30 40")) + // Use `DEFAULT()` in `REPLACE` statement + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 value (1, 2, 3, 4);") + tk.MustExec("replace into t1 set a = 1, c = default(b);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 20 40")) + tk.MustExec("insert into t1 value (2, 2, 3, 4);") + tk.MustExec("replace into t1 set a = 2, d = default(b), c = default(d);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 20 40", "2 20 40 20")) + tk.MustExec("delete from t1") + tk.MustExec("insert into t1 set a = 10, c = 3") + tk.MustExec("replace into t1 set a = 10, b = default(c) + default(d)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 70 30 40")) + tk.MustExec("replace into t1 set a = 20, d = default(c) + default(b)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 70 30 40", "20 20 30 50")) + + // Use `DEFAULT()` in expression of generate columns, issue #12471 + tk.MustExec("create table t2(a int default 9, b int as (1 + default(a)));") + tk.MustExec("insert into t2 values(1, default);") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 10")) + + // Use `DEFAULT()` with subquery, issue #13390 + tk.MustExec("create table t3(f1 int default 11);") + tk.MustExec("insert into t3 value ();") + tk.MustQuery("select default(f1) from (select * from t3) t1;").Check(testkit.Rows("11")) + tk.MustQuery("select default(f1) from (select * from (select * from t3) t1 ) t1;").Check(testkit.Rows("11")) + + tk.MustExec("create table t4(a int default 4);") + tk.MustExec("insert into t4 value (2);") + tk.MustQuery("select default(c) from (select b as c from (select a as b from t4) t3) t2;").Check(testkit.Rows("4")) + tk.MustGetErrCode("select default(a) from (select a from (select 1 as a) t4) t4;", mysql.ErrNoDefaultForField) + + tk.MustExec("drop table t1, t2, t3, t4;") +} + func (s *testDBSuite4) TestIssue9100(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test_db") diff --git a/executor/insert_common.go b/executor/insert_common.go index 44a92ca8db267..a043e8316b9f7 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -91,9 +91,6 @@ func (e *InsertValues) initInsertColumns() error { for _, v := range e.SetList { columns = append(columns, v.Col.ColName.O) } - for _, v := range e.GenColumns { - columns = append(columns, v.Name.O) - } cols, err = table.FindCols(tableCols, columns, e.Table.Meta().PKIsHandle) if err != nil { return errors.Errorf("INSERT INTO %s: %s", e.Table.Meta().Name.O, err) @@ -107,9 +104,6 @@ func (e *InsertValues) initInsertColumns() error { for _, v := range e.Columns { columns = append(columns, v.Name.O) } - for _, v := range e.GenColumns { - columns = append(columns, v.Name.O) - } cols, err = table.FindCols(tableCols, columns, e.Table.Meta().PKIsHandle) if err != nil { return errors.Errorf("INSERT INTO %s: %s", e.Table.Meta().Name.O, err) @@ -119,6 +113,9 @@ func (e *InsertValues) initInsertColumns() error { cols = tableCols } for _, col := range cols { + if !col.IsGenerated() { + e.insertColumns = append(e.insertColumns, col) + } if col.Name.L == model.ExtraHandleName.L { if !e.ctx.GetSessionVars().AllowWriteRowID { return errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported.") @@ -133,7 +130,6 @@ func (e *InsertValues) initInsertColumns() error { if err != nil { return err } - e.insertColumns = cols return nil } diff --git a/executor/write_test.go b/executor/write_test.go index 9f9bd47bce244..c2ef457fe1074 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -747,6 +747,88 @@ func (s *testSuite4) TestInsertIgnoreOnDup(c *C) { r.Check(testkit.Rows("1 1", "2 2")) } +func (s *testSuite4) TestInsertSetWithDefault(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + // Assign `DEFAULT` in `INSERT ... SET ...` statement + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1 (a int default 10, b int default 20);") + tk.MustExec("insert into t1 set a=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 20")) + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 set b=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 20")) + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 set b=default, a=1;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20")) + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 set a=default(a);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("10 20")) + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 set a=default(b), b=default(a)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("20 10")) + tk.MustExec("delete from t1;") + tk.MustExec("insert into t1 set a=default(b)+default(a);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("30 20")) + // With generated columns + tk.MustExec("create table t2 (a int default 10, b int generated always as (-a) virtual, c int generated always as (-a) stored);") + tk.MustExec("insert into t2 set a=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("10 -10 -10")) + tk.MustExec("delete from t2;") + tk.MustExec("insert into t2 set a=2, b=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("2 -2 -2")) + tk.MustExec("delete from t2;") + tk.MustExec("insert into t2 set c=default, a=3;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("3 -3 -3")) + tk.MustExec("delete from t2;") + tk.MustExec("insert into t2 set a=default, b=default, c=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("10 -10 -10")) + tk.MustExec("delete from t2;") + tk.MustExec("insert into t2 set a=default(a), b=default, c=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("10 -10 -10")) + tk.MustExec("delete from t2;") + tk.MustGetErrCode("insert into t2 set b=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 set a=default(b), b=default(b);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 set a=default(a), c=default(c);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 set a=default(a), c=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustExec("drop table t1, t2") +} + +func (s *testSuite4) TestInsertOnDupUpdateDefault(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + // Assign `DEFAULT` in `INSERT ... ON DUPLICATE KEY UPDATE ...` statement + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1 (a int unique, b int default 20, c int default 30);") + tk.MustExec("insert into t1 values (1,default,default);") + tk.MustExec("insert into t1 values (1,default,default) on duplicate key update b=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30")) + tk.MustExec("insert into t1 values (1,default,default) on duplicate key update c=default, b=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30")) + tk.MustExec("insert into t1 values (1,default,default) on duplicate key update c=default, a=2") + tk.MustQuery("select * from t1;").Check(testkit.Rows("2 20 30")) + tk.MustExec("insert into t1 values (2,default,default) on duplicate key update c=default(b)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("2 20 20")) + tk.MustExec("insert into t1 values (2,default,default) on duplicate key update a=default(b)+default(c)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("50 20 20")) + // With generated columns + tk.MustExec("create table t2 (a int unique, b int generated always as (-a) virtual, c int generated always as (-a) stored);") + tk.MustExec("insert into t2 values (1,default,default);") + tk.MustExec("insert into t2 values (1,default,default) on duplicate key update a=2, b=default;") + tk.MustQuery("select * from t2").Check(testkit.Rows("2 -2 -2")) + tk.MustExec("insert into t2 values (2,default,default) on duplicate key update a=3, c=default;") + tk.MustQuery("select * from t2").Check(testkit.Rows("3 -3 -3")) + tk.MustExec("insert into t2 values (3,default,default) on duplicate key update c=default, b=default, a=4;") + tk.MustQuery("select * from t2").Check(testkit.Rows("4 -4 -4")) + tk.MustExec("insert into t2 values (10,default,default) on duplicate key update b=default, a=20, c=default;") + tk.MustQuery("select * from t2").Check(testkit.Rows("4 -4 -4", "10 -10 -10")) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update b=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(b), b=default(b);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(a), c=default(c);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(a), c=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustExec("drop table t1, t2") +} + func (s *testSuite4) TestReplace(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -896,6 +978,36 @@ func (s *testSuite4) TestReplace(c *C) { tk.MustExec(`replace into t1 select * from (select 1, 2) as tmp;`) c.Assert(int64(tk.Se.AffectedRows()), Equals, int64(2)) tk.CheckLastMessage("Records: 1 Duplicates: 1 Warnings: 0") + + // Assign `DEFAULT` in `REPLACE` statement + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1 (a int primary key, b int default 20, c int default 30);") + tk.MustExec("insert into t1 value (1, 2, 3);") + tk.MustExec("replace t1 set a=1, b=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30")) + tk.MustExec("replace t1 set a=2, b=default, c=default") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30", "2 20 30")) + tk.MustExec("replace t1 set a=2, b=default(c), c=default(b);") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30", "2 30 20")) + tk.MustExec("replace t1 set a=default(b)+default(c)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30", "2 30 20", "50 20 30")) + // With generated columns + tk.MustExec("create table t2 (pk int primary key, a int default 1, b int generated always as (-a) virtual, c int generated always as (-a) stored);") + tk.MustExec("replace t2 set pk=1, b=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 1 -1 -1")) + tk.MustExec("replace t2 set pk=2, a=10, b=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 1 -1 -1", "2 10 -10 -10")) + tk.MustExec("replace t2 set pk=2, c=default, a=20;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 1 -1 -1", "2 20 -20 -20")) + tk.MustExec("replace t2 set pk=2, a=default, b=default, c=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 1 -1 -1", "2 1 -1 -1")) + tk.MustExec("replace t2 set pk=3, a=default(a), b=default, c=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 1 -1 -1", "2 1 -1 -1", "3 1 -1 -1")) + tk.MustGetErrCode("replace t2 set b=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("replace t2 set a=default(b), b=default(b);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("replace t2 set a=default(a), c=default(c);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("replace t2 set a=default(a), c=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustExec("drop table t1, t2") } func (s *testSuite2) TestGeneratedColumnForInsert(c *C) { @@ -1393,6 +1505,40 @@ func (s *testSuite) TestUpdate(c *C) { _, err = tk.Exec("update v set a = '2000-11-11'") c.Assert(err.Error(), Equals, core.ErrViewInvalid.GenWithStackByArgs("test", "v").Error()) tk.MustExec("drop view v") + + // Assign `DEFAULT` in `UPDATE` statement + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1 (a int default 1, b int default 2);") + tk.MustExec("insert into t1 values (10, 10), (20, 20);") + tk.MustExec("update t1 set a=default where b=10;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 10", "20 20")) + tk.MustExec("update t1 set a=30, b=default where a=20;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 10", "30 2")) + tk.MustExec("update t1 set a=default, b=default where a=30;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 10", "1 2")) + tk.MustExec("insert into t1 values (40, 40)") + tk.MustExec("update t1 set a=default, b=default") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 2", "1 2", "1 2")) + tk.MustExec("update t1 set a=default(b), b=default(a)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("2 1", "2 1", "2 1")) + // With generated columns + tk.MustExec("create table t2 (a int default 1, b int generated always as (-a) virtual, c int generated always as (-a) stored);") + tk.MustExec("insert into t2 values (10, default, default), (20, default, default)") + tk.MustExec("update t2 set b=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("10 -10 -10", "20 -20 -20")) + tk.MustExec("update t2 set a=30, b=default where a=10;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("30 -30 -30", "20 -20 -20")) + tk.MustExec("update t2 set c=default, a=40 where c=-20;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("30 -30 -30", "40 -40 -40")) + tk.MustExec("update t2 set a=default, b=default, c=default where b=-30;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 -1 -1", "40 -40 -40")) + tk.MustExec("update t2 set a=default(a), b=default, c=default;") + tk.MustQuery("select * from t2;").Check(testkit.Rows("1 -1 -1", "1 -1 -1")) + tk.MustGetErrCode("update t2 set b=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("update t2 set a=default(b), b=default(b);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("update t2 set a=default(a), c=default(c);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("update t2 set a=default(a), c=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustExec("drop table t1, t2") } func (s *testSuite4) TestPartitionedTableUpdate(c *C) { diff --git a/expression/expression.go b/expression/expression.go index 6b47230c132d8..adb5f6cac186b 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -347,13 +347,15 @@ func ColumnInfos2ColumnsWithDBName(ctx sessionctx.Context, dbName, tblName model continue } newCol := &Column{ - ColName: col.Name, - TblName: tblName, - DBName: dbName, - RetType: &col.FieldType, - ID: col.ID, - UniqueID: ctx.GetSessionVars().AllocPlanColumnID(), - Index: col.Offset, + OrigColName: col.Name, + OrigTblName: tblName, + ColName: col.Name, + TblName: tblName, + DBName: dbName, + RetType: &col.FieldType, + ID: col.ID, + UniqueID: ctx.GetSessionVars().AllocPlanColumnID(), + Index: col.Offset, } columns = append(columns, newCol) } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 6e23f4f6553d0..5d85092d959b1 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -164,7 +164,6 @@ func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSet v.TableAsName = &x.AsName } for _, col := range p.Schema().Columns { - col.OrigTblName = col.TblName if x.AsName.L != "" { col.TblName = x.AsName } @@ -608,19 +607,18 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where a // buildProjectionFieldNameFromColumns builds the field name, table name and database name when field expression is a column reference. func (b *PlanBuilder) buildProjectionFieldNameFromColumns(origField *ast.SelectField, colNameField *ast.ColumnNameExpr, c *expression.Column) (colName, origColName, tblName, origTblName, dbName model.CIStr) { - origColName, tblName, dbName = colNameField.Name.Name, colNameField.Name.Table, colNameField.Name.Schema - if origField.AsName.L != "" { - colName = origField.AsName + origTblName, origColName, dbName = c.OrigTblName, c.OrigColName, c.DBName + if origField.AsName.L == "" { + colName = colNameField.Name.Name } else { - colName = origColName + colName = origField.AsName } if tblName.L == "" { tblName = c.TblName + } else { + tblName = colNameField.Name.Table } - if dbName.L == "" { - dbName = c.DBName - } - return colName, origColName, tblName, c.OrigTblName, c.DBName + return } // buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression. @@ -2295,6 +2293,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName) (L DBName: dbName, TblName: tableInfo.Name, ColName: col.Name, + OrigTblName: tableInfo.Name, OrigColName: col.Name, ID: col.ID, RetType: &col.FieldType, @@ -2674,14 +2673,22 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) ([]*expression.Assignment, LogicalPlan, error) { b.curClause = fieldList - modifyColumns := make(map[string]struct{}, p.Schema().Len()) // Which columns are in set list. + // modifyColumns indicates which columns are in set list, + // and if it is set to `DEFAULT` + modifyColumns := make(map[string]bool, p.Schema().Len()) for _, assign := range list { col, _, err := p.findColumn(assign.Column) if err != nil { return nil, nil, err } columnFullName := fmt.Sprintf("%s.%s.%s", col.DBName.L, col.TblName.L, col.ColName.L) - modifyColumns[columnFullName] = struct{}{} + // We save a flag for the column in map `modifyColumns` + // This flag indicated if assign keyword `DEFAULT` to the column + if extractDefaultExpr(assign.Expr) != nil { + modifyColumns[columnFullName] = true + } else { + modifyColumns[columnFullName] = false + } } // If columns in set list contains generated columns, raise error. @@ -2699,9 +2706,12 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab continue } columnFullName := fmt.Sprintf("%s.%s.%s", tn.Schema.L, tn.Name.L, colInfo.Name.L) - if _, ok := modifyColumns[columnFullName]; ok { + // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. + // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html + if isDefault, ok := modifyColumns[columnFullName]; ok && !isDefault { return nil, nil, ErrBadGeneratedColumn.GenWithStackByArgs(colInfo.Name.O, tableInfo.Name.O) } + virtualAssignments = append(virtualAssignments, &ast.Assignment{ Column: &ast.ColumnName{Schema: tn.Schema, Table: tn.Name, Name: colInfo.Name}, Expr: tableVal.Cols()[i].GeneratedExpr, @@ -2719,6 +2729,10 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab var newExpr expression.Expression var np LogicalPlan if i < len(list) { + // If assign `DEFAULT` to column, fill the `defaultExpr.Name` before rewrite expression + if expr := extractDefaultExpr(assign.Expr); expr != nil { + expr.Name = assign.Column + } newExpr, np, err = b.rewrite(ctx, assign.Expr, p, nil, false) } else { // rewrite with generation expression @@ -2764,6 +2778,16 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab return newList, p, nil } +// extractDefaultExpr extract a `DefaultExpr` from `ExprNode`, +// If it is a `DEFAULT` function like `DEFAULT(a)`, return nil. +// Only if it is `DEFAULT` keyword, it will return the `DefaultExpr`. +func extractDefaultExpr(node ast.ExprNode) *ast.DefaultExpr { + if expr, ok := node.(*ast.DefaultExpr); ok && expr.Name == nil { + return expr + } + return nil +} + func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) (Plan, error) { if b.pushTableHints(delete.TableHints) { // table hints are only visible in the current DELETE statement. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index b84649d5e4e9e..45b29f02683d1 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -1555,26 +1555,13 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( } mockTablePlan.SetSchema(insertPlan.Schema4OnDuplicate) - columnByName := make(map[string]*table.Column, len(insertPlan.Table.Cols())) - for _, col := range insertPlan.Table.Cols() { - columnByName[col.Name.L] = col - } - onDupColSet, dupCols, err := insertPlan.validateOnDup(insert.OnDuplicate, columnByName, tableInfo) + + onDupColSet, err := insertPlan.resolveOnDuplicate(insert.OnDuplicate, tableInfo, func(node ast.ExprNode) (expression.Expression, error) { + return b.rewriteInsertOnDuplicateUpdate(ctx, node, mockTablePlan, insertPlan) + }) if err != nil { return nil, err } - for i, assign := range insert.OnDuplicate { - // Construct the function which calculates the assign value of the column. - expr, err1 := b.rewriteInsertOnDuplicateUpdate(ctx, assign.Expr, mockTablePlan, insertPlan) - if err1 != nil { - return nil, err1 - } - - insertPlan.OnDuplicate = append(insertPlan.OnDuplicate, &expression.Assignment{ - Col: dupCols[i], - Expr: expr, - }) - } // Calculate generated columns. mockTablePlan.schema = insertPlan.tableSchema @@ -1587,27 +1574,49 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( return insertPlan, err } -func (p *Insert) validateOnDup(onDup []*ast.Assignment, colMap map[string]*table.Column, tblInfo *model.TableInfo) (map[string]struct{}, []*expression.Column, error) { +func (p *Insert) resolveOnDuplicate(onDup []*ast.Assignment, tblInfo *model.TableInfo, yield func(ast.ExprNode) (expression.Expression, error)) (map[string]struct{}, error) { onDupColSet := make(map[string]struct{}, len(onDup)) - dupCols := make([]*expression.Column, 0, len(onDup)) + colMap := make(map[string]*table.Column, len(p.Table.Cols())) + for _, col := range p.Table.Cols() { + colMap[col.Name.L] = col + } for _, assign := range onDup { // Check whether the column to be updated exists in the source table. col, err := p.tableSchema.FindColumn(assign.Column) if err != nil { - return nil, nil, err + return nil, err } else if col == nil { - return nil, nil, ErrUnknownColumn.GenWithStackByArgs(assign.Column.OrigColName(), "field list") + return nil, ErrUnknownColumn.GenWithStackByArgs(assign.Column.OrigColName(), "field list") } // Check whether the column to be updated is the generated column. column := colMap[assign.Column.Name.L] + defaultExpr := extractDefaultExpr(assign.Expr) + if defaultExpr != nil { + defaultExpr.Name = assign.Column + } + // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. + // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html if column.IsGenerated() { - return nil, nil, ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tblInfo.Name.O) + if defaultExpr != nil { + continue + } + return nil, ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tblInfo.Name.O) } + onDupColSet[column.Name.L] = struct{}{} - dupCols = append(dupCols, col) + + expr, err := yield(assign.Expr) + if err != nil { + return nil, err + } + + p.OnDuplicate = append(p.OnDuplicate, &expression.Assignment{ + Col: col, + Expr: expr, + }) } - return onDupColSet, dupCols, nil + return onDupColSet, nil } func (b *PlanBuilder) getAffectCols(insertStmt *ast.InsertStmt, insertPlan *Insert) (affectedValuesCols []*table.Column, err error) { @@ -1654,13 +1663,26 @@ func (b *PlanBuilder) buildSetValuesOfInsert(ctx context.Context, insert *ast.In if err != nil { return err } + generatedColumns := make(map[string]struct{}, len(tCols)) for _, tCol := range tCols { if tCol.IsGenerated() { - return ErrBadGeneratedColumn.GenWithStackByArgs(tCol.Name.O, tableInfo.Name.O) + generatedColumns[tCol.Name.L] = struct{}{} } } for i, assign := range insert.Setlist { + defaultExpr := extractDefaultExpr(assign.Expr) + if defaultExpr != nil { + defaultExpr.Name = assign.Column + } + // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. + // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html + if _, ok := generatedColumns[assign.Column.Name.L]; ok { + if defaultExpr != nil { + continue + } + return ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tableInfo.Name.O) + } expr, _, err := b.rewriteWithPreprocess(ctx, assign.Expr, mockTablePlan, nil, nil, true, checkRefColumn) if err != nil { return err @@ -1687,12 +1709,6 @@ func (b *PlanBuilder) buildValuesListOfInsert(ctx context.Context, insert *ast.I if len(insert.Lists[0]) != len(affectedValuesCols) { return ErrWrongValueCountOnRow.GenWithStackByArgs(1) } - // No generated column is allowed. - for _, col := range affectedValuesCols { - if col.IsGenerated() { - return ErrBadGeneratedColumn.GenWithStackByArgs(col.Name.O, insertPlan.Table.Meta().Name.O) - } - } } totalTableCols := insertPlan.Table.Cols() @@ -1709,8 +1725,17 @@ func (b *PlanBuilder) buildValuesListOfInsert(ctx context.Context, insert *ast.I for j, valueItem := range valuesItem { var expr expression.Expression var err error + var generatedColumnWithDefaultExpr bool + col := affectedValuesCols[j] switch x := valueItem.(type) { case *ast.DefaultExpr: + if col.IsGenerated() { + if x.Name != nil { + return ErrBadGeneratedColumn.GenWithStackByArgs(col.Name.O, insertPlan.Table.Meta().Name.O) + } + generatedColumnWithDefaultExpr = true + break + } if x.Name != nil { expr, err = b.findDefaultValue(totalTableCols, x.Name) } else { @@ -1727,6 +1752,16 @@ func (b *PlanBuilder) buildValuesListOfInsert(ctx context.Context, insert *ast.I if err != nil { return err } + // insert value into a generated column is not allowed + if col.IsGenerated() { + // but there is only one exception: + // it is allowed to insert the `default` value into a generated column + if generatedColumnWithDefaultExpr { + continue + } + return ErrBadGeneratedColumn.GenWithStackByArgs(col.Name.O, insertPlan.Table.Meta().Name.O) + } + exprList = append(exprList, expr) } insertPlan.Lists = append(insertPlan.Lists, exprList)