Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression/types: fix decimal minus/round/multiple result #7001

Merged
merged 9 commits into from
Jul 11, 2018
3 changes: 2 additions & 1 deletion expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package aggregation

import (
"github.com/cznic/mathutil"
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
Expand Down Expand Up @@ -84,7 +85,7 @@ func (af *avgFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum) {
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = to.Round(to, frac, types.ModeHalfEven)
err = to.Round(to, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(to)
}
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tipb/go-tipb"
)
Expand Down Expand Up @@ -523,7 +524,7 @@ func (s *builtinArithmeticMultiplyDecimalSig) evalDecimal(row types.Row) (*types
}
c := &types.MyDecimal{}
err = types.DecimalMul(a, b, c)
if err != nil {
if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  conventions:
    decimal_smth() <= 1     -- result is usable, but precision loss is possible

In MySQL multiple will ignore mul DataTruncated error without any error or warning.....

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have some test to cover this behavior?

Copy link
Contributor Author

@lysu lysu Jul 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes~ 0151f10 add a new case, today.

select 2.00000000000000000000000000000001 * 1.000000000000000000000000000000000000000000002 will failed in master, but pass in this PR and mysql with 2.000000000000000000000000000000

return nil, true, errors.Trace(err)
}
return c, false, nil
Expand Down
50 changes: 26 additions & 24 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"
"time"

"github.com/cznic/mathutil"
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -263,8 +264,10 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if mysql.HasUnsignedFlag(argFieldTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}

bf.tp.Flen = argFieldTp.Flen
bf.tp.Decimal = 0
bf.tp.Decimal = calculateDecimal4RoundAndTruncate(ctx, args, argTp)

var sig builtinFunc
if len(args) > 1 {
switch argTp {
Expand Down Expand Up @@ -292,6 +295,25 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
return sig, nil
}

// calculateDecimal4RoundAndTruncate calculates tp.decimals of round/truncate func.
func calculateDecimal4RoundAndTruncate(ctx sessionctx.Context, args []Expression, retType types.EvalType) int {
if retType == types.ETInt || len(args) <= 1 {
return 0
}
secondConst, secondIsConst := args[1].(*Constant)
if !secondIsConst {
return args[0].GetType().Decimal
}
argDec, isNull, err := secondConst.EvalInt(ctx, nil)
if err != nil || isNull || argDec < 0 {
return 0
}
if argDec > mysql.MaxDecimalScale {
return mysql.MaxDecimalScale
}
return int(argDec)
}

type builtinRoundRealSig struct {
baseBuiltinFunc
}
Expand Down Expand Up @@ -422,7 +444,7 @@ func (b *builtinRoundWithFracDecSig) evalDecimal(row types.Row) (*types.MyDecima
return nil, isNull, errors.Trace(err)
}
to := new(types.MyDecimal)
if err = val.Round(to, int(frac), types.ModeHalfEven); err != nil {
if err = val.Round(to, mathutil.Min(int(frac), b.tp.Decimal), types.ModeHalfEven); err != nil {
return nil, true, errors.Trace(err)
}
return to, false, nil
Expand Down Expand Up @@ -1695,22 +1717,6 @@ type truncateFunctionClass struct {
baseFunctionClass
}

// getDecimal returns the `Decimal` value of return type for function `TRUNCATE`.
func (c *truncateFunctionClass) getDecimal(ctx sessionctx.Context, arg Expression) int {
if constant, ok := arg.(*Constant); ok {
decimal, isNull, err := constant.EvalInt(ctx, nil)
if isNull || err != nil {
return 0
} else if decimal > 30 {
return 30
} else if decimal < 0 {
return 0
}
return int(decimal)
}
return 3
}

func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
Expand All @@ -1723,11 +1729,7 @@ func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expre

bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, argTp, types.ETInt)

if argTp == types.ETInt {
bf.tp.Decimal = 0
} else {
bf.tp.Decimal = c.getDecimal(bf.ctx, args[1])
}
bf.tp.Decimal = calculateDecimal4RoundAndTruncate(ctx, args, argTp)
bf.tp.Flen = args[0].GetType().Flen - args[0].GetType().Decimal + bf.tp.Decimal
bf.tp.Flag |= args[0].GetType().Flag

Expand Down Expand Up @@ -1768,7 +1770,7 @@ func (b *builtinTruncateDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal
}

result := new(types.MyDecimal)
if err := x.Round(result, int(d), types.ModeTruncate); err != nil {
if err := x.Round(result, mathutil.Min(int(d), b.getRetTp().Decimal), types.ModeTruncate); err != nil {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Truncate's param is given by user, so need check dec < tp.Decimal(which is less than MAX_DECIMAL_SCALE)

e.g. Truncate(1.124, 100)(we had test case).

return nil, true, errors.Trace(err)
}
return result, false, nil
Expand Down
6 changes: 1 addition & 5 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -749,15 +749,11 @@ func (b *builtinUnaryMinusDecimalSig) Clone() builtinFunc {
}

func (b *builtinUnaryMinusDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal, bool, error) {
var dec *types.MyDecimal
dec, isNull, err := b.args[0].EvalDecimal(b.ctx, row)
if err != nil || isNull {
return dec, isNull, errors.Trace(err)
}

to := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), dec, to)
return to, err != nil, errors.Trace(err)
return types.DecimalNeg(dec), false, nil
}

type builtinUnaryMinusRealSig struct {
Expand Down
4 changes: 3 additions & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3372,7 +3372,7 @@ func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) {
return store, dom, errors.Trace(err)
}

func (s *testIntegrationSuite) TestTwoDecimalAssignTruncate(c *C) {
func (s *testIntegrationSuite) TestTwoDecimalTruncate(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)
tk.MustExec("use test")
Expand All @@ -3383,4 +3383,6 @@ func (s *testIntegrationSuite) TestTwoDecimalAssignTruncate(c *C) {
tk.MustExec("update t1 set b = a")
res := tk.MustQuery("select a, b from t1")
res.Check(testkit.Rows("123.12345 123.1"))
res = tk.MustQuery("select 2.00000000000000000000000000000001 * 1.000000000000000000000000000000000000000000002")
res.Check(testkit.Rows("2.000000000000000000000000000000"))
}
6 changes: 3 additions & 3 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,9 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase {

{"round(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"round(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
{"round(c_float_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 12, 0}, // Should be 17.
{"round(c_double_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, // Should be 17.
{"round(c_decimal )", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, // Should be 5.
{"round(c_float_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 12, 0}, // flen Should be 17.
{"round(c_double_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, // flen Should be 17.
{"round(c_decimal )", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, // flen Should be 5.
{"round(c_datetime )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"round(c_time_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"round(c_timestamp_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
Expand Down
23 changes: 15 additions & 8 deletions types/mydecimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,6 @@ func (d *MyDecimal) doMiniRightShift(shift, beg, end int) {
// RETURN VALUE
// eDecOK/eDecTruncated
func (d *MyDecimal) Round(to *MyDecimal, frac int, roundMode RoundMode) (err error) {
if frac > mysql.MaxDecimalScale {
frac = mysql.MaxDecimalScale
}
Copy link
Contributor Author

@lysu lysu Jul 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this check out of mydecimal, and not > mysql.MaxDecimalScale but > e.tp.Decimals

// wordsFracTo is the number of fraction words in buffer.
wordsFracTo := (frac + 1) / digitsPerWord
if frac > 0 {
Expand Down Expand Up @@ -1383,6 +1380,16 @@ func (d *MyDecimal) Compare(to *MyDecimal) int {
return 1
}

// DecimalNeg reverses decimal's sign.
func DecimalNeg(from *MyDecimal) *MyDecimal {
to := *from
if from.IsZero() {
return &to
}
to.negative = !from.negative
return &to
}

// DecimalAdd adds two decimals, sets the result to 'to'.
func DecimalAdd(from1, from2, to *MyDecimal) error {
to.resultFrac = myMaxInt8(from1.resultFrac, from2.resultFrac)
Expand Down Expand Up @@ -1753,7 +1760,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
to.digitsFrac = int8(wordsFracTo * digitsPerWord)
}
if to.digitsInt > int8(wordsIntTo*digitsPerWord) {
to.digitsInt = int8(wordsFracTo * digitsPerWord)
to.digitsInt = int8(wordsIntTo * digitsPerWord)
}
if tmp1 > wordsIntTo {
tmp1 -= wordsIntTo
Expand All @@ -1762,7 +1769,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
wordsFrac1 = 0
wordsFrac2 = 0
} else {
tmp2 -= wordsIntTo
tmp2 -= wordsFracTo
tmp1 = tmp2 >> 1
if wordsFrac1 <= wordsFrac2 {
wordsFrac1 -= tmp1
Expand All @@ -1774,9 +1781,9 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
}
}
startTo := wordsIntTo + wordsFracTo - 1
start2 := wordsInt2 + wordsFrac2 - 1
stop1 := 0
stop2 := 0
start2 := idx2 + wordsFrac2 - 1
stop1 := idx1 - wordsInt1
stop2 := idx2 - wordsInt2
to.wordBuf = zeroMyDecimal.wordBuf

for idx1 += wordsFrac1 - 1; idx1 >= stop1; idx1-- {
Expand Down
24 changes: 22 additions & 2 deletions types/mydecimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,25 @@ func (s *testMyDecimalSuite) TestMaxDecimal(c *C) {
}
}

func (s *testMyDecimalSuite) TestNeg(c *C) {
type testCase struct {
a string
result string
err error
}
tests := []testCase{
{"-0.0000000000000000000000000000000000000000000000000017382578996420603", "0.0000000000000000000000000000000000000000000000000017382578996420603", nil},
{"-13890436710184412000000000000000000000000000000000000000000000000000000000000", "13890436710184412000000000000000000000000000000000000000000000000000000000000", nil},
{"0", "0", nil},
}
for _, tt := range tests {
a := NewDecFromStringForTest(tt.a)
negResult := DecimalNeg(a)
result := negResult.ToString()
c.Assert(string(result), Equals, tt.result)
}
}

func (s *testMyDecimalSuite) TestAdd(c *C) {
type testCase struct {
a string
Expand Down Expand Up @@ -627,6 +646,7 @@ func (s *testMyDecimalSuite) TestMul(c *C) {
{"123456", "9876543210", "1219318518533760", nil},
{"123", "0.01", "1.23", nil},
{"123", "0", "0", nil},
{"-0.0000000000000000000000000000000000000000000000000017382578996420603", "-13890436710184412000000000000000000000000000000000000000000000000000000000000", "0.000000000000000000000000000000", ErrTruncated},
{"1" + strings.Repeat("0", 60), "1" + strings.Repeat("0", 60), "0", ErrOverflow},
}
for _, tt := range tests {
Expand All @@ -635,8 +655,8 @@ func (s *testMyDecimalSuite) TestMul(c *C) {
b.FromString([]byte(tt.b))
err := DecimalMul(&a, &b, &product)
c.Check(err, Equals, tt.err)
result := product.ToString()
c.Assert(string(result), Equals, tt.result)
result := product.String()
c.Assert(result, Equals, tt.result)
}
}

Expand Down