diff --git a/interpreter/prune.go b/interpreter/prune.go
index b8b015a7..24c7e79d 100644
--- a/interpreter/prune.go
+++ b/interpreter/prune.go
@@ -67,10 +67,15 @@ type astPruner struct {
// fold(and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
+ pruneState := NewEvalState()
+ for _, id := range state.IDs() {
+ v, _ := state.Value(id)
+ pruneState.SetValue(id, v)
+ }
pruner := &astPruner{
expr: expr,
macroCalls: macroCalls,
- state: state,
+ state: pruneState,
nextExprID: 1}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
@@ -91,24 +96,31 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
switch val.Type() {
case types.BoolType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
case types.IntType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
case types.UintType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
case types.StringType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
case types.DoubleType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
case types.BytesType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
case types.NullType:
+ p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
}
@@ -128,6 +140,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
}
elemExprs[i] = elemExpr
}
+ p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
@@ -167,6 +180,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
entries[i] = entry
i++
}
+ p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
@@ -182,6 +196,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}
+func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
+ if !p.existsWithUnknownValue(node.GetId()) {
+ return nil, false
+ }
+ call := node.GetCallExpr()
+ val, valueExists := p.value(call.GetArgs()[1].GetId())
+ if !valueExists {
+ return nil, false
+ }
+ if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
+ return p.maybeCreateLiteral(node.GetId(), types.False)
+ }
+ return nil, false
+}
+
+func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
+ if !p.existsWithUnknownValue(node.GetId()) {
+ return nil, false
+ }
+ call := node.GetCallExpr()
+ arg := call.GetArgs()[0]
+ v, exists := p.value(arg.GetId())
+ if !exists {
+ return nil, false
+ }
+ if b, ok := v.(types.Bool); ok {
+ return p.maybeCreateLiteral(node.GetId(), !b)
+ }
+ return nil, false
+}
+
func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
@@ -224,7 +269,12 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
if call.Function == operators.Conditional {
return p.maybePruneConditional(node)
}
-
+ if call.Function == operators.In {
+ return p.maybePruneIn(node)
+ }
+ if call.Function == operators.LogicalNot {
+ return p.maybePruneLogicalNot(node)
+ }
return nil, false
}
@@ -266,10 +316,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}, true
}
case *exprpb.Expr_CallExpr:
- if newExpr, pruned := p.maybePruneFunction(node); pruned {
- newExpr, _ = p.maybePrune(newExpr)
- return newExpr, true
- }
var prunedCall bool
call := node.GetCallExpr()
args := call.GetArgs()
@@ -290,13 +336,18 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
prunedCall = true
newCall.Target = newTarget
}
+ newNode := &exprpb.Expr{
+ Id: node.GetId(),
+ ExprKind: &exprpb.Expr_CallExpr{
+ CallExpr: newCall,
+ },
+ }
+ if newExpr, pruned := p.maybePruneFunction(newNode); pruned {
+ newExpr, _ = p.maybePrune(newExpr)
+ return newExpr, true
+ }
if prunedCall {
- return &exprpb.Expr{
- Id: node.GetId(),
- ExprKind: &exprpb.Expr_CallExpr{
- CallExpr: newCall,
- },
- }, true
+ return newNode, true
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go
index 6320d6ec..dd257490 100644
--- a/interpreter/prune_test.go
+++ b/interpreter/prune_test.go
@@ -56,6 +56,32 @@ var testCases = []testInfo{
expr: `a && [1, 1u, 1.0].exists(x, type(x) == uint)`,
out: `a`,
},
+ {
+ in: unknownActivation("this"),
+ expr: `this in []`,
+ out: `false`,
+ },
+ {
+ in: unknownActivation("this"),
+ expr: `this in {}`,
+ out: `false`,
+ },
+ {
+ in: partialActivation(map[string]any{"rules": []string{}}, "this"),
+ expr: `this in rules`,
+ out: `false`,
+ },
+ {
+ in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"),
+ expr: `this.size() > 0 ? this in rules.not_in : !(this in rules.not_in)`,
+ out: `(this.size() > 0) ? false : true`,
+ },
+ {
+ in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"),
+ expr: `this.size() > 0 ? this in rules.not_in :
+ !(this in rules.not_in) ? true : false`,
+ out: `(this.size() > 0) ? false : true`,
+ },
{
expr: `{'hello': 'world'.size()}`,
out: `{"hello": 5}`,
@@ -96,6 +122,11 @@ var testCases = []testInfo{
expr: `true ? b < 1.2 : c == ['hello']`,
out: `b < 1.2`,
},
+ {
+ in: unknownActivation("b", "c"),
+ expr: `false ? b < 1.2 : c == ['hello']`,
+ out: `c == ["hello"]`,
+ },
{
in: unknownActivation(),
expr: `[1+3, 2+2, 3+1, four]`,
@@ -121,18 +152,27 @@ var testCases = []testInfo{
expr: `test in {'a': 1, 'field': [test, 3]}.field`,
out: `test in {"a": 1, "field": [test, 3]}.field`,
},
- // TODO(issues/) the output test relies on tracking macro expansions back to their original
- // call patterns.
- /* {
- in: unknownActivation(),
- expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`,
- out: `[4, 4, 4, four].exists(x, x == four)`,
- }, */
+ // TODO: the output of an expression like this relies on either
+ // a) doing replacements on the original macro call, or
+ // b) mutating the macro call tracking data rather than the core
+ // expression in order to render the partial correctly.
+ // {
+ // in: unknownActivation(),
+ // expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`,
+ // out: `[4, 4, 4, four].exists(x, x == four)`,
+ // },
}
func TestPrune(t *testing.T) {
+ p, err := parser.NewParser(
+ parser.PopulateMacroCalls(true),
+ parser.Macros(parser.AllMacros...),
+ )
+ if err != nil {
+ t.Fatalf("parser.NewParser() failed: %v", err)
+ }
for i, tst := range testCases {
- ast, iss := parser.Parse(common.NewStringSource(tst.expr, ""))
+ ast, iss := p.Parse(common.NewStringSource(tst.expr, ""))
if len(iss.GetErrors()) > 0 {
t.Fatalf(iss.ToDisplayString())
}
@@ -142,10 +182,10 @@ func TestPrune(t *testing.T) {
interp := NewStandardInterpreter(containers.DefaultContainer, reg, reg, attrs)
interpretable, _ := interp.NewUncheckedInterpretable(
- ast.Expr,
+ ast.GetExpr(),
ExhaustiveEval(), Observe(EvalStateObserver(state)))
interpretable.Eval(testActivation(t, tst.in))
- newExpr := PruneAst(ast.Expr, ast.SourceInfo.GetMacroCalls(), state)
+ newExpr := PruneAst(ast.GetExpr(), ast.GetSourceInfo().GetMacroCalls(), state)
actual, err := parser.Unparse(newExpr.GetExpr(), newExpr.GetSourceInfo())
if err != nil {
t.Error(err)
@@ -165,6 +205,15 @@ func unknownActivation(vars ...string) PartialActivation {
return a
}
+func partialActivation(in map[string]any, vars ...string) PartialActivation {
+ pats := make([]*AttributePattern, len(vars), len(vars))
+ for i, v := range vars {
+ pats[i] = NewAttributePattern(v)
+ }
+ a, _ := NewPartialActivation(in, pats...)
+ return a
+}
+
func testActivation(t *testing.T, in any) Activation {
t.Helper()
if in == nil {