Skip to content

Commit

Permalink
cherry pick pingcap#20017 to release-4.0
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
dyzsr authored and ti-srebot committed Nov 11, 2020
1 parent 737fe82 commit acdb935
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 15 deletions.
4 changes: 4 additions & 0 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ type baseAggFunc struct {
// ordinal stores the ordinal of the columns in the output chunk, which is
// used to append the final result of this function.
ordinal int

// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
}

func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) error {
Expand Down
25 changes: 25 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"fmt"
"strconv"

"github.com/cznic/mathutil"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -242,6 +243,11 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
ordinal: ordinal,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -270,6 +276,15 @@ func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
args: aggFuncDesc.Args,
ordinal: ordinal,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
Expand Down Expand Up @@ -311,6 +326,11 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
args: aggFuncDesc.Args,
ordinal: ordinal,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
Expand Down Expand Up @@ -360,6 +380,11 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
},
isMax: isMax,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
Expand Down
22 changes: 7 additions & 15 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
package aggfuncs

import (
<<<<<<< HEAD
"github.com/cznic/mathutil"
"github.com/pingcap/parser/mysql"
=======
"unsafe"

>>>>>>> a3facd0f7... expression, planner: fix decimal results for aggregate functions (#20017)
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -60,15 +65,7 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if len(e.args) == 2 {
frac = e.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down Expand Up @@ -216,12 +213,7 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,10 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, &p.val)
return nil
}
Expand Down
4 changes: 4 additions & 0 deletions executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,10 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, &p.val)
return nil
}
Expand Down
4 changes: 4 additions & 0 deletions executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, &p.val)
return nil
}
Expand Down
41 changes: 41 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1117,3 +1117,44 @@ func (s *testSuiteAgg) TestIssue15958(c *C) {
tk.MustQuery(`select sum(y) from t`).Check(testkit.Rows("6070"))
tk.MustQuery(`select avg(y) from t`).Check(testkit.Rows("2023.3333"))
}
<<<<<<< HEAD
=======

func (s *testSuiteAgg) TestIssue17216(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec(`CREATE TABLE t1 (
pk int(11) NOT NULL,
col1 decimal(40,20) DEFAULT NULL
)`)
tk.MustExec(`INSERT INTO t1 VALUES (2084,0.02040000000000000000),(35324,0.02190000000000000000),(43760,0.00510000000000000000),(46084,0.01400000000000000000),(46312,0.00560000000000000000),(61632,0.02730000000000000000),(94676,0.00660000000000000000),(102244,0.01810000000000000000),(113144,0.02140000000000000000),(157024,0.02750000000000000000),(157144,0.01750000000000000000),(182076,0.02370000000000000000),(188696,0.02330000000000000000),(833,0.00390000000000000000),(6701,0.00230000000000000000),(8533,0.01690000000000000000),(13801,0.01360000000000000000),(20797,0.00680000000000000000),(36677,0.00550000000000000000),(46305,0.01290000000000000000),(76113,0.00430000000000000000),(76753,0.02400000000000000000),(92393,0.01720000000000000000),(111733,0.02690000000000000000),(152757,0.00250000000000000000),(162393,0.02760000000000000000),(167169,0.00440000000000000000),(168097,0.01360000000000000000),(180309,0.01720000000000000000),(19918,0.02620000000000000000),(58674,0.01820000000000000000),(67454,0.01510000000000000000),(70870,0.02880000000000000000),(89614,0.02530000000000000000),(106742,0.00180000000000000000),(107886,0.01580000000000000000),(147506,0.02230000000000000000),(148366,0.01340000000000000000),(167258,0.01860000000000000000),(194438,0.00500000000000000000),(10307,0.02850000000000000000),(14539,0.02210000000000000000),(27703,0.00050000000000000000),(32495,0.00680000000000000000),(39235,0.01450000000000000000),(52379,0.01640000000000000000),(54551,0.01910000000000000000),(85659,0.02330000000000000000),(104483,0.02670000000000000000),(109911,0.02040000000000000000),(114523,0.02110000000000000000),(119495,0.02120000000000000000),(137603,0.01910000000000000000),(154031,0.02580000000000000000);`)
tk.MustQuery("SELECT count(distinct col1) FROM t1").Check(testkit.Rows("48"))
}

func (s *testSuiteAgg) TestIssue19426(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key, b int)")
tk.MustExec("insert into t values (1, 11), (4, 44), (2, 22), (3, 33)")
tk.MustQuery("select sum(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("110.0"))
tk.MustQuery("select avg(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("27.50000"))
tk.MustQuery("select distinct (case when a <= 0 or a > 1000 then 0.0 else b end) v from t order by v").
Check(testkit.Rows("11.0", "22.0", "33.0", "44.0"))
tk.MustQuery("select group_concat(case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t").
Check(testkit.Rows("44.0,33.0,22.0,11.0"))
tk.MustQuery("select group_concat(a, b, case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t").
Check(testkit.Rows("44444.0,33333.0,22222.0,11111.0"))
tk.MustQuery("select group_concat(distinct case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t").
Check(testkit.Rows("44.0,33.0,22.0,11.0"))
tk.MustQuery("select max(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("44.0"))
tk.MustQuery("select min(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("11.0"))
tk.MustQuery("select a, b, sum(case when a < 1000 then b else 0.0 end) over (order by a) from t").
Check(testkit.Rows("1 11 11.0", "2 22 33.0", "3 33 66.0", "4 44 110.0"))
}
>>>>>>> a3facd0f7... expression, planner: fix decimal results for aggregate functions (#20017)
12 changes: 12 additions & 0 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,18 @@ var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncJsonObjectAgg: {},
}

// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions
// with a cast as decimal function. See issue #19426
func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) {
if a.Name == ast.AggFuncGroupConcat {
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}
}
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {
Expand Down
1 change: 1 addition & 0 deletions planner/core/rule_inject_extra_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan {
// since the types of the args are already the expected.
func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) {
for i := range aggFuncs {
aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx)
if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode {
aggFuncs[i].WrapCastForAggArgs(sctx)
}
Expand Down

0 comments on commit acdb935

Please sign in to comment.