diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 800c034b1d886..84088fb13adc2 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -470,6 +470,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } switch tp { case types.ETInt: + // adjust unsigned flag + greastInitUnsignedFlag := false + if isEqualsInitUnsignedFlag(greastInitUnsignedFlag, args) { + bf.tp.Flag &= ^mysql.UnsignedFlag + } else { + bf.tp.Flag |= mysql.UnsignedFlag + } + sig = &builtinGreatestIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_GreatestInt) case types.ETReal: @@ -689,6 +697,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi } switch tp { case types.ETInt: + // adjust unsigned flag + leastInitUnsignedFlag := true + if isEqualsInitUnsignedFlag(leastInitUnsignedFlag, args) { + bf.tp.Flag |= mysql.UnsignedFlag + } else { + bf.tp.Flag &= ^mysql.UnsignedFlag + } + sig = &builtinLeastIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LeastInt) case types.ETReal: @@ -2756,3 +2772,15 @@ func CompareJSON(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs } return int64(json.CompareBinary(arg0, arg1)), false, nil } + +// isEqualsInitUnsignedFlag can adjust unsigned flag for greatest/least function. +// For greatest, returns unsigned result if there is at least one argument is unsigned. +// For least, returns signed result if there is at least one argument is signed. +func isEqualsInitUnsignedFlag(initUnsigned bool, args []Expression) bool { + for _, arg := range args { + if initUnsigned != mysql.HasUnsignedFlag(arg.GetType().Flag) { + return false + } + } + return true +} diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index b8d3db9ae9937..f1aa15ee1a783 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -263,6 +263,8 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { sc := s.ctx.GetSessionVars().StmtCtx originIgnoreTruncate := sc.IgnoreTruncate sc.IgnoreTruncate = true + decG := &types.MyDecimal{} + decL := &types.MyDecimal{} defer func() { sc.IgnoreTruncate = originIgnoreTruncate }() @@ -274,6 +276,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { isNil bool getErr bool }{ + { + []interface{}{int64(-9223372036854775808), uint64(9223372036854775809)}, + decG.FromUint(9223372036854775809), decL.FromInt(-9223372036854775808), false, false, + }, + { + []interface{}{uint64(9223372036854775808), uint64(9223372036854775809)}, + uint64(9223372036854775809), uint64(9223372036854775808), false, false, + }, { []interface{}{1, 2, 3, 4}, int64(4), int64(1), false, false, diff --git a/expression/integration_test.go b/expression/integration_test.go index 0032cfa8d20b6..e26e8d850ec17 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -10318,3 +10318,12 @@ func (s *testIntegrationSuite) TestIssue29244(c *C) { tk.MustExec("set tidb_enable_vectorized_expression = off;") tk.MustQuery("select microsecond(a) from t;").Check(testkit.Rows("123500", "123500")) } + +func (s *testIntegrationSuite) TestIssue30101(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 bigint unsigned, c2 bigint unsigned);") + tk.MustExec("insert into t1 values(9223372036854775808, 9223372036854775809);") + tk.MustQuery("select greatest(c1, c2) from t1;").Sort().Check(testkit.Rows("9223372036854775809")) +} diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 5365c78f61336..4293241a5ffbd 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -1034,6 +1034,13 @@ func (s *testInferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase { {"interval(c_int_d, c_int_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"interval(c_int_d, c_float_d, c_double_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + + {"greatest(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 11, 0}, + {"least(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"least(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"least(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, } }