From 68d7302ab0616d64f85a5924e53f801f1b451513 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 6 Apr 2023 08:52:37 -0700 Subject: [PATCH] Prune recursion fixes for nested logic (#677) In expressions where the logic is nested and the residual state would result in the production of a constant value in the expression, ensure that the intermediate state for the expression is updated to reflect the constant value. Also, ensure that special cases of pruning for logical operators happen after argument pruning has happened to ensure that the prune steps are properly recursive. --- interpreter/prune.go | 75 ++++++++++++++++++++++++++++++++------- interpreter/prune_test.go | 69 +++++++++++++++++++++++++++++------ 2 files changed, 122 insertions(+), 22 deletions(-) diff --git a/interpreter/prune.go b/interpreter/prune.go index b8b015a7a..24c7e79d6 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -67,10 +67,15 @@ type astPruner struct { // fold(and thus cache results of) some external calls, then they can prepare // the overloads accordingly. func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr { + pruneState := NewEvalState() + for _, id := range state.IDs() { + v, _ := state.Value(id) + pruneState.SetValue(id, v) + } pruner := &astPruner{ expr: expr, macroCalls: macroCalls, - state: state, + state: pruneState, nextExprID: 1} newExpr, _ := pruner.maybePrune(expr) return &exprpb.ParsedExpr{ @@ -91,24 +96,31 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) { switch val.Type() { case types.BoolType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true case types.IntType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true case types.UintType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true case types.StringType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true case types.DoubleType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true case types.BytesType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true case types.NullType: + p.state.SetValue(id, val) return p.createLiteral(id, &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true } @@ -128,6 +140,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo } elemExprs[i] = elemExpr } + p.state.SetValue(id, val) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_ListExpr{ @@ -167,6 +180,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo entries[i] = entry i++ } + p.state.SetValue(id, val) return &exprpb.Expr{ Id: id, ExprKind: &exprpb.Expr_StructExpr{ @@ -182,6 +196,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo return nil, false } +func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { + if !p.existsWithUnknownValue(node.GetId()) { + return nil, false + } + call := node.GetCallExpr() + val, valueExists := p.value(call.GetArgs()[1].GetId()) + if !valueExists { + return nil, false + } + if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { + return p.maybeCreateLiteral(node.GetId(), types.False) + } + return nil, false +} + +func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { + if !p.existsWithUnknownValue(node.GetId()) { + return nil, false + } + call := node.GetCallExpr() + arg := call.GetArgs()[0] + v, exists := p.value(arg.GetId()) + if !exists { + return nil, false + } + if b, ok := v.(types.Bool); ok { + return p.maybeCreateLiteral(node.GetId(), !b) + } + return nil, false +} + func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { if !p.existsWithUnknownValue(node.GetId()) { return nil, false @@ -224,7 +269,12 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { if call.Function == operators.Conditional { return p.maybePruneConditional(node) } - + if call.Function == operators.In { + return p.maybePruneIn(node) + } + if call.Function == operators.LogicalNot { + return p.maybePruneLogicalNot(node) + } return nil, false } @@ -266,10 +316,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { }, true } case *exprpb.Expr_CallExpr: - if newExpr, pruned := p.maybePruneFunction(node); pruned { - newExpr, _ = p.maybePrune(newExpr) - return newExpr, true - } var prunedCall bool call := node.GetCallExpr() args := call.GetArgs() @@ -290,13 +336,18 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { prunedCall = true newCall.Target = newTarget } + newNode := &exprpb.Expr{ + Id: node.GetId(), + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: newCall, + }, + } + if newExpr, pruned := p.maybePruneFunction(newNode); pruned { + newExpr, _ = p.maybePrune(newExpr) + return newExpr, true + } if prunedCall { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: newCall, - }, - }, true + return newNode, true } case *exprpb.Expr_ListExpr: elems := node.GetListExpr().GetElements() diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 6320d6ec1..dd2574901 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -56,6 +56,32 @@ var testCases = []testInfo{ expr: `a && [1, 1u, 1.0].exists(x, type(x) == uint)`, out: `a`, }, + { + in: unknownActivation("this"), + expr: `this in []`, + out: `false`, + }, + { + in: unknownActivation("this"), + expr: `this in {}`, + out: `false`, + }, + { + in: partialActivation(map[string]any{"rules": []string{}}, "this"), + expr: `this in rules`, + out: `false`, + }, + { + in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"), + expr: `this.size() > 0 ? this in rules.not_in : !(this in rules.not_in)`, + out: `(this.size() > 0) ? false : true`, + }, + { + in: partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"), + expr: `this.size() > 0 ? this in rules.not_in : + !(this in rules.not_in) ? true : false`, + out: `(this.size() > 0) ? false : true`, + }, { expr: `{'hello': 'world'.size()}`, out: `{"hello": 5}`, @@ -96,6 +122,11 @@ var testCases = []testInfo{ expr: `true ? b < 1.2 : c == ['hello']`, out: `b < 1.2`, }, + { + in: unknownActivation("b", "c"), + expr: `false ? b < 1.2 : c == ['hello']`, + out: `c == ["hello"]`, + }, { in: unknownActivation(), expr: `[1+3, 2+2, 3+1, four]`, @@ -121,18 +152,27 @@ var testCases = []testInfo{ expr: `test in {'a': 1, 'field': [test, 3]}.field`, out: `test in {"a": 1, "field": [test, 3]}.field`, }, - // TODO(issues/) the output test relies on tracking macro expansions back to their original - // call patterns. - /* { - in: unknownActivation(), - expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`, - out: `[4, 4, 4, four].exists(x, x == four)`, - }, */ + // TODO: the output of an expression like this relies on either + // a) doing replacements on the original macro call, or + // b) mutating the macro call tracking data rather than the core + // expression in order to render the partial correctly. + // { + // in: unknownActivation(), + // expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`, + // out: `[4, 4, 4, four].exists(x, x == four)`, + // }, } func TestPrune(t *testing.T) { + p, err := parser.NewParser( + parser.PopulateMacroCalls(true), + parser.Macros(parser.AllMacros...), + ) + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } for i, tst := range testCases { - ast, iss := parser.Parse(common.NewStringSource(tst.expr, "")) + ast, iss := p.Parse(common.NewStringSource(tst.expr, "")) if len(iss.GetErrors()) > 0 { t.Fatalf(iss.ToDisplayString()) } @@ -142,10 +182,10 @@ func TestPrune(t *testing.T) { interp := NewStandardInterpreter(containers.DefaultContainer, reg, reg, attrs) interpretable, _ := interp.NewUncheckedInterpretable( - ast.Expr, + ast.GetExpr(), ExhaustiveEval(), Observe(EvalStateObserver(state))) interpretable.Eval(testActivation(t, tst.in)) - newExpr := PruneAst(ast.Expr, ast.SourceInfo.GetMacroCalls(), state) + newExpr := PruneAst(ast.GetExpr(), ast.GetSourceInfo().GetMacroCalls(), state) actual, err := parser.Unparse(newExpr.GetExpr(), newExpr.GetSourceInfo()) if err != nil { t.Error(err) @@ -165,6 +205,15 @@ func unknownActivation(vars ...string) PartialActivation { return a } +func partialActivation(in map[string]any, vars ...string) PartialActivation { + pats := make([]*AttributePattern, len(vars), len(vars)) + for i, v := range vars { + pats[i] = NewAttributePattern(v) + } + a, _ := NewPartialActivation(in, pats...) + return a +} + func testActivation(t *testing.T, in any) Activation { t.Helper() if in == nil {