diff --git a/ast/compile.go b/ast/compile.go index cf314358ed..31090085e3 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -479,6 +479,82 @@ func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { return rules } +// GetRulesDynamic returns a slice of rules that could be referred to by a ref. +// When parts of the ref are statically known, we use that information to narrow +// down which rules the ref could refer to, but in the most general case this +// will be an over-approximation. +// +// E.g., given the following modules: +// +// package a.b.c +// +// r1 = 1 # rule1 +// +// and: +// +// package a.d.c +// +// r2 = 2 # rule2 +// +// The following calls yield the rules on the right. +// +// GetRulesDynamic("data.a[x].c[y]") => [rule1, rule2] +// GetRulesDynamic("data.a[x].c.r2") => [rule2] +// GetRulesDynamic("data.a.b[x][y]") => [rule1] +func (c *Compiler) GetRulesDynamic(ref Ref) (rules []*Rule) { + node := c.RuleTree + + set := map[*Rule]struct{}{} + var walk func(node *TreeNode, i int) + walk = func(node *TreeNode, i int) { + if i >= len(ref) { + // We've reached the end of the reference and want to collect everything + // under this "prefix". + node.DepthFirst(func(descendant *TreeNode) bool { + insertRules(set, descendant.Values) + return descendant.Hide + }) + } else if i == 0 || IsConstant(ref[i].Value) { + // The head of the ref is always grounded. In case another part of the + // ref is also grounded, we can lookup the exact child. If it's not found + // we can immediately return... + if child := node.Child(ref[i].Value); child == nil { + return + } else if len(child.Values) > 0 { + // If there are any rules at this position, it's what the ref would + // refer to. We can just append those and stop here. + insertRules(set, child.Values) + } else { + // Otherwise, we continue using the child node. + walk(child, i+1) + } + } else { + // This part of the ref is a dynamic term. We can't know what it refers + // to and will just need to try all of the children. + for _, child := range node.Children { + if child.Hide { + continue + } + insertRules(set, child.Values) + walk(child, i+1) + } + } + } + + walk(node, 0) + for rule := range set { + rules = append(rules, rule) + } + return rules +} + +// Utility: add all rule values to the set. +func insertRules(set map[*Rule]struct{}, rules []util.T) { + for _, rule := range rules { + set[rule.(*Rule)] = struct{}{} + } +} + // RuleIndex returns a RuleIndex built for the rule set referred to by path. // The path must refer to the rule set exactly, i.e., given a rule set at path // data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a @@ -1039,7 +1115,7 @@ func (c *Compiler) setRuleTree() { } func (c *Compiler) setGraph() { - c.Graph = NewGraph(c.Modules, c.GetRules) + c.Graph = NewGraph(c.Modules, c.GetRulesDynamic) } type queryCompiler struct { @@ -1404,7 +1480,7 @@ func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph { return NewGenericVisitor(func(x interface{}) bool { switch x := x.(type) { case Ref: - for _, b := range list(x.GroundPrefix()) { + for _, b := range list(x) { for node := b; node != nil; node = node.Else { graph.addDependency(a, node) } diff --git a/ast/compile_test.go b/ast/compile_test.go index aed4851153..1f68bc56ee 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -2253,6 +2253,31 @@ dataref = true { data }`, } } +func TestCompilerCheckDynamicRecursion(t *testing.T) { + // This test tries to circumvent the recursion check by using dynamic + // references. For more background info, see + // . + c := NewCompiler() + c.Modules = map[string]*Module{ + "recursion": MustParseModule(`package recursion + +pkg = "recursion" + +foo[x] { + data[pkg]["foo"][x] +}`), + } + + compileStages(c, c.checkRecursion) + + result := compilerErrsToStringSlice(c.Errors) + expected := "rego_recursion_error: rule foo is recursive: foo -> foo" + + if len(result) != 1 || result[0] != expected { + t.Errorf("Expected %v but got: %v", expected, result) + } +} + func TestCompilerGetRulesExact(t *testing.T) { mods := getCompilerTestModules() @@ -2472,6 +2497,67 @@ q["b"] = 2 { true }`, for _, tc := range tests { test.Subtest(t, tc.input, func(t *testing.T) { result := compiler.GetRules(MustParseRef(tc.input)) + + if len(result) != len(tc.expected) { + t.Fatalf("Expected %v but got: %v", tc.expected, result) + } + + for i := range result { + found := false + for j := range tc.expected { + if result[i].Equal(tc.expected[j]) { + found = true + break + } + } + if !found { + t.Fatalf("Expected %v but got: %v", tc.expected, result) + } + } + }) + } + +} + +func TestCompilerGetRulesDynamic(t *testing.T) { + compiler := getCompilerWithParsedModules(map[string]string{ + "mod1": `package a.b.c.d +r1 = 1`, + "mod2": `package a.b.c.e +r2 = 2`, + "mod3": `package a.b +r3 = 3`, + }) + + compileStages(compiler, nil) + + rule1 := compiler.Modules["mod1"].Rules[0] + rule2 := compiler.Modules["mod2"].Rules[0] + rule3 := compiler.Modules["mod3"].Rules[0] + + tests := []struct { + input string + expected []*Rule + }{ + {"data.a.b.c.d.r1", []*Rule{rule1}}, + {"data.a.b[x]", []*Rule{rule1, rule2, rule3}}, + {"data.a.b[x].d", []*Rule{rule1, rule3}}, + {"data.a.b.c", []*Rule{rule1, rule2}}, + {"data.a.b.d", nil}, + {"data[x]", []*Rule{rule1, rule2, rule3}}, + {"data[data.complex_computation].b[y]", []*Rule{rule1, rule2, rule3}}, + {"data[x][y].c.e", []*Rule{rule2}}, + {"data[x][y].r3", []*Rule{rule3}}, + } + + for _, tc := range tests { + test.Subtest(t, tc.input, func(t *testing.T) { + result := compiler.GetRulesDynamic(MustParseRef(tc.input)) + + if len(result) != len(tc.expected) { + t.Fatalf("Expected %v but got: %v", tc.expected, result) + } + for i := range result { found := false for j := range tc.expected { diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index f39666e415..7eb0429339 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -2416,6 +2416,37 @@ func TestTopDownElseKeyword(t *testing.T) { } } +// Test that dynamic dispatch is not broken by the recursion check. +func TestTopdownDynamicDispatch(t *testing.T) { + compiler := compileModules([]string{` + package animals + + dog = "woof" + cat = "meow" + `, ` + package dynamic + + sound = data.animals[animal] + animal = "dog" { + 2 > 1 + } + `}) + + data := map[string]interface{}{} + store := inmem.NewFromObject(data) + + assertTopDownWithPath(t, compiler, store, "dynamic dispatch", []string{}, `{}`, `{ + "animals": { + "cat": "meow", + "dog": "woof" + }, + "dynamic": { + "animal": "dog", + "sound": "woof" + } + }`) +} + func TestTopDownSystemDocument(t *testing.T) { compiler := compileModules([]string{`