-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
expression/types: fix decimal minus/round/multiple result #7001
Changes from 7 commits
6668431
2ef808c
7bba6ba
3de2dd1
039feff
126d875
0151f10
f36ed84
7d1dbcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ import ( | |
"strings" | ||
"time" | ||
|
||
"github.com/cznic/mathutil" | ||
"github.com/juju/errors" | ||
"github.com/pingcap/tidb/mysql" | ||
"github.com/pingcap/tidb/sessionctx" | ||
|
@@ -263,8 +264,10 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi | |
if mysql.HasUnsignedFlag(argFieldTp.Flag) { | ||
bf.tp.Flag |= mysql.UnsignedFlag | ||
} | ||
|
||
bf.tp.Flen = argFieldTp.Flen | ||
bf.tp.Decimal = 0 | ||
bf.tp.Decimal = fixDecimal4RoundAndTruncate(ctx, args, argTp) | ||
|
||
var sig builtinFunc | ||
if len(args) > 1 { | ||
switch argTp { | ||
|
@@ -292,6 +295,25 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi | |
return sig, nil | ||
} | ||
|
||
// fixDecimal4RoundAndTruncate fixes tp.decimals of round/truncate func. | ||
func fixDecimal4RoundAndTruncate(ctx sessionctx.Context, args []Expression, retType types.EvalType) int { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/fix.../calc.../ |
||
if retType == types.ETInt || len(args) <= 1 { | ||
return 0 | ||
} | ||
secondConst, secondIsConst := args[1].(*Constant) | ||
if !secondIsConst { | ||
return args[0].GetType().Decimal | ||
} | ||
argDec, isNull, err := secondConst.EvalInt(ctx, nil) | ||
if isNull || err != nil || argDec < 0 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should check |
||
return 0 | ||
} | ||
if argDec > mysql.MaxDecimalScale { | ||
return mysql.MaxDecimalScale | ||
} | ||
return int(argDec) | ||
} | ||
|
||
type builtinRoundRealSig struct { | ||
baseBuiltinFunc | ||
} | ||
|
@@ -422,7 +444,7 @@ func (b *builtinRoundWithFracDecSig) evalDecimal(row types.Row) (*types.MyDecima | |
return nil, isNull, errors.Trace(err) | ||
} | ||
to := new(types.MyDecimal) | ||
if err = val.Round(to, int(frac), types.ModeHalfEven); err != nil { | ||
if err = val.Round(to, mathutil.Min(int(frac), b.tp.Decimal), types.ModeHalfEven); err != nil { | ||
return nil, true, errors.Trace(err) | ||
} | ||
return to, false, nil | ||
|
@@ -1695,22 +1717,6 @@ type truncateFunctionClass struct { | |
baseFunctionClass | ||
} | ||
|
||
// getDecimal returns the `Decimal` value of return type for function `TRUNCATE`. | ||
func (c *truncateFunctionClass) getDecimal(ctx sessionctx.Context, arg Expression) int { | ||
if constant, ok := arg.(*Constant); ok { | ||
decimal, isNull, err := constant.EvalInt(ctx, nil) | ||
if isNull || err != nil { | ||
return 0 | ||
} else if decimal > 30 { | ||
return 30 | ||
} else if decimal < 0 { | ||
return 0 | ||
} | ||
return int(decimal) | ||
} | ||
return 3 | ||
} | ||
|
||
func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { | ||
if err := c.verifyArgs(args); err != nil { | ||
return nil, errors.Trace(err) | ||
|
@@ -1723,11 +1729,7 @@ func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expre | |
|
||
bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, argTp, types.ETInt) | ||
|
||
if argTp == types.ETInt { | ||
bf.tp.Decimal = 0 | ||
} else { | ||
bf.tp.Decimal = c.getDecimal(bf.ctx, args[1]) | ||
} | ||
bf.tp.Decimal = fixDecimal4RoundAndTruncate(ctx, args, argTp) | ||
bf.tp.Flen = args[0].GetType().Flen - args[0].GetType().Decimal + bf.tp.Decimal | ||
bf.tp.Flag |= args[0].GetType().Flag | ||
|
||
|
@@ -1768,7 +1770,7 @@ func (b *builtinTruncateDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal | |
} | ||
|
||
result := new(types.MyDecimal) | ||
if err := x.Round(result, int(d), types.ModeTruncate); err != nil { | ||
if err := x.Round(result, mathutil.Min(int(d), b.getRetTp().Decimal), types.ModeTruncate); err != nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Truncate's param is given by user, so need check dec < tp.Decimal(which is less than MAX_DECIMAL_SCALE) e.g. |
||
return nil, true, errors.Trace(err) | ||
} | ||
return result, false, nil | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -713,9 +713,6 @@ func (d *MyDecimal) doMiniRightShift(shift, beg, end int) { | |
// RETURN VALUE | ||
// eDecOK/eDecTruncated | ||
func (d *MyDecimal) Round(to *MyDecimal, frac int, roundMode RoundMode) (err error) { | ||
if frac > mysql.MaxDecimalScale { | ||
frac = mysql.MaxDecimalScale | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move this check out of mydecimal, and not |
||
// wordsFracTo is the number of fraction words in buffer. | ||
wordsFracTo := (frac + 1) / digitsPerWord | ||
if frac > 0 { | ||
|
@@ -1383,6 +1380,16 @@ func (d *MyDecimal) Compare(to *MyDecimal) int { | |
return 1 | ||
} | ||
|
||
// DecimalNeg reverses decimal's sign. | ||
func DecimalNeg(from *MyDecimal) *MyDecimal { | ||
to := *from | ||
if from.IsZero() { | ||
return &to | ||
} | ||
to.negative = !from.negative | ||
return &to | ||
} | ||
|
||
// DecimalAdd adds two decimals, sets the result to 'to'. | ||
func DecimalAdd(from1, from2, to *MyDecimal) error { | ||
to.resultFrac = myMaxInt8(from1.resultFrac, from2.resultFrac) | ||
|
@@ -1753,7 +1760,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error { | |
to.digitsFrac = int8(wordsFracTo * digitsPerWord) | ||
} | ||
if to.digitsInt > int8(wordsIntTo*digitsPerWord) { | ||
to.digitsInt = int8(wordsFracTo * digitsPerWord) | ||
to.digitsInt = int8(wordsIntTo * digitsPerWord) | ||
} | ||
if tmp1 > wordsIntTo { | ||
tmp1 -= wordsIntTo | ||
|
@@ -1762,7 +1769,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error { | |
wordsFrac1 = 0 | ||
wordsFrac2 = 0 | ||
} else { | ||
tmp2 -= wordsIntTo | ||
tmp2 -= wordsFracTo | ||
tmp1 = tmp2 >> 1 | ||
if wordsFrac1 <= wordsFrac2 { | ||
wordsFrac1 -= tmp1 | ||
|
@@ -1774,9 +1781,9 @@ func DecimalMul(from1, from2, to *MyDecimal) error { | |
} | ||
} | ||
startTo := wordsIntTo + wordsFracTo - 1 | ||
start2 := wordsInt2 + wordsFrac2 - 1 | ||
stop1 := 0 | ||
stop2 := 0 | ||
start2 := idx2 + wordsFrac2 - 1 | ||
stop1 := idx1 - wordsInt1 | ||
stop2 := idx2 - wordsInt2 | ||
to.wordBuf = zeroMyDecimal.wordBuf | ||
|
||
for idx1 += wordsFrac1 - 1; idx1 >= stop1; idx1-- { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -543,6 +543,25 @@ func (s *testMyDecimalSuite) TestMaxDecimal(c *C) { | |
} | ||
} | ||
|
||
func (s *testMyDecimalSuite) TestNeg(c *C) { | ||
type testCase struct { | ||
a string | ||
result string | ||
err error | ||
} | ||
tests := []testCase{ | ||
{"-0.0000000000000000000000000000000000000000000000000017382578996420603", "0.0000000000000000000000000000000000000000000000000017382578996420603", nil}, | ||
{"-13890436710184412000000000000000000000000000000000000000000000000000000000000", "13890436710184412000000000000000000000000000000000000000000000000000000000000", nil}, | ||
{"0", "0", nil}, | ||
} | ||
for _, tt := range tests { | ||
a := NewDecFromStringForTest(tt.a) | ||
negResult := DecimalNeg(a) | ||
result := negResult.ToString() | ||
c.Assert(string(result), Equals, tt.result) | ||
} | ||
} | ||
|
||
func (s *testMyDecimalSuite) TestAdd(c *C) { | ||
type testCase struct { | ||
a string | ||
|
@@ -627,6 +646,7 @@ func (s *testMyDecimalSuite) TestMul(c *C) { | |
{"123456", "9876543210", "1219318518533760", nil}, | ||
{"123", "0.01", "1.23", nil}, | ||
{"123", "0", "0", nil}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why remove the old test case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's my mistake.- - |
||
{"-0.0000000000000000000000000000000000000000000000000017382578996420603", "-13890436710184412000000000000000000000000000000000000000000000000000000000000", "0.000000000000000000000000000000", ErrTruncated}, | ||
{"1" + strings.Repeat("0", 60), "1" + strings.Repeat("0", 60), "0", ErrOverflow}, | ||
} | ||
for _, tt := range tests { | ||
|
@@ -635,8 +655,8 @@ func (s *testMyDecimalSuite) TestMul(c *C) { | |
b.FromString([]byte(tt.b)) | ||
err := DecimalMul(&a, &b, &product) | ||
c.Check(err, Equals, tt.err) | ||
result := product.ToString() | ||
c.Assert(string(result), Equals, tt.result) | ||
result := product.String() | ||
c.Assert(result, Equals, tt.result) | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In MySQL multiple will ignore mul DataTruncated error without any error or warning.....
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have some test to cover this behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes~ 0151f10 add a new case, today.
select 2.00000000000000000000000000000001 * 1.000000000000000000000000000000000000000000002
will failed in master, but pass in this PR and mysql with2.000000000000000000000000000000