Skip to content

Commit

Permalink
Improve the rewriter to simplify more queries (#14059)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
GuptaManan100 committed Sep 22, 2023
1 parent 247cc91 commit e790f2e
Show file tree
Hide file tree
Showing 6 changed files with 793 additions and 140 deletions.
48 changes: 33 additions & 15 deletions go/tools/asthelpergen/asthelpergen.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ type (
}
)

// exprInterfacePath is the path of the sqlparser.Expr interface.
const exprInterfacePath = "vitess.io/vitess/go/vt/sqlparser.Expr"

func (gen *astHelperGen) iface() *types.Interface {
return gen._iface
}
Expand Down Expand Up @@ -200,22 +203,15 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
scopes[pkg.PkgPath] = pkg.Types.Scope()
}

pos := strings.LastIndexByte(options.RootInterface, '.')
if pos < 0 {
return nil, fmt.Errorf("unexpected input type: %s", options.RootInterface)
}

pkgname := options.RootInterface[:pos]
typename := options.RootInterface[pos+1:]

scope := scopes[pkgname]
if scope == nil {
return nil, fmt.Errorf("no scope found for type '%s'", options.RootInterface)
tt, err := findTypeObject(options.RootInterface, scopes)
if err != nil {
return nil, err
}

tt := scope.Lookup(typename)
if tt == nil {
return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname)
exprType, _ := findTypeObject(exprInterfacePath, scopes)
var exprInterface *types.Interface
if exprType != nil {
exprInterface = exprType.Type().(*types.Named).Underlying().(*types.Interface)
}

nt := tt.Type().(*types.Named)
Expand All @@ -224,7 +220,7 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
newEqualsGen(pName, &options.Equals),
newCloneGen(pName, &options.Clone),
newVisitGen(pName),
newRewriterGen(pName, types.TypeString(nt, noQualifier)),
newRewriterGen(pName, types.TypeString(nt, noQualifier), exprInterface),
newCOWGen(pName, nt),
)

Expand All @@ -236,6 +232,28 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
return it, nil
}

// findTypeObject finds the types.Object for the given interface from the given scopes.
func findTypeObject(interfaceToFind string, scopes map[string]*types.Scope) (types.Object, error) {
pos := strings.LastIndexByte(interfaceToFind, '.')
if pos < 0 {
return nil, fmt.Errorf("unexpected input type: %s", interfaceToFind)
}

pkgname := interfaceToFind[:pos]
typename := interfaceToFind[pos+1:]

scope := scopes[pkgname]
if scope == nil {
return nil, fmt.Errorf("no scope found for type '%s'", interfaceToFind)
}

tt := scope.Lookup(typename)
if tt == nil {
return nil, fmt.Errorf("no type called '%s' found in '%s'", typename, pkgname)
}
return tt, nil
}

var _ generatorSPI = (*astHelperGen)(nil)

func (gen *astHelperGen) scope() *types.Scope {
Expand Down
29 changes: 21 additions & 8 deletions go/tools/asthelpergen/rewrite_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,21 @@ const (
type rewriteGen struct {
ifaceName string
file *jen.File
// exprInterface is used to store the sqlparser.Expr interface
exprInterface *types.Interface
}

var _ generator = (*rewriteGen)(nil)

func newRewriterGen(pkgname string, ifaceName string) *rewriteGen {
func newRewriterGen(pkgname string, ifaceName string, exprInterface *types.Interface) *rewriteGen {
file := jen.NewFile(pkgname)
file.HeaderComment(licenseFileHeader)
file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")

return &rewriteGen{
ifaceName: ifaceName,
file: file,
ifaceName: ifaceName,
file: file,
exprInterface: exprInterface,
}
}

Expand Down Expand Up @@ -105,7 +108,7 @@ func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generat
}
fields := r.rewriteAllStructFields(t, strct, spi, true)

stmts := []jen.Code{executePre()}
stmts := []jen.Code{r.executePre(t)}
stmts = append(stmts, fields...)
stmts = append(stmts, executePost(len(fields) > 0))
stmts = append(stmts, returnTrue())
Expand All @@ -130,7 +133,7 @@ func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi ge
return nil
}
*/
stmts = append(stmts, executePre())
stmts = append(stmts, r.executePre(t))
fields := r.rewriteAllStructFields(t, strct, spi, false)
stmts = append(stmts, fields...)
stmts = append(stmts, executePost(len(fields) > 0))
Expand Down Expand Up @@ -225,9 +228,19 @@ func setupCursor() []jen.Code {
jen.Id("a.cur.node = node"),
}
}
func executePre() jen.Code {
func (r *rewriteGen) executePre(t types.Type) jen.Code {
curStmts := setupCursor()
curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue()))
if r.exprInterface != nil && types.Implements(t, r.exprInterface) {
curStmts = append(curStmts, jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
jen.If(jen.Id("a.cur.revisit").Block(
jen.Id("a.cur.revisit").Op("=").False(),
jen.Return(jen.Id("a.rewriteExpr(parent, a.cur.node.(Expr), replacer)")),
)),
jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))),
)
} else {
curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue()))
}
return jen.If(jen.Id("a.pre!= nil").Block(curStmts...))
}

Expand All @@ -251,7 +264,7 @@ func (r *rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI)
return nil
}

stmts := []jen.Code{executePre(), executePost(false), returnTrue()}
stmts := []jen.Code{r.executePre(t), executePost(false), returnTrue()}
r.rewriteFunc(t, stmts)
return nil
}
Expand Down
Loading

0 comments on commit e790f2e

Please sign in to comment.