Skip to content

Commit

Permalink
on aggregate count and sum splittling create new aggr and update the …
Browse files Browse the repository at this point in the history
…column offset based on where it is pushed

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Jun 6, 2023
1 parent faae08e commit 21df396
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 103 deletions.
98 changes: 44 additions & 54 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,12 @@ func splitAggrColumnsToLeftAndRight(
outer:
// we prefer adding the aggregations in the same order as the columns are declared
for colIdx, col := range aggregator.Columns {
for aggrIdx, aggr := range aggregator.Aggregations {
for _, aggr := range aggregator.Aggregations {
if aggr.ColOffset == colIdx {
aggrToKeep, err := builder.handleAggr(ctx, aggr)
err := builder.handleAggr(ctx, aggr)
if err != nil {
return nil, nil, err
}
aggregator.Aggregations[aggrIdx] = aggrToKeep
continue outer
}
}
Expand Down Expand Up @@ -426,19 +425,19 @@ func (p *joinPusher) countStar(ctx *plancontext.PlanningContext) (*sqlparser.Ali
return p.csAE, true
}

func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) {
func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) error {
switch aggr.OpCode {
case opcode.AggregateCountStar:
return ab.handleCountStar(ctx, aggr)
ab.handleCountStar(ctx, aggr)
return nil
case opcode.AggregateMax, opcode.AggregateMin, opcode.AggregateRandom:
return ab.handlePushThroughAggregation(ctx, aggr)
case opcode.AggregateCount, opcode.AggregateSum:
return ab.handleAggrWithCountStarMultiplier(ctx, aggr)

case opcode.AggregateUnassigned:
return Aggr{}, vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original)))
return vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original)))
default:
return Aggr{}, errHorizonNotPlanned()
return errHorizonNotPlanned()
}
}

Expand All @@ -460,30 +459,58 @@ func (ab *aggBuilder) pushThroughRight(aggr Aggr) {
})
}

