From a0e28c31e4cbf8f4ecb7f17b384801b881449d32 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 1 Sep 2020 21:24:51 +0800 Subject: [PATCH] expression: Support stddev_pop function (#19195) (#19541) Signed-off-by: ti-srebot --- executor/aggfuncs/builder.go | 25 ++++++++++ executor/aggfuncs/func_stddevpop.go | 55 ++++++++++++++++++++++ executor/aggfuncs/func_stddevpop_test.go | 25 ++++++++++ executor/aggfuncs/func_varpop.go | 8 ++-- executor/aggregate_test.go | 48 ++++++++++++++++--- expression/aggregation/agg_to_pb.go | 2 + expression/aggregation/base_func.go | 8 ++++ go.mod | 2 +- go.sum | 4 +- planner/core/rule_aggregation_push_down.go | 2 +- 10 files changed, 165 insertions(+), 14 deletions(-) create mode 100644 executor/aggfuncs/func_stddevpop.go create mode 100644 executor/aggfuncs/func_stddevpop_test.go diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 799dc9d9e4d34..06922c0ceb6aa 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -57,6 +57,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal return buildJSONObjectAgg(aggFuncDesc, ordinal) case ast.AggFuncApproxCountDistinct: return buildApproxCountDistinct(aggFuncDesc, ordinal) + case ast.AggFuncStddevPop: + return buildStdDevPop(aggFuncDesc, ordinal) } return nil } @@ -442,6 +444,29 @@ func buildVarPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { } } +// buildStdDevPop builds the AggFunc implementation for function "STD()/STDDEV()/STDDEV_POP()" +func buildStdDevPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseStdDevPopAggFunc{ + varPop4Float64{ + baseVarPopAggFunc{ + baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + }, + }, + }, + } + switch aggFuncDesc.Mode { + case aggregation.DedupMode: + return nil + default: + if aggFuncDesc.HasDistinct { + return &stdDevPop4DistinctFloat64{base} + } + return 
&stdDevPop4Float64{base} + } +} + // buildJSONObjectAgg builds the AggFunc implementation for function "json_objectagg". func buildJSONObjectAgg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ diff --git a/executor/aggfuncs/func_stddevpop.go b/executor/aggfuncs/func_stddevpop.go new file mode 100644 index 0000000000000..2db8c06941240 --- /dev/null +++ b/executor/aggfuncs/func_stddevpop.go @@ -0,0 +1,55 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "math" + + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type baseStdDevPopAggFunc struct { + varPop4Float64 +} + +type stdDevPop4Float64 struct { + baseStdDevPopAggFunc +} + +func (e *stdDevPop4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4VarPopFloat64)(pr) + if p.count == 0 { + chk.AppendNull(e.ordinal) + return nil + } + variance := p.variance / float64(p.count) + chk.AppendFloat64(e.ordinal, math.Sqrt(variance)) + return nil +} + +type stdDevPop4DistinctFloat64 struct { + baseStdDevPopAggFunc +} + +func (e *stdDevPop4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4VarPopDistinctFloat64)(pr) + if p.count == 0 { + chk.AppendNull(e.ordinal) + return nil + } + variance := p.variance / float64(p.count) + chk.AppendFloat64(e.ordinal, math.Sqrt(variance)) + return nil +} diff --git 
a/executor/aggfuncs/func_stddevpop_test.go b/executor/aggfuncs/func_stddevpop_test.go new file mode 100644 index 0000000000000..0c4b7e24f601b --- /dev/null +++ b/executor/aggfuncs/func_stddevpop_test.go @@ -0,0 +1,25 @@ +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" +) + +func (s *testSuite) TestMergePartialResult4Stddevpop(c *C) { + tests := []aggTest{ + buildAggTester(ast.AggFuncStddevPop, mysql.TypeDouble, 5, 1.4142135623730951, 0.816496580927726, 1.3169567191065923), + } + for _, test := range tests { + s.testMergePartialResult(c, test) + } +} + +func (s *testSuite) TestStddevpop(c *C) { + tests := []aggTest{ + buildAggTester(ast.AggFuncStddevPop, mysql.TypeDouble, 5, nil, 1.4142135623730951), + } + for _, test := range tests { + s.testAggFunc(c, test) + } +} diff --git a/executor/aggfuncs/func_varpop.go b/executor/aggfuncs/func_varpop.go index e8bd9f1147f59..c7d0b3fded82a 100644 --- a/executor/aggfuncs/func_varpop.go +++ b/executor/aggfuncs/func_varpop.go @@ -51,8 +51,8 @@ func (e *varPop4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par chk.AppendNull(e.ordinal) return nil } - varicance := p.variance / float64(p.count) - chk.AppendFloat64(e.ordinal, varicance) + variance := p.variance / float64(p.count) + chk.AppendFloat64(e.ordinal, variance) return nil } @@ -143,8 +143,8 @@ func (e *varPop4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context chk.AppendNull(e.ordinal) return nil } - varicance := p.variance / float64(p.count) - chk.AppendFloat64(e.ordinal, varicance) + variance := p.variance / float64(p.count) + chk.AppendFloat64(e.ordinal, variance) return nil } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index dba989565948e..2f6a5fedaaddb 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -406,12 +406,6 @@ func (s *testSuiteAgg) TestAggregation(c *C) { result = tk.MustQuery("select a, 
variance(b) over w from t window w as (partition by a)").Sort() result.Check(testkit.Rows("1 2364075.6875", "1 2364075.6875", "1 2364075.6875", "1 2364075.6875", "2 0")) - _, err = tk.Exec("select std(a) from t") - c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: std") - _, err = tk.Exec("select stddev(a) from t") - c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev") - _, err = tk.Exec("select stddev_pop(a) from t") - c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev_pop") _, err = tk.Exec("select std_samp(a) from t") // TODO: Fix this error message. c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION test.std_samp does not exist") @@ -428,6 +422,48 @@ func (s *testSuiteAgg) TestAggregation(c *C) { tk.MustQuery("select sum(b) from t1 group by c order by c;").Check(testkit.Rows("", "-3", "-2", "-2")) tk.MustQuery("select sum(c) from t1 group by a order by a;").Check(testkit.Rows("", "-2", "-2", "-3")) tk.MustQuery("select sum(c) from t1 group by b order by b;").Check(testkit.Rows("", "-3", "-2", "-2")) + + // For stddev_pop()/std()/stddev() function + tk.MustExec("drop table if exists t1;") + tk.MustExec(`create table t1 (grp int, a bigint unsigned, c char(10) not null);`) + tk.MustExec(`insert into t1 values (1,1,"a");`) + tk.MustExec(`insert into t1 values (2,2,"b");`) + tk.MustExec(`insert into t1 values (2,3,"c");`) + tk.MustExec(`insert into t1 values (3,4,"E");`) + tk.MustExec(`insert into t1 values (3,5,"C");`) + tk.MustExec(`insert into t1 values (3,6,"D");`) + tk.MustQuery(`select stddev_pop(all a) from t1;`).Check(testkit.Rows("1.707825127659933")) + tk.MustQuery(`select stddev_pop(a) from t1 group by grp order by grp;`).Check(testkit.Rows("0", "0.5", "0.816496580927726")) + tk.MustQuery(`select sum(a)+count(a)+avg(a)+stddev_pop(a) as sum from t1 group by grp order by grp;`).Check(testkit.Rows("3", "10", "23.816496580927726")) + tk.MustQuery(`select std(all a) 
from t1;`).Check(testkit.Rows("1.707825127659933")) + tk.MustQuery(`select std(a) from t1 group by grp order by grp;`).Check(testkit.Rows("0", "0.5", "0.816496580927726")) + tk.MustQuery(`select sum(a)+count(a)+avg(a)+std(a) as sum from t1 group by grp order by grp;`).Check(testkit.Rows("3", "10", "23.816496580927726")) + tk.MustQuery(`select stddev(all a) from t1;`).Check(testkit.Rows("1.707825127659933")) + tk.MustQuery(`select stddev(a) from t1 group by grp order by grp;`).Check(testkit.Rows("0", "0.5", "0.816496580927726")) + tk.MustQuery(`select sum(a)+count(a)+avg(a)+stddev(a) as sum from t1 group by grp order by grp;`).Check(testkit.Rows("3", "10", "23.816496580927726")) + // test null + tk.MustExec("drop table if exists t1;") + tk.MustExec("CREATE TABLE t1 (a int, b int);") + tk.MustQuery("select stddev_pop(b) from t1;").Check(testkit.Rows("")) + tk.MustQuery("select std(b) from t1;").Check(testkit.Rows("")) + tk.MustQuery("select stddev(b) from t1;").Check(testkit.Rows("")) + tk.MustExec("insert into t1 values (1,null);") + tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("")) + tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("")) + tk.MustQuery("select stddev(b) from t1 group by a order by a;").Check(testkit.Rows("")) + tk.MustExec("insert into t1 values (1,null);") + tk.MustExec("insert into t1 values (2,null);") + tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("", "")) + tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("", "")) + tk.MustQuery("select stddev(b) from t1 group by a order by a;").Check(testkit.Rows("", "")) + tk.MustExec("insert into t1 values (2,1);") + tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("", "0")) + tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("", "0")) + tk.MustQuery("select stddev(b) from t1 group by a order by 
a;").Check(testkit.Rows("", "0")) + tk.MustExec("insert into t1 values (3,1);") + tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("", "0", "0")) + tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("", "0", "0")) + tk.MustQuery("select stddev(b) from t1 group by a order by a;").Check(testkit.Rows("", "0", "0")) } func (s *testSuiteAgg) TestAggPrune(c *C) { diff --git a/expression/aggregation/agg_to_pb.go b/expression/aggregation/agg_to_pb.go index 98bb1feded596..f256667f3abaa 100644 --- a/expression/aggregation/agg_to_pb.go +++ b/expression/aggregation/agg_to_pb.go @@ -61,6 +61,8 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag tp = tipb.ExprType_VarPop case ast.AggFuncJsonObjectAgg: tp = tipb.ExprType_JsonObjectAgg + case ast.AggFuncStddevPop: + tp = tipb.ExprType_StddevPop } if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) { return nil diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 6604f4ca144a2..ec0bd523f9d6f 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -111,6 +111,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error { a.typeInfer4LeadLag(ctx) case ast.AggFuncVarPop: a.typeInfer4VarPop(ctx) + case ast.AggFuncStddevPop: + a.typeInfer4Std(ctx) case ast.AggFuncJsonObjectAgg: a.typeInfer4JsonFuncs(ctx) default: @@ -256,6 +258,12 @@ func (a *baseFuncDesc) typeInfer4VarPop(ctx sessionctx.Context) { a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength } +func (a *baseFuncDesc) typeInfer4Std(ctx sessionctx.Context) { + // The return type of std is always double. + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength +} + // GetDefaultValue gets the default value when the function's input is null.
// According to MySQL, default values of the function are listed as follows: // e.g. diff --git a/go.mod b/go.mod index 2cf750382a1fa..de84a24bbd989 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 github.com/pingcap/kvproto v0.0.0-20200818080353-7aaed8998596 github.com/pingcap/log v0.0.0-20200828042413-fce0951f1463 - github.com/pingcap/parser v0.0.0-20200831060432-37fb52783318 + github.com/pingcap/parser v0.0.0-20200901062802-475ea5e2e0a7 github.com/pingcap/pd/v4 v4.0.5-0.20200817114353-e465cafe8a91 github.com/pingcap/sysutil v0.0.0-20200715082929-4c47bcac246a github.com/pingcap/tidb-tools v4.0.1-0.20200530144555-cdec43635625+incompatible diff --git a/go.sum b/go.sum index 412e248ed3e64..4d78bb1b31449 100644 --- a/go.sum +++ b/go.sum @@ -453,8 +453,8 @@ github.com/pingcap/parser v0.0.0-20200424075042-8222d8b724a4/go.mod h1:9v0Edh8Ib github.com/pingcap/parser v0.0.0-20200507022230-f3bf29096657/go.mod h1:9v0Edh8IbgjGYW2ArJr19E+bvL8zKahsFp+ixWeId+4= github.com/pingcap/parser v0.0.0-20200603032439-c4ecb4508d2f/go.mod h1:9v0Edh8IbgjGYW2ArJr19E+bvL8zKahsFp+ixWeId+4= github.com/pingcap/parser v0.0.0-20200623164729-3a18f1e5dceb/go.mod h1:vQdbJqobJAgFyiRNNtXahpMoGWwPEuWciVEK5A20NS0= -github.com/pingcap/parser v0.0.0-20200831060432-37fb52783318 h1:QrLimON13AgrHi4ly7bdZhnRkw+I7O2yVqhZ53tEX4I= -github.com/pingcap/parser v0.0.0-20200831060432-37fb52783318/go.mod h1:vQdbJqobJAgFyiRNNtXahpMoGWwPEuWciVEK5A20NS0= +github.com/pingcap/parser v0.0.0-20200901062802-475ea5e2e0a7 h1:B+x0Vu4YNkgudIqZZ3DHFYISiKu+UvFCu84zKf8FeLc= +github.com/pingcap/parser v0.0.0-20200901062802-475ea5e2e0a7/go.mod h1:vQdbJqobJAgFyiRNNtXahpMoGWwPEuWciVEK5A20NS0= github.com/pingcap/pd/v4 v4.0.0-rc.1.0.20200422143320-428acd53eba2/go.mod h1:s+utZtXDznOiL24VK0qGmtoHjjXNsscJx3m1n8cC56s= github.com/pingcap/pd/v4 v4.0.0-rc.2.0.20200520083007-2c251bd8f181 h1:FM+PzdoR3fmWAJx3ug+p5aOgs5aZYwFkoDL7Potdsz0= github.com/pingcap/pd/v4 
v4.0.0-rc.2.0.20200520083007-2c251bd8f181/go.mod h1:q4HTx/bA8aKBa4S7L+SQKHvjRPXCRV0tA0yRw0qkZSA= diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index e7d371f321c42..c713a0badf92f 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -38,7 +38,7 @@ func (a *aggregationPushDownSolver) isDecomposableWithJoin(fun *aggregation.AggF return false } switch fun.Name { - case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg: + case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop: // TODO: Support avg push down. return false case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: