Skip to content

Commit

Permalink
ast: Fix rule indexing when multiple glob.match mappers are required
Browse files Browse the repository at this point in the history
Previously when glob.match statements were indexed, the index trie
node would have a mapper function set on it. The problem was that if
subsequent rules were added to the index and required a different
mapper (or none at all in the case of equality statements), the
original mapper would be overwritten. This would result in
false-negatives when the index was queried (i.e., the rule that was indexed
first would not be returned).

This commit fixes the issue by storing multiple mappers on the trie
node (one per delimiter in the case of glob.match). If multiple
mappers are encountered during traversal, each one will be tested.

Fixes open-policy-agent#2617

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Aug 12, 2020
1 parent 9cc0c70 commit 6c40f3f
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 19 deletions.
55 changes: 39 additions & 16 deletions ast/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,15 @@ func (r *ruleWalker) Do(x interface{}) trieWalker {
return r
}

type valueMapper func(Value) Value
type valueMapper struct {
Key string
MapValue func(Value) Value
}

type refindex struct {
Ref Ref
Value Value
Mapper func(Value) Value
Mapper *valueMapper
}

type refindices struct {
Expand Down Expand Up @@ -280,7 +283,7 @@ func (i *refindices) Value(rule *Rule, ref Ref) Value {
return nil
}

func (i *refindices) Mapper(rule *Rule, ref Ref) valueMapper {
func (i *refindices) Mapper(rule *Rule, ref Ref) *valueMapper {
if index := i.index(rule, ref); index != nil {
return index.Mapper
}
Expand Down Expand Up @@ -320,11 +323,14 @@ func (i *refindices) updateGlobMatch(rule *Rule, expr *Expr) {
i.insert(rule, &refindex{
Ref: other.Ref,
Value: arr.Value,
Mapper: func(v Value) Value {
if s, ok := v.(String); ok {
return stringSliceToArray(splitStringEscaped(string(s), delim))
}
return v
Mapper: &valueMapper{
Key: delim,
MapValue: func(v Value) Value {
if s, ok := v.(String); ok {
return stringSliceToArray(splitStringEscaped(string(s), delim))
}
return v
},
},
})
}
Expand Down Expand Up @@ -387,7 +393,7 @@ func (tr *trieTraversalResult) Add(node *ruleNode) {

type trieNode struct {
ref Ref
mapper valueMapper
mappers []*valueMapper
next *trieNode
any *trieNode
undefined *trieNode
Expand Down Expand Up @@ -425,8 +431,8 @@ func (node *trieNode) String() string {
if len(node.rules) > 0 {
flags = append(flags, fmt.Sprintf("%d rule(s)", len(node.rules)))
}
if node.mapper != nil {
flags = append(flags, "mapper")
if len(node.mappers) > 0 {
flags = append(flags, "mapper(s)")
}
return strings.Join(flags, " ")
}
Expand Down Expand Up @@ -464,14 +470,16 @@ func (node *trieNode) Do(walker trieWalker) {
}
}

func (node *trieNode) Insert(ref Ref, value Value, mapper valueMapper) *trieNode {
func (node *trieNode) Insert(ref Ref, value Value, mapper *valueMapper) *trieNode {

if node.next == nil {
node.next = newTrieNodeImpl()
node.next.ref = ref
}

node.next.mapper = mapper
if mapper != nil {
node.next.addMapper(mapper)
}

return node.next.insertValue(value)
}
Expand All @@ -489,6 +497,15 @@ func (node *trieNode) Traverse(resolver ValueResolver, tr *trieTraversalResult)
return node.next.traverse(resolver, tr)
}

func (node *trieNode) addMapper(mapper *valueMapper) {
for i := range node.mappers {
if node.mappers[i].Key == mapper.Key {
return
}
}
node.mappers = append(node.mappers, mapper)
}

func (node *trieNode) insertValue(value Value) *trieNode {

switch value := value.(type) {
Expand Down Expand Up @@ -569,11 +586,17 @@ func (node *trieNode) traverse(resolver ValueResolver, tr *trieTraversalResult)
node.any.Traverse(resolver, tr)
}

if node.mapper != nil {
v = node.mapper(v)
if len(node.mappers) == 0 {
return node.traverseValue(resolver, tr, v)
}

for i := range node.mappers {
if err := node.traverseValue(resolver, tr, node.mappers[i].MapValue(v)); err != nil {
return err
}
}

return node.traverseValue(resolver, tr, v)
return nil
}

func (node *trieNode) traverseValue(resolver ValueResolver, tr *trieTraversalResult, value Value) error {
Expand Down
89 changes: 86 additions & 3 deletions ast/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,35 @@ func TestBaseDocEqIndexing(t *testing.T) {
x = input.x
glob.match("dead:*:beef", [":"], x)
}
glob_match_mappers {
input.x = x
glob.match("foo:*", [":"], x)
}
glob_match_mappers {
input.x = x
}
glob_match_overlapped_mappers {
input.x = x
glob.match("foo:*", [":"], x)
}
glob_match_overlapped_mappers {
input.x = x
glob.match("foo/*", ["/"], x)
}
glob_match_disjoint_mappers {
input.x = x
glob.match("foo:*", [":"], x)
}
glob_match_disjoint_mappers {
input.x = x
glob.match("bar/*", ["/"], x)
}
`)

tests := []struct {
Expand Down Expand Up @@ -346,6 +375,60 @@ func TestBaseDocEqIndexing(t *testing.T) {
glob.match("foo:*:*", [":"], x)
}`},
},
{
note: "glob.match - mapper and no mapper",
ruleset: "glob_match_mappers",
input: `{"x": "foo:bar"}`,
expectedRS: []string{
`
glob_match_mappers {
input.x = x
glob.match("foo:*", [":"], x)
}
`,
`
glob_match_mappers {
input.x = x
}
`},
},
{
// NOTE(tsandall): The rule index returns both rules because the trie nodes
// store multiple mappers and will traverse each one. Since both mappers
// generate a trie structure of:
//
// array
// scalar("foo")
// any
//
// The rules are added to the same leaf node. In the future, we could improve
// the indexer to distinguish the trie nodes using the delimiter but until
// then the indexer can just return extra rules.
note: "glob.match - multiple overlapped mappers",
ruleset: "glob_match_overlapped_mappers",
input: `{"x": "foo:bar"}`,
expectedRS: []string{
`
glob_match_overlapped_mappers {
input.x = x
glob.match("foo:*", [":"], x)
}
`, `
glob_match_overlapped_mappers {
input.x = x
glob.match("foo/*", ["/"], x)
}
`,
},
},
{
note: "glob.match - multiple disjoint mappers",
ruleset: "glob_match_disjoint_mappers",
input: `{"x": "foo:bar"}`,
expectedRS: []string{
`glob_match_disjoint_mappers { input.x = x; glob.match("foo:*", [":"], x) }`,
},
},
{
note: "glob.match unexpected value type",
ruleset: "glob_match",
Expand Down Expand Up @@ -565,9 +648,9 @@ func TestSplitStringEscaped(t *testing.T) {
func TestGetAllRules(t *testing.T) {
module := MustParseModule(`
package test
default p = 42
p {
input.x = "x1"
input.y = "y1"
Expand All @@ -578,7 +661,7 @@ func TestGetAllRules(t *testing.T) {
}
p {
input.z = "z1"
input.z = "z1"
}
`)

Expand Down

0 comments on commit 6c40f3f

Please sign in to comment.