From b545a278d883460f607101d5fd09e4c0d58d24f9 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Thu, 11 May 2023 18:13:44 -0700 Subject: [PATCH] Scan the input AST to establish a non-conflicting expression ID for pruned expressions --- interpreter/prune.go | 40 +++++++++++++++++++++++++++++++++++++++ interpreter/prune_test.go | 12 ++++++++++++ 2 files changed, 52 insertions(+) diff --git a/interpreter/prune.go b/interpreter/prune.go index 85b3b0659..674964838 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -76,6 +76,46 @@ func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalSt maxID = id + 1 } } + exprs := []*exprpb.Expr{expr} + for len(exprs) != 0 { + e := exprs[0] + if e.GetId() >= maxID { + maxID = e.GetId() + 1 + } + exprs = exprs[1:] + switch e.GetExprKind().(type) { + case *exprpb.Expr_SelectExpr: + exprs = append(exprs, e.GetSelectExpr().GetOperand()) + case *exprpb.Expr_CallExpr: + call := e.GetCallExpr() + if call.GetTarget() != nil { + exprs = append(exprs, call.GetTarget()) + } + exprs = append(exprs, call.GetArgs()...) + case *exprpb.Expr_ComprehensionExpr: + compre := e.GetComprehensionExpr() + exprs = append(exprs, + compre.GetIterRange(), + compre.GetAccuInit(), + compre.GetLoopCondition(), + compre.GetLoopStep(), + compre.GetResult()) + case *exprpb.Expr_ListExpr: + list := e.GetListExpr() + exprs = append(exprs, list.GetElements()...) + case *exprpb.Expr_StructExpr: + for _, entry := range expr.GetStructExpr().GetEntries() { + if entry.GetMapKey() != nil { + exprs = append(exprs, entry.GetMapKey()) + } + exprs = append(exprs, entry.GetValue()) + if entry.GetId() >= maxID { + maxID = entry.GetId() + 1 + } + } + } + } + pruner := &astPruner{ expr: expr, macroCalls: macroCalls, diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index fb7855d5e..46b316e91 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -15,6 +15,7 @@ package interpreter import ( + "fmt" "testing" "github.com/google/cel-go/common" @@ -157,6 +158,16 @@ var testCases = []testInfo{ expr: `foo == "bar" && r.attr.loc in ["GB", "US"]`, out: `r.attr.loc in ["GB", "US"]`, }, + { + in: partialActivation(map[string]any{ + "users": []map[string]string{ + {"name": "alice", "role": "EMPLOYEE"}, + {"name": "bob", "role": "MANAGER"}, + {"name": "eve", "role": "CUSTOMER"}, + }}, "r.attr.*"), + expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`, + out: `["bob"] == r.attr.authorized["managers"]`, + }, // TODO: the output of an expression like this relies on either // a) doing replacements on the original macro call, or // b) mutating the macro call tracking data rather than the core @@ -196,6 +207,7 @@ func TestPrune(t *testing.T) { t.Error(err) } if !test.Compare(actual, tst.out) { + fmt.Println(ast.GetExpr()) t.Errorf("prune[%d], diff: %s", i, test.DiffMessage("structure", actual, tst.out)) } }