From eb977aff5e6199422217a88c5edffecead62e5b4 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Sat, 24 Jul 2021 00:08:47 +0530 Subject: [PATCH 01/12] gen4: support distinct inside function expression Signed-off-by: Harshit Gangal --- .../planbuilder/abstract/queryprojection.go | 20 +++++-- go/vt/vtgate/planbuilder/horizon_planning.go | 54 ++++++++++++++++--- go/vt/vtgate/planbuilder/ordered_aggregate.go | 27 ++++++++++ .../planbuilder/testdata/aggr_cases.txt | 1 + 4 files changed, 92 insertions(+), 10 deletions(-) diff --git a/go/vt/vtgate/planbuilder/abstract/queryprojection.go b/go/vt/vtgate/planbuilder/abstract/queryprojection.go index 412f4fad69e..769aeb30fd5 100644 --- a/go/vt/vtgate/planbuilder/abstract/queryprojection.go +++ b/go/vt/vtgate/planbuilder/abstract/queryprojection.go @@ -54,6 +54,9 @@ type ( GroupBy struct { Inner sqlparser.Expr WeightStrExpr sqlparser.Expr + + // This is to add the distinct function expression in grouping column for pushing down but not be to used as grouping key at VTGate level. + Distinct bool } ) @@ -63,15 +66,18 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) { Distinct: sel.Distinct, } + distinctAggrFunc := false for _, selExp := range sel.SelectExprs { exp, ok := selExp.(*sqlparser.AliasedExpr) if !ok { return nil, semantics.Gen4NotSupportedF("%T in select list", selExp) } - if err := checkForInvalidAggregations(exp); err != nil { + foundDistinctAggrFunc, err := checkForInvalidAggregations(exp, distinctAggrFunc) + if err != nil { return nil, err } + distinctAggrFunc = distinctAggrFunc || foundDistinctAggrFunc col := SelectExpr{ Col: exp, } @@ -130,19 +136,25 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) { return qp, nil } -func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error { - return sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { +func checkForInvalidAggregations(exp *sqlparser.AliasedExpr, failOnDistinctAggrFunc bool) (bool, error) { + distinctAggrFunc := false + err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { fExpr, ok := node.(*sqlparser.FuncExpr) if ok && fExpr.IsAggregate() { if len(fExpr.Exprs) != 1 { return false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(fExpr)) } if fExpr.Distinct { - return false, semantics.Gen4NotSupportedF("distinct aggregation") + if failOnDistinctAggrFunc { + return false, semantics.Gen4NotSupportedF("multiple distinct aggregation function") + } + distinctAggrFunc = true + return true, nil } } return true, nil }, exp.Expr) + return distinctAggrFunc, err } func (qp *QueryProjection) getNonAggrExprNotMatchingGroupByExprs() sqlparser.Expr { diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 1b0e5dddf3a..d682121d8e6 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -145,21 +145,58 @@ func (hp *horizonPlanning) planAggregations() error { } for _, e := range hp.qp.SelectExprs { - offset, _, err := pushProjection(e.Col, hp.plan, hp.semTable, true, false) - if err != nil { - return err + // push all expression if they are non-aggregating or the plan is not ordered aggregated plan. + if !e.Aggr || oa == nil { + _, _, err := pushProjection(e.Col, hp.plan, hp.semTable, true, false) + if err != nil { + return err + } } + if e.Aggr && oa != nil { fExpr, isFunc := e.Col.Expr.(*sqlparser.FuncExpr) if !isFunc { return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") } opcode := engine.SupportedAggregates[fExpr.Name.Lowered()] - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, engine.AggregateParams{ + handleDistinct, innerAliased, err := oa.needDistinctHandlingGen4(fExpr, opcode) + if err != nil { + return err + } + + pushExpr := e.Col + var alias string + if handleDistinct { + switch opcode { + case engine.AggregateCount: + opcode = engine.AggregateCountDistinct + case engine.AggregateSum: + opcode = engine.AggregateSumDistinct + } + + if e.Col.As.IsEmpty() { + alias = sqlparser.String(e.Col.Expr) + } else { + alias = e.Col.As.String() + } + + pushExpr = innerAliased + oa.eaggr.PreProcess = true + + hp.haveToTruncate(true) + hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, Distinct: true}) + } + offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true) + if err != nil { + return err + } + aggrParams := engine.AggregateParams{ Opcode: opcode, Col: offset, + Alias: alias, Expr: fExpr, - }) + } + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, aggrParams) } } @@ -229,7 +266,9 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem if err != nil { return false, err } - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: keyCol, WeightStringCol: weightStringOffset, Expr: groupExpr.WeightStrExpr}) + if !groupExpr.Distinct { + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: keyCol, WeightStringCol: weightStringOffset, Expr: groupExpr.WeightStrExpr}) + } colAddedRecursively, err := planGroupByGen4(groupExpr, node.input, semTable) if err != nil { return false, err @@ -327,6 +366,9 @@ func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan log if err != nil { return 0, 0, false, err } + if weightStrExpr == nil { + return offset, -1, added, nil + } colName, ok := expr.(*sqlparser.ColName) if !ok { return 0, 0, false, semantics.Gen4NotSupportedF("group by/order by non-column expression") diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index cf4bbed2e42..33dd81d21ed 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -303,6 +303,33 @@ func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, funcExpr return true, innerAliased, nil } +// needDistinctHandling returns true if oa needs to handle the distinct clause. +// If true, it will also return the aliased expression that needs to be pushed +// down into the underlying route. +func (oa *orderedAggregate) needDistinctHandlingGen4(funcExpr *sqlparser.FuncExpr, opcode engine.AggregateOpcode) (bool, *sqlparser.AliasedExpr, error) { + if !funcExpr.Distinct { + return false, nil, nil + } + if opcode != engine.AggregateCount && opcode != engine.AggregateSum { + return false, nil, nil + } + innerAliased, ok := funcExpr.Exprs[0].(*sqlparser.AliasedExpr) + if !ok { + return false, nil, fmt.Errorf("syntax error: %s", sqlparser.String(funcExpr)) + } + _, ok = oa.input.(*route) + if !ok { + // Unreachable + return true, innerAliased, nil + } + // check for unique vindex + //vindex := pb.st.Vindex(innerAliased.Expr, rb) + //if vindex != nil && vindex.IsUnique() { + // return false, nil, nil + //} + return true, innerAliased, nil +} + // Wireup implements the logicalPlan interface // If text columns are detected in the keys, then the function modifies // the primitive to pull a corresponding weight_string from mysql and diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index 544c0bfaa5e..2d054559070 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -866,6 +866,7 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont "Table": "`user`" } } +Gen4 plan same as above # count with distinct unique vindex "select col, count(distinct id) from user group by col" From f980ec3cc9f891a4338ea68da6c3d7e8890fbd55 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 26 Jul 2021 11:35:35 +0530 Subject: [PATCH 02/12] gen4: added e2e test for count distinct function Signed-off-by: Harshit Gangal --- go/test/endtoend/vtgate/gen4/gen4_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/go/test/endtoend/vtgate/gen4/gen4_test.go b/go/test/endtoend/vtgate/gen4/gen4_test.go index 34b463fec10..9d01c34b73b 100644 --- a/go/test/endtoend/vtgate/gen4/gen4_test.go +++ b/go/test/endtoend/vtgate/gen4/gen4_test.go @@ -86,6 +86,27 @@ func TestGroupBy(t *testing.T) { `[INT64(2) VARCHAR("B") VARCHAR("C") VARCHAR("abc")]]`) } +func TestAggregationFunc(t *testing.T) { + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() + + defer exec(t, conn, `delete from t2`) + + // insert some data. + checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) + + // on primary vindex + assertMatches(t, conn, `select tcol1, count(distinct id) from t2 group by tcol1`, + `[[VARCHAR("A") INT64(3)] [VARCHAR("B") INT64(3)] [VARCHAR("C") INT64(2)]]`) + + // on any column + assertMatches(t, conn, `select tcol1, count(distinct tcol2) from t2 group by tcol1`, + `[[VARCHAR("A") INT64(2)] [VARCHAR("B") INT64(2)] [VARCHAR("C") INT64(1)]]`) + +} + func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { t.Helper() qr := checkedExec(t, conn, query) From 96913fbd5728094f8236854c5b63e2b0d6a20933 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 09:37:41 +0530 Subject: [PATCH 03/12] make aggregation work with varchar columns Signed-off-by: Harshit Gangal --- go/sqltypes/value.go | 12 ++ go/vt/vtgate/engine/ordered_aggregate.go | 47 +++--- go/vt/vtgate/engine/ordered_aggregate_test.go | 144 ++++++++++++++++-- go/vt/vtgate/planbuilder/horizon_planning.go | 40 ++--- go/vt/vtgate/planbuilder/ordered_aggregate.go | 4 +- go/vt/vtgate/planbuilder/show.go | 2 +- 6 files changed, 198 insertions(+), 51 deletions(-) diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 491ee7d7059..85002c8bc4d 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -372,6 +372,18 @@ func (v Value) IsDateTime() bool { return int(v.typ)&dt == dt } +// IsComparable returns true if the Value is null safe comparable without collation information. +func (v *Value) IsComparable() bool { + if v.typ == Null || IsNumber(v.typ) || IsBinary(v.typ) { + return true + } + switch v.typ { + case Timestamp, Date, Time, Datetime, Enum, Set, TypeJSON, Bit: + return true + } + return false +} + // MarshalJSON should only be used for testing. // It's not a complete implementation. func (v Value) MarshalJSON() ([]byte, error) { diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 747a77db040..920173f9279 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -45,7 +45,7 @@ type OrderedAggregate struct { PreProcess bool `json:",omitempty"` // Aggregates specifies the aggregation parameters for each // aggregation function: function opcode and input column number. - Aggregates []AggregateParams + Aggregates []*AggregateParams // GroupByKeys specifies the input values that must be used for // the aggregation key. @@ -78,27 +78,33 @@ func (gbp GroupByParams) String() string { // AggregateParams specify the parameters for each aggregation. // It contains the opcode and input column number. type AggregateParams struct { - Opcode AggregateOpcode - Col int + Opcode AggregateOpcode + Col int + KeyCol int + WCol int + WAssigned bool // Alias is set only for distinct opcodes. Alias string `json:",omitempty"` Expr sqlparser.Expr } -func (ap AggregateParams) isDistinct() bool { +func (ap *AggregateParams) isDistinct() bool { return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct } -func (ap AggregateParams) preProcess() bool { +func (ap *AggregateParams) preProcess() bool { return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct || ap.Opcode == AggregateGtid } -func (ap AggregateParams) String() string { +func (ap *AggregateParams) String() string { + keyCol := strconv.Itoa(ap.Col) + if ap.Opcode == AggregateCountDistinct && ap.WAssigned { + keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) + } if ap.Alias != "" { - return fmt.Sprintf("%s(%d) AS %s", ap.Opcode.String(), ap.Col, ap.Alias) + return fmt.Sprintf("%s(%s) AS %s", ap.Opcode.String(), keyCol, ap.Alias) } - - return fmt.Sprintf("%s(%d)", ap.Opcode.String(), ap.Col) + return fmt.Sprintf("%s(%s)", ap.Opcode.String(), keyCol) } // AggregateOpcode is the aggregation Opcode. @@ -306,6 +312,9 @@ func (oa *OrderedAggregate) convertFields(fields []*querypb.Field) []*querypb.Fi Name: aggr.Alias, Type: opcodeType[aggr.Opcode], } + if aggr.isDistinct() { + aggr.KeyCol = aggr.Col + } } return fields } @@ -318,17 +327,21 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. for _, aggr := range oa.Aggregates { switch aggr.Opcode { case AggregateCountDistinct: - curDistinct = row[aggr.Col] + curDistinct = row[aggr.KeyCol] + if aggr.WAssigned && !curDistinct.IsComparable() { + aggr.KeyCol = aggr.WCol + curDistinct = row[aggr.KeyCol] + } // Type is int64. Ok to call MakeTrusted. - if row[aggr.Col].IsNull() { + if row[aggr.KeyCol].IsNull() { newRow[aggr.Col] = countZero } else { newRow[aggr.Col] = countOne } case AggregateSumDistinct: - curDistinct = row[aggr.Col] + curDistinct = row[aggr.KeyCol] var err error - newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) + newRow[aggr.Col], err = evalengine.Cast(row[aggr.KeyCol], opcodeType[aggr.Opcode]) if err != nil { newRow[aggr.Col] = sumZero } @@ -392,17 +405,17 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes result := sqltypes.CopyRow(row1) for _, aggr := range oa.Aggregates { if aggr.isDistinct() { - if row2[aggr.Col].IsNull() { + if row2[aggr.KeyCol].IsNull() { continue } - cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.Col]) + cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.KeyCol]) if err != nil { return nil, sqltypes.NULL, err } if cmp == 0 { continue } - curDistinct = row2[aggr.Col] + curDistinct = row2[aggr.KeyCol] } var err error switch aggr.Opcode { @@ -473,7 +486,7 @@ func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) { } func aggregateParamsToString(in interface{}) string { - return in.(AggregateParams).String() + return in.(*AggregateParams).String() } func groupByParamsToString(i interface{}) string { diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index b471b94a580..8560f49d579 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -49,7 +49,7 @@ func TestOrderedAggregateExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, @@ -86,7 +86,7 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, @@ -128,7 +128,7 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, @@ -171,7 +171,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, @@ -307,7 +307,7 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCountDistinct, Col: 1, Alias: "count(distinct col2)", @@ -383,7 +383,7 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCountDistinct, Col: 1, Alias: "count(distinct col2)", @@ -471,7 +471,7 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateSumDistinct, Col: 1, Alias: "sum(distinct col2)", @@ -520,7 +520,7 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateSumDistinct, Col: 1, Alias: "sum(distinct col2)", @@ -556,7 +556,7 @@ func TestOrderedAggregateKeysFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, @@ -589,7 +589,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }}, @@ -629,7 +629,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) { func TestMerge(t *testing.T) { assert := assert.New(t) oa := &OrderedAggregate{ - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateCount, Col: 1, }, { @@ -716,7 +716,7 @@ func TestNoInputAndNoGroupingKeys(outer *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: test.opcode, Col: 0, Alias: test.name, @@ -769,7 +769,7 @@ func TestOrderedAggregateExecuteGtid(t *testing.T) { oa := &OrderedAggregate{ PreProcess: true, - Aggregates: []AggregateParams{{ + Aggregates: []*AggregateParams{{ Opcode: AggregateGtid, Col: 1, Alias: "vgtid", @@ -790,3 +790,121 @@ func TestOrderedAggregateExecuteGtid(t *testing.T) { ) assert.Equal(t, wantResult, result) } + +func TestCountDistinctOnVarchar(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|weight_string(c2)", + "int64|varchar|varbinary", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "1|a|0x41", + "1|a|0x41", + "1|b|0x42", + "2|b|0x42", + )}, + } + + oa := &OrderedAggregate{ + PreProcess: true, + Aggregates: []*AggregateParams{{ + Opcode: AggregateCountDistinct, + Col: 1, + WCol: 2, + WAssigned: true, + Alias: "count(distinct c2)", + }}, + GroupByKeys: []GroupByParams{{KeyCol: 0}}, + Input: fp, + TruncateColumnCount: 2, + } + + want := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "c1|count(distinct c2)", + "int64|int64", + ), + `1|2`, `2|1`, + ) + + qr, err := oa.Execute(nil, nil, false) + require.NoError(t, err) + assert.Equal(t, want, qr) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) + return nil + }) + require.NoError(t, err) + assert.Equal(t, want, results) +} + +func TestCountDistinctOnVarcharWithNulls(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|weight_string(c2)", + "int64|varchar|varbinary", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "null|a|0x41", + "null|b|0x42", + "null|null|null", + "1|null|null", + "1|null|null", + "1|a|0x41", + "1|a|0x41", + "1|b|0x42", + "2|null|null", + "2|b|0x42", + "3|null|null", + "3|null|null", + "3|null|null", + "3|null|null", + )}, + } + + oa := &OrderedAggregate{ + PreProcess: true, + Aggregates: []*AggregateParams{{ + Opcode: AggregateCountDistinct, + Col: 1, + WCol: 2, + WAssigned: true, + Alias: "count(distinct c2)", + }}, + GroupByKeys: []GroupByParams{{KeyCol: 0}}, + Input: fp, + TruncateColumnCount: 2, + } + + want := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "c1|count(distinct c2)", + "int64|int64", + ), + `null|2`, `1|2`, `2|1`, `3|0`, + ) + + qr, err := oa.Execute(nil, nil, false) + require.NoError(t, err) + assert.Equal(t, want, qr) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) + return nil + }) + require.NoError(t, err) + assert.Equal(t, want, results) +} diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index d682121d8e6..1f45a83ecf9 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -164,37 +164,41 @@ func (hp *horizonPlanning) planAggregations() error { return err } - pushExpr := e.Col - var alias string + aggrParams := &engine.AggregateParams{ + Opcode: opcode, + Expr: fExpr, + } if handleDistinct { switch opcode { case engine.AggregateCount: - opcode = engine.AggregateCountDistinct + aggrParams.Opcode = engine.AggregateCountDistinct case engine.AggregateSum: - opcode = engine.AggregateSumDistinct + aggrParams.Opcode = engine.AggregateSumDistinct } if e.Col.As.IsEmpty() { - alias = sqlparser.String(e.Col.Expr) + aggrParams.Alias = sqlparser.String(e.Col.Expr) } else { - alias = e.Col.As.String() + aggrParams.Alias = e.Col.As.String() } - - pushExpr = innerAliased oa.eaggr.PreProcess = true hp.haveToTruncate(true) hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, Distinct: true}) - } - offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true) - if err != nil { - return err - } - aggrParams := engine.AggregateParams{ - Opcode: opcode, - Col: offset, - Alias: alias, - Expr: fExpr, + + offset, wOffset, _, err := wrapAndPushExpr(innerAliased.Expr, innerAliased.Expr, oa.input, hp.semTable) + if err != nil { + return err + } + aggrParams.Col = offset + aggrParams.WCol = wOffset + aggrParams.WAssigned = true + } else { + offset, _, err := pushProjection(e.Col, oa.input, hp.semTable, true, true) + if err != nil { + return err + } + aggrParams.Col = offset } oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, aggrParams) } diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index 33dd81d21ed..f2e9fd7c471 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -253,7 +253,7 @@ func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.Alias case engine.AggregateSum: opcode = engine.AggregateSumDistinct } - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, engine.AggregateParams{ + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ Opcode: opcode, Col: innerCol, Alias: alias, @@ -264,7 +264,7 @@ func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.Alias return nil, 0, err } pb.plan = newBuilder - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, engine.AggregateParams{ + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ Opcode: opcode, Col: innerCol, }) diff --git a/go/vt/vtgate/planbuilder/show.go b/go/vt/vtgate/planbuilder/show.go index bce8a701d7b..8a403bb3df3 100644 --- a/go/vt/vtgate/planbuilder/show.go +++ b/go/vt/vtgate/planbuilder/show.go @@ -505,7 +505,7 @@ func buildShowVGtidPlan(show *sqlparser.ShowBasic, vschema ContextVSchema) (engi } return &engine.OrderedAggregate{ PreProcess: true, - Aggregates: []engine.AggregateParams{ + Aggregates: []*engine.AggregateParams{ { Opcode: engine.AggregateGtid, Col: 1, From 23a12dfd19ed0e4c628c8e1add525c2d6feeab3c Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 10:19:59 +0530 Subject: [PATCH 04/12] add weight_string expr to group by exprs for count distinct Signed-off-by: Harshit Gangal --- go/vt/vtgate/planbuilder/horizon_planning.go | 3 ++- go/vt/wrangler/vdiff.go | 4 ++-- go/vt/wrangler/vdiff_test.go | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 1f45a83ecf9..7a82b3a13e3 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -151,6 +151,7 @@ func (hp *horizonPlanning) planAggregations() error { if err != nil { return err } + continue } if e.Aggr && oa != nil { @@ -184,7 +185,7 @@ func (hp *horizonPlanning) planAggregations() error { oa.eaggr.PreProcess = true hp.haveToTruncate(true) - hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, Distinct: true}) + hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, WeightStrExpr: innerAliased.Expr, Distinct: true}) offset, wOffset, _, err := wrapAndPushExpr(innerAliased.Expr, innerAliased.Expr, oa.input, hp.semTable) if err != nil { diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 9eea5470aa0..b6e91ea0a82 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -434,7 +434,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer sourceSelect := &sqlparser.Select{} targetSelect := &sqlparser.Select{} // aggregates contains the list if Aggregate functions, if any. - var aggregates []engine.AggregateParams + var aggregates []*engine.AggregateParams for _, selExpr := range sel.SelectExprs { switch selExpr := selExpr.(type) { case *sqlparser.StarExpr: @@ -463,7 +463,7 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer if expr, ok := selExpr.Expr.(*sqlparser.FuncExpr); ok { switch fname := expr.Name.Lowered(); fname { case "count", "sum": - aggregates = append(aggregates, engine.AggregateParams{ + aggregates = append(aggregates, &engine.AggregateParams{ Opcode: engine.SupportedAggregates[fname], Col: len(sourceSelect.SelectExprs) - 1, }) diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index a0ed626a660..1dcb7bccad6 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -387,7 +387,7 @@ func TestVDiffPlanSuccess(t *testing.T) { pkCols: []int{0}, selectPks: []int{0}, sourcePrimitive: &engine.OrderedAggregate{ - Aggregates: []engine.AggregateParams{{ + Aggregates: []*engine.AggregateParams{{ Opcode: engine.AggregateCount, Col: 2, }, { From 8a6aa73045866900024f20a7df27506afb56b1f8 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 12:00:40 +0530 Subject: [PATCH 05/12] push the aggr columns and later push all the group by columns with weight_string Signed-off-by: Harshit Gangal --- .../planbuilder/abstract/queryprojection.go | 3 +- go/vt/vtgate/planbuilder/horizon_planning.go | 57 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/go/vt/vtgate/planbuilder/abstract/queryprojection.go b/go/vt/vtgate/planbuilder/abstract/queryprojection.go index 769aeb30fd5..478b8f2c9bc 100644 --- a/go/vt/vtgate/planbuilder/abstract/queryprojection.go +++ b/go/vt/vtgate/planbuilder/abstract/queryprojection.go @@ -56,7 +56,8 @@ type ( WeightStrExpr sqlparser.Expr // This is to add the distinct function expression in grouping column for pushing down but not be to used as grouping key at VTGate level. - Distinct bool + // Starts with 1 so that default (0) means unassigned. + DistinctAggrIndex int } ) diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 7a82b3a13e3..59548250280 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -165,43 +165,37 @@ func (hp *horizonPlanning) planAggregations() error { return err } - aggrParams := &engine.AggregateParams{ - Opcode: opcode, - Expr: fExpr, - } + pushExpr := e.Col + var alias string if handleDistinct { + pushExpr = innerAliased + switch opcode { case engine.AggregateCount: - aggrParams.Opcode = engine.AggregateCountDistinct + opcode = engine.AggregateCountDistinct case engine.AggregateSum: - aggrParams.Opcode = engine.AggregateSumDistinct + opcode = engine.AggregateSumDistinct } - if e.Col.As.IsEmpty() { - aggrParams.Alias = sqlparser.String(e.Col.Expr) + alias = sqlparser.String(e.Col.Expr) } else { - aggrParams.Alias = e.Col.As.String() + alias = e.Col.As.String() } - oa.eaggr.PreProcess = true + oa.eaggr.PreProcess = true hp.haveToTruncate(true) - hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, WeightStrExpr: innerAliased.Expr, Distinct: true}) - - offset, wOffset, _, err := wrapAndPushExpr(innerAliased.Expr, innerAliased.Expr, oa.input, hp.semTable) - if err != nil { - return err - } - aggrParams.Col = offset - aggrParams.WCol = wOffset - aggrParams.WAssigned = true - } else { - offset, _, err := pushProjection(e.Col, oa.input, hp.semTable, true, true) - if err != nil { - return err - } - aggrParams.Col = offset + hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, WeightStrExpr: innerAliased.Expr, DistinctAggrIndex: len(oa.eaggr.Aggregates) + 1}) } - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, aggrParams) + offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true) + if err != nil { + return err + } + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ + Opcode: opcode, + Col: offset, + Alias: alias, + Expr: fExpr, + }) } } @@ -267,12 +261,17 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem _, _, added, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node, semTable) return added, err case *orderedAggregate: - keyCol, weightStringOffset, colAdded, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable) + keyCol, wsOffset, colAdded, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node.input, semTable) if err != nil { return false, err } - if !groupExpr.Distinct { - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: keyCol, WeightStringCol: weightStringOffset, Expr: groupExpr.WeightStrExpr}) + if groupExpr.DistinctAggrIndex == 0 { + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: keyCol, WeightStringCol: wsOffset, Expr: groupExpr.WeightStrExpr}) + } else { + if wsOffset != -1 { + node.eaggr.Aggregates[groupExpr.DistinctAggrIndex-1].WAssigned = true + node.eaggr.Aggregates[groupExpr.DistinctAggrIndex-1].WCol = wsOffset + } } colAddedRecursively, err := planGroupByGen4(groupExpr, node.input, semTable) if err != nil { From 17419980dace270955fa032ba4570201f67dbe0e Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 12:27:19 +0530 Subject: [PATCH 06/12] count/sum distinct on unique vindex does not require special handling Signed-off-by: Harshit Gangal --- go/vt/vtgate/planbuilder/horizon_planning.go | 2 +- go/vt/vtgate/planbuilder/ordered_aggregate.go | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 59548250280..ec88ba27c9e 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -160,7 +160,7 @@ func (hp *horizonPlanning) planAggregations() error { return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") } opcode := engine.SupportedAggregates[fExpr.Name.Lowered()] - handleDistinct, innerAliased, err := oa.needDistinctHandlingGen4(fExpr, opcode) + handleDistinct, innerAliased, err := oa.needDistinctHandlingGen4(fExpr, opcode, hp.semTable, hp.vschema) if err != nil { return err } diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index f2e9fd7c471..f8dcd407ac9 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -306,7 +306,7 @@ func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, funcExpr // needDistinctHandling returns true if oa needs to handle the distinct clause. // If true, it will also return the aliased expression that needs to be pushed // down into the underlying route. -func (oa *orderedAggregate) needDistinctHandlingGen4(funcExpr *sqlparser.FuncExpr, opcode engine.AggregateOpcode) (bool, *sqlparser.AliasedExpr, error) { +func (oa *orderedAggregate) needDistinctHandlingGen4(funcExpr *sqlparser.FuncExpr, opcode engine.AggregateOpcode, semTable *semantics.SemTable, vschema ContextVSchema) (bool, *sqlparser.AliasedExpr, error) { if !funcExpr.Distinct { return false, nil, nil } @@ -322,11 +322,9 @@ func (oa *orderedAggregate) needDistinctHandlingGen4(funcExpr *sqlparser.FuncExp // Unreachable return true, innerAliased, nil } - // check for unique vindex - //vindex := pb.st.Vindex(innerAliased.Expr, rb) - //if vindex != nil && vindex.IsUnique() { - // return false, nil, nil - //} + if exprHasUniqueVindex(vschema, semTable, innerAliased.Expr) { + return false, nil, nil + } return true, innerAliased, nil } From 1fc0a0fa0646e4559891851dbb82ef411089a1b5 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 12:43:36 +0530 Subject: [PATCH 07/12] update gen4 query plans in aggr cases Signed-off-by: Harshit Gangal --- .../planbuilder/testdata/aggr_cases.txt | 181 ++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index 2d054559070..b531326e3a5 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -895,6 +895,31 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select col, count(distinct id) from user group by col", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(1)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col, count(distinct id), weight_string(col) from `user` where 1 != 1 group by col", + "OrderBy": "(0|2) ASC", + "Query": "select col, count(distinct id), weight_string(col) from `user` group by col order by col asc", + "Table": "`user`" + } + ] + } +} # count with distinct no unique vindex "select col1, count(distinct col2) from user group by col1" @@ -923,6 +948,31 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, count(distinct col2) from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(1|3) AS count(distinct col2)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } +} # count with distinct no unique vindex and no group by "select count(distinct col2) from user" @@ -950,6 +1000,30 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select count(distinct col2) from user", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(0|1) AS count(distinct col2)", + "ResultColumns": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col2, weight_string(col2) from `user` where 1 != 1 group by col2", + "OrderBy": "(0|1) ASC", + "Query": "select col2, weight_string(col2) from `user` group by col2 order by col2 asc", + "Table": "`user`" + } + ] + } +} # count with distinct no unique vindex, count expression aliased "select col1, count(distinct col2) c2 from user group by col1" @@ -978,6 +1052,31 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, count(distinct col2) c2 from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(1|3) AS c2", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } +} # sum with distinct no unique vindex "select col1, sum(distinct col2) from user group by col1" @@ -1006,6 +1105,31 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, sum(distinct col2) from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_distinct(1) AS sum(distinct col2)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } +} # min with distinct no unique vindex. distinct is ignored. "select col1, min(distinct col2) from user group by col1" @@ -1034,6 +1158,31 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, min(distinct col2) from user group by col1", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "min(1)", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, min(distinct col2), weight_string(col1) from `user` where 1 != 1 group by col1", + "OrderBy": "(0|2) ASC", + "Query": "select col1, min(distinct col2), weight_string(col1) from `user` group by col1 order by col1 asc", + "Table": "`user`" + } + ] + } +} # order by count distinct "select col1, count(distinct col2) k from user group by col1 order by k" @@ -1069,6 +1218,38 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select col1, count(distinct col2) k from user group by col1 order by k", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "1 ASC", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_distinct(1|3) AS k", + "GroupBy": "(0|2)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2", + "OrderBy": "(0|2) ASC, (1|3) ASC", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2 order by col1 asc, col2 asc", + "Table": "`user`" + } + ] + } + ] + } +} # scatter aggregate group by aggregate function "select count(*) b from user group by b" From 065741121868dbfd6605831b1bd3e1e903e3a1a1 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 13:28:20 +0530 Subject: [PATCH 08/12] made groupbyparams as reference so to preserve the change of keycol on comparison Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/cached_size.go | 14 +++++----- go/vt/vtgate/engine/ordered_aggregate.go | 4 +-- go/vt/vtgate/engine/ordered_aggregate_test.go | 26 +++++++++---------- go/vt/vtgate/planbuilder/grouping.go | 4 +-- go/vt/vtgate/planbuilder/horizon_planning.go | 6 ++--- go/vt/wrangler/vdiff.go | 6 ++--- go/vt/wrangler/vdiff_test.go | 2 +- 7 files changed, 31 insertions(+), 31 deletions(-) diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index acfc640f6e4..1876c9819d2 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -33,7 +33,7 @@ func (cached *AggregateParams) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(72) } // field Alias string size += int64(len(cached.Alias)) @@ -406,18 +406,18 @@ func (cached *OrderedAggregate) CachedSize(alloc bool) int64 { if alloc { size += int64(80) } - // field Aggregates []vitess.io/vitess/go/vt/vtgate/engine.AggregateParams + // field Aggregates []*vitess.io/vitess/go/vt/vtgate/engine.AggregateParams { - size += int64(cap(cached.Aggregates)) * int64(48) + size += int64(cap(cached.Aggregates)) * int64(8) for _, elem := range cached.Aggregates { - size += elem.CachedSize(false) + size += elem.CachedSize(true) } } - // field GroupByKeys []vitess.io/vitess/go/vt/vtgate/engine.GroupByParams + // field GroupByKeys []*vitess.io/vitess/go/vt/vtgate/engine.GroupByParams { - size += int64(cap(cached.GroupByKeys)) * int64(32) + size += int64(cap(cached.GroupByKeys)) * int64(8) for _, elem := range cached.GroupByKeys { - size += elem.CachedSize(false) + size += elem.CachedSize(true) } } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 920173f9279..9c0efd42f14 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -49,7 +49,7 @@ type OrderedAggregate struct { // GroupByKeys specifies the input values that must be used for // the aggregation key. - GroupByKeys []GroupByParams + GroupByKeys []*GroupByParams // TruncateColumnCount specifies the number of columns to return // in the final result. Rest of the columns are truncated @@ -490,7 +490,7 @@ func aggregateParamsToString(in interface{}) string { } func groupByParamsToString(i interface{}) string { - return i.(GroupByParams).String() + return i.(*GroupByParams).String() } func (oa *OrderedAggregate) description() PrimitiveDescription { diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 8560f49d579..fed61165c52 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -53,7 +53,7 @@ func TestOrderedAggregateExecute(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -90,7 +90,7 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 2}}, + GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } @@ -132,7 +132,7 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -175,7 +175,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 2}}, + GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } @@ -316,7 +316,7 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { Opcode: AggregateCount, Col: 2, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -392,7 +392,7 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { Opcode: AggregateCount, Col: 2, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -480,7 +480,7 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { Opcode: AggregateSum, Col: 2, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -525,7 +525,7 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { Col: 1, Alias: "sum(distinct col2)", }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -560,7 +560,7 @@ func TestOrderedAggregateKeysFail(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -593,7 +593,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) { Opcode: AggregateCount, Col: 1, }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -721,7 +721,7 @@ func TestNoInputAndNoGroupingKeys(outer *testing.T) { Col: 0, Alias: test.name, }}, - GroupByKeys: []GroupByParams{}, + GroupByKeys: []*GroupByParams{}, Input: fp, } @@ -815,7 +815,7 @@ func TestCountDistinctOnVarchar(t *testing.T) { WAssigned: true, Alias: "count(distinct c2)", }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, TruncateColumnCount: 2, } @@ -879,7 +879,7 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { WAssigned: true, Alias: "count(distinct c2)", }}, - GroupByKeys: []GroupByParams{{KeyCol: 0}}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, TruncateColumnCount: 2, } diff --git a/go/vt/vtgate/planbuilder/grouping.go b/go/vt/vtgate/planbuilder/grouping.go index 1f17aa4ec54..1bf8ee06f62 100644 --- a/go/vt/vtgate/planbuilder/grouping.go +++ b/go/vt/vtgate/planbuilder/grouping.go @@ -77,7 +77,7 @@ func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.Grou default: return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: only simple references allowed") } - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: colNumber, WeightStringCol: -1}) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: colNumber, WeightStringCol: -1}) } // Append the distinct aggregate if any. if node.extraDistinct != nil { @@ -110,7 +110,7 @@ func planDistinct(input logicalPlan) (logicalPlan, error) { if rc.column.Origin() == node { return newDistinct(node), nil } - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: i, WeightStringCol: -1}) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: i, WeightStringCol: -1}) } newInput, err := planDistinct(node.input) if err != nil { diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index ec88ba27c9e..86af9f3ddca 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -266,7 +266,7 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem return false, err } if groupExpr.DistinctAggrIndex == 0 { - node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, engine.GroupByParams{KeyCol: keyCol, WeightStringCol: wsOffset, Expr: groupExpr.WeightStrExpr}) + node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: keyCol, WeightStringCol: wsOffset, Expr: groupExpr.WeightStrExpr}) } else { if wsOffset != -1 { node.eaggr.Aggregates[groupExpr.DistinctAggrIndex-1].WAssigned = true @@ -556,7 +556,7 @@ func (hp *horizonPlanning) planDistinctOA(currPlan *orderedAggregate) error { for _, aggrParam := range currPlan.eaggr.Aggregates { if sqlparser.EqualsExpr(sExpr.Col.Expr, aggrParam.Expr) { found = true - eaggr.GroupByKeys = append(eaggr.GroupByKeys, engine.GroupByParams{KeyCol: aggrParam.Col, WeightStringCol: -1}) + eaggr.GroupByKeys = append(eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: aggrParam.Col, WeightStringCol: -1}) break } } @@ -579,7 +579,7 @@ func (hp *horizonPlanning) addDistinct() error { eaggr: eaggr, } for index, sExpr := range hp.qp.SelectExprs { - grpParam := engine.GroupByParams{KeyCol: index, WeightStringCol: -1} + grpParam := &engine.GroupByParams{KeyCol: index, WeightStringCol: -1} _, wOffset, added, err := wrapAndPushExpr(sExpr.Col.Expr, sExpr.Col.Expr, hp.plan, hp.semTable) if err != nil { return err diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index b6e91ea0a82..545e2ae63f7 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -538,10 +538,10 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer return td, nil } -func pkColsToGroupByParams(pkCols []int) []engine.GroupByParams { - var res []engine.GroupByParams +func pkColsToGroupByParams(pkCols []int) []*engine.GroupByParams { + var res []*engine.GroupByParams for _, col := range pkCols { - res = append(res, engine.GroupByParams{KeyCol: col, WeightStringCol: -1}) + res = append(res, &engine.GroupByParams{KeyCol: col, WeightStringCol: -1}) } return res } diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 1dcb7bccad6..53736a04687 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -394,7 +394,7 @@ func TestVDiffPlanSuccess(t *testing.T) { Opcode: engine.AggregateSum, Col: 3, }}, - GroupByKeys: []engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1}}, + GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1}}, Input: newMergeSorter(nil, []compareColInfo{{0, 0, true}}), }, targetPrimitive: newMergeSorter(nil, []compareColInfo{{0, 0, true}}), From 128d0b875135698a828dccef1081e87bad34f139 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 20:55:15 +0530 Subject: [PATCH 09/12] make sum distinct work with varchar column types Signed-off-by: Harshit Gangal --- go/test/endtoend/vtgate/gen4/gen4_test.go | 17 ++++- go/vt/vtgate/engine/ordered_aggregate.go | 10 ++- go/vt/vtgate/engine/ordered_aggregate_test.go | 66 ++++++++++++++++++- 3 files changed, 88 insertions(+), 5 deletions(-) diff --git a/go/test/endtoend/vtgate/gen4/gen4_test.go b/go/test/endtoend/vtgate/gen4/gen4_test.go index 9d01c34b73b..8359044210c 100644 --- a/go/test/endtoend/vtgate/gen4/gen4_test.go +++ b/go/test/endtoend/vtgate/gen4/gen4_test.go @@ -86,7 +86,7 @@ func TestGroupBy(t *testing.T) { `[INT64(2) VARCHAR("B") VARCHAR("C") VARCHAR("abc")]]`) } -func TestAggregationFunc(t *testing.T) { +func TestDistinctAggregationFunc(t *testing.T) { ctx := context.Background() conn, err := mysql.Connect(ctx, &vtParams) require.NoError(t, err) @@ -97,14 +97,25 @@ func TestAggregationFunc(t *testing.T) { // insert some data. checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) - // on primary vindex + // count on primary vindex assertMatches(t, conn, `select tcol1, count(distinct id) from t2 group by tcol1`, `[[VARCHAR("A") INT64(3)] [VARCHAR("B") INT64(3)] [VARCHAR("C") INT64(2)]]`) - // on any column + // count on any column assertMatches(t, conn, `select tcol1, count(distinct tcol2) from t2 group by tcol1`, `[[VARCHAR("A") INT64(2)] [VARCHAR("B") INT64(2)] [VARCHAR("C") INT64(1)]]`) + // sum of columns + assertMatches(t, conn, `select sum(id), sum(tcol1) from t2`, + `[[DECIMAL(36) FLOAT64(0)]]`) + + // sum on primary vindex + assertMatches(t, conn, `select tcol1, sum(distinct id) from t2 group by tcol1`, + `[[VARCHAR("A") DECIMAL(9)] [VARCHAR("B") DECIMAL(15)] [VARCHAR("C") DECIMAL(12)]]`) + + // sum on any column + assertMatches(t, conn, `select tcol1, sum(distinct tcol2) from t2 group by tcol1`, + `[[VARCHAR("A") DECIMAL(0)] [VARCHAR("B") DECIMAL(0)] [VARCHAR("C") DECIMAL(0)]]`) } func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 9c0efd42f14..759bc2bd680 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -340,8 +340,12 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. } case AggregateSumDistinct: curDistinct = row[aggr.KeyCol] + if aggr.WAssigned && !curDistinct.IsComparable() { + aggr.KeyCol = aggr.WCol + curDistinct = row[aggr.KeyCol] + } var err error - newRow[aggr.Col], err = evalengine.Cast(row[aggr.KeyCol], opcodeType[aggr.Opcode]) + newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) if err != nil { newRow[aggr.Col] = sumZero } @@ -416,6 +420,10 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes continue } curDistinct = row2[aggr.KeyCol] + if aggr.WAssigned && !curDistinct.IsComparable() { + aggr.KeyCol = aggr.WCol + curDistinct = row2[aggr.KeyCol] + } } var err error switch aggr.Opcode { diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index fed61165c52..5a8c1dc354d 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -853,9 +853,9 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( fields, + "null|null|null", "null|a|0x41", "null|b|0x42", - "null|null|null", "1|null|null", "1|null|null", "1|a|0x41", @@ -908,3 +908,67 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { require.NoError(t, err) assert.Equal(t, want, results) } + +func TestSumDistinctOnVarcharWithNulls(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|weight_string(c2)", + "int64|varchar|varbinary", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "null|null|null", + "null|a|0x41", + "null|b|0x42", + "1|null|null", + "1|null|null", + "1|a|0x41", + "1|a|0x41", + "1|b|0x42", + "2|null|null", + "2|b|0x42", + "3|null|null", + "3|null|null", + "3|null|null", + "3|null|null", + )}, + } + + oa := &OrderedAggregate{ + PreProcess: true, + Aggregates: []*AggregateParams{{ + Opcode: AggregateSumDistinct, + Col: 1, + WCol: 2, + WAssigned: true, + Alias: "sum(distinct c2)", + }}, + GroupByKeys: []*GroupByParams{{KeyCol: 0}}, + Input: fp, + TruncateColumnCount: 2, + } + + want := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "c1|sum(distinct c2)", + "int64|decimal", + ), + `null|0`, `1|0`, `2|0`, `3|null`, + ) + + qr, err := oa.Execute(nil, nil, false) + require.NoError(t, err) + assert.Equal(t, want, qr) + + fp.rewind() + results := &sqltypes.Result{} + err = oa.StreamExecute(nil, nil, false, func(qr *sqltypes.Result) error { + if qr.Fields != nil { + results.Fields = qr.Fields + } + results.Rows = append(results.Rows, qr.Rows...) + return nil + }) + require.NoError(t, err) + assert.Equal(t, want, results) +} From a92ef0f98d49a60beaa2a6c41d926d81aed8e441 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 27 Jul 2021 22:13:00 +0530 Subject: [PATCH 10/12] some code refactor Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/ordered_aggregate.go | 38 +++++++++---------- .../planbuilder/abstract/queryprojection.go | 17 ++------- go/vt/vtgate/planbuilder/horizon_planning.go | 6 +++ .../planbuilder/testdata/aggr_cases.txt | 2 +- 4 files changed, 29 insertions(+), 34 deletions(-) diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 759bc2bd680..e78db30f25e 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -78,14 +78,17 @@ func (gbp GroupByParams) String() string { // AggregateParams specify the parameters for each aggregation. // It contains the opcode and input column number. type AggregateParams struct { - Opcode AggregateOpcode - Col int + Opcode AggregateOpcode + Col int + + // These are used only for distinct opcodes. KeyCol int WCol int WAssigned bool // Alias is set only for distinct opcodes. Alias string `json:",omitempty"` - Expr sqlparser.Expr + + Expr sqlparser.Expr } func (ap *AggregateParams) isDistinct() bool { @@ -98,7 +101,7 @@ func (ap *AggregateParams) preProcess() bool { func (ap *AggregateParams) String() string { keyCol := strconv.Itoa(ap.Col) - if ap.Opcode == AggregateCountDistinct && ap.WAssigned { + if ap.WAssigned { keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) } if ap.Alias != "" { @@ -327,11 +330,7 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. for _, aggr := range oa.Aggregates { switch aggr.Opcode { case AggregateCountDistinct: - curDistinct = row[aggr.KeyCol] - if aggr.WAssigned && !curDistinct.IsComparable() { - aggr.KeyCol = aggr.WCol - curDistinct = row[aggr.KeyCol] - } + curDistinct = findComparableCurrentDistinct(row, aggr) // Type is int64. Ok to call MakeTrusted. if row[aggr.KeyCol].IsNull() { newRow[aggr.Col] = countZero @@ -339,11 +338,7 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. newRow[aggr.Col] = countOne } case AggregateSumDistinct: - curDistinct = row[aggr.KeyCol] - if aggr.WAssigned && !curDistinct.IsComparable() { - aggr.KeyCol = aggr.WCol - curDistinct = row[aggr.KeyCol] - } + curDistinct = findComparableCurrentDistinct(row, aggr) var err error newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) if err != nil { @@ -364,6 +359,15 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. return newRow, curDistinct } +func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value { + curDistinct := row[aggr.KeyCol] + if aggr.WAssigned && !curDistinct.IsComparable() { + aggr.KeyCol = aggr.WCol + curDistinct = row[aggr.KeyCol] + } + return curDistinct +} + // GetFields is a Primitive function. func (oa *OrderedAggregate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { qr, err := oa.Input.GetFields(vcursor, bindVars) @@ -419,11 +423,7 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes if cmp == 0 { continue } - curDistinct = row2[aggr.KeyCol] - if aggr.WAssigned && !curDistinct.IsComparable() { - aggr.KeyCol = aggr.WCol - curDistinct = row2[aggr.KeyCol] - } + curDistinct = findComparableCurrentDistinct(row2, aggr) } var err error switch aggr.Opcode { diff --git a/go/vt/vtgate/planbuilder/abstract/queryprojection.go b/go/vt/vtgate/planbuilder/abstract/queryprojection.go index 478b8f2c9bc..9311645a120 100644 --- a/go/vt/vtgate/planbuilder/abstract/queryprojection.go +++ b/go/vt/vtgate/planbuilder/abstract/queryprojection.go @@ -67,18 +67,16 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) { Distinct: sel.Distinct, } - distinctAggrFunc := false for _, selExp := range sel.SelectExprs { exp, ok := selExp.(*sqlparser.AliasedExpr) if !ok { return nil, semantics.Gen4NotSupportedF("%T in select list", selExp) } - foundDistinctAggrFunc, err := checkForInvalidAggregations(exp, distinctAggrFunc) + err := checkForInvalidAggregations(exp) if err != nil { return nil, err } - distinctAggrFunc = distinctAggrFunc || foundDistinctAggrFunc col := SelectExpr{ Col: exp, } @@ -137,25 +135,16 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) { return qp, nil } -func checkForInvalidAggregations(exp *sqlparser.AliasedExpr, failOnDistinctAggrFunc bool) (bool, error) { - distinctAggrFunc := false - err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { +func checkForInvalidAggregations(exp *sqlparser.AliasedExpr) error { + return sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { fExpr, ok := node.(*sqlparser.FuncExpr) if ok && fExpr.IsAggregate() { if len(fExpr.Exprs) != 1 { return false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.SyntaxError, "aggregate functions take a single argument '%s'", sqlparser.String(fExpr)) } - if fExpr.Distinct { - if failOnDistinctAggrFunc { - return false, semantics.Gen4NotSupportedF("multiple distinct aggregation function") - } - distinctAggrFunc = true - return true, nil - } } return true, nil }, exp.Expr) - return distinctAggrFunc, err } func (qp *QueryProjection) getNonAggrExprNotMatchingGroupByExprs() sqlparser.Expr { diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 86af9f3ddca..997c67fa48b 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -165,6 +165,12 @@ func (hp *horizonPlanning) planAggregations() error { return err } + // Currently the OA engine primitive is able to handle only one distinct aggregation function. + // PreProcess being true tells that it is already handling it. + if oa.eaggr.PreProcess && handleDistinct { + return semantics.Gen4NotSupportedF("multiple distinct aggregation function") + } + pushExpr := e.Col var alias string if handleDistinct { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index b531326e3a5..e226ded6152 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -1111,7 +1111,7 @@ Gen4 plan same as above "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "sum_distinct(1) AS sum(distinct col2)", + "Aggregates": "sum_distinct(1|3) AS sum(distinct col2)", "GroupBy": "(0|2)", "ResultColumns": 2, "Inputs": [ From c4ba17d3746677df818f45c887bb4c84c06a21e8 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Jul 2021 14:01:49 +0200 Subject: [PATCH 11/12] tweaked tests to make the easier to read Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/ordered_aggregate_test.go | 65 ++++++++++--------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 5a8c1dc354d..a32f50b5e29 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -799,10 +799,10 @@ func TestCountDistinctOnVarchar(t *testing.T) { fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( fields, - "1|a|0x41", - "1|a|0x41", - "1|b|0x42", - "2|b|0x42", + "10|a|0x41", + "10|a|0x41", + "10|b|0x42", + "20|b|0x42", )}, } @@ -825,7 +825,8 @@ func TestCountDistinctOnVarchar(t *testing.T) { "c1|count(distinct c2)", "int64|int64", ), - `1|2`, `2|1`, + `10|2`, + `20|1`, ) qr, err := oa.Execute(nil, nil, false) @@ -856,17 +857,17 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { "null|null|null", "null|a|0x41", "null|b|0x42", - "1|null|null", - "1|null|null", - "1|a|0x41", - "1|a|0x41", - "1|b|0x42", - "2|null|null", - "2|b|0x42", - "3|null|null", - "3|null|null", - "3|null|null", - "3|null|null", + "10|null|null", + "10|null|null", + "10|a|0x41", + "10|a|0x41", + "10|b|0x42", + "20|null|null", + "20|b|0x42", + "30|null|null", + "30|null|null", + "30|null|null", + "30|null|null", )}, } @@ -889,7 +890,10 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { "c1|count(distinct c2)", "int64|int64", ), - `null|2`, `1|2`, `2|1`, `3|0`, + `null|2`, + `10|2`, + `20|1`, + `30|0`, ) qr, err := oa.Execute(nil, nil, false) @@ -920,17 +924,17 @@ func TestSumDistinctOnVarcharWithNulls(t *testing.T) { "null|null|null", "null|a|0x41", "null|b|0x42", - "1|null|null", - "1|null|null", - "1|a|0x41", - "1|a|0x41", - "1|b|0x42", - "2|null|null", - "2|b|0x42", - "3|null|null", - "3|null|null", - "3|null|null", - "3|null|null", + "10|null|null", + "10|null|null", + "10|a|0x41", + "10|a|0x41", + "10|b|0x42", + "20|null|null", + "20|b|0x42", + "30|null|null", + "30|null|null", + "30|null|null", + "30|null|null", )}, } @@ -953,7 +957,10 @@ func TestSumDistinctOnVarcharWithNulls(t *testing.T) { "c1|sum(distinct c2)", "int64|decimal", ), - `null|0`, `1|0`, `2|0`, `3|null`, + `null|0`, + `10|0`, + `20|0`, + `30|null`, ) qr, err := oa.Execute(nil, nil, false) From 3407e07bd138e506c82347e3965c8d2ba58da5f4 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Jul 2021 14:43:10 +0200 Subject: [PATCH 12/12] small refactoring to move code closer together Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/horizon_planning.go | 135 ++++++++++++------ go/vt/vtgate/planbuilder/ordered_aggregate.go | 25 ---- 2 files changed, 89 insertions(+), 71 deletions(-) diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 997c67fa48b..3f5aa033bc1 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -154,55 +154,33 @@ func (hp *horizonPlanning) planAggregations() error { continue } - if e.Aggr && oa != nil { - fExpr, isFunc := e.Col.Expr.(*sqlparser.FuncExpr) - if !isFunc { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") - } - opcode := engine.SupportedAggregates[fExpr.Name.Lowered()] - handleDistinct, innerAliased, err := oa.needDistinctHandlingGen4(fExpr, opcode, hp.semTable, hp.vschema) - if err != nil { - return err - } - - // Currently the OA engine primitive is able to handle only one distinct aggregation function. - // PreProcess being true tells that it is already handling it. - if oa.eaggr.PreProcess && handleDistinct { - return semantics.Gen4NotSupportedF("multiple distinct aggregation function") - } - - pushExpr := e.Col - var alias string - if handleDistinct { - pushExpr = innerAliased + fExpr, isFunc := e.Col.Expr.(*sqlparser.FuncExpr) + if !isFunc { + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: complex aggregate expression") + } + opcode := engine.SupportedAggregates[fExpr.Name.Lowered()] + handleDistinct, innerAliased, err := hp.needDistinctHandling(fExpr, opcode, oa.input) + if err != nil { + return err + } - switch opcode { - case engine.AggregateCount: - opcode = engine.AggregateCountDistinct - case engine.AggregateSum: - opcode = engine.AggregateSumDistinct - } - if e.Col.As.IsEmpty() { - alias = sqlparser.String(e.Col.Expr) - } else { - alias = e.Col.As.String() - } + // Currently the OA engine primitive is able to handle only one distinct aggregation function. + // PreProcess being true tells that it is already handling it. + if oa.eaggr.PreProcess && handleDistinct { + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "multiple distinct aggregation function") + } - oa.eaggr.PreProcess = true - hp.haveToTruncate(true) - hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, abstract.GroupBy{Inner: innerAliased.Expr, WeightStrExpr: innerAliased.Expr, DistinctAggrIndex: len(oa.eaggr.Aggregates) + 1}) - } - offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true) - if err != nil { - return err - } - oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ - Opcode: opcode, - Col: offset, - Alias: alias, - Expr: fExpr, - }) + pushExpr, alias, opcode := hp.createPushExprAndAlias(e, handleDistinct, innerAliased, opcode, oa) + offset, _, err := pushProjection(pushExpr, oa.input, hp.semTable, true, true) + if err != nil { + return err } + oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ + Opcode: opcode, + Col: offset, + Alias: alias, + Expr: fExpr, + }) } for _, groupExpr := range hp.qp.GroupByExprs { @@ -248,6 +226,44 @@ func (hp *horizonPlanning) planAggregations() error { return nil } +// createPushExprAndAlias creates the expression that should be pushed down to the leaves, +// and changes the opcode so it is a distinct one if needed +func (hp *horizonPlanning) createPushExprAndAlias( + expr abstract.SelectExpr, + handleDistinct bool, + innerAliased *sqlparser.AliasedExpr, + opcode engine.AggregateOpcode, + oa *orderedAggregate, +) (*sqlparser.AliasedExpr, string, engine.AggregateOpcode) { + pushExpr := expr.Col + var alias string + if handleDistinct { + pushExpr = innerAliased + + switch opcode { + case engine.AggregateCount: + opcode = engine.AggregateCountDistinct + case engine.AggregateSum: + opcode = engine.AggregateSumDistinct + } + if expr.Col.As.IsEmpty() { + alias = sqlparser.String(expr.Col.Expr) + } else { + alias = expr.Col.As.String() + } + + oa.eaggr.PreProcess = true + hp.haveToTruncate(true) + by := abstract.GroupBy{ + Inner: innerAliased.Expr, + WeightStrExpr: innerAliased.Expr, + DistinctAggrIndex: len(oa.eaggr.Aggregates) + 1, + } + hp.qp.GroupByExprs = append(hp.qp.GroupByExprs, by) + } + return pushExpr, alias, opcode +} + func hasUniqueVindex(vschema ContextVSchema, semTable *semantics.SemTable, groupByExprs []abstract.GroupBy) bool { for _, groupByExpr := range groupByExprs { if exprHasUniqueVindex(vschema, semTable, groupByExpr.WeightStrExpr) { @@ -606,3 +622,30 @@ func selectHasUniqueVindex(vschema ContextVSchema, semTable *semantics.SemTable, } return false } + +// needDistinctHandling returns true if oa needs to handle the distinct clause. +// If true, it will also return the aliased expression that needs to be pushed +// down into the underlying route. +func (hp *horizonPlanning) needDistinctHandling(funcExpr *sqlparser.FuncExpr, opcode engine.AggregateOpcode, input logicalPlan) (bool, *sqlparser.AliasedExpr, error) { + if !funcExpr.Distinct { + return false, nil, nil + } + if opcode != engine.AggregateCount && opcode != engine.AggregateSum { + return false, nil, nil + } + innerAliased, ok := funcExpr.Exprs[0].(*sqlparser.AliasedExpr) + if !ok { + return false, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "syntax error: %s", sqlparser.String(funcExpr)) + } + _, ok = input.(*route) + if !ok { + // Unreachable + return true, innerAliased, nil + } + if exprHasUniqueVindex(hp.vschema, hp.semTable, innerAliased.Expr) { + // if we can see a unique vindex on this table/column, + // we know the results will be unique, and we don't need to DISTINCTify them + return false, nil, nil + } + return true, innerAliased, nil +} diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index f8dcd407ac9..63b6cd33871 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -303,31 +303,6 @@ func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, funcExpr return true, innerAliased, nil } -// needDistinctHandling returns true if oa needs to handle the distinct clause. -// If true, it will also return the aliased expression that needs to be pushed -// down into the underlying route. -func (oa *orderedAggregate) needDistinctHandlingGen4(funcExpr *sqlparser.FuncExpr, opcode engine.AggregateOpcode, semTable *semantics.SemTable, vschema ContextVSchema) (bool, *sqlparser.AliasedExpr, error) { - if !funcExpr.Distinct { - return false, nil, nil - } - if opcode != engine.AggregateCount && opcode != engine.AggregateSum { - return false, nil, nil - } - innerAliased, ok := funcExpr.Exprs[0].(*sqlparser.AliasedExpr) - if !ok { - return false, nil, fmt.Errorf("syntax error: %s", sqlparser.String(funcExpr)) - } - _, ok = oa.input.(*route) - if !ok { - // Unreachable - return true, innerAliased, nil - } - if exprHasUniqueVindex(vschema, semTable, innerAliased.Expr) { - return false, nil, nil - } - return true, innerAliased, nil -} - // Wireup implements the logicalPlan interface // If text columns are detected in the keys, then the function modifies // the primitive to pull a corresponding weight_string from mysql and