diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index ca8f63cb353ba..4de94a8be3798 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -127,6 +127,26 @@ func (s *testEvaluatorSuite) TestCast(c *C) { lastWarn = warnings[len(warnings)-1] c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + // cast('125e342.83' as unsigned) + f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("125e342.83"), RetType: types.NewFieldType(mysql.TypeString)}, tp1) + res, err = f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(res.GetUint64() == 125, IsTrue) + + warnings = sc.GetWarnings() + lastWarn = warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + + // cast('1e9223372036854775807' as unsigned) + f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("1e9223372036854775807"), RetType: types.NewFieldType(mysql.TypeString)}, tp1) + res, err = f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(res.GetUint64() == 1, IsTrue) + + warnings = sc.GetWarnings() + lastWarn = warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + // cast('18446744073709551616' as signed); mask := ^mysql.UnsignedFlag tp1.Flag &= uint(mask) @@ -149,6 +169,26 @@ func (s *testEvaluatorSuite) TestCast(c *C) { lastWarn = warnings[len(warnings)-1] c.Assert(terror.ErrorEqual(types.ErrCastAsSignedOverflow, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + // cast('125e342.83' as signed) + f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("125e342.83"), RetType: types.NewFieldType(mysql.TypeString)}, tp1) + res, err = f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(res.GetInt64() == 125, IsTrue) + + warnings = sc.GetWarnings() + lastWarn = warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + + // cast('1e9223372036854775807' as signed) + f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("1e9223372036854775807"), RetType: types.NewFieldType(mysql.TypeString)}, tp1) + res, err = f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(res.GetInt64() == 1, IsTrue) + + warnings = sc.GetWarnings() + lastWarn = warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + // create table t1(s1 time); // insert into t1 values('11:11:11'); // select cast(s1 as decimal(7, 2)) from t1; diff --git a/types/convert.go b/types/convert.go index 82aac0547b04a..ab727258b76fd 100644 --- a/types/convert.go +++ b/types/convert.go @@ -236,13 +236,13 @@ func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) if err != nil { return floatPrefix, errors.Trace(err) } - return floatStrToIntStr(floatPrefix) + return floatStrToIntStr(sc, floatPrefix, str) } // floatStrToIntStr converts a valid float string into valid integer string which can be parsed by // strconv.ParseInt, we can't parse float first then convert it to string because precision will // be lost. -func floatStrToIntStr(validFloat string) (string, error) { +func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr string) (string, error) { var dotIdx = -1 var eIdx = -1 for i := 0; i < len(validFloat); i++ { @@ -275,7 +275,8 @@ func floatStrToIntStr(validFloat string) (string, error) { } if exp > 0 && int64(intCnt) > (math.MaxInt64-int64(exp)) { // (exp + incCnt) overflows MaxInt64. - return validFloat, ErrOverflow.GenByArgs("BIGINT", validFloat) + sc.AppendWarning(ErrOverflow.GenByArgs("BIGINT", oriStr)) + return validFloat[:eIdx], nil } intCnt += exp if intCnt <= 0 { @@ -291,8 +292,9 @@ func floatStrToIntStr(validFloat string) (string, error) { // convert scientific notation decimal number extraZeroCount := intCnt - len(digits) if extraZeroCount > 20 { - // Return overflow to avoid allocating too much memory. - return validFloat, ErrOverflow.GenByArgs("BIGINT", validFloat) + // Append overflow warning and return to avoid allocating too much memory. + sc.AppendWarning(ErrOverflow.GenByArgs("BIGINT", oriStr)) + return validFloat[:eIdx], nil } validInt = string(digits) + strings.Repeat("0", extraZeroCount) } diff --git a/types/convert_test.go b/types/convert_test.go index 2a70a0a2d9d9a..08657d627dae8 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -684,10 +684,15 @@ func (s *testTypeConvertSuite) TestGetValidFloat(c *C) { _, err := strconv.ParseFloat(prefix, 64) c.Assert(err, IsNil) } - _, err := floatStrToIntStr("1e9223372036854775807") - c.Assert(terror.ErrorEqual(err, ErrOverflow), IsTrue, Commentf("err %v", err)) - _, err = floatStrToIntStr("1e21") - c.Assert(terror.ErrorEqual(err, ErrOverflow), IsTrue, Commentf("err %v", err)) + floatStr, err := floatStrToIntStr(sc, "1e9223372036854775807", "1e9223372036854775807") + c.Assert(err, IsNil) + c.Assert(floatStr, Equals, "1") + floatStr, err = floatStrToIntStr(sc, "125e342", "125e342.83") + c.Assert(err, IsNil) + c.Assert(floatStr, Equals, "125") + floatStr, err = floatStrToIntStr(sc, "1e21", "1e21") + c.Assert(err, IsNil) + c.Assert(floatStr, Equals, "1") } // TestConvertTime tests time related conversion.