diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 819912db67a5d..ff556cfc09a3d 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -461,9 +461,13 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b if isNull || err != nil { return res, isNull, err } - if !mysql.HasUnsignedFlag(b.tp.Flag) && !mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { + if unsignedArgs0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag); !mysql.HasUnsignedFlag(b.tp.Flag) && !unsignedArgs0 { res = float64(val) - } else if b.inUnion && val < 0 { + } else if b.inUnion && !unsignedArgs0 && val < 0 { + // Round up to 0 if the value is negative but the expression eval type is unsigned in `UNION` statement + // NOTE: the following expressions are equal (so choose the more efficient one): + // `b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && !unsignedArgs0 && val < 0` + // `b.inUnion && !unsignedArgs0 && val < 0` res = 0 } else { // recall that, int to float is different from uint to float @@ -487,9 +491,14 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe if isNull || err != nil { return res, isNull, err } - if !mysql.HasUnsignedFlag(b.tp.Flag) && !mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { + + if unsignedArgs0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag); !mysql.HasUnsignedFlag(b.tp.Flag) && !unsignedArgs0 { res = types.NewDecFromInt(val) - } else if b.inUnion && val < 0 { + // Round up to 0 if the value is negative but the expression eval type is unsigned in `UNION` statement + // NOTE: the following expressions are equal (so choose the more efficient one): + // `b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && !unsignedArgs0 && val < 0` + // `b.inUnion && !unsignedArgs0 && val < 0` + } else if b.inUnion && !unsignedArgs0 && val < 0 { res = &types.MyDecimal{} } else { res = types.NewDecFromUint(uint64(val)) diff --git a/expression/builtin_cast_vec.go b/expression/builtin_cast_vec.go index 81d733e627328..16b452415abc4 100644 --- a/expression/builtin_cast_vec.go +++ b/expression/builtin_cast_vec.go @@ -111,7 +111,11 @@ func (b *builtinCastIntAsRealSig) vecEvalReal(input *chunk.Chunk, result *chunk. } if !hasUnsignedFlag0 && !hasUnsignedFlag1 { rs[i] = float64(i64s[i]) - } else if b.inUnion && i64s[i] < 0 { + } else if b.inUnion && !hasUnsignedFlag1 && i64s[i] < 0 { + // Round up to 0 if the value is negative but the expression eval type is unsigned in `UNION` statement + // NOTE: the following expressions are equal (so choose the more efficient one): + // `b.inUnion && hasUnsignedFlag0 && !hasUnsignedFlag1 && i64s[i] < 0` + // `b.inUnion && !hasUnsignedFlag1 && i64s[i] < 0` rs[i] = 0 } else { // recall that, int to float is different from uint to float diff --git a/expression/integration_test.go b/expression/integration_test.go index 642dbc914e88d..78c2dc33f5223 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2308,6 +2308,22 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { result.Check(testkit.Rows("9223372036854775808 9223372036854775808", "9223372036854775808 9223372036854775808")) tk.MustExec(`drop table tb5;`) + // test builtinCastIntAsDecimalSig + tk.MustExec(`drop table if exists tb5`) + tk.MustExec(`create table tb5 (a decimal(65), b bigint(64) unsigned);`) + tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`) + result = tk.MustQuery(`select cast(b as decimal(64)) from tb5 union all select b from tb5;`) + result.Check(testkit.Rows("9223372036854775808", "9223372036854775808")) + tk.MustExec(`drop table tb5`) + + // test builtinCastIntAsRealSig + tk.MustExec(`drop table if exists tb5`) + tk.MustExec(`create table tb5 (a bigint(64) unsigned, b double(64, 10));`) + tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`) + result = tk.MustQuery(`select a from tb5 where a = b union all select b from tb5;`) + result.Check(testkit.Rows("9223372036854776000", "9223372036854776000")) + tk.MustExec(`drop table tb5`) + // Test corner cases of cast string as datetime result = tk.MustQuery(`select cast("170102034" as datetime);`) result.Check(testkit.Rows("2017-01-02 03:04:00"))