Skip to content

Commit

Permalink
[release-19.0] fix: order by subquery planning (#16049) (#16132)
Browse files Browse the repository at this point in the history
Co-authored-by: Harshit Gangal <harshit@planetscale.com>
Co-authored-by: Andres Taylor <andres@planetscale.com>
Co-authored-by: Florent Poinsard <florent.poinsard@outlook.fr>
  • Loading branch information
4 people authored Jun 13, 2024
1 parent 741a026 commit 2d16a00
Show file tree
Hide file tree
Showing 26 changed files with 584 additions and 213 deletions.
18 changes: 17 additions & 1 deletion go/test/endtoend/vtgate/queries/subquery/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,39 @@ create table t1
id2 bigint,
primary key (id1)
) Engine = InnoDB;

create table t1_id2_idx
(
id2 bigint,
keyspace_id varbinary(10),
primary key (id2)
) Engine = InnoDB;

create table t2
(
id3 bigint,
id4 bigint,
primary key (id3)
) Engine = InnoDB;

create table t2_id4_idx
(
id bigint not null auto_increment,
id4 bigint,
id3 bigint,
primary key (id),
key idx_id4 (id4)
) Engine = InnoDB;
) Engine = InnoDB;

CREATE TABLE user
(
id INT PRIMARY KEY,
name VARCHAR(100)
);

CREATE TABLE user_extra
(
user_id INT,
extra_info VARCHAR(100),
PRIMARY KEY (user_id, extra_info)
);
44 changes: 44 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package subquery

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -188,3 +189,46 @@ func TestSubqueryInDerivedTable(t *testing.T) {
mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`)
mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`)
}

func TestSubqueries(t *testing.T) {
// This method tests many types of subqueries. The queries should move to a vitess-tester test file once we have a way to run them.
// The commented out queries are failing because of wrong types being returned.
// The tests are commented out until the issue is fixed.
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
mcmp, closer := start(t)
defer closer()
queries := []string{
`INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')`,
`INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')`,
`SELECT (SELECT COUNT(*) FROM user_extra) AS order_count, id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra)`,
`SELECT id, (SELECT COUNT(*) FROM user_extra) AS order_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra)`,
`SELECT id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra) ORDER BY (SELECT COUNT(*) FROM user_extra)`,
`SELECT (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count, id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name, COUNT(*) FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, round(MAX(id + (SELECT COUNT(*) FROM user_extra where user_id = 42))) as r FROM user WHERE id = 42 GROUP BY id ORDER BY r`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * 2 AS double_extra_count FROM user`,
`SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE LENGTH(extra_info) > 4)`,
`SELECT id, COUNT(*) FROM user GROUP BY id HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user_extra.user_id = user.id) + 1`,
`SELECT id, name FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * id`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) + id AS extra_count_plus_id FROM user`,
`SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info1') OR id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info2')`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, SUM(id) AS sum_ids FROM user GROUP BY id, name ORDER BY (SELECT COUNT(*) FROM user_extra)`,
// `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, AVG(id) AS avg_ids FROM user GROUP BY id, name HAVING (SELECT SUM(LENGTH(extra_info)) FROM user_extra) > 10`,
`SELECT id, name, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, MAX(id) AS max_id FROM user WHERE id IN (SELECT user_id FROM user_extra) GROUP BY id, name`,
`SELECT id, name, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, MIN(id) AS min_id FROM user GROUP BY id, name ORDER BY (SELECT MAX(LENGTH(extra_info)) FROM user_extra)`,
`SELECT id, name, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT MIN(LENGTH(extra_info)) FROM user_extra) < 5`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, AVG(id) AS avg_ids FROM user WHERE id > (SELECT COUNT(*) FROM user_extra) GROUP BY id, name`,
// `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, COUNT(id) AS count_ids FROM user GROUP BY id, name ORDER BY (SELECT SUM(LENGTH(extra_info)) FROM user_extra)`,
// `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT AVG(LENGTH(extra_info)) FROM user_extra) > 2`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) + id AS total_extra_count_plus_id, AVG(id) AS avg_ids FROM user WHERE id < (SELECT MAX(user_id) FROM user_extra) GROUP BY id, name`,
}

