Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Expression and Statement Simplifier #13636

Merged
merged 17 commits into from
Aug 9, 2023
Merged
Changes from 1 commit
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
replaced Rewrite with SafeRewrite in visitAllExpressionsInAST
Signed-off-by: Arvind Murty <10248018+arvind-murty@users.noreply.github.com>
  • Loading branch information
arvind-murty committed Aug 1, 2023
commit c378d5ed62d9e703bbf4230ea97a8147cb10af22
38 changes: 16 additions & 22 deletions go/vt/vtgate/simplifier/simplifier.go
Original file line number Diff line number Diff line change
@@ -98,7 +98,7 @@ func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.Se
if simplified {
return in
}

// we found no simplifications
return nil
}

@@ -379,17 +379,13 @@ func newExprCursor(expr sqlparser.Expr, replace func(replaceWith sqlparser.Expr)
}

// visitAllExpressionsInAST will walk the AST and visit all expressions
// This cursor has a few extra capabilities that the normal sqlparser.Rewrite does not have,
// This cursor has a few extra capabilities that the normal sqlparser.SafeRewrite does not have,
// such as visiting and being able to change individual expressions in a AND tree
func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expressionCursor) bool) {
abort := false
post := func(*sqlparser.Cursor) bool {
return !abort
alwaysVisit := func(node, parent sqlparser.SQLNode) bool {
return true
}
pre := func(cursor *sqlparser.Cursor) bool {
if abort {
return true
}
up := func(cursor *sqlparser.Cursor) bool {
switch node := cursor.Node().(type) {
case sqlparser.SelectExprs:
_, isSel := cursor.Parent().(*sqlparser.Select)
@@ -441,35 +437,36 @@ func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expres
expr.Expr = original
},
)
abort = !visit(item)
visit(item)
}
case *sqlparser.Where:
exprs := sqlparser.SplitAndExpression(nil, node.Expr)
set := func(input []sqlparser.Expr) {
node.Expr = sqlparser.AndExpressions(input...)
exprs = input
}
abort = !visitExpressions(exprs, set, visit)
visitExpressions(exprs, set, visit)
case *sqlparser.JoinCondition:
join, ok := cursor.Parent().(*sqlparser.JoinTableExpr)
if !ok {
return true
}
// TODO: improve this
if join.Join != sqlparser.NormalJoinType || node.Using != nil {
return false
return true
}
exprs := sqlparser.SplitAndExpression(nil, node.On)
set := func(input []sqlparser.Expr) {
node.On = sqlparser.AndExpressions(input...)
exprs = input
}
abort = !visitExpressions(exprs, set, visit)
visitExpressions(exprs, set, visit)
case sqlparser.GroupBy:
set := func(input []sqlparser.Expr) {
node = input
cursor.Replace(node)
}
abort = !visitExpressions(node, set, visit)
visitExpressions(node, set, visit)
case sqlparser.OrderBy:
for idx := 0; idx < len(node); idx++ {
order := node[idx]
@@ -513,10 +510,7 @@ func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expres
order.Expr = original
},
)
abort = visit(item)
if abort {
break
}
visit(item)
}
case *sqlparser.Limit:
if node.Offset != nil {
@@ -532,9 +526,9 @@ func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expres
/*restore*/ func() {
node.Offset = original
})
abort = visit(cursor)
visit(cursor)
}
if !abort && node.Rowcount != nil {
if node.Rowcount != nil {
original := node.Rowcount
cursor := newExprCursor(node.Rowcount,
/*replace*/ func(replaceWith sqlparser.Expr) {
@@ -547,12 +541,12 @@ func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expres
/*restore*/ func() {
node.Rowcount = original
})
abort = visit(cursor)
visit(cursor)
}
}
return true
}
sqlparser.Rewrite(clone, pre, post)
sqlparser.SafeRewrite(clone, alwaysVisit, up)
}

// visitExpressions allows the cursor to visit all expressions in a slice,