Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve and Fix Distinct Aggregation planner #13466

Merged
merged 7 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,33 @@ func (mcmp *MySQLCompare) ExecAndIgnore(query string) (*sqltypes.Result, error)
_, _ = mcmp.MySQLConn.ExecuteFetch(query, 1000, true)
return mcmp.VtConn.ExecuteFetch(query, 1000, true)
}

// Run executes f as a named subtest (the query text is the subtest name),
// handing it a fresh MySQLCompare that reuses the same MySQL and Vitess
// connections but is bound to the subtest's *testing.T, so failures are
// reported against the right subtest.
func (mcmp *MySQLCompare) Run(query string, f func(mcmp *MySQLCompare)) {
	mcmp.t.Run(query, func(subT *testing.T) {
		sub := &MySQLCompare{
			t:         subT,
			MySQLConn: mcmp.MySQLConn,
			VtConn:    mcmp.VtConn,
		}
		f(sub)
	})
}

// ExecAllowError executes the query against both Vitess and MySQL.
// If neither side returns an error, it compares the two results.
// Any Vitess execution error is returned immediately, without running
// the query on MySQL or comparing results.
func (mcmp *MySQLCompare) ExecAllowError(query string) (*sqltypes.Result, error) {
	mcmp.t.Helper()
	vtQr, vtErr := mcmp.VtConn.ExecuteFetch(query, 1000, true)
	if vtErr != nil {
		return nil, vtErr
	}
	mysqlQr, mysqlErr := mcmp.MySQLConn.ExecuteFetch(query, 1000, true)

	// Since we allow errors, we don't want to compare the results if one of the clients failed.
	// Vitess and MySQL should always agree on whether the query returns an error or not.
	// NOTE(review): when MySQL fails but Vitess succeeds, mysqlErr is silently
	// dropped and no comparison happens — confirm this asymmetry is intended.
	if mysqlErr == nil {
		vtErr = compareVitessAndMySQLResults(mcmp.t, query, mcmp.VtConn, vtQr, mysqlQr, false)
	}
	return vtQr, vtErr
}
34 changes: 34 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,37 @@ func compareRow(t *testing.T, mRes *sqltypes.Result, vtRes *sqltypes.Result, grp
require.True(t, foundKey, "mysql and vitess result does not same row: vitess:%v, mysql:%v", vtRes.Rows, mRes.Rows)
}
}