for idx, query := range queries {
mcmp.Run(fmt.Sprintf("%d %s", idx, query), func(mcmp *utils.MySQLCompare) {
mcmp.Exec(query)
})
}
}
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
"autocommit": "true"
},
"owner": "t2"
},
"xxhash": {
"type": "xxhash"
}
},
"tables": {
Expand Down Expand Up @@ -64,6 +67,34 @@
"name": "hash"
}
]
},
"user_extra": {
"name": "user_extra",
"column_vindexes": [
{
"columns": [
"user_id",
"extra_info"
],
"type": "xxhash",
"name": "xxhash",
"vindex": null
}
]
},
"user": {
"name": "user",
"column_vindexes": [
{
"columns": [
"id"
],
"type": "xxhash",
"name": "xxhash",
"vindex": null
}
]
}

}
}
47 changes: 47 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2864,6 +2864,8 @@ type (
Expr
GetArg() Expr
GetArgs() Exprs
SetArg(expr Expr)
SetArgs(exprs Exprs) error
// AggrName returns the lower case string representing this aggregation function
AggrName() string
}
Expand Down Expand Up @@ -3375,6 +3377,51 @@ func (varS *VarSamp) GetArgs() Exprs { return Exprs{varS.Arg} }
func (variance *Variance) GetArgs() Exprs { return Exprs{variance.Arg} }
func (av *AnyValue) GetArgs() Exprs { return Exprs{av.Arg} }

func (min *Min) SetArg(expr Expr) { min.Arg = expr }
func (sum *Sum) SetArg(expr Expr) { sum.Arg = expr }
func (max *Max) SetArg(expr Expr) { max.Arg = expr }
func (avg *Avg) SetArg(expr Expr) { avg.Arg = expr }
func (*CountStar) SetArg(expr Expr) {}
func (count *Count) SetArg(expr Expr) { count.Args = Exprs{expr} }
func (grpConcat *GroupConcatExpr) SetArg(expr Expr) { grpConcat.Exprs = Exprs{expr} }
func (bAnd *BitAnd) SetArg(expr Expr) { bAnd.Arg = expr }
func (bOr *BitOr) SetArg(expr Expr) { bOr.Arg = expr }
func (bXor *BitXor) SetArg(expr Expr) { bXor.Arg = expr }
func (std *Std) SetArg(expr Expr) { std.Arg = expr }
func (stdD *StdDev) SetArg(expr Expr) { stdD.Arg = expr }
func (stdP *StdPop) SetArg(expr Expr) { stdP.Arg = expr }
func (stdS *StdSamp) SetArg(expr Expr) { stdS.Arg = expr }
func (varP *VarPop) SetArg(expr Expr) { varP.Arg = expr }
func (varS *VarSamp) SetArg(expr Expr) { varS.Arg = expr }
func (variance *Variance) SetArg(expr Expr) { variance.Arg = expr }
func (av *AnyValue) SetArg(expr Expr) { av.Arg = expr }

func (min *Min) SetArgs(exprs Exprs) error { return setFuncArgs(min, exprs, "MIN") }
func (sum *Sum) SetArgs(exprs Exprs) error { return setFuncArgs(sum, exprs, "SUM") }
func (max *Max) SetArgs(exprs Exprs) error { return setFuncArgs(max, exprs, "MAX") }
func (avg *Avg) SetArgs(exprs Exprs) error { return setFuncArgs(avg, exprs, "AVG") }
func (*CountStar) SetArgs(Exprs) error { return nil }
func (bAnd *BitAnd) SetArgs(exprs Exprs) error { return setFuncArgs(bAnd, exprs, "BIT_AND") }
func (bOr *BitOr) SetArgs(exprs Exprs) error { return setFuncArgs(bOr, exprs, "BIT_OR") }
func (bXor *BitXor) SetArgs(exprs Exprs) error { return setFuncArgs(bXor, exprs, "BIT_XOR") }
func (std *Std) SetArgs(exprs Exprs) error { return setFuncArgs(std, exprs, "STD") }
func (stdD *StdDev) SetArgs(exprs Exprs) error { return setFuncArgs(stdD, exprs, "STDDEV") }
func (stdP *StdPop) SetArgs(exprs Exprs) error { return setFuncArgs(stdP, exprs, "STDDEV_POP") }
func (stdS *StdSamp) SetArgs(exprs Exprs) error { return setFuncArgs(stdS, exprs, "STDDEV_SAMP") }
func (varP *VarPop) SetArgs(exprs Exprs) error { return setFuncArgs(varP, exprs, "VAR_POP") }
func (varS *VarSamp) SetArgs(exprs Exprs) error { return setFuncArgs(varS, exprs, "VAR_SAMP") }
func (variance *Variance) SetArgs(exprs Exprs) error { return setFuncArgs(variance, exprs, "VARIANCE") }
func (av *AnyValue) SetArgs(exprs Exprs) error { return setFuncArgs(av, exprs, "ANY_VALUE") }