func (ab *aggBuilder) handlePushThroughAggregation(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) {
func (ab *aggBuilder) handlePushThroughAggregation(ctx *plancontext.PlanningContext, aggr Aggr) error {
ab.proj.addUnexploredExpr(aggr.Original, aggr.Original.Expr)

deps := ctx.SemTable.RecursiveDeps(aggr.Original.Expr)
switch {
case deps.IsSolvedBy(ab.lhs.tableID):
ab.pushThroughLeft(aggr)
return aggr, nil
case deps.IsSolvedBy(ab.rhs.tableID):
ab.pushThroughRight(aggr)
return aggr, nil
default:
return Aggr{}, vterrors.VT12001("aggregation on columns from different sources: " + sqlparser.String(aggr.Original.Expr))
return vterrors.VT12001("aggregation on columns from different sources: " + sqlparser.String(aggr.Original.Expr))
}
return nil
}

func (ab *aggBuilder) handleCountStar(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) {
func (ab *aggBuilder) handleCountStar(ctx *plancontext.PlanningContext, aggr Aggr) {
// Projection is necessary since we are going to need to do arithmetics to summarize the aggregates
ab.projectionRequired = true

// Add the aggregate to both sides of the join.
lhsAE := ab.leftCountStar(ctx)
rhsAE := ab.rightCountStar(ctx)

ab.buildProjectionForAggr(lhsAE, rhsAE, aggr)
}

func (ab *aggBuilder) handleAggrWithCountStarMultiplier(ctx *plancontext.PlanningContext, aggr Aggr) error {
ab.projectionRequired = true

deps := ctx.SemTable.RecursiveDeps(aggr.Original.Expr)

var lhsAE, rhsAE *sqlparser.AliasedExpr
switch {
case deps.IsSolvedBy(ab.lhs.tableID):
ab.pushThroughLeft(aggr)
lhsAE = aggr.Original
rhsAE = ab.rightCountStar(ctx)

case deps.IsSolvedBy(ab.rhs.tableID):
ab.pushThroughRight(aggr)
lhsAE = ab.leftCountStar(ctx)
rhsAE = aggr.Original

default:
return errHorizonNotPlanned()
}

ab.buildProjectionForAggr(lhsAE, rhsAE, aggr)
return nil
}

func (ab *aggBuilder) buildProjectionForAggr(lhsAE *sqlparser.AliasedExpr, rhsAE *sqlparser.AliasedExpr, aggr Aggr) {
// We expect the expressions to be different on each side of the join, otherwise it's an error.
if lhsAE.Expr == rhsAE.Expr {
panic(fmt.Sprintf("Need the two produced expressions to be different. %T %T", lhsAE, rhsAE))
Expand Down Expand Up @@ -513,45 +540,6 @@ func (ab *aggBuilder) handleCountStar(ctx *plancontext.PlanningContext, aggr Agg
}

ab.proj.addUnexploredExpr(projAE, projExpr)
return aggr, nil
}

func (ab *aggBuilder) handleAggrWithCountStarMultiplier(ctx *plancontext.PlanningContext, aggr Aggr) (Aggr, error) {
ab.projectionRequired = true

expr := aggr.Original.Expr
deps := ctx.SemTable.RecursiveDeps(expr)
var otherSide sqlparser.Expr

switch {
case deps.IsSolvedBy(ab.lhs.tableID):
ab.pushThroughLeft(aggr)
ae := ab.rightCountStar(ctx)
otherSide = ae.Expr

case deps.IsSolvedBy(ab.rhs.tableID):
ab.pushThroughRight(aggr)
ae := ab.leftCountStar(ctx)
otherSide = ae.Expr

default:
return Aggr{}, errHorizonNotPlanned()
}

if ab.outerJoin {
otherSide = coalesceFunc(otherSide)
}

projAE := &sqlparser.AliasedExpr{
Expr: aggr.Original.Expr,
As: sqlparser.NewIdentifierCI(aggr.Original.ColumnName()),
}
ab.proj.addUnexploredExpr(projAE, &sqlparser.BinaryExpr{
Operator: sqlparser.MultOp,
Left: expr,
Right: otherSide,
})
return aggr, nil
}

func coalesceFunc(e sqlparser.Expr) sqlparser.Expr {
Expand Down Expand Up @@ -584,8 +572,10 @@ func (p *joinPusher) addAggr(ctx *plancontext.PlanningContext, aggr Aggr) sqlpar
// pushThroughAggr pushes through an aggregation without changing dependencies.
// Can be used for aggregations we can push in one piece
func (p *joinPusher) pushThroughAggr(aggr Aggr) {
p.pushed.Columns = append(p.pushed.Columns, aggr.Original)
p.pushed.Aggregations = append(p.pushed.Aggregations, aggr)
newAggr := NewAggr(aggr.OpCode, aggr.Func, aggr.Original, aggr.Alias)
newAggr.ColOffset = len(p.pushed.Columns)
p.pushed.Columns = append(p.pushed.Columns, newAggr.Original)
p.pushed.Aggregations = append(p.pushed.Aggregations, newAggr)
}

// addGrouping creates a copy of the given GroupBy, updates its column offset to point to the correct location in the new Aggregator,
Expand Down
92 changes: 90 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -3374,7 +3374,7 @@
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 2] as a",
"[COLUMN 0] * [COLUMN 1] as count(user_extra.a)",
"[COLUMN 1] * [COLUMN 0] as count(user_extra.a)",
"[COLUMN 3] as weight_string(`user`.a)"
],
"Inputs": [
Expand Down Expand Up @@ -3437,7 +3437,7 @@
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] * [COLUMN 1] as count(u.textcol1)",
"[COLUMN 2] * [COLUMN 3] as count(ue.foo)",
"[COLUMN 3] * [COLUMN 2] as count(ue.foo)",
"[COLUMN 4] as bar",
"[COLUMN 5] as weight_string(us.bar)"
],
Expand Down Expand Up @@ -6299,5 +6299,93 @@
"user.user"
]
}
},
{
"comment": "multiple count star and a count with 3 table join",
"query": "select count(*), count(*), count(u.col) from user u, user u2, user_extra ue",
"v3-plan": "VT12001: unsupported: cross-shard query with aggregates",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select count(*), count(*), count(u.col) from user u, user u2, user_extra ue",
"Instructions": {
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "sum_count_star(0) AS count(*), sum_count_star(1) AS count(*), sum_count(2) AS count(u.col)",
"Inputs": [
{
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] * [COLUMN 1] as count(*)",
"[COLUMN 0] * [COLUMN 1] as count(*)",
"[COLUMN 0] * [COLUMN 2] as count(u.col)"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,R:1",
"TableName": "user_extra_`user`_`user`",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from user_extra as ue where 1 != 1",
"Query": "select count(*) from user_extra as ue",
"Table": "user_extra"
},
{
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] * [COLUMN 1] as count(*)",
"[COLUMN 2] * [COLUMN 1] as count(u.col)"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,L:1",
"TableName": "`user`_`user`",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*), count(u.col) from `user` as u where 1 != 1 group by .0",
"Query": "select count(*), count(u.col) from `user` as u group by .0",
"Table": "`user`"
},
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from `user` as u2 where 1 != 1 group by .0",
"Query": "select count(*) from `user` as u2 group by .0",
"Table": "`user`"
}
]
}
]
}
]
}
]
}
]
},
"TablesUsed": [
"user.user",
"user.user_extra"
]
}
}
]
Loading

0 comments on commit 21df396

Please sign in to comment.