// TestDistinctAggregation runs a set of DISTINCT-aggregation queries against
// both Vitess and MySQL, checking that supported shapes return matching
// results and unsupported shapes fail with the expected planner error.
func TestDistinctAggregation(t *testing.T) {
	mcmp, closer := start(t)
	defer closer()
	mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)")

	testCases := []struct {
		query       string
		expectedErr string
	}{{
		query:       `SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
		expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct shardkey) (errno 1235) (sqlstate 42000)",
	}, {
		query: `SELECT /*vt+ PLANNER=gen4 */ a.t1_id, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.t1_id`,
	}, {
		query: `SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.value`,
	}, {
		query:       `SELECT /*vt+ PLANNER=gen4 */ count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`,
		expectedErr: "VT12001: unsupported: only one DISTINCT aggregation is allowed in a SELECT: sum(distinct b.t1_id) (errno 1235) (sqlstate 42000)",
	}, {
		query: `SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.t1_id), min(DISTINCT a.t1_id) FROM t1 a, t1 b group by a.value`,
	}}

	for _, tcase := range testCases {
		mcmp.Run(tcase.query, func(mcmp *utils.MySQLCompare) {
			_, err := mcmp.ExecAllowError(tcase.query)
			if tcase.expectedErr != "" {
				require.ErrorContains(t, err, tcase.expectedErr)
				return
			}
			require.NoError(t, err)
		})
	}
}
8 changes: 8 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2859,6 +2859,7 @@ type (

DistinctableAggr interface {
IsDistinct() bool
SetDistinct(bool)
}

Count struct {
Expand Down Expand Up @@ -3371,6 +3372,13 @@ func (avg *Avg) IsDistinct() bool { return avg.Distinct }
func (count *Count) IsDistinct() bool { return count.Distinct }
func (grpConcat *GroupConcatExpr) IsDistinct() bool { return grpConcat.Distinct }

// SetDistinct implementations toggle the DISTINCT modifier on each
// aggregation-function AST node, satisfying the DistinctableAggr interface.
func (sum *Sum) SetDistinct(distinct bool) { sum.Distinct = distinct }
func (min *Min) SetDistinct(distinct bool) { min.Distinct = distinct }
func (max *Max) SetDistinct(distinct bool) { max.Distinct = distinct }
func (avg *Avg) SetDistinct(distinct bool) { avg.Distinct = distinct }
func (count *Count) SetDistinct(distinct bool) { count.Distinct = distinct }
func (grpConcat *GroupConcatExpr) SetDistinct(distinct bool) { grpConcat.Distinct = distinct }

func (*Sum) AggrName() string { return "sum" }
func (*Min) AggrName() string { return "min" }
func (*Max) AggrName() string { return "max" }
Expand Down
14 changes: 14 additions & 0 deletions go/vt/sqlparser/ast_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ func (er *astRewriter) rewriteUp(cursor *Cursor) bool {
er.rewriteShowBasic(node)
case *ExistsExpr:
er.existsRewrite(cursor, node)
case DistinctableAggr:
er.rewriteDistinctableAggr(cursor, node)
}
return true
}
Expand Down Expand Up @@ -683,6 +685,18 @@ func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) {
sel.GroupBy = nil
}

// rewriteDistinctableAggr removes the DISTINCT modifier from MAX and MIN
// aggregations: duplicates can never change their result, and dropping the
// modifier makes the plan simpler.
func (er *astRewriter) rewriteDistinctableAggr(cursor *Cursor, node DistinctableAggr) {
	if !node.IsDistinct() {
		return
	}
	switch node.(type) {
	case *Max, *Min:
		node.SetDistinct(false)
		er.bindVars.NoteRewrite()
	}
}

// bindVarExpression builds an Expr that references the bind variable with the given name.
func bindVarExpression(name string) Expr {
	return NewArgument(name)
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/sqlparser/ast_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ func TestRewrites(in *testing.T) {
}, {
in: "SELECT id, name, salary FROM user_details",
expected: "SELECT id, name, salary FROM (select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id) as user_details",
}, {
in: "select max(distinct c1), min(distinct c2), avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl",
expected: "select max(c1) as `max(distinct c1)`, min(c2) as `min(distinct c2)`, avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl",
}, {
in: "SHOW VARIABLES",
expected: "SHOW VARIABLES",
Expand Down
88 changes: 86 additions & 2 deletions go/vt/vtgate/engine/scalar_aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/utils"
. "vitess.io/vitess/go/vt/vtgate/engine/opcode"
)

Expand Down Expand Up @@ -255,6 +254,91 @@ func TestScalarGroupConcatWithAggrOnEngine(t *testing.T) {
}
}

// TestScalarDistinctAggrOnEngine tests distinct aggregation on the engine primitive,
// on both the execute and the stream-execute paths.
func TestScalarDistinctAggrOnEngine(t *testing.T) {
	fields := sqltypes.MakeTestFields(
		"value|value",
		"int64|int64",
	)

	// The duplicate rows below must be collapsed by the DISTINCT aggregations.
	fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
		fields,
		"100|100",
		"200|200",
		"200|200",
		"400|400",
		"400|400",
		"600|600",
	)}}

	oa := &ScalarAggregate{
		Aggregates: []*AggregateParams{
			NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)"),
			NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct value)"),
		},
		Input: fp,
	}
	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
	require.NoError(t, err)
	// distinct values are 100, 200, 400, 600 -> count = 4, sum = 1300
	require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", qr.Rows))

	fp.rewind()
	// The streaming path must produce the same aggregated result.
	results := &sqltypes.Result{}
	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
		if qr.Fields != nil {
			results.Fields = qr.Fields
		}
		results.Rows = append(results.Rows, qr.Rows...)
		return nil
	})
	require.NoError(t, err)
	require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", results.Rows))
}

// TestScalarDistinctPushedDown checks that partial distinct aggregates coming
// back from the shards (already de-duplicated below the route) are combined
// with a plain SUM at the vtgate level, on both execute and stream paths.
func TestScalarDistinctPushedDown(t *testing.T) {
	fields := sqltypes.MakeTestFields(
		"count(distinct value)|sum(distinct value)",
		"int64|decimal",
	)

	input := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
		fields,
		"2|200",
		"6|400",
		"3|700",
		"1|10",
		"7|30",
		"8|90",
	)}}

	// Both aggregates run as SUM above the route while remembering the
	// original distinct opcode they were pushed down from.
	countAggr := NewAggregateParam(AggregateSum, 0, "count(distinct value)")
	countAggr.OrigOpcode = AggregateCountDistinct
	sumAggr := NewAggregateParam(AggregateSum, 1, "sum(distinct value)")
	sumAggr.OrigOpcode = AggregateSumDistinct
	oa := &ScalarAggregate{
		Aggregates: []*AggregateParams{
			countAggr,
			sumAggr,
		},
		Input: input,
	}

	want := `[[INT64(27) DECIMAL(1430)]]`
	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
	require.NoError(t, err)
	require.Equal(t, want, fmt.Sprintf("%v", qr.Rows))

	input.rewind()
	streamed := &sqltypes.Result{}
	err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(partial *sqltypes.Result) error {
		if partial.Fields != nil {
			streamed.Fields = partial.Fields
		}
		streamed.Rows = append(streamed.Rows, partial.Rows...)
		return nil
	})
	require.NoError(t, err)
	require.Equal(t, want, fmt.Sprintf("%v", streamed.Rows))
}

// TestScalarGroupConcat tests group_concat with partial aggregation on engine.
func TestScalarGroupConcat(t *testing.T) {
fields := sqltypes.MakeTestFields(
Expand Down
92 changes: 70 additions & 22 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func (a *Aggregator) aggregateTheAggregates() {
func aggregateTheAggregate(a *Aggregator, i int) {
aggr := a.Aggregations[i]
switch aggr.OpCode {
case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct:
// All count variations turn into SUM above the Route.
case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct, opcode.AggregateSumDistinct:
// All count variations turn into SUM above the Route. This is also applied for Sum distinct when it is pushed down.
// Think of it as we are SUMming together a bunch of distributed COUNTs.
aggr.OriginalOpCode, aggr.OpCode = aggr.OpCode, opcode.AggregateSum
a.Aggregations[i] = aggr
Expand Down Expand Up @@ -115,37 +115,72 @@ func pushDownAggregationThroughRoute(

// pushDownAggregations splits aggregations between the original aggregator and the one we are pushing down
func pushDownAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) error {
for i, aggregation := range aggregator.Aggregations {
if !aggregation.Distinct || exprHasUniqueVindex(ctx, aggregation.Func.GetArg()) {
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggregation)
canPushDownDistinctAggr, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
if err != nil {
return err
}

distinctAggrGroupByAdded := false

for i, aggr := range aggregator.Aggregations {
if !aggr.Distinct || canPushDownDistinctAggr {
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggr)
aggregateTheAggregate(aggregator, i)
continue
}
innerExpr := aggregation.Func.GetArg()

if aggregator.DistinctExpr != nil {
if ctx.SemTable.EqualsExpr(aggregator.DistinctExpr, innerExpr) {
// we can handle multiple distinct aggregations, as long as they are aggregating on the same expression
aggrBelowRoute.Columns[aggregation.ColOffset] = aeWrap(innerExpr)
continue
}
return vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(aggregation.Original)))
}

// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
aggregator.DistinctExpr = innerExpr
aeDistinctExpr := aeWrap(aggregator.DistinctExpr)
aeDistinctExpr := aeWrap(distinctExpr)
aggrBelowRoute.Columns[aggr.ColOffset] = aeDistinctExpr

aggrBelowRoute.Columns[aggregation.ColOffset] = aeDistinctExpr
// We handle a distinct aggregation by turning it into a group by and
// doing the aggregating on the vtgate level instead
// Adding to group by can be done only once even though there are multiple distinct aggregation with same expression.
if !distinctAggrGroupByAdded {
groupBy := NewGroupBy(distinctExpr, distinctExpr, aeDistinctExpr)
groupBy.ColOffset = aggr.ColOffset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
distinctAggrGroupByAdded = true
}
}

groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr)
groupBy.ColOffset = aggregation.ColOffset
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
if !canPushDownDistinctAggr {
aggregator.DistinctExpr = distinctExpr
}

return nil
}

// checkIfWeCanPushDown inspects all distinct aggregations on the aggregator.
// It reports whether they can be pushed below the route (every distinct
// argument has a unique vindex), and returns the first distinct expression
// seen. If they cannot be pushed down and the distinct aggregations do not all
// share the same expression, an unsupported-query error is returned.
func checkIfWeCanPushDown(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Expr, error) {
	pushable := true
	var firstDistinctExpr sqlparser.Expr
	var offending *sqlparser.AliasedExpr

	for _, aggr := range aggregator.Aggregations {
		if !aggr.Distinct {
			continue
		}

		arg := aggr.Func.GetArg()
		if !exprHasUniqueVindex(ctx, arg) {
			pushable = false
		}
		if firstDistinctExpr == nil {
			firstDistinctExpr = arg
		}
		if !ctx.SemTable.EqualsExpr(firstDistinctExpr, arg) {
			offending = aggr.Original
		}
	}

	if offending != nil && !pushable {
		return false, nil, vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(offending)))
	}

	return pushable, firstDistinctExpr, nil
}

func pushDownAggregationThroughFilter(
ctx *plancontext.PlanningContext,
aggregator *Aggregator,
Expand Down Expand Up @@ -411,6 +446,18 @@ func splitAggrColumnsToLeftAndRight(
outerJoin: join.LeftJoin,
}

canPushDownDistinctAggr, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
if err != nil {
return nil, nil, err
}

// Distinct aggregation cannot be pushed down in the join.
// We keep node of the distinct aggregation expression to be used later for ordering.
if !canPushDownDistinctAggr {
aggregator.DistinctExpr = distinctExpr
return nil, nil, errAbortAggrPushing
}

outer:
// we prefer adding the aggregations in the same order as the columns are declared
for colIdx, col := range aggregator.Columns {
Expand Down Expand Up @@ -509,7 +556,8 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er
// this is only used for SHOW GTID queries that will never contain joins
return vterrors.VT13001("cannot do join with vgtid")
case opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
return errAbortAggrPushing
// we are not going to see values multiple times, so we don't need to multiply with the count(*) from the other side
return ab.handlePushThroughAggregation(ctx, aggr)
default:
return errHorizonNotPlanned()
}
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ type (
Grouping []GroupBy
Aggregations []Aggr

// We support a single distinct aggregation per aggregator. It is stored here
// We support a single distinct aggregation per aggregator. It is stored here.
// When planning the ordering that the OrderedAggregate will require,
// this needs to be the last ORDER BY expression
DistinctExpr sqlparser.Expr

// Pushed will be set to true once this aggregation has been pushed deeper in the tree
Expand Down
Loading