diff --git a/ast/compile.go b/ast/compile.go index 1aeb683912..2470257d76 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -72,6 +72,7 @@ type Compiler struct { generatedVars map[*Module]VarSet moduleLoader ModuleLoader + ruleIndices *util.HashMap stages []func() } @@ -161,6 +162,12 @@ func NewCompiler() *Compiler { Modules: map[string]*Module{}, TypeEnv: NewTypeEnv(), generatedVars: map[*Module]VarSet{}, + ruleIndices: util.NewHashMap(func(a, b util.T) bool { + r1, r2 := a.(Ref), b.(Ref) + return r1.Equal(r2) + }, func(x util.T) int { + return x.(Ref).Hash() + }), } c.ModuleTree = NewModuleTree(nil) @@ -178,6 +185,7 @@ func NewCompiler() *Compiler { c.checkSafetyRuleBodies, c.checkRecursion, c.checkTypes, + c.buildRuleIndices, } return c @@ -339,6 +347,18 @@ func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { return rules } +// RuleIndex returns a RuleIndex built for the rule set referred to by path. +// The path must refer to the rule set exactly, i.e., given a rule set at path +// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a +// RuleIndex built for the rule. +func (c *Compiler) RuleIndex(path Ref) RuleIndex { + r, ok := c.ruleIndices.Get(path) + if !ok { + return nil + } + return r.(RuleIndex) +} + // ModuleLoader defines the interface that callers can implement to enable lazy // loading of modules during compilation. type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error) @@ -356,6 +376,23 @@ func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler { return c } +// buildRuleIndices constructs indices for rules. +func (c *Compiler) buildRuleIndices() { + + c.RuleTree.DepthFirst(func(node *RuleTreeNode) bool { + if len(node.Rules) > 1 { + index := newBaseDocEqIndex(func(ref Ref) bool { + return len(c.GetRules(ref.GroundPrefix())) > 0 + }) + if index.Build(node.Rules) { + c.ruleIndices.Put(node.Rules[0].Path(), index) + } + } + return false + }) + +} + // checkRecursion ensures that there are no recursive rule definitions, i.e., there are // no cycles in the RuleGraph. func (c *Compiler) checkRecursion() { @@ -690,7 +727,6 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) { qc.resolveRefs, qc.checkWithModifiers, qc.checkSafety, - qc.checkInput, qc.checkTypes, } @@ -742,73 +778,6 @@ func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { return reordered, nil } -func (qc *queryCompiler) checkInput(qctx *QueryContext, body Body) (Body, error) { - return body, qc.checkInputRec(qctx.InputDefined(), body) -} - -func (qc *queryCompiler) checkInputRec(definedPrev bool, body Body) error { - - // Perform DFS for conflicting or missing input document. - for _, expr := range body { - - definedCurr := definesInput(expr) - - if definedPrev && definedCurr { - return NewError(InputErr, expr.Location, "input document conflict") - } else if !definedCurr && !definedPrev && referencesInput(expr) { - return NewError(InputErr, expr.Location, "input document not defined") - } - - var err error - - // Check closures contained in this expression. - vis := NewGenericVisitor(func(x interface{}) bool { - if err != nil { - return true - } - switch x := x.(type) { - case *ArrayComprehension: - if err = qc.checkInputRec(definedPrev || definedCurr, x.Body); err != nil { - return true - } - } - return false - }) - - Walk(vis, expr) - - if err != nil { - return err - } - - // Check rule bodies referred to by this expression. - vis = NewGenericVisitor(func(x interface{}) bool { - if err != nil { - return true - } - switch x := x.(type) { - case Ref: - if x.HasPrefix(DefaultRootRef) { - for _, rule := range qc.compiler.GetRules(x.GroundPrefix()) { - if err = qc.checkInputRec(definedPrev || definedCurr, rule.Body); err != nil { - return true - } - } - } - } - return false - }) - - Walk(vis, expr) - - if err != nil { - return err - } - } - - return nil -} - // referencesInput returns true if expr refers to the input document. This // function will not visit closures. func referencesInput(expr *Expr) bool { diff --git a/ast/compile_test.go b/ast/compile_test.go index 83a32374f4..47f6ac13cc 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1081,12 +1081,8 @@ func TestQueryCompiler(t *testing.T) { {"unsafe vars", "z", "", nil, "", fmt.Errorf("1 error occurred: 1:1: rego_unsafe_var_error: var z is unsafe")}, {"safe vars", `data; abc`, `package ex`, []string{"import input.xyz as abc"}, `{}`, `data; input.xyz`}, {"reorder", `x != 1; x = 0`, "", nil, "", `x = 0; x != 1`}, - // {"bad builtin", "deadbeef(1,2,3)", "", nil, "", fmt.Errorf("1 error occurred: 1:1: rego_type_error: undefined built-in function")}, {"bad with target", "x = 1 with data.p as null", "", nil, "", fmt.Errorf("1 error occurred: 1:7: rego_type_error: with keyword target must be input")}, - // wrapping refs in extra terms to cover error handling - {"undefined input", `[[true | [data.a.b.d.t, true]], true]`, "", nil, "", fmt.Errorf("5:12: rego_input_error: input document not defined")}, - {"conflicting input", `[true | data.a.b.d.t with input as 1]`, "", nil, "2", fmt.Errorf("1:9: rego_input_error: input document conflict")}, - {"conflicting input-2", `sum([1 | data.a.b.d.t with input as 2], x) with input as 3`, "", nil, "", fmt.Errorf("1:10: rego_input_error: input document conflict")}, + {"check types", "x = data.a.b.c.z; y = null; x = y", "", nil, "", fmt.Errorf("match error\n\tleft : number\n\tright : null")}, } for _, tc := range tests { @@ -1259,7 +1255,7 @@ func runQueryCompilerTest(t *testing.T, note, q, pkg string, imports []string, i if err == nil { t.Fatalf("Expected error from %v but got: %v", query, result) } - if err.Error() != expected.Error() { + if !strings.Contains(err.Error(), expected.Error()) { t.Fatalf("Expected error %v but got: %v", expected, err) } } diff --git a/ast/index.go b/ast/index.go new file mode 100644 index 0000000000..ab728d2bd2 --- /dev/null +++ b/ast/index.go @@ -0,0 +1,422 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "io" + "sort" + "strings" + + "github.com/open-policy-agent/opa/util" +) + +// RuleIndex defines the interface for rule indices. +type RuleIndex interface { + + // Build tries to construct an index for the given rules. If the index was + // constructed, ok is true, otherwise false. + Build(rules []*Rule) (ok bool) + + // Index returns a set of rules to evaluate and a default rule if one was + // present when the index was built. + Index(resolver ValueResolver) (rules []*Rule, defaultRule *Rule, err error) +} + +type baseDocEqIndex struct { + isVirtual func(Ref) bool + root *trieNode + defaultRule *Rule +} + +func newBaseDocEqIndex(isVirtual func(Ref) bool) *baseDocEqIndex { + return &baseDocEqIndex{ + isVirtual: isVirtual, + root: newTrieNodeImpl(), + } +} + +func (i *baseDocEqIndex) Build(rules []*Rule) bool { + + refs := make(refValueIndex, len(rules)) + + // freq is map[ref]int where the values represent the frequency of the + // ref/key. + freq := util.NewHashMap(func(a, b util.T) bool { + r1, r2 := a.(Ref), b.(Ref) + return r1.Equal(r2) + }, func(x util.T) int { + return x.(Ref).Hash() + }) + + // Build refs and freq maps + for _, rule := range rules { + + if rule.Default { + // Compiler guarantees that only one default will be defined per path. + i.defaultRule = rule + continue + } + + for _, expr := range rule.Body { + ref, value, ok := i.getRefAndValue(expr) + if ok { + refs.Insert(rule, ref, value) + count, ok := freq.Get(ref) + if !ok { + count = 0 + } + count = count.(int) + 1 + freq.Put(ref, count) + } + } + } + + // Sort by frequency + type refCountPair struct { + ref Ref + count int + } + + sorted := make([]refCountPair, 0, freq.Len()) + freq.Iter(func(k, v util.T) bool { + ref, count := k.(Ref), v.(int) + sorted = append(sorted, refCountPair{ref, count}) + return false + }) + + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].count > sorted[j].count { + return true + } + return false + }) + + // Build trie + for _, rule := range rules { + + if rule.Default { + continue + } + + node := i.root + + if refs := refs[rule]; refs != nil { + for _, pair := range sorted { + value := refs.Get(pair.ref) + node = node.Insert(pair.ref, value) + } + } + + node.rules = append(node.rules, rule) + } + + return true +} + +func (i *baseDocEqIndex) Index(resolver ValueResolver) ([]*Rule, *Rule, error) { + rules, err := i.traverse(i.root, resolver) + return rules, i.defaultRule, err +} + +func (i *baseDocEqIndex) traverse(node *trieNode, resolver ValueResolver) ([]*Rule, error) { + + if node == nil { + return nil, nil + } + + result := make([]*Rule, len(node.rules)) + copy(result, node.rules) + next := node.next + + if next == nil { + return result, nil + } + + children, err := next.Resolve(resolver) + if err != nil { + return nil, err + } + + for _, child := range children { + rules, err := i.traverse(child, resolver) + if err != nil { + return nil, err + } + result = append(result, rules...) + } + + return result, nil +} + +func (i *baseDocEqIndex) getRefAndValue(expr *Expr) (Ref, Value, bool) { + + if !expr.IsEquality() || expr.Negated { + return nil, nil, false + } + + a, b := expr.Operand(0), expr.Operand(1) + + if ref, value, ok := i.getRefAndValueFromTerms(a, b); ok { + return ref, value, true + } + + return i.getRefAndValueFromTerms(b, a) +} + +func (i *baseDocEqIndex) getRefAndValueFromTerms(a, b *Term) (Ref, Value, bool) { + + ref, ok := a.Value.(Ref) + if !ok { + return nil, nil, false + } + + if !RootDocumentNames.Contains(ref[0]) { + return nil, nil, false + } + + if i.isVirtual(ref) { + return nil, nil, false + } + + if ref.IsNested() || !ref.IsGround() { + return nil, nil, false + } + + switch b := b.Value.(type) { + case Null, Boolean, Number, String, Var: + return ref, b, true + case Array: + stop := false + first := true + vis := NewGenericVisitor(func(x interface{}) bool { + if first { + first = false + return false + } + switch x.(type) { + // No nested structures or values that require evaluation (other than var). + case Array, Object, *Set, *ArrayComprehension, Ref: + stop = true + } + return stop + }) + Walk(vis, b) + if !stop { + return ref, b, true + } + } + + return nil, nil, false +} + +type refValueIndex map[*Rule]*ValueMap + +func (m refValueIndex) Insert(rule *Rule, ref Ref, value Value) { + vm, ok := m[rule] + if !ok { + vm = NewValueMap() + m[rule] = vm + } + vm.Put(ref, value) +} + +type trieWalker interface { + Do(x interface{}) trieWalker +} + +type trieNode struct { + ref Ref + next *trieNode + any *trieNode + undefined *trieNode + scalars map[Value]*trieNode + array *trieNode + rules []*Rule +} + +func newTrieNodeImpl() *trieNode { + return &trieNode{ + scalars: map[Value]*trieNode{}, + } +} + +func (node *trieNode) Do(walker trieWalker) { + next := walker.Do(node) + if next == nil { + return + } + if node.next != nil { + node.next.Do(next) + return + } + if node.any != nil { + node.any.Do(next) + } + if node.undefined != nil { + node.undefined.Do(next) + } + for _, child := range node.scalars { + child.Do(next) + } + if node.array != nil { + node.array.Do(next) + } +} + +func (node *trieNode) Insert(ref Ref, value Value) *trieNode { + + if node.next == nil { + node.next = newTrieNodeImpl() + node.next.ref = ref + } + + return node.next.insertValue(value) +} + +func (node *trieNode) Resolve(resolver ValueResolver) ([]*trieNode, error) { + + v, err := resolver.Resolve(node.ref) + if err != nil { + return nil, err + } + + result := []*trieNode{} + + if node.undefined != nil { + result = append(result, node.undefined) + } + + if v == nil { + return result, nil + } + + if node.any != nil { + result = append(result, node.any) + } + + result = append(result, node.resolveValue(v)...) + return result, nil +} + +func (node *trieNode) insertValue(value Value) *trieNode { + + switch value := value.(type) { + case nil: + if node.undefined == nil { + node.undefined = newTrieNodeImpl() + } + return node.undefined + case Var: + if node.any == nil { + node.any = newTrieNodeImpl() + } + return node.any + case Null, Boolean, Number, String: + child, ok := node.scalars[value] + if !ok { + child = newTrieNodeImpl() + node.scalars[value] = child + } + return child + case Array: + if node.array == nil { + node.array = newTrieNodeImpl() + } + return node.array.insertArray(value) + } + + panic("illegal value") +} + +func (node *trieNode) insertArray(arr Array) *trieNode { + + if len(arr) == 0 { + return node + } + + switch head := arr[0].Value.(type) { + case Var: + if node.any == nil { + node.any = newTrieNodeImpl() + } + return node.any.insertArray(arr[1:]) + case Null, Boolean, Number, String: + child, ok := node.scalars[head] + if !ok { + child = newTrieNodeImpl() + node.scalars[head] = child + } + return child.insertArray(arr[1:]) + } + + panic("illegal value") +} + +func (node *trieNode) resolveValue(value Value) []*trieNode { + + switch value := value.(type) { + case Array: + if node.array == nil { + return nil + } + return node.array.resolveArray(value) + + case Null, Boolean, Number, String: + child, ok := node.scalars[value] + if !ok { + return nil + } + return []*trieNode{child} + } + + return nil +} + +func (node *trieNode) resolveArray(arr Array) []*trieNode { + + if len(arr) == 0 { + if node.next != nil || len(node.rules) > 0 { + return []*trieNode{node} + } + return nil + } + + head := arr[0].Value + + if !IsScalar(head) { + return nil + } + + var result []*trieNode + + if node.any != nil { + result = append(result, node.any.resolveArray(arr[1:])...) + } + + child, ok := node.scalars[head] + if !ok { + return result + } + + return append(result, child.resolveArray(arr[1:])...) +} + +type triePrinter struct { + depth int + w io.Writer +} + +func (p triePrinter) Do(x interface{}) trieWalker { + padding := strings.Repeat(" ", p.depth) + fmt.Fprintf(p.w, "%v%v\n", padding, x) + p.depth++ + return p +} + +func printTrie(w io.Writer, trie *trieNode) { + pp := triePrinter{0, w} + trie.Do(pp) +} diff --git a/ast/index_test.go b/ast/index_test.go new file mode 100644 index 0000000000..961e94a886 --- /dev/null +++ b/ast/index_test.go @@ -0,0 +1,350 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "testing" + + "github.com/open-policy-agent/opa/util/test" +) + +type testResolver struct { + input *Term + failRef Ref +} + +func (r testResolver) Resolve(ref Ref) (Value, error) { + if ref.Equal(r.failRef) { + return nil, fmt.Errorf("some error") + } + if ref.HasPrefix(InputRootRef) { + v, err := r.input.Value.Find(ref[1:]) + if err != nil { + return nil, nil + } + return v, nil + } + panic("illegal value") +} + +func TestBaseDocEqIndexing(t *testing.T) { + + module := MustParseModule(` + package test + + exact { + input.x = 1 + input.y = 2 + } { + input.x = 3 + input.y = 4 + } + + scalars { + input.x = 0 + input.y = 1 + } { + 1 = input.y # exercise ordering + input.x = 0 + } { + input.y = 2 + input.z = 2 + } { + input.x = 2 + } + + vars { + input.x = 1 + input.y = 2 + } { + input.x = x + input.y = 3 + } { + input.x = 4 + input.z = 5 + } + + composite_arr { + input.x = 1 + input.y = [1,2,3] + input.z = 1 + } { + input.x = 1 + input.y = [1,2,4,x] + } { + input.y = [1,2,y,5] + input.z = 3 + } { + input.y = [] + } { + # Must be included in all results as nested composites are not indexed. + input.y = [1,[2,3],4] + } + + composite_obj { + input.y = {"foo": "bar", "bar": x} + } + + # filtering ruleset contains rules that cannot be indexed (for different reasons). + filtering { + count([], x) + } { + not input.x = 0 + } { + x = [1,2,3] + x[0] = 1 + } { + input.x[_] = 1 + } { + input.x[input.y] = 1 + } { + # include one rule that can be indexed to exercise merging of root non-indexable + # rules with other rules. + input.x = 1 + } + + # exercise default keyword + default allow = false + allow { + input.x = 1 + } { + input.x = 0 + } + `) + + tests := []struct { + note string + ruleset string + input string + expectedRS interface{} + expectedDR *Rule + }{ + { + note: "exact match", + ruleset: "exact", + input: `{"x": 3, "y": 4}`, + expectedRS: []string{ + `exact { input.x = 3; input.y = 4 }`, + }, + }, + { + note: "undefined match", + ruleset: "scalars", + input: `{"x": 2, "y": 2}`, + expectedRS: []string{ + `scalars { input.x = 2 }`}, + }, + { + note: "disjoint match", + ruleset: "scalars", + input: `{"x": 2, "y": 2, "z": 2}`, + expectedRS: []string{ + `scalars { input.x = 2 }`, + `scalars { input.y = 2; input.z = 2}`}, + }, + { + note: "ordering match", + ruleset: "scalars", + input: `{"x": 0, "y": 1}`, + expectedRS: []string{ + `scalars { input.x = 0; input.y = 1 }`, + `scalars { 1 = input.y; input.x = 0 }`}, + }, + { + note: "type no match", + ruleset: "vars", + input: `{"y": 3, "x": {1,2,3}}`, + expectedRS: []string{ + `vars { input.x = x; input.y = 3 }`, + }, + }, + { + note: "var match", + ruleset: "vars", + input: `{"x": 1, "y": 3}`, + expectedRS: []string{ + `vars { input.x = x; input.y = 3 }`, + }, + }, + { + note: "var match disjoint", + ruleset: "vars", + input: `{"x": 4, "z": 5, "y": 3}`, + expectedRS: []string{ + `vars { input.x = x; input.y = 3 }`, + `vars { input.x = 4; input.z = 5 }`, + }, + }, + { + note: "array match", + ruleset: "composite_arr", + input: `{ + "x": 1, + "y": [1,2,3], + "z": 1, + }`, + expectedRS: []string{ + `composite_arr { input.x = 1; input.y = [1,2,3]; input.z = 1 }`, + `composite_arr { input.y = [1,[2,3],4] }`, + }, + }, + { + note: "array var match", + ruleset: "composite_arr", + input: `{ + "x": 1, + "y": [1,2,4,5], + }`, + expectedRS: []string{ + `composite_arr { input.x = 1; input.y = [1,2,4,x] }`, + `composite_arr { input.y = [1,[2,3],4] }`, + }, + }, + { + note: "array var multiple match", + ruleset: "composite_arr", + input: `{ + "x": 1, + "y": [1,2,4,5], + "z": 3, + }`, + expectedRS: []string{ + `composite_arr { input.x = 1; input.y = [1,2,4,x] }`, + `composite_arr { input.y = [1,2,y,5]; input.z = 3 }`, + `composite_arr { input.y = [1,[2,3],4] }`, + }, + }, + { + note: "array nested match non-indexable rules", + ruleset: "composite_arr", + input: `{ + "x": 1, + "y": [1,[2,3],4], + }`, + expectedRS: []string{ + `composite_arr { input.y = [1,[2,3],4] }`, + }, + }, + { + note: "array empty match", + ruleset: "composite_arr", + input: `{"y": []}`, + expectedRS: []string{ + `composite_arr { input.y = [] }`, + `composite_arr { input.y = [1,[2,3],4] }`, + }, + }, + { + note: "object match non-indexable rule", + ruleset: "composite_obj", + input: `{"y": {"foo": "bar", "bar": "baz"}}`, + expectedRS: []string{ + `composite_obj { input.y = {"foo": "bar", "bar": x} }`, + }, + }, + { + note: "default rule only", + ruleset: "allow", + input: `{"x": 2}`, + expectedRS: []string{}, + expectedDR: MustParseRule(`default allow = false`), + }, + { + note: "match and default rule", + ruleset: "allow", + input: `{"x": 1}`, + expectedRS: []string{"allow { input.x = 1 }"}, + expectedDR: MustParseRule(`default allow = false`), + }, + { + note: "match and non-indexable rules", + ruleset: "filtering", + input: `{"x": 1}`, + expectedRS: module.RuleSet(Var("filtering")), + }, + { + note: "non-indexable rules", + ruleset: "filtering", + input: `{}`, + expectedRS: module.RuleSet(Var("filtering")).Diff(NewRuleSet(MustParseRule(`filtering { input.x = 1 }`))), + }, + } + + for _, tc := range tests { + test.Subtest(t, tc.note, func(t *testing.T) { + + rules := []*Rule{} + for _, rule := range module.Rules { + if rule.Head.Name == Var(tc.ruleset) { + rules = append(rules, rule) + } + } + + input := MustParseTerm(tc.input) + var expectedRS RuleSet + + switch e := tc.expectedRS.(type) { + case []string: + for _, r := range e { + expectedRS.Add(MustParseRule(r)) + } + case RuleSet: + expectedRS = e + default: + panic("Unexpected test case expected value") + } + + index := newBaseDocEqIndex(func(Ref) bool { + return false + }) + + if !index.Build(rules) { + t.Fatalf("Expected index build to succeed") + } + + rs, dr, err := index.Index(testResolver{input, nil}) + if err != nil { + t.Fatalf("Unexpected error during index lookup: %v", err) + } + + result := NewRuleSet(rs...) + + if !result.Equal(expectedRS) { + t.Fatalf("Expected ruleset %v but got: %v", expectedRS, rs) + } + + if dr == nil && tc.expectedDR != nil { + t.Fatalf("Expected default rule but got nil") + } else if dr != nil && tc.expectedDR == nil { + t.Fatalf("Unexpected default rule %v", dr) + } else if dr != nil && tc.expectedDR != nil && !dr.Equal(tc.expectedDR) { + t.Fatalf("Expected default rule %v but got: %v", tc.expectedDR, dr) + } + }) + } + +} + +func TestBaseDocEqIndexingErrors(t *testing.T) { + index := newBaseDocEqIndex(func(Ref) bool { + return false + }) + + module := MustParseModule(` + package ex + + p { input.raise_error = 1 }`) + + if !index.Build(module.Rules) { + t.Fatalf("Expected index to build") + } + + _, _, err := index.Index(testResolver{MustParseTerm(`{}`), MustParseRef("input.raise_error")}) + + if err == nil || err.Error() != "some error" { + t.Fatalf("Expected error but got: %v", err) + } +} diff --git a/ast/policy.go b/ast/policy.go index 8cb90ea503..2411ebb56d 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -215,6 +215,17 @@ func (mod *Module) String() string { return strings.Join(buf, "\n") } +// RuleSet returns a RuleSet containing named rules in the mod. +func (mod *Module) RuleSet(name Var) RuleSet { + rs := NewRuleSet() + for _, rule := range mod.Rules { + if rule.Head.Name.Equal(name) { + rs.Add(rule) + } + } + return rs +} + // NewComment returns a new Comment object. func NewComment(text []byte) *Comment { return &Comment{ @@ -992,6 +1003,74 @@ func (w With) Hash() int { return w.Target.Hash() + w.Value.Hash() } +// RuleSet represents a collection of rules that produce a virtual document. +type RuleSet []*Rule + +// NewRuleSet returns a new RuleSet containing the given rules. +func NewRuleSet(rules ...*Rule) RuleSet { + rs := make(RuleSet, 0, len(rules)) + for _, rule := range rules { + rs.Add(rule) + } + return rs +} + +// Add inserts the rule into rs. +func (rs *RuleSet) Add(rule *Rule) { + for _, exist := range *rs { + if exist.Equal(rule) { + return + } + } + *rs = append(*rs, rule) +} + +// Contains returns true if rs contains rule. +func (rs RuleSet) Contains(rule *Rule) bool { + for i := range rs { + if rs[i].Equal(rule) { + return true + } + } + return false +} + +// Diff returns a new RuleSet containing rules in rs that are not in other. +func (rs RuleSet) Diff(other RuleSet) RuleSet { + result := NewRuleSet() + for i := range rs { + if !other.Contains(rs[i]) { + result.Add(rs[i]) + } + } + return result +} + +// Equal returns true if rs equals other. +func (rs RuleSet) Equal(other RuleSet) bool { + return len(rs.Diff(other)) == 0 && len(other.Diff(rs)) == 0 +} + +// Merge returns a ruleset containing the union of rules from rs an other. +func (rs RuleSet) Merge(other RuleSet) RuleSet { + result := NewRuleSet() + for i := range rs { + result.Add(rs[i]) + } + for i := range other { + result.Add(other[i]) + } + return result +} + +func (rs RuleSet) String() string { + buf := make([]string, 0, len(rs)) + for _, rule := range rs { + buf = append(buf, rule.String()) + } + return "{" + strings.Join(buf, ", ") + "}" +} + type ruleSlice []*Rule func (s ruleSlice) Less(i, j int) bool { return Compare(s[i], s[j]) < 0 } diff --git a/ast/term.go b/ast/term.go index 91a4865864..25da8bb6f5 100644 --- a/ast/term.go +++ b/ast/term.go @@ -73,11 +73,11 @@ func (loc *Location) String() string { // - Variables, References // - Array Comprehensions type Value interface { - Equal(other Value) bool // Equal returns true if this value equals the other value. - Find(path []string) (Value, error) // Find returns value referred or an error if path is not found. - Hash() int // Returns hash code of the value. - IsGround() bool // IsGround returns true if this value is not a variable or contains no variables. - String() string // String returns a human readable string representation of the value. + Equal(other Value) bool // Equal returns true if this value equals the other value. + Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found. + Hash() int // Returns hash code of the value. + IsGround() bool // IsGround returns true if this value is not a variable or contains no variables. + String() string // String returns a human readable string representation of the value. } // InterfaceToValue converts a native Go value x to a Value. @@ -125,6 +125,11 @@ type Resolver interface { Resolve(ref Ref) (value interface{}, err error) } +// ValueResolver defines the interface for resolving references to AST values. +type ValueResolver interface { + Resolve(ref Ref) (value Value, err error) +} + type illegalResolver struct{} func (illegalResolver) Resolve(ref Ref) (interface{}, error) { @@ -367,7 +372,7 @@ func (null Null) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (null Null) Find(path []string) (Value, error) { +func (null Null) Find(path Ref) (Value, error) { if len(path) == 0 { return null, nil } @@ -407,7 +412,7 @@ func (bol Boolean) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (bol Boolean) Find(path []string) (Value, error) { +func (bol Boolean) Find(path Ref) (Value, error) { if len(path) == 0 { return bol, nil } @@ -462,7 +467,7 @@ func (num Number) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (num Number) Find(path []string) (Value, error) { +func (num Number) Find(path Ref) (Value, error) { if len(path) == 0 { return num, nil } @@ -521,7 +526,7 @@ func (str String) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (str String) Find(path []string) (Value, error) { +func (str String) Find(path Ref) (Value, error) { if len(path) == 0 { return str, nil } @@ -563,7 +568,7 @@ func (v Var) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (v Var) Find(path []string) (Value, error) { +func (v Var) Find(path Ref) (Value, error) { if len(path) == 0 { return v, nil } @@ -657,7 +662,7 @@ func (ref Ref) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (ref Ref) Find(path []string) (Value, error) { +func (ref Ref) Find(path Ref) (Value, error) { if len(path) == 0 { return ref, nil } @@ -776,19 +781,41 @@ func (arr Array) Equal(other Value) bool { } // Find returns the value at the index or an out-of-range error. -func (arr Array) Find(path []string) (Value, error) { +func (arr Array) Find(path Ref) (Value, error) { if len(path) == 0 { return arr, nil } - idx, err := strconv.ParseInt(path[0], 10, 64) - if err != nil { - return nil, err + num, ok := path[0].Value.(Number) + if !ok { + return nil, fmt.Errorf("find: not found") + } + i, ok := num.Int() + if !ok { + return nil, fmt.Errorf("find: not found") } - i := int(idx) if i < 0 || i >= len(arr) { return nil, fmt.Errorf("find: not found") } - return arr[idx].Value.Find(path[1:]) + return arr[i].Value.Find(path[1:]) +} + +// Get returns the element at pos or nil if not possible. +func (arr Array) Get(pos *Term) *Term { + num, ok := pos.Value.(Number) + if !ok { + return nil + } + + i, ok := num.Int() + if !ok { + return nil + } + + if i > 0 && i < len(arr) { + return arr[i] + } + + return nil } // Hash returns the hash code for the Value. @@ -870,7 +897,7 @@ func (s *Set) Equal(v Value) bool { } // Find returns the current value or a not found error. -func (s *Set) Find(path []string) (Value, error) { +func (s *Set) Find(path Ref) (Value, error) { if len(path) == 0 { return s, nil } @@ -987,11 +1014,11 @@ func (obj Object) Equal(other Value) bool { } // Find returns the value at the key or undefined. -func (obj Object) Find(path []string) (Value, error) { +func (obj Object) Find(path Ref) (Value, error) { if len(path) == 0 { return obj, nil } - value := obj.Get(StringTerm(path[0])) + value := obj.Get(path[0]) if value == nil { return nil, fmt.Errorf("find: not found") } @@ -1163,7 +1190,7 @@ func (ac *ArrayComprehension) Equal(other Value) bool { } // Find returns the current value or a not found error. -func (ac *ArrayComprehension) Find(path []string) (Value, error) { +func (ac *ArrayComprehension) Find(path Ref) (Value, error) { if len(path) == 0 { return ac, nil } diff --git a/ast/term_test.go b/ast/term_test.go index a54ed9c1e7..a8b1187d81 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -131,16 +131,16 @@ func TestFind(t *testing.T) { term := MustParseTerm(`{"foo": [1,{"bar": {2,3,4}}], "baz": {"qux": ["hello", "world"]}}`) tests := []struct { - path string + path *Term expected interface{} }{ - {"foo/1/bar", MustParseTerm(`{2, 3, 4}`)}, - {"foo/2", fmt.Errorf("not found")}, - {"baz/qux/0", MustParseTerm(`"hello"`)}, + {RefTerm(StringTerm("foo"), IntNumberTerm(1), StringTerm("bar")), MustParseTerm(`{2, 3, 4}`)}, + {RefTerm(StringTerm("foo"), IntNumberTerm(2)), fmt.Errorf("not found")}, + {RefTerm(StringTerm("baz"), StringTerm("qux"), IntNumberTerm(0)), MustParseTerm(`"hello"`)}, } for _, tc := range tests { - result, err := term.Value.Find(strings.Split(tc.path, "/")) + result, err := term.Value.Find(tc.path.Value.(Ref)) switch expected := tc.expected.(type) { case *Term: if err != nil { diff --git a/repl/repl_test.go b/repl/repl_test.go index 9d3583e6ba..8f63e5c18a 100644 --- a/repl/repl_test.go +++ b/repl/repl_test.go @@ -321,11 +321,13 @@ func TestUnset(t *testing.T) { t.Fatalf("Expected unset to succeed for input: %v", err) } - err = repl.OneShot(ctx, `true = input`) + buffer.Reset() + repl.OneShot(ctx, `not input`) - if !strings.Contains(err.Error(), "input document not defined") { - t.Fatalf("Expected undefined error but got: %v", err) + if buffer.String() != "true\n" { + t.Fatalf("Expected unset input to remove input document: %v", buffer.String()) } + } func TestOneShotEmptyBufferOneExpr(t *testing.T) { @@ -827,12 +829,14 @@ func TestEvalBodyWith(t *testing.T) { repl := newRepl(store, &buffer) repl.OneShot(ctx, `p = true { input.foo = "bar" }`) - err := repl.OneShot(ctx, "p") + repl.OneShot(ctx, "p") - if err == nil || !strings.Contains(err.Error(), "input document not defined") { - t.Fatalf("Expected input document undefined error") + if buffer.String() != "undefined\n" { + t.Fatalf("Expected undefined but got: %v", buffer.String()) } + buffer.Reset() + repl.OneShot(ctx, `p with input.foo as "bar"`) result := buffer.String() diff --git a/server/authorizer/authorizer_test.go b/server/authorizer/authorizer_test.go index 46e484dd19..ba057d8420 100644 --- a/server/authorizer/authorizer_test.go +++ b/server/authorizer/authorizer_test.go @@ -175,14 +175,14 @@ func TestBasic(t *testing.T) { t.Fatalf("Expected JSON response but got: %v", recorder) } response := ast.MustInterfaceToValue(x) - code, err := response.Find([]string{"code"}) + code, err := response.Find(ast.RefTerm(ast.StringTerm("code")).Value.(ast.Ref)) if err != nil { t.Fatalf("Missing code in response: %v", recorder) } else if !code.Equal(ast.String(tc.expectedCode)) { t.Fatalf("Expected code %v but got: %v", tc.expectedCode, recorder) } - msg, err := response.Find([]string{"message"}) + msg, err := response.Find(ast.RefTerm(ast.StringTerm("message")).Value.(ast.Ref)) if err != nil { t.Fatalf("Missing message in response: %v", recorder) } else if !strings.Contains(msg.String(), tc.expectedMsg) { diff --git a/server/server_test.go b/server/server_test.go index 71c7b4fa5b..acbace6a2f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -54,11 +54,6 @@ undef = true { false }` p = [1, 2, 3, 4] { true } q = {"a": 1, "b": 2} { true }` - testMod3 := `package testmod - -p = true { loopback with input as true } -loopback = input { true }` - testMod4 := `package testmod p = true { true } @@ -168,24 +163,6 @@ p = true { false }` tr{"PUT", "/policies/test", testMod1, 200, ""}, tr{"GET", "/data/testmod/g?input=req1%3A%7B%22a%22%3A%5B1%5D%7D&input=req2%3A%7B%22b%22%3A%5B0%2C1%5D%7D", "", 200, `{"result": true}`}, }}, - {"get missing input", []tr{ - tr{"PUT", "/policies/test", testMod1, 200, ""}, - tr{"GET", "/data/testmod/g", "", 400, `{ - "code": "invalid_parameter", - "errors": [ - { - "code": "rego_input_error", - "location": { - "col": 12, - "file": "test", - "row": 10 - }, - "message": "input document not defined" - } - ], - "message": "input document is missing or conflicts with query" - }`}, - }}, {"get with input (missing input value)", []tr{ tr{"PUT", "/policies/test", testMod1, 200, ""}, tr{"GET", "/data/testmod/g?input=req1%3A%7B%22a%22%3A%5B1%5D%7D", "", 200, "{}"}, // req2 not specified @@ -257,24 +234,6 @@ p = true { false }` tr{"PUT", "/policies/test", testMod1, 200, ""}, tr{"POST", "/data/testmod/gt1", `{"input": {"req1": 2}}`, 200, `{"result": true}`}, }}, - {"post missing input", []tr{ - tr{"PUT", "/policies/test", testMod1, 200, ""}, - tr{"POST", "/data/testmod/gt1", ``, 400, `{ - "code": "invalid_parameter", - "message": "input document is missing or conflicts with query", - "errors": [ - { - "code": "rego_input_error", - "location": { - "file": "test", - "row": 12, - "col": 14 - }, - "message": "input document not defined" - } - ] - }`}, - }}, {"post malformed input", []tr{ tr{"POST", "/data/deadbeef", `{"input": @}`, 400, `{ "code": "invalid_parameter", @@ -299,24 +258,6 @@ p = true { false }` "message": "error(s) occurred while evaluating query" }`}, }}, - {"input conflict", []tr{ - tr{"PUT", "/policies/test", testMod3, 200, ""}, - tr{"POST", "/data/testmod/p", `{"input": false}`, 400, `{ - "code": "invalid_parameter", - "errors": [ - { - "code": "rego_input_error", - "location": { - "col": 12, - "file": "test", - "row": 3 - }, - "message": "input document conflict" - } - ], - "message": "input document is missing or conflicts with query" - }`}, - }}, {"query wildcards omitted", []tr{ tr{"PATCH", "/data/x", `[{"op": "add", "path": "/", "value": [1,2,3,4]}]`, 204, ""}, tr{"GET", "/query?q=data.x[_]%20=%20x", "", 200, `{"result": [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]}`}, @@ -593,7 +534,7 @@ func TestPoliciesPutV1ParseError(t *testing.T) { v := ast.MustInterfaceToValue(response) - name, err := v.Find([]string{"errors", "0", "location", "file"}) + name, err := v.Find(ast.MustParseRef("_.errors[0].location.file")[1:]) if err != nil { t.Fatalf("Expecfted to find name in errors but: %v", err) } @@ -629,7 +570,7 @@ q[x] { p[x] }`, v := ast.MustInterfaceToValue(response) - name, err := v.Find([]string{"errors", "0", "location", "file"}) + name, err := v.Find(ast.MustParseRef("_.errors[0].location.file")[1:]) if err != nil { t.Fatalf("Expecfted to find name in errors but: %v", err) } diff --git a/topdown/topdown.go b/topdown/topdown.go index c97ef4bd74..a6a4c27606 100644 --- a/topdown/topdown.go +++ b/topdown/topdown.go @@ -235,6 +235,18 @@ func (t *Topdown) Resolve(ref ast.Ref) (interface{}, error) { return doc, nil } +type valueResolver struct { + t *Topdown +} + +func (r valueResolver) Resolve(ref ast.Ref) (ast.Value, error) { + v, err := lookupValue(r.t, ref) + if storage.IsNotFound(err) { + return nil, nil + } + return v, nil +} + // Step returns a new Topdown object to evaluate the next expression. func (t *Topdown) Step() *Topdown { cpy := *t @@ -1211,6 +1223,21 @@ func evalRefRecNonGround(t *Topdown, ref, prefix ast.Ref, iter Iterator) error { func evalRefRule(t *Topdown, ref ast.Ref, path ast.Ref, rules []*ast.Rule, iter Iterator) error { + if index := t.Compiler.RuleIndex(path); index != nil { + rs, dr, err := index.Index(valueResolver{t}) + if err != nil { + return err + } + if dr != nil { + rs = append(rs, dr) + } + rules = rs + } + + if len(rules) == 0 { + return nil + } + suffix := ref[len(path):] switch rules[0].Head.DocKind() { @@ -2104,11 +2131,24 @@ func indexingAllowed(ref ast.Ref, term *ast.Term) bool { } func lookupValue(t *Topdown, ref ast.Ref) (ast.Value, error) { - r, err := t.Resolve(ref) - if err != nil { - return nil, err + if ref[0].Equal(ast.DefaultRootDocument) { + r, err := t.Resolve(ref) + if err != nil { + return nil, err + } + return ast.InterfaceToValue(r) + } + if ref[0].Equal(ast.InputRootDocument) { + if t.Input == nil { + return nil, nil + } + r, err := t.Input.Find(ref[1:]) + if err != nil { + return nil, nil + } + return r, nil } - return ast.InterfaceToValue(r) + panic("illegal value") } // valueMapStack is used to store a stack of bindings. diff --git a/topdown/topdown_bench_test.go b/topdown/topdown_bench_test.go index 1b5baf91fb..c128537433 100644 --- a/topdown/topdown_bench_test.go +++ b/topdown/topdown_bench_test.go @@ -14,77 +14,53 @@ import ( "github.com/open-policy-agent/opa/storage" ) -func BenchmarkVirtualDocs1(b *testing.B) { - runVirtualDocsBenchmark(b, 1) +func BenchmarkVirtualDocs1x1(b *testing.B) { + runVirtualDocsBenchmark(b, 1, 1) } -func BenchmarkVirtualDocs10(b *testing.B) { - runVirtualDocsBenchmark(b, 10) +func BenchmarkVirtualDocs10x1(b *testing.B) { + runVirtualDocsBenchmark(b, 10, 1) } -func BenchmarkVirtualDocs100(b *testing.B) { - runVirtualDocsBenchmark(b, 100) +func BenchmarkVirtualDocs100x1(b *testing.B) { + runVirtualDocsBenchmark(b, 100, 1) } -func BenchmarkVirtualDocs1000(b *testing.B) { - runVirtualDocsBenchmark(b, 1000) +func BenchmarkVirtualDocs1000x1(b *testing.B) { + runVirtualDocsBenchmark(b, 1000, 1) } -func runVirtualDocsBenchmark(b *testing.B, numRules int) { - - // Generate test module containing numRules instances of allow. - testRule := ` - allow { - input.method = "POST" - input.path = ["accounts", account_id] - input.user_id = account_id - } - ` - - testModuleTmpl := ` - package a.b.c - - {{range . }} - {{ . }} - {{end}} - ` +func BenchmarkVirtualDocs10x10(b *testing.B) { + runVirtualDocsBenchmark(b, 10, 10) +} - tmpl, err := template.New("Test").Parse(testModuleTmpl) - if err != nil { - b.Fatalf("Unexpected error while parsing template: %v", err) - } +func BenchmarkVirtualDocs100x10(b *testing.B) { + runVirtualDocsBenchmark(b, 100, 10) +} - var buf bytes.Buffer +func BenchmarkVirtualDocs1000x10(b *testing.B) { + runVirtualDocsBenchmark(b, 1000, 10) +} - rules := make([]string, numRules) - for i := range rules { - rules[i] = testRule - } +func runVirtualDocsBenchmark(b *testing.B, numTotalRules, numHitRules int) { - err = tmpl.Execute(&buf, rules) - if err != nil { - b.Fatalf("Unexpected error while executing template: %v", err) - } + // Generate test input + mod, input := generateRulesAndInput(numTotalRules, numHitRules) // Setup evaluation... ctx := context.Background() compiler := ast.NewCompiler() - mod := ast.MustParseModule(buf.String()) mods := map[string]*ast.Module{"module": mod} store := storage.New(storage.InMemoryConfig()) txn := storage.NewTransactionOrDie(ctx, store) - input := ast.MustParseTerm(`{ - "path": ["accounts", "alice"], - "method": "POST", - "user_id": "alice" - }`).Value - if compiler.Compile(mods); compiler.Failed() { b.Fatalf("Unexpected compiler error: %v", compiler.Errors) } params := NewQueryParams(ctx, compiler, store, txn, input, ast.MustParseRef("data.a.b.c.allow")) + b.ResetTimer() + // Run query N times. for i := 0; i < b.N; i++ { func() { @@ -101,3 +77,76 @@ func runVirtualDocsBenchmark(b *testing.B, numRules int) { } } + +func generateRulesAndInput(numTotalRules, numHitRules int) (*ast.Module, ast.Value) { + + hitRule := ` + allow { + input.method = "POST" + input.path = ["accounts", account_id] + input.user_id = account_id + } + ` + + missRule := ` + allow { + input.method = "GET" + input.path = ["salaries", account_id] + input.user_id = account_id + } + ` + + testModuleTmpl := ` + package a.b.c + + {{range .MissRules }} + {{ . }} + {{end}} + + {{range .HitRules }} + {{ . }} + {{end}} + ` + + tmpl, err := template.New("Test").Parse(testModuleTmpl) + if err != nil { + panic(err) + } + + var buf bytes.Buffer + + var missRules []string + + if numTotalRules > numHitRules { + missRules = make([]string, numTotalRules-numHitRules) + for i := range missRules { + missRules[i] = missRule + } + } + + hitRules := make([]string, numHitRules) + for i := range hitRules { + hitRules[i] = hitRule + } + + params := struct { + MissRules []string + HitRules []string + }{ + MissRules: missRules, + HitRules: hitRules, + } + + err = tmpl.Execute(&buf, params) + if err != nil { + panic(err) + } + + input := ast.MustParseTerm(`{ + "path": ["accounts", "alice"], + "method": "POST", + "user_id": "alice" + }`).Value + + return ast.MustParseModule(buf.String()), input +} diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 5fca16276d..ec73e106eb 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -1186,7 +1186,7 @@ loopback = input { true }`}) assertTopDown(t, compiler, store, "loopback", []string{"z", "loopback"}, `{"foo": 1}`, `{"foo": 1}`) - assertTopDown(t, compiler, store, "loopback undefined", []string{"z", "loopback"}, ``, fmt.Errorf("input document not defined")) + assertTopDown(t, compiler, store, "loopback undefined", []string{"z", "loopback"}, ``, ``) assertTopDown(t, compiler, store, "simple", []string{"z", "p"}, `{ "req1": {"foo": 4},