diff --git a/interpreter/prune.go b/interpreter/prune.go index b8b015a7..24c7e79d 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -67,10 +67,15 @@ type astPruner struct { // fold(and thus cache results of) some external calls, then they can prepare // the overloads accordingly. func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr { + pruneState := NewEvalState() + for _, id := range state.IDs() { + v, _ := state.Value(id) + pruneState.SetValue(id, v) + } pruner := &astPruner{ expr: expr, macroCalls: macroCalls, - state: state, + state: pruneState, nextExprID: 1} newExpr, _ := pruner.maybePrune(expr) return &exprpb.ParsedExpr{ @@ -91,24 +96,31 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) { switch val.Type() { case types.BoolType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true case types.IntType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true case types.UintType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true case types.StringType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true case types.DoubleType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true case types.BytesType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true case types.NullType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true } @@ -128,6 +140,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo } elemExprs[i] = elemExpr } + p.state.SetValue(id, val) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_ListExpr{ @@ -167,6 +180,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo entries[i] = entry i++ } + p.state.SetValue(id, val) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_StructExpr{ @@ -182,6 +196,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo return nil, false } +func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { + if !p.existsWithUnknownValue(node.GetId()) { + return nil, false + } + call := node.GetCallExpr() + val, valueExists := p.value(call.GetArgs()[1].GetId()) + if !valueExists { + return nil, false + } + if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { + return p.maybeCreateLiteral(node.GetId(), types.False) + } + return nil, false +} + +func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { + if !p.existsWithUnknownValue(node.GetId()) { + return nil, false + } + call := node.GetCallExpr() + arg := call.GetArgs()[0] + v, exists := p.value(arg.GetId()) + if !exists { + return nil, false + } + if b, ok := v.(types.Bool); ok { + return p.maybeCreateLiteral(node.GetId(), !b) + } + return nil, false +} + func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { if !p.existsWithUnknownValue(node.GetId()) { return nil, false @@ -224,7 +269,12 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { if call.Function == operators.Conditional { return p.maybePruneConditional(node) } - + if call.Function == operators.In { + return p.maybePruneIn(node) + } + if call.Function == operators.LogicalNot { + return p.maybePruneLogicalNot(node) + } return nil, false } @@ -266,10 +316,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { }, true } case *exprpb.Expr_CallExpr: - if newExpr, pruned := p.maybePruneFunction(node); pruned { - newExpr, _ = p.maybePrune(newExpr) - return newExpr, true - } var prunedCall bool call := node.GetCallExpr() args := call.GetArgs() @@ -290,13 +336,18 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { prunedCall = true newCall.Target = newTarget } + newNode := &exprpb.Expr{ + Id: node.GetId(), + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: newCall, + }, + } + if newExpr, pruned := p.maybePruneFunction(newNode); pruned { + newExpr, _ = p.maybePrune(newExpr) + return newExpr, true + } if prunedCall { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: newCall, - }, - }, true + return newNode, true } case *exprpb.Expr_ListExpr: elems := node.GetListExpr().GetElements() diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 6320d6ec..dd257490 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -56,6 +56,32 @@ var testCases = []testInfo{ expr: `a && [1, 1u, 1.0].exists(x, type(x) == uint)`, out: `a`, }, + { + in: unknownActivation("this"), + expr: `this in []`, + out: `false`, + }, + { + in: unknownActivation("this"), + expr: `this in {}`, + out: `false`, + }, + { + in: partialActivation(map[string]any{"rules": []string{}}, "this"), + expr: `this in rules`, + out: `false`, + }, + { + in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"), + expr: `this.size() > 0 ? this in rules.not_in : !(this in rules.not_in)`, + out: `(this.size() > 0) ? false : true`, + }, + { + in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"), + expr: `this.size() > 0 ? this in rules.not_in : + !(this in rules.not_in) ? true : false`, + out: `(this.size() > 0) ? false : true`, + }, { expr: `{'hello': 'world'.size()}`, out: `{"hello": 5}`, @@ -96,6 +122,11 @@ var testCases = []testInfo{ expr: `true ? b < 1.2 : c == ['hello']`, out: `b < 1.2`, }, + { + in: unknownActivation("b", "c"), + expr: `false ? b < 1.2 : c == ['hello']`, + out: `c == ["hello"]`, + }, { in: unknownActivation(), expr: `[1+3, 2+2, 3+1, four]`, @@ -121,18 +152,27 @@ var testCases = []testInfo{ expr: `test in {'a': 1, 'field': [test, 3]}.field`, out: `test in {"a": 1, "field": [test, 3]}.field`, }, - // TODO(issues/) the output test relies on tracking macro expansions back to their original - // call patterns. - /* { - in: unknownActivation(), - expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`, - out: `[4, 4, 4, four].exists(x, x == four)`, - }, */ + // TODO: the output of an expression like this relies on either + // a) doing replacements on the original macro call, or + // b) mutating the macro call tracking data rather than the core + // expression in order to render the partial correctly. + // { + // in: unknownActivation(), + // expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`, + // out: `[4, 4, 4, four].exists(x, x == four)`, + // }, } func TestPrune(t *testing.T) { + p, err := parser.NewParser( + parser.PopulateMacroCalls(true), + parser.Macros(parser.AllMacros...), + ) + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } for i, tst := range testCases { - ast, iss := parser.Parse(common.NewStringSource(tst.expr, "")) + ast, iss := p.Parse(common.NewStringSource(tst.expr, "")) if len(iss.GetErrors()) > 0 { t.Fatalf(iss.ToDisplayString()) } @@ -142,10 +182,10 @@ func TestPrune(t *testing.T) { interp := NewStandardInterpreter(containers.DefaultContainer, reg, reg, attrs) interpretable, _ := interp.NewUncheckedInterpretable( - ast.Expr, + ast.GetExpr(), ExhaustiveEval(), Observe(EvalStateObserver(state))) interpretable.Eval(testActivation(t, tst.in)) - newExpr := PruneAst(ast.Expr, ast.SourceInfo.GetMacroCalls(), state) + newExpr := PruneAst(ast.GetExpr(), ast.GetSourceInfo().GetMacroCalls(), state) actual, err := parser.Unparse(newExpr.GetExpr(), newExpr.GetSourceInfo()) if err != nil { t.Error(err) @@ -165,6 +205,15 @@ func unknownActivation(vars ...string) PartialActivation { return a } +func partialActivation(in map[string]any, vars ...string) PartialActivation { + pats := make([]*AttributePattern, len(vars), len(vars)) + for i, v := range vars { + pats[i] = NewAttributePattern(v) + } + a, _ := NewPartialActivation(in, pats...) + return a +} + func testActivation(t *testing.T, in any) Activation { t.Helper() if in == nil {