Skip to content

Commit

Permalink
Introduce new helper APIs for optimizers (#903)
Browse files Browse the repository at this point in the history
* Introduce new helper APIs for optimizers
  • Loading branch information
TristonianJones authored Mar 6, 2024
1 parent 13b3d56 commit 9775e65
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 29 deletions.
1 change: 1 addition & 0 deletions cel/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ go_test(
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//ext:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
Expand Down
21 changes: 5 additions & 16 deletions cel/inlining.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.A
if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
// Copy the inlined AST expr and source info.
copyExpr := copyASTAndMetadata(ctx, inlineVar.def)
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type())
}
continue
Expand Down Expand Up @@ -121,26 +121,15 @@ func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.A
}

// Copy the inlined AST expr and source info.
copyExpr := copyASTAndMetadata(ctx, inlineVar.def)
copyExpr := ctx.CopyASTAndMetadata(inlineVar.def)
// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
ctx.sourceInfo.SetMacroCall(lca.ID(), bindMacro)
ctx.SetMacroCall(lca.ID(), bindMacro)
}
return a
}

// copyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func copyASTAndMetadata(ctx *OptimizerContext, a *ast.AST) ast.Expr {
copyExpr, copyInfo := ctx.CopyAST(a)
// Add in the macro calls from the inlined AST
for id, call := range copyInfo.MacroCalls() {
ctx.sourceInfo.SetMacroCall(id, call)
}
return copyExpr
}

// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
Expand Down Expand Up @@ -168,11 +157,11 @@ func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, i
if inlined.Kind() == ast.SelectKind {
presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined)
ctx.UpdateExpr(prev, presenceTest)
ctx.sourceInfo.SetMacroCall(prev.ID(), hasMacro)
ctx.SetMacroCall(prev.ID(), hasMacro)
return
}

ctx.sourceInfo.ClearMacroCall(prev.ID())
ctx.ClearMacroCall(prev.ID())
if inlinedType.IsAssignableType(NullType) {
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
Expand Down
29 changes: 28 additions & 1 deletion cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ type optimizerExprFactory struct {
sourceInfo *ast.SourceInfo
}

// NewAST creates an AST from the current expression using the tracked source info which
// is modified and managed by the OptimizerContext.
func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST {
return ast.NewAST(expr, opt.sourceInfo)
}

// CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the
// renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions
// between copies.
Expand All @@ -226,6 +232,27 @@ func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo)
return copyExpr, copyInfo
}

// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
copyExpr, copyInfo := opt.CopyAST(a)
for macroID, call := range copyInfo.MacroCalls() {
opt.SetMacroCall(macroID, call)
}
return copyExpr
}

// ClearMacroCall clears the macro at the given expression id.
func (opt *optimizerExprFactory) ClearMacroCall(id int64) {
opt.sourceInfo.ClearMacroCall(id)
}

// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info
// metadata.
func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
opt.sourceInfo.SetMacroCall(id, expr)
}

// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
Expand All @@ -239,7 +266,7 @@ func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, var
return id
})
if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists {
opt.sourceInfo.SetMacroCall(remainingID, opt.fac.CopyExpr(call))
opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call))
}

astExpr = opt.fac.NewComprehension(macroID,
Expand Down
77 changes: 65 additions & 12 deletions cel/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/ext"

proto3pb "github.com/google/cel-go/test/proto3pb"
)
Expand All @@ -29,18 +30,7 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
expr := `has(a.b)`
inlined := `[x, y].filter(i, i.size() > 0)[0].z`

opts := []cel.EnvOption{
cel.Types(&proto3pb.TestAllTypes{}),
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
cel.Variable("a", cel.MapType(cel.StringType, cel.StringType)),
cel.Variable("x", cel.MapType(cel.StringType, cel.StringType)),
cel.Variable("y", cel.MapType(cel.StringType, cel.StringType)),
}
e, err := cel.NewEnv(opts...)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
e := optimizerEnv(t)
exprAST, iss := e.Compile(expr)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
Expand All @@ -65,6 +55,51 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
}
}

func TestStaticOptimizerNewAST(t *testing.T) {
tests := []string{
`[3, 2, 1]`,
`[1, 2, 3].all(i, i != 0)`,
`cel.bind(m, {"a": 1, "b": 2}, m.filter(k, m[k] > 1))`,
}
for _, tst := range tests {
tc := tst
t.Run(tc, func(t *testing.T) {
e := optimizerEnv(t)
exprAST, iss := e.Compile(tc)
if iss.Err() != nil {
t.Fatalf("Compile(%q) failed: %v", tc, iss.Err())
}
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
optAST, iss := opt.Optimize(e, exprAST)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
}
optString, err := cel.AstToString(optAST)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
if tc != optString {
t.Errorf("identity optimizer got %q, wanted %q", optString, tc)
}
})
}
}

type identityOptimizer struct {
t *testing.T
}

func (opt *identityOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
opt.t.Helper()
// The copy method should effectively update all of the old macro refs with new ones that are
// identical, but renumbered.
main := ctx.CopyASTAndMetadata(a)
// The new AST call will create a parsed expression which will be type-checked by the static
// optimizer. The input and output expressions should be identical, though may vary by number
// though.
return ctx.NewAST(main)
}

type testOptimizer struct {
t *testing.T
inlineExpr *ast.AST
Expand Down Expand Up @@ -106,3 +141,21 @@ func getMacroKeys(macroCalls map[int64]ast.Expr) []int {
sort.Ints(keys)
return keys
}

func optimizerEnv(t *testing.T) *cel.Env {
t.Helper()
opts := []cel.EnvOption{
cel.Types(&proto3pb.TestAllTypes{}),
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
ext.Bindings(),
cel.Variable("a", cel.MapType(cel.StringType, cel.StringType)),
cel.Variable("x", cel.MapType(cel.StringType, cel.StringType)),
cel.Variable("y", cel.MapType(cel.StringType, cel.StringType)),
}
e, err := cel.NewEnv(opts...)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
return e
}

0 comments on commit 9775e65

Please sign in to comment.