Skip to content

Commit

Permalink
executor: refine the precision for avg (#7860) (#7874)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored and zz-jason committed Oct 11, 2018
1 parent 8823f12 commit e6025cb
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
25 changes: 24 additions & 1 deletion executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package aggfuncs

import (
"github.com/cznic/mathutil"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -56,7 +58,19 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
finalResult := new(types.MyDecimal)
err := types.DecimalDiv(&p.sum, decimalCount, finalResult, types.DivFracIncr)
if err != nil {
return errors.Trace(err)
return err
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if len(e.args) == 2 {
frac = e.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, finalResult)
return nil
Expand Down Expand Up @@ -195,6 +209,15 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return errors.Trace(err)
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, finalResult)
return nil
}
Expand Down
21 changes: 21 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,27 @@ func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) {
}
for i := range f.Args {
f.Args[i] = castFunc(b.ctx, f.Args[i])
if f.Name != ast.AggFuncAvg && f.Name != ast.AggFuncSum {
continue
}
// After wrapping cast on the argument, flen etc. may not the same
// as the type of the aggregation function. The following part set
// the type of the argument exactly as the type of the aggregation
// function.
// Note: If the `Tp` of argument is the same as the `Tp` of the
// aggregation function, it will not wrap cast function on it
// internally. The reason of the special handling for `Column` is
// that the `RetType` of `Column` refers to the `infoschema`, so we
// need to set a new variable for it to avoid modifying the
// definition in `infoschema`.
if col, ok := f.Args[i].(*expression.Column); ok {
col.RetType = types.NewFieldType(col.RetType.Tp)
}
// originTp is used when the the `Tp` of column is TypeFloat32 while
// the type of the aggregation function is TypeFloat64.
originTp := f.Args[i].GetType().Tp
*(f.Args[i].GetType()) = *(f.RetTp)
f.Args[i].GetType().Tp = originTp
}
}
}
Expand Down
19 changes: 12 additions & 7 deletions expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,21 @@ func (a *AggFuncDesc) typeInfer4Count(ctx sessionctx.Context) {
// Because child returns integer or decimal type.
func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
switch a.Args[0].GetType().Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal:
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, 0
case mysql.TypeNewDecimal:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal
if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale {
a.RetTp.Decimal = mysql.MaxDecimalScale
}
// TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0])
default:
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
//TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}
types.SetBinChsClnFlag(a.RetTp)
}
Expand All @@ -318,11 +322,12 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale)
}
a.RetTp.Flen = mysql.MaxDecimalWidth
// TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0])
default:
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
// TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0])
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}
types.SetBinChsClnFlag(a.RetTp)
}
Expand Down
4 changes: 2 additions & 2 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,14 +822,14 @@ func (s *testInferTypeSuite) createTestCase4Aggregations() []typeInferTestCase {
{"sum(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 3},
{"sum(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 1},
{"sum(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 4},
{"avg(c_float_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_double_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 7},
{"avg(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 5},
{"avg(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"group_concat(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, 0},
}
}
Expand Down

0 comments on commit e6025cb

Please sign in to comment.