From 34194de06bfd067ad54d150b4ff2275bc4eab771 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 5 Oct 2021 19:04:01 +0200 Subject: [PATCH 1/4] Support in the rewriter for replace & revisit Co-authored-by: Florent Poinsard Signed-off-by: Andres Taylor --- .../asthelpergen/integration/ast_rewrite.go | 24 +++- .../integration/integration_rewriter_test.go | 40 ++++++ .../asthelpergen/integration/test_helpers.go | 21 ++- go/tools/asthelpergen/integration/types.go | 4 + go/tools/asthelpergen/rewrite_gen.go | 20 ++- go/vt/sqlparser/ast_rewrite.go | 120 +++++++++++++++--- go/vt/sqlparser/rewriter_api.go | 20 +++ go/vt/sqlparser/rewriter_test.go | 54 ++++++++ 8 files changed, 283 insertions(+), 20 deletions(-) diff --git a/go/tools/asthelpergen/integration/ast_rewrite.go b/go/tools/asthelpergen/integration/ast_rewrite.go index 5d554e92e43..3741b2080cb 100644 --- a/go/tools/asthelpergen/integration/ast_rewrite.go +++ b/go/tools/asthelpergen/integration/ast_rewrite.go @@ -59,7 +59,13 @@ func (a *application) rewriteBytes(parent AST, node Bytes, replacer replacerFunc a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(Bytes) + a.cur.revisit = false + return a.rewriteBytes(parent, node, replacer) + } + if kontinue { return true } } @@ -104,7 +110,13 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(InterfaceSlice) + a.cur.revisit = false + return a.rewriteInterfaceSlice(parent, node, replacer) + } + if kontinue { return true } } @@ -159,7 +171,13 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(LeafSlice) + a.cur.revisit = false + return a.rewriteLeafSlice(parent, node, replacer) + } + if kontinue { return true } } diff --git a/go/tools/asthelpergen/integration/integration_rewriter_test.go b/go/tools/asthelpergen/integration/integration_rewriter_test.go index 7699fe45f6a..c7822980cd4 100644 --- a/go/tools/asthelpergen/integration/integration_rewriter_test.go +++ b/go/tools/asthelpergen/integration/integration_rewriter_test.go @@ -164,6 +164,46 @@ func TestRewriteVisitInterfaceSlice(t *testing.T) { }) } +func TestRewriteAndRevisitInterfaceSlice(t *testing.T) { + leaf1 := &Leaf{2} + leaf2 := &Leaf{3} + ast := InterfaceSlice{ + leaf1, + leaf2, + } + ast2 := InterfaceSlice{ + leaf2, + leaf1, + } + + tv := &rewriteTestVisitor{} + + a := false + _ = Rewrite(ast, func(cursor *Cursor) bool { + tv.pre(cursor) + switch cursor.node.(type) { + case InterfaceSlice: + if a { + break + } + a = true + cursor.ReplaceAndRevisit(ast2) + } + return true + }, tv.post) + + tv.assertEquals(t, []step{ + Pre{ast}, // when we visit ast, we want to replace and revisit, + // which means that we don't do a post on this node, or visit the children + Pre{ast2}, + Pre{leaf2}, + Post{leaf2}, + Pre{leaf1}, + Post{leaf1}, + Post{ast2}, + }) +} + func TestRewriteVisitRefContainerReplace(t *testing.T) { ast := &RefContainer{ ASTType: &RefContainer{NotASTType: 12}, diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index 923f6f7c546..db9facb3c3b 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -47,6 +47,8 @@ type Cursor struct { parent AST replacer replacerFunc node AST + // marks that the node has been replaced, and the new node should be visited + revisit bool } // Node returns the current Node. @@ -55,13 +57,30 @@ func (c *Cursor) Node() AST { return c.node } // Parent returns the parent of the current Node. func (c *Cursor) Parent() AST { return c.parent } -// Replace replaces the current node in the parent field with this new object. The use needs to make sure to not +// Replace replaces the current node in the parent field with this new object. The user needs to make sure to not // replace the object with something of the wrong type, or the visitor will panic. func (c *Cursor) Replace(newNode AST) { c.replacer(newNode, c.parent) c.node = newNode } +// ReplaceAndRevisit replaces the current node in the parent field with this new object. +// When used, this will abort the visitation of the current node - no post or children visited, +// and the new node visited. +func (c *Cursor) ReplaceAndRevisit(newNode AST) { + switch newNode.(type) { + case InterfaceSlice: + default: + // We need to add support to the generated code for when to look at the revisit flag. At the moment it is only + // there for slices of AST implementations + panic("no support added for this type yet") + } + + c.replacer(newNode, c.parent) + c.node = newNode + c.revisit = true +} + type replacerFunc func(newNode, parent AST) // Rewrite is the api. diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index 3bed2b5e009..4bb4081fb87 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -110,6 +110,10 @@ func (r InterfaceSlice) String() string { return "[" + strings.Join(elements, ", ") + "]" } +func (r InterfaceSlice) IsRevisitable() bool { + return true +} + // We need to support these types - a slice of AST elements can implement the interface type Bytes []byte diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 1a7d2411d7e..ab01584c1cb 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -27,6 +27,10 @@ const ( rewriteName = "rewrite" ) +type Revisitable interface { + IsRevisitable() bool +} + type rewriteGen struct { ifaceName string file *jen.File @@ -178,7 +182,21 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator stmts := []jen.Code{ jen.If(jen.Id("node == nil").Block(returnTrue())), } - stmts = append(stmts, executePre()) + + typeString := types.TypeString(t, noQualifier) + + preStmts := setupCursor() + preStmts = append(preStmts, + jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"), + jen.If(jen.Id("a.cur.revisit").Block( + jen.Id("node").Op("=").Id("a.cur.node.("+typeString+")"), + jen.Id("a.cur.revisit").Op("=").False(), + jen.Return(jen.Id("a.rewrite"+typeString+"(parent, node, replacer)")), + )), + jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))), + ) + + stmts = append(stmts, jen.If(jen.Id("a.pre!= nil").Block(preStmts...))) haveChildren := false if shouldAdd(slice.Elem(), spi.iface()) { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index c13144d32a8..6da577e73cb 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1130,7 +1130,13 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(Columns) + a.cur.revisit = false + return a.rewriteColumns(parent, node, replacer) + } + if kontinue { return true } } @@ -1161,7 +1167,13 @@ func (a *application) rewriteComments(parent SQLNode, node Comments, replacer re a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(Comments) + a.cur.revisit = false + return a.rewriteComments(parent, node, replacer) + } + if kontinue { return true } } @@ -1846,7 +1858,13 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(Exprs) + a.cur.revisit = false + return a.rewriteExprs(parent, node, replacer) + } + if kontinue { return true } } @@ -2002,7 +2020,13 @@ func (a *application) rewriteGroupBy(parent SQLNode, node GroupBy, replacer repl a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(GroupBy) + a.cur.revisit = false + return a.rewriteGroupBy(parent, node, replacer) + } + if kontinue { return true } } @@ -2629,7 +2653,13 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(OnDup) + a.cur.revisit = false + return a.rewriteOnDup(parent, node, replacer) + } + if kontinue { return true } } @@ -2746,7 +2776,13 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(OrderBy) + a.cur.revisit = false + return a.rewriteOrderBy(parent, node, replacer) + } + if kontinue { return true } } @@ -2957,7 +2993,13 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(Partitions) + a.cur.revisit = false + return a.rewritePartitions(parent, node, replacer) + } + if kontinue { return true } } @@ -3377,7 +3419,13 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(SelectExprs) + a.cur.revisit = false + return a.rewriteSelectExprs(parent, node, replacer) + } + if kontinue { return true } } @@ -3496,7 +3544,13 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(SetExprs) + a.cur.revisit = false + return a.rewriteSetExprs(parent, node, replacer) + } + if kontinue { return true } } @@ -3883,7 +3937,13 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(TableExprs) + a.cur.revisit = false + return a.rewriteTableExprs(parent, node, replacer) + } + if kontinue { return true } } @@ -3964,7 +4024,13 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(TableNames) + a.cur.revisit = false + return a.rewriteTableNames(parent, node, replacer) + } + if kontinue { return true } } @@ -3995,7 +4061,13 @@ func (a *application) rewriteTableOptions(parent SQLNode, node TableOptions, rep a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(TableOptions) + a.cur.revisit = false + return a.rewriteTableOptions(parent, node, replacer) + } + if kontinue { return true } } @@ -4338,7 +4410,13 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(UpdateExprs) + a.cur.revisit = false + return a.rewriteUpdateExprs(parent, node, replacer) + } + if kontinue { return true } } @@ -4443,7 +4521,13 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(ValTuple) + a.cur.revisit = false + return a.rewriteValTuple(parent, node, replacer) + } + if kontinue { return true } } @@ -4498,7 +4582,13 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - if !a.pre(&a.cur) { + kontinue := !a.pre(&a.cur) + if a.cur.revisit { + node = a.cur.node.(Values) + a.cur.revisit = false + return a.rewriteValues(parent, node, replacer) + } + if kontinue { return true } } diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 71c0660deb8..a1de8a8d539 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -72,6 +72,9 @@ type Cursor struct { parent SQLNode replacer replacerFunc node SQLNode + + // marks that the node has been replaced, and the new node should be visited + revisit bool } // Node returns the current Node. @@ -87,6 +90,23 @@ func (c *Cursor) Replace(newNode SQLNode) { c.node = newNode } +// ReplaceAndRevisit replaces the current node in the parent field with this new object. +// When used, this will abort the visitation of the current node - no post or children visited, +// and the new node visited. +func (c *Cursor) ReplaceAndRevisit(newNode SQLNode) { + switch newNode.(type) { + case SelectExprs: + default: + // We need to add support to the generated code for when to look at the revisit flag. At the moment it is only + // there for slices of SQLNode implementations + panic("no support added for this type yet") + } + + c.replacer(newNode, c.parent) + c.node = newNode + c.revisit = true +} + type replacerFunc func(newNode, parent SQLNode) // application carries all the shared data so we can pass it around cheaply. diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go index 6887da8c1a8..dadd2c501df 100644 --- a/go/vt/sqlparser/rewriter_test.go +++ b/go/vt/sqlparser/rewriter_test.go @@ -19,6 +19,8 @@ package sqlparser import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) @@ -38,6 +40,58 @@ func BenchmarkVisitLargeExpression(b *testing.B) { } } +func TestReplaceWorksInLaterCalls(t *testing.T) { + q := "select * from tbl1" + stmt, err := Parse(q) + require.NoError(t, err) + count := 0 + Rewrite(stmt, func(cursor *Cursor) bool { + switch node := cursor.Node().(type) { + case *Select: + node.SelectExprs[0] = &AliasedExpr{ + Expr: NewStrLiteral("apa"), + } + node.SelectExprs = append(node.SelectExprs, &AliasedExpr{ + Expr: NewStrLiteral("foobar"), + }) + case *StarExpr: + t.Errorf("should not have seen the star") + case *Literal: + count++ + } + return true + }, nil) + assert.Equal(t, 2, count) +} + +func TestReplaceAndRevisitWorksInLaterCalls(t *testing.T) { + q := "select * from tbl1" + stmt, err := Parse(q) + require.NoError(t, err) + count := 0 + Rewrite(stmt, func(cursor *Cursor) bool { + switch node := cursor.Node().(type) { + case SelectExprs: + if len(node) != 1 { + return true + } + expr1 := &AliasedExpr{ + Expr: NewStrLiteral("apa"), + } + expr2 := &AliasedExpr{ + Expr: NewStrLiteral("foobar"), + } + cursor.ReplaceAndRevisit(SelectExprs{expr1, expr2}) + case *StarExpr: + t.Errorf("should not have seen the star") + case *Literal: + count++ + } + return true + }, nil) + assert.Equal(t, 2, count) +} + func TestChangeValueTypeGivesError(t *testing.T) { parse, err := Parse("select * from a join b on a.id = b.id") require.NoError(t, err) From 37e7884558bd1340a78403ac119895e007577729 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 6 Oct 2021 07:50:09 +0200 Subject: [PATCH 2/4] removed earlyRewrite phase in the analyzer Co-authored-by: Andres Taylor Signed-off-by: Florent Poinsard --- go/tools/asthelpergen/integration/types.go | 4 -- go/tools/asthelpergen/rewrite_gen.go | 4 -- go/vt/sqlparser/ast_rewrite.go | 2 +- go/vt/vtgate/semantics/analyzer.go | 22 +++++----- go/vt/vtgate/semantics/early_rewriter.go | 49 +++++++++------------- 5 files changed, 33 insertions(+), 48 deletions(-) diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index 4bb4081fb87..3bed2b5e009 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -110,10 +110,6 @@ func (r InterfaceSlice) String() string { return "[" + strings.Join(elements, ", ") + "]" } -func (r InterfaceSlice) IsRevisitable() bool { - return true -} - // We need to support these types - a slice of AST elements can implement the interface type Bytes []byte diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index ab01584c1cb..4804ef8d874 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -27,10 +27,6 @@ const ( rewriteName = "rewrite" ) -type Revisitable interface { - IsRevisitable() bool -} - type rewriteGen struct { ifaceName string file *jen.File diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 6da577e73cb..e5d2f77bb9f 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -3419,7 +3419,7 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - kontinue := !a.pre(&a.cur) + kontinue := !a.pre(&a.cur) if a.cur.revisit { node = a.cur.node.(SelectExprs) a.cur.revisit = false diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 197cb6f73bb..e109b6f3046 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -32,16 +32,16 @@ import ( // analyzer controls the flow of the analysis. // It starts the tree walking and controls which part of the analysis sees which parts of the tree type analyzer struct { - scoper *scoper - tables *tableCollector - binder *binder - typer *typer + scoper *scoper + tables *tableCollector + binder *binder + typer *typer + rewriter *earlyRewriter - err error + err error inProjection int - projErr error - + projErr error hasRewritten bool } @@ -57,6 +57,7 @@ func newAnalyzer(dbName string, si SchemaInformation) *analyzer { s.org = a a.tables.org = a a.binder = newBinder(s, a, a.tables, a.typer) + a.rewriter = &earlyRewriter{scoper: s} return a } @@ -75,9 +76,6 @@ func Analyze(statement sqlparser.SelectStatement, currentDb string, si SchemaInf semTable := analyzer.newSemTable(statement) // Rewriting operation - if err = earlyRewrite(statement, semTable, analyzer.scoper); err != nil { - return nil, err - } analyzer.hasRewritten = true // Analysis post rewriting @@ -135,6 +133,10 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { a.setError(err) return true } + if err := a.rewriter.down(cursor); err != nil { + a.setError(err) + return true + } } else { // after expand star if err := checkUnionColumns(cursor); err != nil { a.setError(err) diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index cbb7c10fad7..3d09e922df3 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -25,29 +25,21 @@ import ( ) type earlyRewriter struct { - err error - semTable *SemTable scoper *scoper clause string } -// earlyRewrite rewrites the query before the binder has had a chance to work on the query -// it introduces new expressions that the binder will later need to bind correctly -func earlyRewrite(statement sqlparser.SelectStatement, semTable *SemTable, scoper *scoper) error { - r := earlyRewriter{ - semTable: semTable, - scoper: scoper, - } - sqlparser.Rewrite(statement, r.rewrite, nil) - return r.err -} - -func (r *earlyRewriter) rewrite(cursor *sqlparser.Cursor) bool { +func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { - case *sqlparser.Select: - tables := r.semTable.GetSelectTables(node) + case sqlparser.SelectExprs: + _, isSel:=cursor.Parent().(*sqlparser.Select) + if !isSel { + return nil + } + tables := r.scoper.currentScope().tables var selExprs sqlparser.SelectExprs - for _, selectExpr := range node.SelectExprs { + changed := false + for _, selectExpr := range node { starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr) if !isStarExpr { selExprs = append(selExprs, selectExpr) @@ -55,16 +47,18 @@ func (r *earlyRewriter) rewrite(cursor *sqlparser.Cursor) bool { } starExpanded, colNames, err := expandTableColumns(tables, starExpr) if err != nil { - r.err = err - return false + return err } if !starExpanded || colNames == nil { selExprs = append(selExprs, selectExpr) continue } selExprs = append(selExprs, colNames...) + changed = true + } + if changed { + cursor.ReplaceAndRevisit(selExprs) } - node.SelectExprs = selExprs case *sqlparser.Order: r.clause = "order clause" case sqlparser.GroupBy: @@ -77,27 +71,24 @@ func (r *earlyRewriter) rewrite(cursor *sqlparser.Cursor) bool { } num, err := strconv.Atoi(node.Val) if err != nil { - r.err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) - break + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) + } if num < 1 || num > len(currScope.selectStmt.SelectExprs) { - r.err = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) - break + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) } for i := 0; i < num; i++ { expr := currScope.selectStmt.SelectExprs[i] _, ok := expr.(*sqlparser.AliasedExpr) if !ok { - r.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(expr)) - return true + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(expr)) } } aliasedExpr, ok := currScope.selectStmt.SelectExprs[num-1].(*sqlparser.AliasedExpr) if !ok { - r.err = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) - break + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) } if !aliasedExpr.As.IsEmpty() { @@ -107,7 +98,7 @@ func (r *earlyRewriter) rewrite(cursor *sqlparser.Cursor) bool { cursor.Replace(expr) } } - return true + return nil } // realCloneOfColNames clones all the expressions including ColName. From 04670655bdb362e7556456de51f1be75da3e11fb Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Wed, 6 Oct 2021 08:03:07 +0200 Subject: [PATCH 3/4] removed post up and down steps in the analyzer Co-authored-by: Andres Taylor Signed-off-by: Florent Poinsard --- go/vt/vtgate/semantics/analyzer.go | 89 +++++++++++------------------- go/vt/vtgate/semantics/scoper.go | 50 ----------------- 2 files changed, 31 insertions(+), 108 deletions(-) diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index e109b6f3046..f13b6fba61b 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -42,7 +42,6 @@ type analyzer struct { inProjection int projErr error - hasRewritten bool } // newAnalyzer create the semantic analyzer @@ -75,15 +74,6 @@ func Analyze(statement sqlparser.SelectStatement, currentDb string, si SchemaInf // Creation of the semantic table semTable := analyzer.newSemTable(statement) - // Rewriting operation - analyzer.hasRewritten = true - - // Analysis post rewriting - err = analyzer.analyze(statement) - if err != nil { - return nil, err - } - semTable.ProjectionErr = analyzer.projErr return semTable, nil } @@ -124,31 +114,21 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { return true } - if !a.hasRewritten { - if err := a.scoper.down(cursor); err != nil { - a.setError(err) - return true - } - if err := a.checkForInvalidConstructs(cursor); err != nil { - a.setError(err) - return true - } - if err := a.rewriter.down(cursor); err != nil { - a.setError(err) - return true - } - } else { // after expand star - if err := checkUnionColumns(cursor); err != nil { - a.setError(err) - return true - } - - a.scoper.downPost(cursor) - - if err := a.binder.down(cursor); err != nil { - a.setError(err) - return true - } + if err := a.scoper.down(cursor); err != nil { + a.setError(err) + return true + } + if err := a.checkForInvalidConstructs(cursor); err != nil { + a.setError(err) + return true + } + if err := a.rewriter.down(cursor); err != nil { + a.setError(err) + return true + } + if err := a.binder.down(cursor); err != nil { + a.setError(err) + return true } a.enterProjection(cursor) @@ -163,24 +143,17 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } - if !a.hasRewritten { - if err := a.scoper.up(cursor); err != nil { - a.setError(err) - return false - } - if err := a.tables.up(cursor); err != nil { - a.setError(err) - return false - } - } else { // after expand star - if err := a.scoper.upPost(cursor); err != nil { - a.setError(err) - return false - } - if err := a.typer.up(cursor); err != nil { - a.setError(err) - return false - } + if err := a.scoper.up(cursor); err != nil { + a.setError(err) + return false + } + if err := a.tables.up(cursor); err != nil { + a.setError(err) + return false + } + if err := a.typer.up(cursor); err != nil { + a.setError(err) + return false } a.leaveProjection(cursor) @@ -197,11 +170,7 @@ func containsStar(s sqlparser.SelectExprs) bool { return false } -func checkUnionColumns(cursor *sqlparser.Cursor) error { - union, isUnion := cursor.Node().(*sqlparser.Union) - if !isUnion { - return nil - } +func checkUnionColumns(union *sqlparser.Union) error { firstProj := sqlparser.GetFirstSelect(union).SelectExprs if containsStar(firstProj) { // if we still have *, we can't figure out if the query is invalid or not @@ -338,6 +307,10 @@ func (a *analyzer) checkForInvalidConstructs(cursor *sqlparser.Cursor) error { if err != nil { return err } + err = checkUnionColumns(node) + if err != nil { + return err + } } return nil diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index dab7f00bce7..9819dc71862 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -177,60 +177,10 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { return nil } -func (s *scoper) downPost(cursor *sqlparser.Cursor) { - var scope *scope - var found bool - - switch node := cursor.Node().(type) { - case sqlparser.OrderBy: - scope, found = s.sqlNodeScope[scopeKey{node: cursor.Parent(), typ: orderBy}] - case sqlparser.GroupBy: - scope, found = s.sqlNodeScope[scopeKey{node: cursor.Parent(), typ: groupBy}] - case *sqlparser.Where: - if node.Type != sqlparser.HavingClause { - break - } - scope, found = s.sqlNodeScope[scopeKey{node: cursor.Parent(), typ: having}] - default: - if validAsMapKey(node) { - scope, found = s.sqlNodeScope[scopeKey{node: node}] - } - } - - if found { - s.push(scope) - } -} - func validAsMapKey(s sqlparser.SQLNode) bool { return reflect.TypeOf(s).Comparable() } -func (s *scoper) upPost(cursor *sqlparser.Cursor) error { - var found bool - - switch node := cursor.Node().(type) { - case sqlparser.OrderBy: - _, found = s.sqlNodeScope[scopeKey{node: cursor.Parent(), typ: orderBy}] - case sqlparser.GroupBy: - _, found = s.sqlNodeScope[scopeKey{node: cursor.Parent(), typ: groupBy}] - case *sqlparser.Where: - if node.Type != sqlparser.HavingClause { - break - } - _, found = s.sqlNodeScope[scopeKey{node: cursor.Parent(), typ: having}] - default: - if validAsMapKey(node) { - _, found = s.sqlNodeScope[scopeKey{node: node}] - } - } - - if found { - s.popScope() - } - return nil -} - func (s *scoper) changeScopeForNode(cursor *sqlparser.Cursor, k scopeKey) error { switch parent := cursor.Parent().(type) { case *sqlparser.Select: From d3727870e929007412db5b13061b959d3c110582 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 6 Oct 2021 09:37:34 +0200 Subject: [PATCH 4/4] simplified the scoper Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_rewrite.go | 2 +- go/vt/vtgate/semantics/analyzer.go | 4 +-- go/vt/vtgate/semantics/early_rewriter.go | 6 ++-- go/vt/vtgate/semantics/scoper.go | 39 ++++++------------------ 4 files changed, 16 insertions(+), 35 deletions(-) diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index e5d2f77bb9f..6da577e73cb 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -3419,7 +3419,7 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla a.cur.replacer = replacer a.cur.parent = parent a.cur.node = node - kontinue := !a.pre(&a.cur) + kontinue := !a.pre(&a.cur) if a.cur.revisit { node = a.cur.node.(SelectExprs) a.cur.revisit = false diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index f13b6fba61b..b208fa1da05 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -38,10 +38,10 @@ type analyzer struct { typer *typer rewriter *earlyRewriter - err error + err error inProjection int - projErr error + projErr error } // newAnalyzer create the semantic analyzer diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 3d09e922df3..613d39c39a5 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -25,14 +25,14 @@ import ( ) type earlyRewriter struct { - scoper *scoper - clause string + scoper *scoper + clause string } func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case sqlparser.SelectExprs: - _, isSel:=cursor.Parent().(*sqlparser.Select) + _, isSel := cursor.Parent().(*sqlparser.Select) if !isSel { return nil } diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 9819dc71862..04a47c562f5 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -30,23 +30,15 @@ type ( // scoper is responsible for figuring out the scoping for the query, // and keeps the current scope when walking the tree scoper struct { - rScope map[*sqlparser.Select]*scope - wScope map[*sqlparser.Select]*scope - sqlNodeScope map[scopeKey]*scope - scopes []*scope - org originable + rScope map[*sqlparser.Select]*scope + wScope map[*sqlparser.Select]*scope + scopes []*scope + org originable // These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1 specialExprScopes map[*sqlparser.Literal]*scope } - scopeKey struct { - typ keyType - node sqlparser.SQLNode - } - - keyType int8 - scope struct { parent *scope selectStmt *sqlparser.Select @@ -55,18 +47,10 @@ type ( } ) -const ( - _ keyType = iota - orderBy - groupBy - having -) - func newScoper() *scoper { return &scoper{ rScope: map[*sqlparser.Select]*scope{}, wScope: map[*sqlparser.Select]*scope{}, - sqlNodeScope: map[scopeKey]*scope{}, specialExprScopes: map[*sqlparser.Literal]*scope{}, } } @@ -83,7 +67,6 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { s.rScope[node] = currScope s.wScope[node] = newScope(nil) - s.sqlNodeScope[scopeKey{node: node}] = currScope case sqlparser.TableExpr: if isParentSelect(cursor) { // when checking the expressions used in JOIN conditions, special rules apply where the ON expression @@ -93,7 +76,6 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { nScope := newScope(nil) nScope.selectStmt = cursor.Parent().(*sqlparser.Select) s.push(nScope) - s.sqlNodeScope[scopeKey{node: node}] = nScope } case sqlparser.SelectExprs: sel, parentIsSelect := cursor.Parent().(*sqlparser.Select) @@ -109,7 +91,7 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { } wScope.tables = append(wScope.tables, createVTableInfoForExpressions(node, s.currentScope().tables, s.org)) case sqlparser.OrderBy: - err := s.changeScopeForNode(cursor, scopeKey{node: cursor.Parent(), typ: orderBy}) + err := s.createSpecialScopePostProjection(cursor.Parent()) if err != nil { return err } @@ -120,7 +102,7 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { } } case sqlparser.GroupBy: - err := s.changeScopeForNode(cursor, scopeKey{node: cursor.Parent(), typ: groupBy}) + err := s.createSpecialScopePostProjection(cursor.Parent()) if err != nil { return err } @@ -134,7 +116,7 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { if node.Type != sqlparser.HavingClause { break } - return s.changeScopeForNode(cursor, scopeKey{node: cursor.Parent(), typ: having}) + return s.createSpecialScopePostProjection(cursor.Parent()) } return nil } @@ -181,15 +163,15 @@ func validAsMapKey(s sqlparser.SQLNode) bool { return reflect.TypeOf(s).Comparable() } -func (s *scoper) changeScopeForNode(cursor *sqlparser.Cursor, k scopeKey) error { - switch parent := cursor.Parent().(type) { +// createSpecialScopePostProjection is used for the special projection in ORDER BY, GROUP BY and HAVING +func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) error { + switch parent := parent.(type) { case *sqlparser.Select: // In ORDER BY, GROUP BY and HAVING, we can see both the scope in the FROM part of the query, and the SELECT columns created // so before walking the rest of the tree, we change the scope to match this behaviour incomingScope := s.currentScope() nScope := newScope(incomingScope) s.push(nScope) - s.sqlNodeScope[k] = nScope wScope := s.wScope[parent] nScope.tables = append(nScope.tables, wScope.tables...) nScope.selectStmt = incomingScope.selectStmt @@ -221,7 +203,6 @@ func (s *scoper) changeScopeForNode(cursor *sqlparser.Cursor, k scopeKey) error } s.push(nScope) - s.sqlNodeScope[k] = nScope } return nil }