func (count *Count) SetArgs(exprs Exprs) error {
count.Args = exprs
return nil
}
func (grpConcat *GroupConcatExpr) SetArgs(exprs Exprs) error {
grpConcat.Exprs = exprs
return nil
}

func (sum *Sum) IsDistinct() bool { return sum.Distinct }
func (min *Min) IsDistinct() bool { return min.Distinct }
func (max *Max) IsDistinct() bool { return max.Distinct }
Expand Down
9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,15 @@ func ContainsAggregation(e SQLNode) bool {
return hasAggregates
}

// setFuncArgs sets the arguments for the aggregation function, while checking that there is only one argument
func setFuncArgs(aggr AggrFunc, exprs Exprs, name string) error {
if len(exprs) != 1 {
return vterrors.VT03001(name)
}
aggr.SetArg(exprs[0])
return nil
}

// GetFirstSelect gets the first select statement
func GetFirstSelect(selStmt SelectStatement) *Select {
if selStmt == nil {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import (
"sync"
"testing"

fuzz "github.com/AdaLogics/go-fuzz-headers"

"vitess.io/vitess/go/json2"
"vitess.io/vitess/go/sqltypes"
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
"vitess.io/vitess/go/vt/vtgate/vindexes"

fuzz "github.com/AdaLogics/go-fuzz-headers"
)

var initter sync.Once
Expand Down
23 changes: 11 additions & 12 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,21 @@ func reachedPhase(ctx *plancontext.PlanningContext, p Phase) bool {
// Any columns that are needed to evaluate the subquery needs to be added as
// grouping columns to the aggregation being pushed down, and then after the
// subquery evaluation we are free to reassemble the total aggregation values.
// This is very similar to how we push aggregation through an shouldRun-join.
// This is very similar to how we push aggregation through an apply-join.
func pushAggregationThroughSubquery(
ctx *plancontext.PlanningContext,
rootAggr *Aggregator,
src *SubQueryContainer,
) (Operator, *ApplyResult) {
pushedAggr := rootAggr.Clone([]Operator{src.Outer}).(*Aggregator)
pushedAggr.Original = false
pushedAggr.Pushed = false

pushedAggr := rootAggr.SplitAggregatorBelowOperators(ctx, []Operator{src.Outer})
for _, subQuery := range src.Inner {
lhsCols := subQuery.OuterExpressionsNeeded(ctx, src.Outer)
for _, colName := range lhsCols {
idx := slices.IndexFunc(pushedAggr.Columns, func(ae *sqlparser.AliasedExpr) bool {
findColName := func(ae *sqlparser.AliasedExpr) bool {
return ctx.SemTable.EqualsExpr(ae.Expr, colName)
})
if idx >= 0 {
}
if slices.IndexFunc(pushedAggr.Columns, findColName) >= 0 {
// we already have the column, no need to push it again
continue
}
pushedAggr.addColumnWithoutPushing(ctx, aeWrap(colName), true)
Expand All @@ -114,8 +112,10 @@ func pushAggregationThroughSubquery(

src.Outer = pushedAggr

for _, aggregation := range pushedAggr.Aggregations {
aggregation.Original.Expr = rewriteColNameToArgument(ctx, aggregation.Original.Expr, aggregation.SubQueryExpression, src.Inner...)
for _, aggr := range pushedAggr.Aggregations {
// we rewrite columns in the aggregation to use the argument form of the subquery
aggr.Original.Expr = rewriteColNameToArgument(ctx, aggr.Original.Expr, aggr.SubQueryExpression, src.Inner...)
pushedAggr.Columns[aggr.ColOffset].Expr = rewriteColNameToArgument(ctx, pushedAggr.Columns[aggr.ColOffset].Expr, aggr.SubQueryExpression, src.Inner...)
}

if !rootAggr.Original {
Expand Down Expand Up @@ -150,7 +150,7 @@ func pushAggregationThroughRoute(
route *Route,
) (Operator, *ApplyResult) {
// Create a new aggregator to be placed below the route.
aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs())
aggrBelowRoute := aggregator.SplitAggregatorBelowOperators(ctx, route.Inputs())
aggrBelowRoute.Aggregations = nil

pushAggregations(ctx, aggregator, aggrBelowRoute)
Expand Down Expand Up @@ -256,7 +256,6 @@ func pushAggregationThroughFilter(
pushedAggr := aggregator.Clone([]Operator{filter.Source}).(*Aggregator)
pushedAggr.Pushed = false
pushedAggr.Original = false

withNextColumn:
for _, col := range columnsNeeded {
for _, gb := range pushedAggr.Grouping {
Expand Down
42 changes: 40 additions & 2 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,21 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator {
return nil
}

func (aggr Aggr) setPushColumn(exprs sqlparser.Exprs) {
if aggr.Func == nil {
if len(exprs) > 1 {
panic(vterrors.VT13001(fmt.Sprintf("unexpected number of expression in an random aggregation: %s", sqlparser.String(exprs))))
}
aggr.Original.Expr = exprs[0]
return
}

err := aggr.Func.SetArgs(exprs)
if err != nil {
panic(err)
}
}

func (aggr Aggr) getPushColumn() sqlparser.Expr {
switch aggr.OpCode {
case opcode.AggregateAnyValue:
Expand All @@ -311,6 +326,17 @@ func (aggr Aggr) getPushColumn() sqlparser.Expr {
}
}

func (aggr Aggr) getPushColumnExprs() sqlparser.Exprs {
switch aggr.OpCode {
case opcode.AggregateAnyValue:
return sqlparser.Exprs{aggr.Original.Expr}
case opcode.AggregateCountStar:
return sqlparser.Exprs{sqlparser.NewIntLiteral("1")}
default:
return aggr.Func.GetArgs()
}
}

func (a *Aggregator) planOffsetsNotPushed(ctx *plancontext.PlanningContext) {
a.Source = newAliasedProjection(a.Source)
// we need to keep things in the column order, so we can't iterate over the aggregations or groupings
Expand Down Expand Up @@ -408,14 +434,26 @@ func (a *Aggregator) internalAddColumn(ctx *plancontext.PlanningContext, aliased
return offset
}

// SplitAggregatorBelowRoute returns the aggregator that will live under the Route.
// SplitAggregatorBelowOperators returns the aggregator that will live under the Route.
// This is used when we are splitting the aggregation so one part is done
// at the mysql level and one part at the vtgate level
func (a *Aggregator) SplitAggregatorBelowRoute(input []Operator) *Aggregator {
func (a *Aggregator) SplitAggregatorBelowOperators(ctx *plancontext.PlanningContext, input []Operator) *Aggregator {
newOp := a.Clone(input).(*Aggregator)
newOp.Pushed = false
newOp.Original = false
newOp.DT = nil

// We need to make sure that the columns are cloned so that the original operator is not affected
// by the changes we make to the new operator
newOp.Columns = slice.Map(a.Columns, func(from *sqlparser.AliasedExpr) *sqlparser.AliasedExpr {
return ctx.SemTable.Clone(from).(*sqlparser.AliasedExpr)
})
for idx, aggr := range newOp.Aggregations {
newOp.Aggregations[idx].Original = ctx.SemTable.Clone(aggr.Original).(*sqlparser.AliasedExpr)
}
for idx, gb := range newOp.Grouping {
newOp.Grouping[idx].Inner = ctx.SemTable.Clone(gb.Inner).(sqlparser.Expr)
}
return newOp
}

Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/planbuilder/operators/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ import (
"vitess.io/vitess/go/vt/vtgate/vindexes"
)

type compactable interface {
// Compact implement this interface for operators that have easy to see optimisations
Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult)
}

// compact will optimise the operator tree into a smaller but equivalent version
func compact(ctx *plancontext.PlanningContext, op Operator) Operator {
type compactable interface {
// Compact implement this interface for operators that have easy to see optimisations
Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult)
}

newOp := BottomUp(op, TableID, func(op Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
newOp, ok := op.(compactable)
if !ok {
Expand Down
Loading

0 comments on commit 2d16a00

Please sign in to comment.