Skip to content

Commit

Permalink
ast/compile: use all vars from rule body for index candidates (open-p…
Browse files Browse the repository at this point in the history
…olicy-agent#3709)

Before, we'd only looked at the vars preceding the comprehension in the body
containing it. In the case of nested comprehensions, that would have excluded
the rule body OUTSIDE of the nested body.

Now, we'll accumulate candidates over multiple bodies -- capturing the ones
that had been missing before.

Fixes open-policy-agent#3579.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
Signed-off-by: Dolev Farhi <farhi.dolev@gmail.com>
  • Loading branch information
srenatus authored and dolevf committed Nov 4, 2021
1 parent 3e1d221 commit 45cd4ca
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 53 deletions.
2 changes: 1 addition & 1 deletion ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -1833,8 +1833,8 @@ func (ci *ComprehensionIndex) String() string {

func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node interface{}, result map[*Term]*ComprehensionIndex) uint64 {
var n uint64
cpy := candidates.Copy()
WalkBodies(node, func(b Body) bool {
cpy := candidates.Copy()
for _, expr := range b {
index := getComprehensionIndex(dbg, arity, cpy, rwVars, expr)
if index != nil {
Expand Down
116 changes: 64 additions & 52 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3678,12 +3678,14 @@ func TestCompilerWithStageAfterWithMetrics(t *testing.T) {

func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {

type expectedComprehension struct {
term, keys string
}
type exp map[int]expectedComprehension
tests := []struct {
note string
module string
atRow int
wantTerm string
wantKeys string
expected exp
wantDebug int
}{
{
Expand All @@ -3696,9 +3698,10 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {
keys = [j | value = input[j]]
}
`,
atRow: 6,
wantTerm: `[j | value = input[j]]`,
wantKeys: `[value]`,
expected: exp{6: {
term: `[j | value = input[j]]`,
keys: `[value]`,
}},
wantDebug: 1,
},
{
Expand All @@ -3712,9 +3715,10 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {
keys = [j | v1 = input[j].v1; v2 = input[j].v2]
}
`,
atRow: 7,
wantTerm: `[j | v1 = input[j].v1; v2 = input[j].v2]`,
wantKeys: `[v1, v2]`,
expected: exp{7: {
term: `[j | v1 = input[j].v1; v2 = input[j].v2]`,
keys: `[v1, v2]`,
}},
wantDebug: 1,
},
{
Expand All @@ -3727,9 +3731,10 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {
ys = {y | x = input[y]}
}
`,
atRow: 6,
wantTerm: `{y | x = input[y]}`,
wantKeys: `[x]`,
expected: exp{6: {
term: `{y | x = input[y]}`,
keys: `[x]`,
}},
// there are still things going on here that'll be reported, besides successful indexing
wantDebug: 2,
},
Expand Down Expand Up @@ -3798,19 +3803,24 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {
wantDebug: 1,
},
{
note: "skip: due to nested comprehension containing candidate",
note: "mixed: due to nested comprehension containing candidate + indexed nested comprehension with key from rule body",
module: `
package test
p {
x = input[i] # 'x' is a candidate
y = 2 # 'y' is a candidate
x = input[i] # 'x' is a candidate for z (line 7)
y = 2 # 'y' is a candidate for z
z = [1 |
x = data.foo[j] # 'x' is an index key
t = [1 | data.bar[k] = y] # 'y' disqualifies indexing because it is nested inside a comprehension
x = data.foo[j] # 'x' is an index key for z
t = [1 | data.bar[k] = y] # 'y' disqualifies indexing of z because it is nested inside a comprehension
]
}
`,
// Note: no comprehension index for line 7 (`z = [ ...`)
expected: exp{9: {
keys: `[y]`,
term: `[1 | data.bar[k] = y]`,
}},
wantDebug: 2,
},
{
Expand Down Expand Up @@ -3855,9 +3865,10 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {
y = input[x]
ys = [y | y = input[z]; z = x]
}`,
atRow: 6,
wantTerm: ` [y | y = input[z]; z = x]`,
wantKeys: `[x, y]`,
expected: exp{6: {
term: ` [y | y = input[z]; z = x]`,
keys: `[x, y]`,
}},
wantDebug: 1,
},
}
Expand Down Expand Up @@ -3886,48 +3897,49 @@ func TestCompilerBuildComprehensionIndexKeySet(t *testing.T) {
}

n := m.Counter(compileStageComprehensionIndexBuild).Value().(uint64)

if tc.atRow == 0 {
if n > 0 || len(compiler.comprehensionIndices) > 0 {
t.Fatal("expected no indices to be built. got:", compiler.comprehensionIndices)
}
if exp, act := len(tc.expected), len(compiler.comprehensionIndices); exp != act {
t.Fatalf("expected %d indices to be built. got: %d", exp, act)
}
if len(tc.expected) == 0 {
return
}

if n != 1 {
if n == 0 {
t.Fatal("expected counter to be incremented")
}

var comprehension *Term
WalkTerms(compiler.Modules["test.rego"], func(x *Term) bool {
if !IsComprehension(x.Value) {
return true
}
if x.Location.Row != tc.atRow {
for row, exp := range tc.expected {
var comprehension *Term
WalkTerms(compiler.Modules["test.rego"], func(x *Term) bool {
if !IsComprehension(x.Value) {
return true
}
_, ok := tc.expected[x.Location.Row]
if !ok {
return false
} else if comprehension != nil {
t.Fatal("expected at most one comprehension per line in test module")
}
comprehension = x
return false
} else if comprehension != nil {
t.Fatal("expected at most one comprehension per line in test module")
})
if comprehension == nil {
t.Fatal("expected comprehension at line:", row)
}
comprehension = x
return false
})
if comprehension == nil {
t.Fatal("expected comprehension at line:", tc.atRow)
}

result := compiler.ComprehensionIndex(comprehension)
if result == nil {
t.Fatal("expected result")
}
result := compiler.ComprehensionIndex(comprehension)
if result == nil {
t.Fatal("expected result")
}

expTerm := MustParseTerm(tc.wantTerm)
if !result.Term.Equal(expTerm) {
t.Fatalf("expected term to be %v but got: %v", expTerm, result.Term)
}
expTerm := MustParseTerm(exp.term)
if !result.Term.Equal(expTerm) {
t.Fatalf("expected term to be %v but got: %v", expTerm, result.Term)
}

expKeys := MustParseTerm(tc.wantKeys).Value.(*Array)
if NewArray(result.Keys...).Compare(expKeys) != 0 {
t.Fatalf("expected keys to be %v but got: %v", expKeys, result.Keys)
expKeys := MustParseTerm(exp.keys).Value.(*Array)
if NewArray(result.Keys...).Compare(expKeys) != 0 {
t.Fatalf("expected keys to be %v but got: %v", expKeys, result.Keys)
}
}
})
}
Expand Down

0 comments on commit 45cd4ca

Please sign in to comment.