diff --git a/cel/inlining.go b/cel/inlining.go index e3383b68..78d5bea6 100644 --- a/cel/inlining.go +++ b/cel/inlining.go @@ -88,7 +88,7 @@ func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.A if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) { for _, match := range matches { // Copy the inlined AST expr and source info. - copyExpr := copyASTAndMetadata(ctx, inlineVar.def) + copyExpr := ctx.CopyASTAndMetadata(inlineVar.def) opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type()) } continue @@ -121,26 +121,15 @@ func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.A } // Copy the inlined AST expr and source info. - copyExpr := copyASTAndMetadata(ctx, inlineVar.def) + copyExpr := ctx.CopyASTAndMetadata(inlineVar.def) // Update the least common ancestor by inserting a cel.bind() call to the alias. inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca) opt.inlineExpr(ctx, lca, inlined, inlineVar.Type()) - ctx.sourceInfo.SetMacroCall(lca.ID(), bindMacro) + ctx.SetMacroCall(lca.ID(), bindMacro) } return a } -// copyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being -// optimized. -func copyASTAndMetadata(ctx *OptimizerContext, a *ast.AST) ast.Expr { - copyExpr, copyInfo := ctx.CopyAST(a) - // Add in the macro calls from the inlined AST - for id, call := range copyInfo.MacroCalls() { - ctx.sourceInfo.SetMacroCall(id, call) - } - return copyExpr -} - // inlineExpr replaces the current expression with the inlined one, unless the location of the inlining // happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is // made to determine whether the inlined value can be presence or existence tested. @@ -168,11 +157,11 @@ func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, i if inlined.Kind() == ast.SelectKind { presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined) ctx.UpdateExpr(prev, presenceTest) - ctx.sourceInfo.SetMacroCall(prev.ID(), hasMacro) + ctx.SetMacroCall(prev.ID(), hasMacro) return } - ctx.sourceInfo.ClearMacroCall(prev.ID()) + ctx.ClearMacroCall(prev.ID()) if inlinedType.IsAssignableType(NullType) { ctx.UpdateExpr(prev, ctx.NewCall(operators.NotEquals, diff --git a/cel/optimizer.go b/cel/optimizer.go index 99aeeb81..f26df462 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -212,6 +212,12 @@ type optimizerExprFactory struct { sourceInfo *ast.SourceInfo } +// NewAST creates an AST from the current expression using the tracked source info which +// is modified and managed by the OptimizerContext. +func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST { + return ast.NewAST(expr, opt.sourceInfo) +} + // CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the // renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions // between copies. @@ -226,6 +232,27 @@ func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) return copyExpr, copyInfo } +// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being +// optimized. +func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr { + copyExpr, copyInfo := opt.CopyAST(a) + for macroID, call := range copyInfo.MacroCalls() { + opt.SetMacroCall(macroID, call) + } + return copyExpr +} + +// ClearMacroCall clears the macro at the given expression id. +func (opt *optimizerExprFactory) ClearMacroCall(id int64) { + opt.sourceInfo.ClearMacroCall(id) +} + +// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info +// metadata. +func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) { + opt.sourceInfo.SetMacroCall(id, expr) +} + // NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression // representing the unexpanded call signature to be inserted into the source info macro call metadata. func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) { @@ -239,7 +266,7 @@ func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, var return id }) if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists { - opt.sourceInfo.SetMacroCall(remainingID, opt.fac.CopyExpr(call)) + opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call)) } astExpr = opt.fac.NewComprehension(macroID, diff --git a/cel/optimizer_test.go b/cel/optimizer_test.go index 713d11c8..8ecd8216 100644 --- a/cel/optimizer_test.go +++ b/cel/optimizer_test.go @@ -21,6 +21,7 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/ext" proto3pb "github.com/google/cel-go/test/proto3pb" ) @@ -29,18 +30,7 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) { expr := `has(a.b)` inlined := `[x, y].filter(i, i.size() > 0)[0].z` - opts := []cel.EnvOption{ - cel.Types(&proto3pb.TestAllTypes{}), - cel.OptionalTypes(), - cel.EnableMacroCallTracking(), - cel.Variable("a", cel.MapType(cel.StringType, cel.StringType)), - cel.Variable("x", cel.MapType(cel.StringType, cel.StringType)), - cel.Variable("y", cel.MapType(cel.StringType, cel.StringType)), - } - e, err := cel.NewEnv(opts...) - if err != nil { - t.Fatalf("NewEnv() failed: %v", err) - } + e := optimizerEnv(t) exprAST, iss := e.Compile(expr) if iss.Err() != nil { t.Fatalf("Compile() failed: %v", iss.Err()) @@ -65,6 +55,51 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) { } } +func TestStaticOptimizerNewAST(t *testing.T) { + tests := []string{ + `[3, 2, 1]`, + `[1, 2, 3].all(i, i != 0)`, + `cel.bind(m, {"a": 1, "b": 2}, m.filter(k, m[k] > 1))`, + } + for _, tst := range tests { + tc := tst + t.Run(tc, func(t *testing.T) { + e := optimizerEnv(t) + exprAST, iss := e.Compile(tc) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", tc, iss.Err()) + } + opt := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + optAST, iss := opt.Optimize(e, exprAST) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + optString, err := cel.AstToString(optAST) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if tc != optString { + t.Errorf("identity optimizer got %q, wanted %q", optString, tc) + } + }) + } +} + +type identityOptimizer struct { + t *testing.T +} + +func (opt *identityOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { + opt.t.Helper() + // The copy method should effectively update all of the old macro refs with new ones that are + // identical, but renumbered. + main := ctx.CopyASTAndMetadata(a) + // The new AST call will create a parsed expression which will be type-checked by the static + // optimizer. The input and output expressions should be identical, though may vary by number + // though. + return ctx.NewAST(main) +} + type testOptimizer struct { t *testing.T inlineExpr *ast.AST @@ -106,3 +141,21 @@ func getMacroKeys(macroCalls map[int64]ast.Expr) []int { sort.Ints(keys) return keys } + +func optimizerEnv(t *testing.T) *cel.Env { + t.Helper() + opts := []cel.EnvOption{ + cel.Types(&proto3pb.TestAllTypes{}), + cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + ext.Bindings(), + cel.Variable("a", cel.MapType(cel.StringType, cel.StringType)), + cel.Variable("x", cel.MapType(cel.StringType, cel.StringType)), + cel.Variable("y", cel.MapType(cel.StringType, cel.StringType)), + } + e, err := cel.NewEnv(opts...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + return e +}