Skip to content

Commit

Permalink
Scan the input AST to establish a non-conflicting expression ID for p…
Browse files Browse the repository at this point in the history
…runed expressions
  • Loading branch information
TristonianJones committed May 12, 2023
1 parent c08c0cc commit b545a27
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
40 changes: 40 additions & 0 deletions interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,46 @@ func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalSt
maxID = id + 1
}
}
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if e.GetId() >= maxID {
maxID = e.GetId() + 1
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
exprs = append(exprs, e.GetSelectExpr().GetOperand())
case *exprpb.Expr_CallExpr:
call := e.GetCallExpr()
if call.GetTarget() != nil {
exprs = append(exprs, call.GetTarget())
}
exprs = append(exprs, call.GetArgs()...)
case *exprpb.Expr_ComprehensionExpr:
compre := e.GetComprehensionExpr()
exprs = append(exprs,
compre.GetIterRange(),
compre.GetAccuInit(),
compre.GetLoopCondition(),
compre.GetLoopStep(),
compre.GetResult())
case *exprpb.Expr_ListExpr:
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range expr.GetStructExpr().GetEntries() {
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
exprs = append(exprs, entry.GetValue())
if entry.GetId() >= maxID {
maxID = entry.GetId() + 1
}
}
}
}

pruner := &astPruner{
expr: expr,
macroCalls: macroCalls,
Expand Down
12 changes: 12 additions & 0 deletions interpreter/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package interpreter

import (
"fmt"
"testing"

"github.com/google/cel-go/common"
Expand Down Expand Up @@ -157,6 +158,16 @@ var testCases = []testInfo{
expr: `foo == "bar" && r.attr.loc in ["GB", "US"]`,
out: `r.attr.loc in ["GB", "US"]`,
},
{
in: partialActivation(map[string]any{
"users": []map[string]string{
{"name": "alice", "role": "EMPLOYEE"},
{"name": "bob", "role": "MANAGER"},
{"name": "eve", "role": "CUSTOMER"},
}}, "r.attr.*"),
expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`,
out: `["bob"] == r.attr.authorized["managers"]`,
},
// TODO: the output of an expression like this relies on either
// a) doing replacements on the original macro call, or
// b) mutating the macro call tracking data rather than the core
Expand Down Expand Up @@ -196,6 +207,7 @@ func TestPrune(t *testing.T) {
t.Error(err)
}
if !test.Compare(actual, tst.out) {
fmt.Println(ast.GetExpr())
t.Errorf("prune[%d], diff: %s", i, test.DiffMessage("structure", actual, tst.out))
}
}
Expand Down

0 comments on commit b545a27

Please sign in to comment.