Skip to content

Commit

Permalink
expression, executor: refine avg/sum precision
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu committed Oct 11, 2018
1 parent 1a9741d commit b424203
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
25 changes: 24 additions & 1 deletion executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package aggfuncs

import (
"github.com/cznic/mathutil"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -56,7 +58,19 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
finalResult := new(types.MyDecimal)
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, types.DivFracIncr)
if err != nil {
return errors.Trace(err)
return err
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if len(e.args) == 2 {
frac = e.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, finalResult)
return nil
Expand Down Expand Up @@ -195,6 +209,15 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return errors.Trace(err)
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, finalResult)
return nil
}
Expand Down
18 changes: 18 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,24 @@ func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) {
}
for i := range f.Args {
f.Args[i] = castFunc(b.ctx, f.Args[i])
// After wrapping cast on the argument, flen etc. may not the same
// as the type of the aggregation function. The following part set
// the type of the argument exactly as the type of the aggregation
// function.
// Note: If the `Tp` of argument is the same as the `Tp` of the
// aggregation function, it will not wrap cast function on it
// internally. The reason of the special handling for `Column` is
// that the `RetType` of `Column` refers to the `infoschema`, so we
// need to set a new variable for it to avoid modifying the
// definition in `infoschema`.
if col, ok := f.Args[i].(*expression.Column); ok {
col.RetType = types.NewFieldType(col.RetType.Tp)
}
// originTp is used when the the `Tp` of column is TypeFloat32 while
// the type of the aggregation function is TypeFloat64.
originTp := f.Args[i].GetType().Tp
*(f.Args[i].GetType()) = *(f.RetTp)
f.Args[i].GetType().Tp = originTp
}
}
}
Expand Down
11 changes: 4 additions & 7 deletions expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,21 +291,21 @@ func (a *AggFuncDesc) typeInfer4Count(ctx sessionctx.Context) {
// Because child returns integer or decimal type.
func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
switch a.Args[0].GetType().Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal:
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, 0
case mysql.TypeNewDecimal:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal
if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale {
a.RetTp.Decimal = mysql.MaxDecimalScale
}
// TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0])
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
//TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
}
types.SetBinChsClnFlag(a.RetTp)
}
Expand All @@ -322,15 +322,12 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale)
}
a.RetTp.Flen = mysql.MaxDecimalWidth
// TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0])
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
}
types.SetBinChsClnFlag(a.RetTp)
}
Expand Down

0 comments on commit b424203

Please sign in to comment.