diff --git a/go/vt/vtgate/planbuilder/rewrite.go b/go/vt/vtgate/planbuilder/rewrite.go index 1d6cf13314f..de7ca445949 100644 --- a/go/vt/vtgate/planbuilder/rewrite.go +++ b/go/vt/vtgate/planbuilder/rewrite.go @@ -25,10 +25,10 @@ import ( ) type rewriter struct { - semTable *semantics.SemTable - reservedVars *sqlparser.ReservedVars - isInSubquery int - err error + semTable *semantics.SemTable + reservedVars *sqlparser.ReservedVars + subqueryStack []*sqlparser.ExtractedSubquery + err error } func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.ReservedVars, statement sqlparser.SelectStatement) error { @@ -58,7 +58,7 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { case *sqlparser.AliasedTableExpr: // rewrite names of the routed tables for the subquery // We only need to do this for non-derived tables and if they are in a subquery - if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived || r.isInSubquery == 0 { + if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived || len(r.subqueryStack) == 0 { break } // find the tableSet and tableInfo that this table points to @@ -89,13 +89,24 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { // replace the table name with the original table tableName.Name = vindexTable.Name node.Expr = tableName + + if len(r.subqueryStack) == 0 { + break + } + currSuqbuery := r.subqueryStack[len(r.subqueryStack)-1] + sqlparser.Rewrite(currSuqbuery.Original, func(cursor *sqlparser.Cursor) bool { + switch cursor.Node().(type) { + case *sqlparser.Subquery: + cursor.Replace(&sqlparser.Subquery{Select: currSuqbuery.Subquery}) + return false + } + return true + }, nil) case *sqlparser.Subquery: - r.isInSubquery++ err := rewriteSubquery(cursor, r, node) if err != nil { r.err = err } - } return true } @@ -103,7 +114,7 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool { func (r *rewriter) rewriteUp(cursor *sqlparser.Cursor) bool { switch cursor.Node().(type) { case *sqlparser.Subquery: - r.isInSubquery-- + r.subqueryStack = r.subqueryStack[:len(r.subqueryStack)-1] } return r.err == nil } @@ -125,6 +136,7 @@ func rewriteInSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Co return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: came across subquery that was not in the subq map") } + r.subqueryStack = append(r.subqueryStack, semTableSQ) argName, hasValuesArg := r.reservedVars.ReserveSubQueryWithHasValues() semTableSQ.ArgName = argName semTableSQ.HasValuesArg = hasValuesArg @@ -154,6 +166,7 @@ func rewriteSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Subq if semTableSQ.ArgName != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue { return nil } + r.subqueryStack = append(r.subqueryStack, semTableSQ) argName := r.reservedVars.ReserveSubQuery() semTableSQ.ArgName = argName cursor.Replace(semTableSQ) @@ -166,6 +179,7 @@ func (r *rewriter) rewriteExistsSubquery(cursor *sqlparser.Cursor, node *sqlpars return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: came across subquery that was not in the subq map") } + r.subqueryStack = append(r.subqueryStack, semTableSQ) argName := r.reservedVars.ReserveHasValuesSubQuery() semTableSQ.ArgName = argName cursor.Replace(semTableSQ)