diff --git a/interpreter/evalstate.go b/interpreter/evalstate.go index cc0d3e6f..4bdd1fdc 100644 --- a/interpreter/evalstate.go +++ b/interpreter/evalstate.go @@ -66,7 +66,11 @@ func (s *evalState) Value(exprID int64) (ref.Val, bool) { // SetValue is an implementation of the EvalState interface method. func (s *evalState) SetValue(exprID int64, val ref.Val) { - s.values[exprID] = val + if val == nil { + delete(s.values, exprID) + } else { + s.values[exprID] = val + } } // Reset implements the EvalState interface method. diff --git a/interpreter/prune.go b/interpreter/prune.go index d1b5d6bd..d94d9aeb 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -341,6 +341,11 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { } } if macro, found := p.macroCalls[node.GetId()]; found { + // Ensure that intermediate values for the comprehension are cleared during pruning + compre := node.GetComprehensionExpr() + if compre != nil { + visit(macro, clearIterVarVisitor(compre.IterVar, p.state)) + } // prune the expression in terms of the macro call instead of the expanded form. if newMacro, pruned := p.prune(macro); pruned { p.macroCalls[node.GetId()] = newMacro @@ -524,6 +529,17 @@ func getMaxID(expr *exprpb.Expr) int64 { return maxID } +func clearIterVarVisitor(varName string, state EvalState) astVisitor { + return astVisitor{ + visitExpr: func(e *exprpb.Expr) { + ident := e.GetIdentExpr() + if ident != nil && ident.GetName() == varName { + state.SetValue(e.GetId(), nil) + } + }, + } +} + func maxIDVisitor(maxID *int64) astVisitor { return astVisitor{ visitExpr: func(e *exprpb.Expr) { @@ -543,7 +559,9 @@ func visit(expr *exprpb.Expr, visitor astVisitor) { exprs := []*exprpb.Expr{expr} for len(exprs) != 0 { e := exprs[0] - visitor.visitExpr(e) + if visitor.visitExpr != nil { + visitor.visitExpr(e) + } exprs = exprs[1:] switch e.GetExprKind().(type) { case *exprpb.Expr_SelectExpr: @@ -567,7 +585,9 @@ func visit(expr *exprpb.Expr, visitor astVisitor) { exprs = append(exprs, list.GetElements()...) case *exprpb.Expr_StructExpr: for _, entry := range e.GetStructExpr().GetEntries() { - visitor.visitEntry(entry) + if visitor.visitEntry != nil { + visitor.visitEntry(entry) + } if entry.GetMapKey() != nil { exprs = append(exprs, entry.GetMapKey()) } diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 3a759aba..8ea3edff 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -370,6 +370,20 @@ var testCases = []testInfo{ expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`, out: `["bob"] == r.attr.authorized["managers"]`, }, + { + in: partialActivation(map[string]any{ + "users": []string{"alice", "bob"}, + }, NewAttributePattern("r").QualString("attr").Wildcard()), + expr: `users.filter(u, u.startsWith(r.attr.prefix))`, + out: `["alice", "bob"].filter(u, u.startsWith(r.attr.prefix))`, + }, + { + in: partialActivation(map[string]any{ + "users": []string{"alice", "bob"}, + }, NewAttributePattern("r").QualString("attr").Wildcard()), + expr: `users.filter(u, r.attr.prefix.endsWith(u))`, + out: `["alice", "bob"].filter(u, r.attr.prefix.endsWith(u))`, + }, { in: unknownActivation("four"), expr: `[1+3, 2+2, 3+1, four]`,