From 55565f1c938f47ab09bae8b04dc4edc29ce4d8e3 Mon Sep 17 00:00:00 2001 From: xiaojian cai Date: Tue, 28 Aug 2018 11:24:48 +0800 Subject: [PATCH] expression: fix out of range error for intdiv (#7492) --- expression/builtin_arithmetic.go | 36 ++++++++++++++++++++++---------- expression/errors.go | 2 ++ expression/integration_test.go | 7 +++++++ 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index 93cd217986ab8..4e7117b637e10 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -729,27 +729,41 @@ func (s *builtinArithmeticIntDivideIntSig) evalInt(row chunk.Row) (int64, bool, return ret, err != nil, errors.Trace(err) } -func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (int64, bool, error) { - a, isNull, err := s.args[0].EvalDecimal(s.ctx, row) - if isNull || err != nil { - return 0, isNull, errors.Trace(err) - } - - b, isNull, err := s.args[1].EvalDecimal(s.ctx, row) - if isNull || err != nil { - return 0, isNull, errors.Trace(err) +func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64, isNull bool, err error) { + sc := s.ctx.GetSessionVars().StmtCtx + var num [2]*types.MyDecimal + for i, arg := range s.args { + num[i], isNull, err = arg.EvalDecimal(s.ctx, row) + // Its behavior is consistent with MySQL. + if terror.ErrorEqual(err, types.ErrTruncated) { + err = nil + } + if terror.ErrorEqual(err, types.ErrOverflow) { + newErr := errTruncatedWrongValue.GenByArgs("DECIMAL", arg) + err = sc.HandleOverflow(newErr, newErr) + } + if isNull || err != nil { + return 0, isNull, errors.Trace(err) + } } c := &types.MyDecimal{} - err = types.DecimalDiv(a, b, c, types.DivFracIncr) + err = types.DecimalDiv(num[0], num[1], c, types.DivFracIncr) if err == types.ErrDivByZero { return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx)) } + if err == types.ErrTruncated { + err = sc.HandleTruncate(errTruncatedWrongValue.GenByArgs("DECIMAL", c)) + } + if err == types.ErrOverflow { + newErr := errTruncatedWrongValue.GenByArgs("DECIMAL", c) + err = sc.HandleOverflow(newErr, newErr) + } if err != nil { return 0, true, errors.Trace(err) } - ret, err := c.ToInt() + ret, err = c.ToInt() // err returned by ToInt may be ErrTruncated or ErrOverflow, only handle ErrOverflow, ignore ErrTruncated. if err == types.ErrOverflow { return 0, true, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String())) diff --git a/expression/errors.go b/expression/errors.go index f342bf42916ae..92be6e1487054 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -38,6 +38,7 @@ var ( errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField]) errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) + errTruncatedWrongValue = terror.ClassExpression.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) func init() { @@ -53,6 +54,7 @@ func init() { mysql.ErrOperandColumns: mysql.ErrOperandColumns, mysql.ErrRegexp: mysql.ErrRegexp, mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, + mysql.ErrTruncatedWrongValue: mysql.ErrTruncatedWrongValue, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes } diff --git a/expression/integration_test.go b/expression/integration_test.go index 20369d6038ab3..c37b478e5453e 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2515,6 +2515,13 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { result.Check(testkit.Rows("1 1300 -6 ")) result = tk.MustQuery("SELECT 2.4 div 1.1, 2.4 div 1.2, 2.4 div 1.3;") result.Check(testkit.Rows("2 2 1")) + result = tk.MustQuery("SELECT 1.175494351E-37 div 1.7976931348623157E+308, 1.7976931348623157E+308 div -1.7976931348623157E+307, 1 div 1e-82;") + result.Check(testkit.Rows("0 -1 ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect DECIMAL value: 'cast(1.7976931348623157e+308)'", + "Warning|1292|Truncated incorrect DECIMAL value: 'cast(1.7976931348623157e+308)'", + "Warning|1292|Truncated incorrect DECIMAL value: 'cast(-1.7976931348623158e+307)'", + "Warning|1365|Division by 0")) rs, err = tk.Exec("select 1e300 DIV 1.5") c.Assert(err, IsNil) _, err = session.GetRows4Test(ctx, tk.Se, rs)