Skip to content

Commit

Permalink
Fix eval of objects/sets containing vars
Browse files Browse the repository at this point in the history
If object keys (or set elements) are vars, then they must be plugged
before lookup is attempted, otherwise lookup will fail and expression
will be undefined. This could result in poor performance if rule outputs
are large enough. In that case, we could revisit how the rule outputs
are constructed (e.g., we could construct outputs with the first level
of set/object plugged and propagate binding lists for each key/value
pair.)

Also, updated test runner to accept OPA_TRACE_TEST environment var to
selectively enable tracing output. This is useful for debugging
purposes.

Fixes open-policy-agent#505
  • Loading branch information
tsandall committed Nov 16, 2017
1 parent 02871a4 commit 46e8c74
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 16 deletions.
48 changes: 32 additions & 16 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,15 +1328,6 @@ func (e evalTerm) next(iter unifyIterator, plugged *ast.Term) error {
func (e evalTerm) enumerate(iter unifyIterator) error {

switch v := e.term.Value.(type) {
case *ast.Set:
for _, elem := range *v {
err := e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error {
return e.next(iter, elem)
})
if err != nil {
return err
}
}
case ast.Array:
for i := range v {
k := ast.IntNumberTerm(i)
Expand All @@ -1349,8 +1340,17 @@ func (e evalTerm) enumerate(iter unifyIterator) error {
}
case ast.Object:
for _, pair := range v {
err := e.e.biunify(pair[0], e.ref[e.pos], e.bindings, e.bindings, func() error {
return e.next(iter, pair[0])
err := e.e.biunify(pair[0], e.ref[e.pos], e.termbindings, e.bindings, func() error {
return e.next(iter, e.bindings.Plug(e.ref[e.pos]))
})
if err != nil {
return err
}
}
case *ast.Set:
for _, elem := range *v {
err := e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error {
return e.next(iter, e.bindings.Plug(e.ref[e.pos]))
})
if err != nil {
return err
Expand All @@ -1364,13 +1364,29 @@ func (e evalTerm) enumerate(iter unifyIterator) error {
func (e evalTerm) get(plugged *ast.Term) (*ast.Term, *bindings) {
switch v := e.term.Value.(type) {
case *ast.Set:
if v.Contains(plugged) {
return e.termbindings.apply(plugged)
if v.IsGround() {
if v.Contains(plugged) {
return e.termbindings.apply(plugged)
}
} else {
for _, elem := range *v {
if e.termbindings.Plug(elem).Equal(plugged) {
return e.termbindings.apply(plugged)
}
}
}
case ast.Object:
term := v.Get(plugged)
if term != nil {
return e.termbindings.apply(term)
if v.IsGround() {
term := v.Get(plugged)
if term != nil {
return e.termbindings.apply(term)
}
} else {
for i := range v {
if e.termbindings.Plug(v[i][0]).Equal(plugged) {
return e.termbindings.apply(v[i][1])
}
}
}
case ast.Array:
term := v.Get(plugged)
Expand Down
31 changes: 31 additions & 0 deletions topdown/topdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
"os"
"reflect"
"sort"
"strings"
Expand Down Expand Up @@ -471,6 +472,26 @@ func TestTopDownVirtualDocs(t *testing.T) {

{"empty partial set", []string{"p[1] { a[0] = 100 }"}, "[]"},
{"empty partial object", []string{`p["x"] = 1 { a[0] = 100 }`}, "{}"},

{"input: non-ground object keys", []string{
`p = x { q.a.b = x }`,
`q = {x: {y: 1}} { x = "a"; y = "b" }`,
}, "1"},

{"input: non-ground set elements", []string{
`p { q["c"] }`,
`q = {x, "b", z} { x = "a"; z = "c" }`,
}, "true"},

{"output: non-ground object keys", []string{
`p[x] { q[i][j] = x }`,
`q = {x: {x1: 1}, y: {y1: 2}} { x = "a"; y = "b"; x1 = "a1"; y1 = "b1" }`,
}, "[1, 2]"},

{"output: non-ground set elements", []string{
`p[x] { q[x] }`,
`q = {x, "b", z} { x = "a"; z = "c" }`,
}, `["a", "b", "c"]`},
}

data := loadSmallTestData()
Expand Down Expand Up @@ -2375,6 +2396,12 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S
WithTransaction(txn).
WithInput(inputTerm)

var tracer BufferTracer

if os.Getenv("OPA_TRACE_TEST") != "" {
query = query.WithTracer(&tracer)
}

testutil.Subtest(t, note, func(t *testing.T) {
switch e := expected.(type) {
case error:
Expand All @@ -2391,6 +2418,10 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S
case string:
qrs, err := query.Run(ctx)

if tracer != nil {
PrettyTrace(os.Stdout, tracer)
}

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
Expand Down

0 comments on commit 46e8c74

Please sign in to comment.