Skip to content

Commit

Permalink
Clear iteration variable data during expression pruning (#740)
Browse files Browse the repository at this point in the history
* Clear iteration variable data during expression pruning
* Remove unnecessary state clearing call
  • Loading branch information
TristonianJones authored Jun 15, 2023
1 parent 9812c42 commit ecf04a5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
6 changes: 5 additions & 1 deletion interpreter/evalstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ func (s *evalState) Value(exprID int64) (ref.Val, bool) {

// SetValue is an implementation of the EvalState interface method.
func (s *evalState) SetValue(exprID int64, val ref.Val) {
s.values[exprID] = val
if val == nil {
delete(s.values, exprID)
} else {
s.values[exprID] = val
}
}

// Reset implements the EvalState interface method.
Expand Down
24 changes: 22 additions & 2 deletions interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
// Ensure that intermediate values for the comprehension are cleared during pruning
compre := node.GetComprehensionExpr()
if compre != nil {
visit(macro, clearIterVarVisitor(compre.IterVar, p.state))
}
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
Expand Down Expand Up @@ -524,6 +529,17 @@ func getMaxID(expr *exprpb.Expr) int64 {
return maxID
}

func clearIterVarVisitor(varName string, state EvalState) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
ident := e.GetIdentExpr()
if ident != nil && ident.GetName() == varName {
state.SetValue(e.GetId(), nil)
}
},
}
}

func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
Expand All @@ -543,7 +559,9 @@ func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
visitor.visitExpr(e)
if visitor.visitExpr != nil {
visitor.visitExpr(e)
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
Expand All @@ -567,7 +585,9 @@ func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range e.GetStructExpr().GetEntries() {
visitor.visitEntry(entry)
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
Expand Down
14 changes: 14 additions & 0 deletions interpreter/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,20 @@ var testCases = []testInfo{
expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`,
out: `["bob"] == r.attr.authorized["managers"]`,
},
{
in: partialActivation(map[string]any{
"users": []string{"alice", "bob"},
}, NewAttributePattern("r").QualString("attr").Wildcard()),
expr: `users.filter(u, u.startsWith(r.attr.prefix))`,
out: `["alice", "bob"].filter(u, u.startsWith(r.attr.prefix))`,
},
{
in: partialActivation(map[string]any{
"users": []string{"alice", "bob"},
}, NewAttributePattern("r").QualString("attr").Wildcard()),
expr: `users.filter(u, r.attr.prefix.endsWith(u))`,
out: `["alice", "bob"].filter(u, r.attr.prefix.endsWith(u))`,
},
{
in: unknownActivation("four"),
expr: `[1+3, 2+2, 3+1, four]`,
Expand Down

0 comments on commit ecf04a5

Please sign in to comment.