From 9522f811f3b9b0488df084823e8af5dbfa53b5c2 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Tue, 10 Dec 2019 16:03:25 +0800 Subject: [PATCH] types: fix insert error when convert string to float (#13716) --- executor/insert_test.go | 21 +++++++++++++++++++++ executor/write_test.go | 2 +- expression/builtin_cast.go | 8 ++++++++ planner/core/logical_plan_builder.go | 1 - types/convert.go | 2 +- types/convert_test.go | 7 ++----- 6 files changed, 33 insertions(+), 8 deletions(-) diff --git a/executor/insert_test.go b/executor/insert_test.go index 891728c2d050f..162a020fce638 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -665,3 +665,24 @@ func (s *testSuite) TestJiraIssue5366(c *C) { tk.MustExec(` insert into bug select ifnull(JSON_UNQUOTE(JSON_EXTRACT('[{"amount":2000,"feeAmount":0,"merchantNo":"20190430140319679394","shareBizCode":"20160311162_SECOND"}]', '$[0].merchantNo')),'') merchant_no union SELECT '20180531557' merchant_no;`) tk.MustQuery(`select * from bug`).Sort().Check(testkit.Rows("20180531557", "20190430140319679394")) } + +func (s *testSuite) TestDMLCast(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a int, b double)`) + tk.MustExec(`insert into t values (ifnull('',0)+0, 0)`) + tk.MustExec(`insert into t values (0, ifnull('',0)+0)`) + tk.MustQuery(`select * from t`).Check(testkit.Rows("0 0", "0 0")) + _, err := tk.Exec(`insert into t values ('', 0)`) + c.Assert(err, NotNil) + _, err = tk.Exec(`insert into t values (0, '')`) + c.Assert(err, NotNil) + _, err = tk.Exec(`update t set a = ''`) + c.Assert(err, NotNil) + _, err = tk.Exec(`update t set b = ''`) + c.Assert(err, NotNil) + tk.MustExec("update t set a = ifnull('',0)+0") + tk.MustExec("update t set b = ifnull('',0)+0") + tk.MustExec("delete from t where a = ''") + tk.MustQuery(`select * from t`).Check(testkit.Rows()) +} diff --git a/executor/write_test.go b/executor/write_test.go index fe727a3ef1459..863ef7ec6be08 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -1180,7 +1180,7 @@ func (s *testSuite) TestUpdate(c *C) { tk.MustExec("update t set a = ''") tk.MustQuery("select * from t").Check(testkit.Rows("0000-00-00 00:00:00 1999-12-13 00:00:00")) tk.MustExec("update t set b = ''") - tk.MustQuery("select * from t").Check(testkit.Rows("0000-00-00 00:00:00 ")) + tk.MustQuery("select * from t").Check(testkit.Rows("0000-00-00 00:00:00 0000-00-00 00:00:00")) tk.MustExec("set @@sql_mode=@orig_sql_mode;") } diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index f1abfa174e939..739b4bacd245e 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1096,6 +1096,10 @@ func (b *builtinCastStringAsIntSig) evalInt(row chunk.Row) (res int64, isNull bo if len(val) > 1 && val[0] == '-' { // negative number isNegative = true } + sctx := b.ctx.GetSessionVars().StmtCtx + if val == "" && (sctx.InInsertStmt || sctx.InUpdateStmt) { + return 0, false, nil + } var ures uint64 sc := b.ctx.GetSessionVars().StmtCtx @@ -1138,6 +1142,10 @@ func (b *builtinCastStringAsRealSig) evalReal(row chunk.Row) (res float64, isNul if isNull || err != nil { return res, isNull, errors.Trace(err) } + sctx := b.ctx.GetSessionVars().StmtCtx + if val == "" && (sctx.InInsertStmt || sctx.InUpdateStmt) { + return 0, false, nil + } sc := b.ctx.GetSessionVars().StmtCtx res, err = types.StrToFloat(sc, val) if err != nil { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index e4d69f56c0244..785c4f90dbde2 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2379,7 +2379,6 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A if err != nil { return nil, nil, errors.Trace(err) } - newExpr = expression.BuildCastFunction(b.ctx, newExpr, col.GetType()) p = np newList = append(newList, &expression.Assignment{Col: col, Expr: newExpr}) } diff --git a/types/convert.go b/types/convert.go index 72e5520c4b243..2ce128617a05a 100644 --- a/types/convert.go +++ b/types/convert.go @@ -588,7 +588,7 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyD // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) { - if (sc.InDeleteStmt || sc.InSelectStmt || sc.InUpdateStmt) && s == "" { + if (sc.InDeleteStmt || sc.InSelectStmt) && s == "" { return "0", nil } diff --git a/types/convert_test.go b/types/convert_test.go index f968ca8783ebc..5b44a1322223c 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -467,17 +467,14 @@ func (s *testTypeConvertSuite) TestStrToNum(c *C) { func testSelectUpdateDeleteEmptyStringError(c *C) { testCases := []struct { inSelect bool - inUpdate bool inDelete bool }{ - {true, false, false}, - {false, true, false}, - {false, false, true}, + {true, false}, + {false, true}, } sc := new(stmtctx.StatementContext) for _, tc := range testCases { sc.InSelectStmt = tc.inSelect - sc.InUpdateStmt = tc.inUpdate sc.InDeleteStmt = tc.inDelete str := ""