From b4242030b5d71630c7dfb28125ebe3116356d2e1 Mon Sep 17 00:00:00 2001 From: xuhuaiyu <391585975@qq.com> Date: Thu, 11 Oct 2018 12:36:16 +0800 Subject: [PATCH 1/2] expression, executor: refine avg/sum precision --- executor/aggfuncs/func_avg.go | 25 ++++++++++++++++++++++++- executor/builder.go | 18 ++++++++++++++++++ expression/aggregation/descriptor.go | 11 ++++------- 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index 139d60845273f..f917c60e1d044 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -14,6 +14,8 @@ package aggfuncs import ( + "github.com/cznic/mathutil" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -56,7 +58,19 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par finalResult := new(types.MyDecimal) err := types.DecimalDiv(&p.sum, decimalCount, finalResult, types.DivFracIncr) if err != nil { - return errors.Trace(err) + return err + } + // Make the decimal be the result of type inferring. + frac := e.args[0].GetType().Decimal + if len(e.args) == 2 { + frac = e.args[1].GetType().Decimal + } + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) + if err != nil { + return err } chk.AppendMyDecimal(e.ordinal, finalResult) return nil @@ -195,6 +209,15 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co if err != nil { return errors.Trace(err) } + // Make the decimal be the result of type inferring. + frac := e.args[0].GetType().Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) + if err != nil { + return err + } chk.AppendMyDecimal(e.ordinal, finalResult) return nil } diff --git a/executor/builder.go b/executor/builder.go index 1f9ef0e0cb96a..0392565593edd 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -887,6 +887,24 @@ func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) { } for i := range f.Args { f.Args[i] = castFunc(b.ctx, f.Args[i]) + // After wrapping cast on the argument, flen etc. may not the same + // as the type of the aggregation function. The following part set + // the type of the argument exactly as the type of the aggregation + // function. + // Note: If the `Tp` of argument is the same as the `Tp` of the + // aggregation function, it will not wrap cast function on it + // internally. The reason of the special handling for `Column` is + // that the `RetType` of `Column` refers to the `infoschema`, so we + // need to set a new variable for it to avoid modifying the + // definition in `infoschema`. + if col, ok := f.Args[i].(*expression.Column); ok { + col.RetType = types.NewFieldType(col.RetType.Tp) + } + // originTp is used when the the `Tp` of column is TypeFloat32 while + // the type of the aggregation function is TypeFloat64. + originTp := f.Args[i].GetType().Tp + *(f.Args[i].GetType()) = *(f.RetTp) + f.Args[i].GetType().Tp = originTp } } } diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 2d40426374518..7c47ec85c0d64 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -291,21 +291,21 @@ func (a *AggFuncDesc) typeInfer4Count(ctx sessionctx.Context) { // Because child returns integer or decimal type. func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { switch a.Args[0].GetType().Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal: + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, 0 + case mysql.TypeNewDecimal: a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale { a.RetTp.Decimal = mysql.MaxDecimalScale } - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal - //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) default: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength - // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) } types.SetBinChsClnFlag(a.RetTp) } @@ -322,15 +322,12 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) } a.RetTp.Flen = mysql.MaxDecimalWidth - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal - // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) default: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength - // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) } types.SetBinChsClnFlag(a.RetTp) } From f31e15ac57eeb2f2a6d97ee0a3c957963eba6947 Mon Sep 17 00:00:00 2001 From: xuhuaiyu <391585975@qq.com> Date: Thu, 11 Oct 2018 13:16:01 +0800 Subject: [PATCH 2/2] tiny change --- executor/builder.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/executor/builder.go b/executor/builder.go index 0392565593edd..b16d39eee04fa 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -887,6 +887,9 @@ func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) { } for i := range f.Args { f.Args[i] = castFunc(b.ctx, f.Args[i]) + if f.Name != ast.AggFuncAvg && f.Name != ast.AggFuncSum { + continue + } // After wrapping cast on the argument, flen etc. may not the same // as the type of the aggregation function. The following part set // the type of the argument exactly as the type of the aggregation