Skip to content

Commit

Permalink
Fixed ExtractedSuqbuery issue with routed tables
Browse files Browse the repository at this point in the history
Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
  • Loading branch information
frouioui committed Oct 12, 2021
1 parent e12a286 commit dc15f61
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions go/vt/vtgate/planbuilder/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ import (
)

type rewriter struct {
semTable *semantics.SemTable
reservedVars *sqlparser.ReservedVars
isInSubquery int
err error
semTable *semantics.SemTable
reservedVars *sqlparser.ReservedVars
subqueryStack []*sqlparser.ExtractedSubquery
err error
}

func queryRewrite(semTable *semantics.SemTable, reservedVars *sqlparser.ReservedVars, statement sqlparser.SelectStatement) error {
Expand Down Expand Up @@ -58,7 +58,7 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool {
case *sqlparser.AliasedTableExpr:
// rewrite names of the routed tables for the subquery
// We only need to do this for non-derived tables and if they are in a subquery
if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived || r.isInSubquery == 0 {
if _, isDerived := node.Expr.(*sqlparser.DerivedTable); isDerived || len(r.subqueryStack) == 0 {
break
}
// find the tableSet and tableInfo that this table points to
Expand Down Expand Up @@ -89,21 +89,32 @@ func (r *rewriter) rewriteDown(cursor *sqlparser.Cursor) bool {
// replace the table name with the original table
tableName.Name = vindexTable.Name
node.Expr = tableName

if len(r.subqueryStack) == 0 {
break
}
currSuqbuery := r.subqueryStack[len(r.subqueryStack)-1]
sqlparser.Rewrite(currSuqbuery.Original, func(cursor *sqlparser.Cursor) bool {
switch cursor.Node().(type) {
case *sqlparser.Subquery:
cursor.Replace(&sqlparser.Subquery{Select: currSuqbuery.Subquery})
return false
}
return true
}, nil)
case *sqlparser.Subquery:
r.isInSubquery++
err := rewriteSubquery(cursor, r, node)
if err != nil {
r.err = err
}

}
return true
}

func (r *rewriter) rewriteUp(cursor *sqlparser.Cursor) bool {
switch cursor.Node().(type) {
case *sqlparser.Subquery:
r.isInSubquery--
r.subqueryStack = r.subqueryStack[:len(r.subqueryStack)-1]
}
return r.err == nil
}
Expand All @@ -125,6 +136,7 @@ func rewriteInSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Co
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: came across subquery that was not in the subq map")
}

r.subqueryStack = append(r.subqueryStack, semTableSQ)
argName, hasValuesArg := r.reservedVars.ReserveSubQueryWithHasValues()
semTableSQ.ArgName = argName
semTableSQ.HasValuesArg = hasValuesArg
Expand Down Expand Up @@ -154,6 +166,7 @@ func rewriteSubquery(cursor *sqlparser.Cursor, r *rewriter, node *sqlparser.Subq
if semTableSQ.ArgName != "" || engine.PulloutOpcode(semTableSQ.OpCode) != engine.PulloutValue {
return nil
}
r.subqueryStack = append(r.subqueryStack, semTableSQ)
argName := r.reservedVars.ReserveSubQuery()
semTableSQ.ArgName = argName
cursor.Replace(semTableSQ)
Expand All @@ -166,6 +179,7 @@ func (r *rewriter) rewriteExistsSubquery(cursor *sqlparser.Cursor, node *sqlpars
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: came across subquery that was not in the subq map")
}

r.subqueryStack = append(r.subqueryStack, semTableSQ)
argName := r.reservedVars.ReserveHasValuesSubQuery()
semTableSQ.ArgName = argName
cursor.Replace(semTableSQ)
Expand Down

0 comments on commit dc15f61

Please sign in to comment.