From 897cfa790b7fba4552937e32ce3086895d63bf85 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 13 Oct 2021 19:43:11 +0530 Subject: [PATCH 1/5] added failing testcases as unit tests Signed-off-by: Manan Gupta --- go/vt/vtgate/planbuilder/gen4_planner_test.go | 84 +++++++++++++++++++ go/vt/vtgate/planbuilder/route_planning.go | 14 ++-- go/vt/vtgate/semantics/semantic_state.go | 5 +- 3 files changed, 95 insertions(+), 8 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/gen4_planner_test.go diff --git a/go/vt/vtgate/planbuilder/gen4_planner_test.go b/go/vt/vtgate/planbuilder/gen4_planner_test.go new file mode 100644 index 00000000000..498bb8cfbe0 --- /dev/null +++ b/go/vt/vtgate/planbuilder/gen4_planner_test.go @@ -0,0 +1,84 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestBindingSubquery(t *testing.T) { + testcases := []struct { + query string + numOfTablesReq int + extractor func(p *sqlparser.Select) sqlparser.Expr + rewrite bool + }{ + { + query: "select (select col from tabl limit 1) as a from foo join tabl order by a + 1", + numOfTablesReq: 0, + extractor: func(sel *sqlparser.Select) sqlparser.Expr { + return sel.OrderBy[0].Expr + }, + rewrite: true, + }, { + query: "select t.a from (select (select col from tabl limit 1) as a from foo join tabl) t", + numOfTablesReq: 0, + extractor: func(sel *sqlparser.Select) sqlparser.Expr { + return extractExpr(sel, 0) + }, + rewrite: true, + }, { + query: "select (select col from tabl where foo.id = 4 limit 1) as a from foo join tabl", + numOfTablesReq: 1, + extractor: func(sel *sqlparser.Select) sqlparser.Expr { + return extractExpr(sel, 0) + }, + rewrite: false, + }, + } + for _, testcase := range testcases { + t.Run(testcase.query, func(t *testing.T) { + parse, err := sqlparser.Parse(testcase.query) + require.NoError(t, err) + selStmt := parse.(*sqlparser.Select) + semTable, err := semantics.Analyze(selStmt, "d", &semantics.FakeSI{ + Tables: map[string]*vindexes.Table{ + "tabl": {Name: sqlparser.NewTableIdent("tabl")}, + "foo": {Name: sqlparser.NewTableIdent("foo")}, + }, + }) + require.NoError(t, err) + if testcase.rewrite { + err = queryRewrite(semTable, sqlparser.NewReservedVars("vt", make(sqlparser.BindVars)), selStmt) + require.NoError(t, err) + } + expr := testcase.extractor(selStmt) + tableset := semTable.RecursiveDeps(expr) + require.Equal(t, testcase.numOfTablesReq, tableset.NumberOfTables()) + }) + } +} + +func extractExpr(in *sqlparser.Select, idx int) sqlparser.Expr { + return in.SelectExprs[idx].(*sqlparser.AliasedExpr).Expr +} diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 0ddab3923df..9d8d374c581 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -199,7 +199,7 @@ func tryMergeSubQuery(ctx *planningContext, outer, subq queryTree, subQueryInner if err != nil { return nil, err } - return rt, rewriteSubqueryDependenciesForJoin(ctx, outerTree.rhs, outerTree, subQueryInner) + return rt, rewriteColumnsInSubqueryForJoin(ctx, outerTree.rhs, outerTree, subQueryInner) } merged, err = tryMergeSubQuery(ctx, outerTree.lhs, subq, subQueryInner, joinPredicates, newMergefunc) if err != nil { @@ -215,7 +215,7 @@ func tryMergeSubQuery(ctx *planningContext, outer, subq queryTree, subQueryInner if err != nil { return nil, err } - return rt, rewriteSubqueryDependenciesForJoin(ctx, outerTree.lhs, outerTree, subQueryInner) + return rt, rewriteColumnsInSubqueryForJoin(ctx, outerTree.lhs, outerTree, subQueryInner) } merged, err = tryMergeSubQuery(ctx, outerTree.rhs, subq, subQueryInner, joinPredicates, newMergefunc) if err != nil { @@ -231,13 +231,15 @@ func tryMergeSubQuery(ctx *planningContext, outer, subq queryTree, subQueryInner } } +// rewriteColumnsInSubqueryForJoin rewrites the columns that appear from the other side +// of the join. For example, let's say we merged a subquery on the right side of a join tree +// If it was using any columns from the left side then they need to be replaced by bind variables supplied +// from that side. // outerTree is the joinTree within whose children the subquery lives in // the child of joinTree which does not contain the subquery is the otherTree -func rewriteSubqueryDependenciesForJoin(ctx *planningContext, otherTree queryTree, outerTree *joinTree, subQueryInner *abstract.SubQueryInner) error { - // first we find the other side of the tree by comparing the tableIDs - // other side is RHS if the subquery is in the LHS, otherwise it is LHS +func rewriteColumnsInSubqueryForJoin(ctx *planningContext, otherTree queryTree, outerTree *joinTree, subQueryInner *abstract.SubQueryInner) error { var rewriteError error - // go over the entire where expression in the subquery + // go over the entire expression in the subquery sqlparser.Rewrite(subQueryInner.ExtractedSubquery.Original, func(cursor *sqlparser.Cursor) bool { sqlNode := cursor.Node() switch node := sqlNode.(type) { diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index bd2a91e327e..2d37661ab36 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -219,8 +219,8 @@ func (d ExprDependencies) Dependencies(expr sqlparser.Expr) (deps TableSet) { }() } - // During the original semantic analysis, all ColName:s were found and bound the the corresponding tables - // Here, we'll walk the expression tree and look to see if we can found any sub-expressions + // During the original semantic analysis, all ColNames were found and bound to the corresponding tables + // Here, we'll walk the expression tree and look to see if we can find any sub-expressions // that have already set dependencies. _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { expr, ok := node.(sqlparser.Expr) @@ -282,6 +282,7 @@ func (st *SemTable) FindSubqueryReference(subquery *sqlparser.Subquery) *sqlpars return nil } +// GetSubqueryNeedingRewrite returns a list of sub-queries that need to be rewritten func (st *SemTable) GetSubqueryNeedingRewrite() []*sqlparser.ExtractedSubquery { var res []*sqlparser.ExtractedSubquery for _, extractedSubquery := range st.SubqueryRef { From 39540838b16c77f3a6e48dd056dd8e0514f9248e Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Thu, 14 Oct 2021 13:45:31 +0530 Subject: [PATCH 2/5] fixed subquery binding Signed-off-by: Manan Gupta --- go/vt/sqlparser/ast.go | 2 +- go/vt/sqlparser/ast_clone.go | 2 +- go/vt/sqlparser/ast_equals.go | 2 +- go/vt/sqlparser/ast_rewrite.go | 4 +- go/vt/sqlparser/ast_visit.go | 2 +- go/vt/sqlparser/cached_size.go | 6 +- go/vt/sqlparser/precedence_test.go | 2 +- go/vt/vtgate/planbuilder/abstract/operator.go | 2 +- go/vt/vtgate/planbuilder/gen4_planner_test.go | 22 +++--- .../planbuilder/querytree_transformers.go | 2 +- go/vt/vtgate/planbuilder/route_planning.go | 10 ++- go/vt/vtgate/planbuilder/routetree.go | 2 +- go/vt/vtgate/semantics/analyzer.go | 13 ++-- go/vt/vtgate/semantics/analyzer_test.go | 4 +- go/vt/vtgate/semantics/binder.go | 22 +++++- go/vt/vtgate/semantics/derived_table.go | 2 +- go/vt/vtgate/semantics/real_table.go | 2 +- go/vt/vtgate/semantics/semantic_state.go | 14 ++-- go/vt/vtgate/semantics/tabletset.go | 45 +++++++++++ go/vt/vtgate/semantics/tabletset_test.go | 76 +++++++++++++++++++ go/vt/vtgate/semantics/vindex_table.go | 2 +- go/vt/vtgate/semantics/vtable.go | 2 +- 22 files changed, 193 insertions(+), 47 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index b724a8b375a..e42227d52d6 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -1971,7 +1971,7 @@ type ( ArgName string HasValuesArg string OpCode int // this should really be engine.PulloutOpCode, but we cannot depend on engine :( - Subquery SelectStatement + Subquery *Subquery OtherSide Expr // represents the side of the comparison, this field will be nil if Original is not a comparison NeedsRewrite bool // tells whether we need to rewrite this subquery to Original or not } diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 375e0116132..96060511316 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -862,7 +862,7 @@ func CloneRefOfExtractedSubquery(n *ExtractedSubquery) *ExtractedSubquery { } out := *n out.Original = CloneExpr(n.Original) - out.Subquery = CloneSelectStatement(n.Subquery) + out.Subquery = CloneRefOfSubquery(n.Subquery) out.OtherSide = CloneExpr(n.OtherSide) return &out } diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 155409ac6b1..20794604648 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1568,7 +1568,7 @@ func EqualsRefOfExtractedSubquery(a, b *ExtractedSubquery) bool { a.OpCode == b.OpCode && a.NeedsRewrite == b.NeedsRewrite && EqualsExpr(a.Original, b.Original) && - EqualsSelectStatement(a.Subquery, b.Subquery) && + EqualsRefOfSubquery(a.Subquery, b.Subquery) && EqualsExpr(a.OtherSide, b.OtherSide) } diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index f26d237f4ac..747d7bfb6f2 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1906,8 +1906,8 @@ func (a *application) rewriteRefOfExtractedSubquery(parent SQLNode, node *Extrac }) { return false } - if !a.rewriteSelectStatement(node, node.Subquery, func(newNode, parent SQLNode) { - parent.(*ExtractedSubquery).Subquery = newNode.(SelectStatement) + if !a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { + parent.(*ExtractedSubquery).Subquery = newNode.(*Subquery) }) { return false } diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index 2a110e53fff..c75bc920069 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1039,7 +1039,7 @@ func VisitRefOfExtractedSubquery(in *ExtractedSubquery, f Visit) error { if err := VisitExpr(in.Original, f); err != nil { return err } - if err := VisitSelectStatement(in.Subquery, f); err != nil { + if err := VisitRefOfSubquery(in.Subquery, f); err != nil { return err } if err := VisitExpr(in.OtherSide, f); err != nil { diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index ef25f7d2e11..7c4f311a396 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -944,10 +944,8 @@ func (cached *ExtractedSubquery) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.ArgName))) // field HasValuesArg string size += hack.RuntimeAllocSize(int64(len(cached.HasValuesArg))) - // field Subquery vitess.io/vitess/go/vt/sqlparser.SelectStatement - if cc, ok := cached.Subquery.(cachedObject); ok { - size += cc.CachedSize(true) - } + // field Subquery *vitess.io/vitess/go/vt/sqlparser.Subquery + size += cached.Subquery.CachedSize(true) // field OtherSide vitess.io/vitess/go/vt/sqlparser.Expr if cc, ok := cached.OtherSide.(cachedObject); ok { size += cc.CachedSize(true) diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index 7ad31f822fe..d6cbb8a845a 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -79,7 +79,7 @@ func TestNotInSubqueryPrecedence(t *testing.T) { ArgName: "arg1", HasValuesArg: "has_values1", OpCode: 1, - Subquery: subq.Select, + Subquery: subq, OtherSide: cmp.Left, } not.Expr = extracted diff --git a/go/vt/vtgate/planbuilder/abstract/operator.go b/go/vt/vtgate/planbuilder/abstract/operator.go index 93e1c88a036..8f2cfac35a8 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator.go +++ b/go/vt/vtgate/planbuilder/abstract/operator.go @@ -192,7 +192,7 @@ func createOperatorFromSelect(sel *sqlparser.Select, semTable *semantics.SemTabl if len(semTable.SubqueryMap[sel]) > 0 { resultantOp = &SubQuery{} for _, sq := range semTable.SubqueryMap[sel] { - subquerySelectStatement, isSel := sq.Subquery.(*sqlparser.Select) + subquerySelectStatement, isSel := sq.Subquery.Select.(*sqlparser.Select) if !isSel { return nil, semantics.Gen4NotSupportedF("UNION in subquery") } diff --git a/go/vt/vtgate/planbuilder/gen4_planner_test.go b/go/vt/vtgate/planbuilder/gen4_planner_test.go index 498bb8cfbe0..f9841b7f014 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner_test.go +++ b/go/vt/vtgate/planbuilder/gen4_planner_test.go @@ -28,28 +28,28 @@ import ( func TestBindingSubquery(t *testing.T) { testcases := []struct { - query string - numOfTablesReq int - extractor func(p *sqlparser.Select) sqlparser.Expr - rewrite bool + query string + requiredTableSet semantics.TableSet + extractor func(p *sqlparser.Select) sqlparser.Expr + rewrite bool }{ { - query: "select (select col from tabl limit 1) as a from foo join tabl order by a + 1", - numOfTablesReq: 0, + query: "select (select col from tabl limit 1) as a from foo join tabl order by a + 1", + requiredTableSet: semantics.EmptyTableSet(), extractor: func(sel *sqlparser.Select) sqlparser.Expr { return sel.OrderBy[0].Expr }, rewrite: true, }, { - query: "select t.a from (select (select col from tabl limit 1) as a from foo join tabl) t", - numOfTablesReq: 0, + query: "select t.a from (select (select col from tabl limit 1) as a from foo join tabl) t", + requiredTableSet: semantics.EmptyTableSet(), extractor: func(sel *sqlparser.Select) sqlparser.Expr { return extractExpr(sel, 0) }, rewrite: true, }, { - query: "select (select col from tabl where foo.id = 4 limit 1) as a from foo join tabl", - numOfTablesReq: 1, + query: "select (select col from tabl where foo.id = 4 limit 1) as a from foo", + requiredTableSet: semantics.SingleTableSet(0), extractor: func(sel *sqlparser.Select) sqlparser.Expr { return extractExpr(sel, 0) }, @@ -74,7 +74,7 @@ func TestBindingSubquery(t *testing.T) { } expr := testcase.extractor(selStmt) tableset := semTable.RecursiveDeps(expr) - require.Equal(t, testcase.numOfTablesReq, tableset.NumberOfTables()) + require.Equal(t, testcase.requiredTableSet, tableset) }) } } diff --git a/go/vt/vtgate/planbuilder/querytree_transformers.go b/go/vt/vtgate/planbuilder/querytree_transformers.go index e9759d2c6db..fa669ee6b5c 100644 --- a/go/vt/vtgate/planbuilder/querytree_transformers.go +++ b/go/vt/vtgate/planbuilder/querytree_transformers.go @@ -93,7 +93,7 @@ func transformSubqueryTree(ctx *planningContext, n *subqueryTree) (logicalPlan, if err != nil { return nil, err } - innerPlan, err = planHorizon(ctx, innerPlan, n.extracted.Subquery) + innerPlan, err = planHorizon(ctx, innerPlan, n.extracted.Subquery.Select) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 9d8d374c581..80574bcd899 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -46,7 +46,7 @@ func (c planningContext) isSubQueryToReplace(e sqlparser.Expr) bool { return false } for _, extractedSubq := range c.semTable.GetSubqueryNeedingRewrite() { - if extractedSubq.NeedsRewrite && sqlparser.EqualsRefOfSubquery(&sqlparser.Subquery{Select: extractedSubq.Subquery}, ext) { + if extractedSubq.NeedsRewrite && sqlparser.EqualsRefOfSubquery(&sqlparser.Subquery{Select: extractedSubq.Subquery.Select}, ext) { return true } } @@ -268,6 +268,14 @@ func rewriteColumnsInSubqueryForJoin(ctx *planningContext, otherTree queryTree, return true }, nil) + // update the dependencies for the subquery by removing the dependencies from the otherTree + tableSet := ctx.semTable.Direct[subQueryInner.ExtractedSubquery.Subquery] + tableSet.RemoveInPlace(otherTree.tableID()) + ctx.semTable.Direct[subQueryInner.ExtractedSubquery.Subquery] = tableSet + tableSet = ctx.semTable.Recursive[subQueryInner.ExtractedSubquery.Subquery] + tableSet.RemoveInPlace(otherTree.tableID()) + ctx.semTable.Recursive[subQueryInner.ExtractedSubquery.Subquery] = tableSet + // return any error while rewriting return rewriteError } diff --git a/go/vt/vtgate/planbuilder/routetree.go b/go/vt/vtgate/planbuilder/routetree.go index c738d410aae..4e14c776707 100644 --- a/go/vt/vtgate/planbuilder/routetree.go +++ b/go/vt/vtgate/planbuilder/routetree.go @@ -200,7 +200,7 @@ func (rp *routeTree) searchForNewVindexes(ctx *planningContext, predicates []sql // using the node.subquery which is the rewritten version of our subquery cmp := &sqlparser.ComparisonExpr{ Left: node.OtherSide, - Right: &sqlparser.Subquery{Select: node.Subquery}, + Right: &sqlparser.Subquery{Select: node.Subquery.Select}, Operator: originalCmp.Operator, } found, exitEarly, err := rp.planComparison(ctx, cmp) diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 2f25a53a9a0..fd2ae5f8405 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -126,10 +126,6 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { a.setError(err) return true } - if err := a.binder.down(cursor); err != nil { - a.setError(err) - return true - } a.enterProjection(cursor) // this is the visitor going down the tree. Returning false here would just not visit the children @@ -143,6 +139,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } + if err := a.binder.up(cursor); err != nil { + a.setError(err) + return true + } + if err := a.scoper.up(cursor); err != nil { a.setError(err) return false @@ -220,8 +221,8 @@ type originable interface { } func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ *querypb.Type) { - recursive = a.binder.recursive.Dependencies(expr) - direct = a.binder.direct.Dependencies(expr) + recursive = a.binder.recursive.dependencies(expr) + direct = a.binder.direct.dependencies(expr) qt, isFound := a.typer.exprTypes[expr] if !isFound { return diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index c910c0b3aba..6c5a1415daf 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -560,7 +560,7 @@ func TestSubqueriesMappingWhereClause(t *testing.T) { } extractedSubq := semTable.SubqueryRef[subq] - assert.True(t, sqlparser.EqualsExpr(&sqlparser.Subquery{Select: extractedSubq.Subquery}, subq)) + assert.True(t, sqlparser.EqualsExpr(extractedSubq.Subquery, subq)) assert.True(t, sqlparser.EqualsExpr(extractedSubq.Original, sel.Where.Expr)) assert.EqualValues(t, tc.opCode, extractedSubq.OpCode) if tc.otherSideName == "" { @@ -594,7 +594,7 @@ func TestSubqueriesMappingSelectExprs(t *testing.T) { subq := sel.SelectExprs[tc.selExprIdx].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Subquery) extractedSubq := semTable.SubqueryRef[subq] - assert.True(t, sqlparser.EqualsExpr(&sqlparser.Subquery{Select: extractedSubq.Subquery}, subq)) + assert.True(t, sqlparser.EqualsExpr(extractedSubq.Subquery, subq)) assert.True(t, sqlparser.EqualsExpr(extractedSubq.Original, subq)) assert.EqualValues(t, engine.PulloutValue, extractedSubq.OpCode) }) diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index fb29cf5554f..8a37573817b 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -52,7 +52,7 @@ func newBinder(scoper *scoper, org originable, tc *tableCollector, typer *typer) } } -func (b *binder) down(cursor *sqlparser.Cursor) error { +func (b *binder) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.Subquery: currScope := b.scoper.currentScope() @@ -61,7 +61,7 @@ func (b *binder) down(cursor *sqlparser.Cursor) error { } sq := &sqlparser.ExtractedSubquery{ - Subquery: node.Select, + Subquery: node, Original: node, OpCode: int(engine.PulloutValue), } @@ -88,6 +88,24 @@ func (b *binder) down(cursor *sqlparser.Cursor) error { b.subqueryMap[currScope.selectStmt] = append(b.subqueryMap[currScope.selectStmt], sq) b.subqueryRef[node] = sq + + subqRecursiveDeps := b.recursive.dependencies(node) + subqDirectDeps := b.direct.dependencies(node) + + tablesToKeep := EmptyTableSet() + sco := currScope + for sco != nil { + for _, table := range sco.tables { + tablesToKeep.MergeInPlace(table.getTableSet(b.org)) + } + sco = sco.parent + } + + subqDirectDeps.KeepOnly(tablesToKeep) + subqRecursiveDeps.KeepOnly(tablesToKeep) + b.recursive[node] = subqRecursiveDeps + b.direct[node] = subqDirectDeps + case *sqlparser.ColName: deps, err := b.resolveColumn(node, b.scoper.currentScope()) if err != nil { diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index 239184da062..144d0b1d7f3 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -62,7 +62,7 @@ func createDerivedTableForExpressions(expressions sqlparser.SelectExprs, cols sq return vTbl } -// Dependencies implements the TableInfo interface +// dependencies implements the TableInfo interface func (dt *DerivedTable) dependencies(colName string, org originable) (dependencies, error) { directDeps := org.tableSetFor(dt.ASTNode) for i, name := range dt.columnNames { diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 22a01da7f27..bfee236251b 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -35,7 +35,7 @@ type RealTable struct { var _ TableInfo = (*RealTable)(nil) -// Dependencies implements the TableInfo interface +// dependencies implements the TableInfo interface func (r *RealTable) dependencies(colName string, org originable) (dependencies, error) { ts := org.tableSetFor(r.ASTNode) for _, info := range r.getColumns() { diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 2d37661ab36..8c252c2b78a 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -144,17 +144,17 @@ func (st *SemTable) TableInfoFor(id TableSet) (TableInfo, error) { // RecursiveDeps return the table dependencies of the expression. func (st *SemTable) RecursiveDeps(expr sqlparser.Expr) TableSet { - return st.Recursive.Dependencies(expr) + return st.Recursive.dependencies(expr) } // DirectDeps return the table dependencies of the expression. func (st *SemTable) DirectDeps(expr sqlparser.Expr) TableSet { - return st.Direct.Dependencies(expr) + return st.Direct.dependencies(expr) } // AddColumnEquality adds a relation of the given colName to the ColumnEqualities map func (st *SemTable) AddColumnEquality(colName *sqlparser.ColName, expr sqlparser.Expr) { - ts := st.Direct.Dependencies(colName) + ts := st.Direct.dependencies(colName) columnName := columnName{ Table: ts, ColumnName: colName.Name.String(), @@ -179,7 +179,7 @@ func (st *SemTable) GetExprAndEqualities(expr sqlparser.Expr) []sqlparser.Expr { // TableInfoForExpr returns the table info of the table that this expression depends on. // Careful: this only works for expressions that have a single table dependency func (st *SemTable) TableInfoForExpr(expr sqlparser.Expr) (TableInfo, error) { - return st.TableInfoFor(st.Direct.Dependencies(expr)) + return st.TableInfoFor(st.Direct.dependencies(expr)) } // GetSelectTables returns the table in the select. @@ -205,8 +205,8 @@ func (st *SemTable) TypeFor(e sqlparser.Expr) *querypb.Type { return nil } -// Dependencies return the table dependencies of the expression. This method finds table dependencies recursively -func (d ExprDependencies) Dependencies(expr sqlparser.Expr) (deps TableSet) { +// dependencies return the table dependencies of the expression. This method finds table dependencies recursively +func (d ExprDependencies) dependencies(expr sqlparser.Expr) (deps TableSet) { if ValidAsMapKey(expr) { // we have something that could live in the cache var found bool @@ -232,7 +232,7 @@ func (d ExprDependencies) Dependencies(expr sqlparser.Expr) (deps TableSet) { if extracted, ok := expr.(*sqlparser.ExtractedSubquery); ok { if extracted.OtherSide != nil { - set := d.Dependencies(extracted.OtherSide) + set := d.dependencies(extracted.OtherSide) deps.MergeInPlace(set) } return false, nil diff --git a/go/vt/vtgate/semantics/tabletset.go b/go/vt/vtgate/semantics/tabletset.go index 54aab332ebc..23f1858b144 100644 --- a/go/vt/vtgate/semantics/tabletset.go +++ b/go/vt/vtgate/semantics/tabletset.go @@ -283,6 +283,46 @@ func (ts *TableSet) MergeInPlace(other TableSet) { } } +// RemoveInPlace removes all the tables in `other` from this TableSet +func (ts *TableSet) RemoveInPlace(other TableSet) { + switch { + case ts.large == nil && other.large == nil: + ts.small &= ^other.small + case ts.large == nil: + ts.small &= ^other.large.tables[0] + case other.large == nil: + ts.large.tables[0] &= ^other.small + default: + for idx := range ts.large.tables { + if len(other.large.tables) <= idx { + break + } + ts.large.tables[idx] &= ^other.large.tables[idx] + } + } +} + +// KeepOnly removes all the tables not in `other` from this TableSet +func (ts *TableSet) KeepOnly(other TableSet) { + switch { + case ts.large == nil && other.large == nil: + ts.small &= other.small + case ts.large == nil: + ts.small &= other.large.tables[0] + case other.large == nil: + ts.small = ts.large.tables[0] & other.small + ts.large = nil + default: + for idx := range ts.large.tables { + if len(other.large.tables) <= idx { + ts.large.tables = ts.large.tables[0:idx] + break + } + ts.large.tables[idx] &= other.large.tables[idx] + } + } +} + // AddTable adds the given table to this set func (ts *TableSet) AddTable(tableidx int) { switch { @@ -303,6 +343,11 @@ func SingleTableSet(tableidx int) TableSet { return TableSet{large: newLargeTableSet(0x0, tableidx)} } +// EmptyTableSet creates an empty TableSet +func EmptyTableSet() TableSet { + return TableSet{small: 0} +} + // MergeTableSets merges all the given TableSet into a single one func MergeTableSets(tss ...TableSet) (result TableSet) { for _, t := range tss { diff --git a/go/vt/vtgate/semantics/tabletset_test.go b/go/vt/vtgate/semantics/tabletset_test.go index c901f23974e..300810e18ad 100644 --- a/go/vt/vtgate/semantics/tabletset_test.go +++ b/go/vt/vtgate/semantics/tabletset_test.go @@ -139,3 +139,79 @@ func TestTableSet_LargeOffset(t *testing.T) { assert.Equal(t, tid, ts.TableOffset()) } } + +func TestTableSet_KeepOnly(t *testing.T) { + testcases := []struct { + name string + ts1 TableSet + ts2 TableSet + result TableSet + }{ + { + name: "both small", + ts1: SingleTableSet(1).Merge(SingleTableSet(2)).Merge(SingleTableSet(3)), + ts2: SingleTableSet(1).Merge(SingleTableSet(3)).Merge(SingleTableSet(4)), + result: SingleTableSet(1).Merge(SingleTableSet(3)), + }, { + name: "both large", + ts1: SingleTableSet(1428).Merge(SingleTableSet(2432)).Merge(SingleTableSet(3412)), + ts2: SingleTableSet(1428).Merge(SingleTableSet(3412)).Merge(SingleTableSet(4342)), + result: SingleTableSet(1428).Merge(SingleTableSet(3412)), + }, { + name: "ts1 small ts2 large", + ts1: SingleTableSet(1).Merge(SingleTableSet(2)).Merge(SingleTableSet(3)), + ts2: SingleTableSet(1).Merge(SingleTableSet(3)).Merge(SingleTableSet(4342)), + result: SingleTableSet(1).Merge(SingleTableSet(3)), + }, { + name: "ts1 large ts2 small", + ts1: SingleTableSet(1).Merge(SingleTableSet(2771)).Merge(SingleTableSet(3)), + ts2: SingleTableSet(1).Merge(SingleTableSet(3)).Merge(SingleTableSet(4)), + result: SingleTableSet(1).Merge(SingleTableSet(3)), + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + testcase.ts1.KeepOnly(testcase.ts2) + assert.Equal(t, testcase.result, testcase.ts1) + }) + } +} + +func TestTableSet_RemoveInPlace(t *testing.T) { + testcases := []struct { + name string + ts1 TableSet + ts2 TableSet + result TableSet + }{ + { + name: "both small", + ts1: SingleTableSet(1).Merge(SingleTableSet(2)).Merge(SingleTableSet(3)), + ts2: SingleTableSet(1).Merge(SingleTableSet(5)).Merge(SingleTableSet(4)), + result: SingleTableSet(2).Merge(SingleTableSet(3)), + }, { + name: "both large", + ts1: SingleTableSet(1428).Merge(SingleTableSet(2432)).Merge(SingleTableSet(3412)), + ts2: SingleTableSet(1424).Merge(SingleTableSet(2432)).Merge(SingleTableSet(4342)), + result: SingleTableSet(1428).Merge(SingleTableSet(3412)), + }, { + name: "ts1 small ts2 large", + ts1: SingleTableSet(1).Merge(SingleTableSet(2)).Merge(SingleTableSet(3)), + ts2: SingleTableSet(14).Merge(SingleTableSet(2)).Merge(SingleTableSet(4342)), + result: SingleTableSet(1).Merge(SingleTableSet(3)), + }, { + name: "ts1 large ts2 small", + ts1: SingleTableSet(1).Merge(SingleTableSet(2771)).Merge(SingleTableSet(3)), + ts2: SingleTableSet(1).Merge(SingleTableSet(3)).Merge(SingleTableSet(4)), + result: SingleTableSet(2771), + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + testcase.ts1.RemoveInPlace(testcase.ts2) + assert.Equal(t, testcase.result, testcase.ts1) + }) + } +} diff --git a/go/vt/vtgate/semantics/vindex_table.go b/go/vt/vtgate/semantics/vindex_table.go index 768a955b7c5..93e17fb37d0 100644 --- a/go/vt/vtgate/semantics/vindex_table.go +++ b/go/vt/vtgate/semantics/vindex_table.go @@ -31,7 +31,7 @@ type VindexTable struct { var _ TableInfo = (*VindexTable)(nil) -// Dependencies implements the TableInfo interface +// dependencies implements the TableInfo interface func (v *VindexTable) dependencies(colName string, org originable) (dependencies, error) { return v.Table.dependencies(colName, org) } diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go index 127d8d30879..c992246e5d2 100644 --- a/go/vt/vtgate/semantics/vtable.go +++ b/go/vt/vtgate/semantics/vtable.go @@ -34,7 +34,7 @@ type vTableInfo struct { var _ TableInfo = (*vTableInfo)(nil) -// Dependencies implements the TableInfo interface +// dependencies implements the TableInfo interface func (v *vTableInfo) dependencies(colName string, org originable) (dependencies, error) { var deps dependencies = ¬hing{} var err error From 4090a6900734db4263eab6357a5337f0a1063dc6 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 14 Oct 2021 11:08:46 +0200 Subject: [PATCH 3/5] clean up code and add queries to test the new dependency code Signed-off-by: Andres Taylor --- .../vtgate/planbuilder/abstract/querygraph.go | 5 +- go/vt/vtgate/planbuilder/route_planning.go | 2 +- .../planbuilder/testdata/select_cases.txt | 147 ++++++++++++++++++ go/vt/vtgate/semantics/binder.go | 105 +++++++------ 4 files changed, 211 insertions(+), 48 deletions(-) diff --git a/go/vt/vtgate/planbuilder/abstract/querygraph.go b/go/vt/vtgate/planbuilder/abstract/querygraph.go index 7f5db077309..fe4dd53be82 100644 --- a/go/vt/vtgate/planbuilder/abstract/querygraph.go +++ b/go/vt/vtgate/planbuilder/abstract/querygraph.go @@ -17,9 +17,7 @@ limitations under the License. package abstract import ( - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -139,7 +137,8 @@ func (qg *QueryGraph) collectPredicate(predicate sqlparser.Expr, semTable *seman case 1: found := qg.addToSingleTable(deps, predicate) if !found { - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "table %v for predicate %v not found", deps, sqlparser.String(predicate)) + // this could be a predicate that only has dependencies from outside this QG + qg.addJoinPredicates(deps, predicate) } default: qg.addJoinPredicates(deps, predicate) diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 80574bcd899..c1a5bd01da4 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -46,7 +46,7 @@ func (c planningContext) isSubQueryToReplace(e sqlparser.Expr) bool { return false } for _, extractedSubq := range c.semTable.GetSubqueryNeedingRewrite() { - if extractedSubq.NeedsRewrite && sqlparser.EqualsRefOfSubquery(&sqlparser.Subquery{Select: extractedSubq.Subquery.Select}, ext) { + if extractedSubq.NeedsRewrite && sqlparser.EqualsRefOfSubquery(extractedSubq.Subquery, ext) { return true } } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 92669e331d5..3b45f95de28 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -2201,3 +2201,150 @@ Gen4 plan same as above } } Gen4 plan same as above + +"select (select col from user limit 1) as a from user join user_extra order by a + 1" +"unsupported: in scatter query: complex order by expression: a + 1" +Gen4 plan same as above + +"select t.a from (select (select col from user limit 1) as a from user join user_extra) t" +{ + "QueryType": "SELECT", + "Original": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", + "Instructions": { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Subquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq_has_values1", + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from `user` where 1 != 1", + "Query": "select col from `user` limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :__sq1 as a from `user` where 1 != 1", + "Query": "select :__sq1 as a from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + } + ] + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", + "Instructions": { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Subquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from `user` where 1 != 1", + "Query": "select col from `user` limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :__sq1 as a from `user` where 1 != 1", + "Query": "select :__sq1 as a from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + } + ] + } + ] + } +} + +"select (select col from user where user_extra.id = 4 limit 1) as a from user join user_extra" +"unsupported: cross-shard correlated subquery" +Gen4 plan same as above diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 8a37573817b..aeddddad022 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -56,55 +56,15 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case *sqlparser.Subquery: currScope := b.scoper.currentScope() - if currScope.selectStmt == nil { - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unable to bind subquery to select statement") - } - - sq := &sqlparser.ExtractedSubquery{ - Subquery: node, - Original: node, - OpCode: int(engine.PulloutValue), - } - - switch par := cursor.Parent().(type) { - case *sqlparser.ComparisonExpr: - switch par.Operator { - case sqlparser.InOp: - sq.OpCode = int(engine.PulloutIn) - case sqlparser.NotInOp: - sq.OpCode = int(engine.PulloutNotIn) - } - subq, exp := GetSubqueryAndOtherSide(par) - sq.Original = &sqlparser.ComparisonExpr{ - Left: exp, - Operator: par.Operator, - Right: subq, - } - sq.OtherSide = exp - case *sqlparser.ExistsExpr: - sq.OpCode = int(engine.PulloutExists) - sq.Original = par + sq, err := b.createExtractedSubquery(cursor, currScope, node) + if err != nil { + return err } b.subqueryMap[currScope.selectStmt] = append(b.subqueryMap[currScope.selectStmt], sq) b.subqueryRef[node] = sq - subqRecursiveDeps := b.recursive.dependencies(node) - subqDirectDeps := b.direct.dependencies(node) - - tablesToKeep := EmptyTableSet() - sco := currScope - for sco != nil { - for _, table := range sco.tables { - tablesToKeep.MergeInPlace(table.getTableSet(b.org)) - } - sco = sco.parent - } - - subqDirectDeps.KeepOnly(tablesToKeep) - subqRecursiveDeps.KeepOnly(tablesToKeep) - b.recursive[node] = subqRecursiveDeps - b.direct[node] = subqDirectDeps + b.setSubQueryDependencies(node, currScope) case *sqlparser.ColName: deps, err := b.resolveColumn(node, b.scoper.currentScope()) @@ -138,6 +98,62 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { return nil } +// setSubQueryDependencies sets the correct dependencies for the subquery +// the binder usually only sets the dependencies of ColNames, but we need to +// handle the subquery dependencies differently, so they are set manually here +// this method will only keep dependencies to tables outside the subquery +func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *scope) { + subqRecursiveDeps := b.recursive.dependencies(subq) + subqDirectDeps := b.direct.dependencies(subq) + + tablesToKeep := EmptyTableSet() + sco := currScope + for sco != nil { + for _, table := range sco.tables { + tablesToKeep.MergeInPlace(table.getTableSet(b.org)) + } + sco = sco.parent + } + + subqDirectDeps.KeepOnly(tablesToKeep) + subqRecursiveDeps.KeepOnly(tablesToKeep) + b.recursive[subq] = subqRecursiveDeps + b.direct[subq] = subqDirectDeps +} + +func (b *binder) createExtractedSubquery(cursor *sqlparser.Cursor, currScope *scope, subq *sqlparser.Subquery) (*sqlparser.ExtractedSubquery, error) { + if currScope.selectStmt == nil { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unable to bind subquery to select statement") + } + + sq := &sqlparser.ExtractedSubquery{ + Subquery: subq, + Original: subq, + OpCode: int(engine.PulloutValue), + } + + switch par := cursor.Parent().(type) { + case *sqlparser.ComparisonExpr: + switch par.Operator { + case sqlparser.InOp: + sq.OpCode = int(engine.PulloutIn) + case sqlparser.NotInOp: + sq.OpCode = int(engine.PulloutNotIn) + } + subq, exp := GetSubqueryAndOtherSide(par) + sq.Original = &sqlparser.ComparisonExpr{ + Left: exp, + Operator: par.Operator, + Right: subq, + } + sq.OtherSide = exp + case *sqlparser.ExistsExpr: + sq.OpCode = int(engine.PulloutExists) + sq.Original = par + } + return sq, nil +} + func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope) (deps dependency, err error) { var thisDeps dependencies for current != nil { @@ -187,6 +203,7 @@ func makeAmbiguousError(colName *sqlparser.ColName, err error) error { return err } +// GetSubqueryAndOtherSide returns the subquery and other side of a comparison, iff one of the sides is a SubQuery func GetSubqueryAndOtherSide(node *sqlparser.ComparisonExpr) (*sqlparser.Subquery, sqlparser.Expr) { var subq *sqlparser.Subquery var exp sqlparser.Expr From c0d1609c746242409805e4b1f482267c51002d4c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 14 Oct 2021 11:35:52 +0200 Subject: [PATCH 4/5] update test assertions Signed-off-by: Andres Taylor --- .../planbuilder/abstract/operator_test_data.txt | 4 +--- go/vt/vtgate/planbuilder/rewrite_test.go | 2 +- go/vt/vtgate/semantics/analyzer_test.go | 15 ++++++++------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/go/vt/vtgate/planbuilder/abstract/operator_test_data.txt b/go/vt/vtgate/planbuilder/abstract/operator_test_data.txt index 2970ce7863a..e2179858757 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator_test_data.txt +++ b/go/vt/vtgate/planbuilder/abstract/operator_test_data.txt @@ -303,9 +303,7 @@ SubQuery: { }] Outer: QueryGraph: { Tables: - TableSet{0}:`user` AS u - JoinPredicates: - TableSet{0,1} - u.id = (select id from user_extra where id = u.id) + TableSet{0}:`user` AS u where u.id = (select id from user_extra where id = u.id) } } diff --git a/go/vt/vtgate/planbuilder/rewrite_test.go b/go/vt/vtgate/planbuilder/rewrite_test.go index 31f39323fed..afa9d680abd 100644 --- a/go/vt/vtgate/planbuilder/rewrite_test.go +++ b/go/vt/vtgate/planbuilder/rewrite_test.go @@ -140,7 +140,7 @@ func TestHavingRewrite(t *testing.T) { assert.Equal(t, len(tcase.sqs), len(squeries), "number of subqueries not matched") } for _, sq := range squeries { - assert.Equal(t, tcase.sqs[sq.ArgName], sqlparser.String(sq.Subquery)) + assert.Equal(t, tcase.sqs[sq.ArgName], sqlparser.String(sq.Subquery.Select)) } }) } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 6c5a1415daf..9d631db6677 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -34,11 +34,12 @@ var T0 TableSet var ( // Just here to make outputs more readable - T1 = SingleTableSet(0) - T2 = SingleTableSet(1) - T3 = SingleTableSet(2) - T4 = SingleTableSet(3) - T5 = SingleTableSet(4) + None = EmptyTableSet() + T1 = SingleTableSet(0) + T2 = SingleTableSet(1) + T3 = SingleTableSet(2) + T4 = SingleTableSet(3) + T5 = SingleTableSet(4) ) func extract(in *sqlparser.Select, idx int) sqlparser.Expr { @@ -488,10 +489,10 @@ func TestScopeForSubqueries(t *testing.T) { deps: T2, }, { sql: `select t.col1, (select (select y.col2 from y) from z) from x as t`, - deps: T3, + deps: None, }, { sql: `select t.col1, (select (select (select (select w.col2 from w) from x) from y) from z) from x as t`, - deps: T5, + deps: None, }, { sql: `select t.col1, (select id from t) from x as t`, deps: T2, From 438a96d01ac713e831fb5f78387ba36c8b386692 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 14 Oct 2021 11:51:15 +0200 Subject: [PATCH 5/5] update query to test a more interesting situation Signed-off-by: Andres Taylor --- .../planbuilder/testdata/select_cases.txt | 128 +++++++++++++++++- 1 file changed, 125 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 3b45f95de28..21ec7e3eb2f 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -2202,9 +2202,131 @@ Gen4 plan same as above } Gen4 plan same as above -"select (select col from user limit 1) as a from user join user_extra order by a + 1" -"unsupported: in scatter query: complex order by expression: a + 1" -Gen4 plan same as above +"select (select col from user limit 1) as a from user join user_extra order by a" +{ + "QueryType": "SELECT", + "Original": "select (select col from user limit 1) as a from user join user_extra order by a", + "Instructions": { + "OperatorType": "Subquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq_has_values1", + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from `user` where 1 != 1", + "Query": "select col from `user` limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :__sq1 as a, weight_string(:__sq1) from `user` where 1 != 1", + "OrderBy": "(0|1) ASC", + "Query": "select :__sq1 as a, weight_string(:__sq1) from `user` order by a asc", + "ResultColumns": 1, + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select (select col from user limit 1) as a from user join user_extra order by a", + "Instructions": { + "OperatorType": "Subquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from `user` where 1 != 1", + "Query": "select col from `user` limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :__sq1 as a, weight_string(:__sq1) from `user` where 1 != 1", + "OrderBy": "(0|1) ASC", + "Query": "select :__sq1 as a, weight_string(:__sq1) from `user` order by a asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + } + ] + } +} "select t.a from (select (select col from user limit 1) as a from user join user_extra) t" {