From 46e8c74b548c539fc5f080cbe4b71168605167b7 Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Thu, 16 Nov 2017 14:48:17 -0800 Subject: [PATCH] Fix eval of objects/sets containing vars If object keys (or set elements) are vars, then they must be plugged before lookup is attempted, otherwise lookup will fail and expression will be undefined. This could result in poor performance if rule outputs are large enough. In that case, we could revisit how the rule outputs are constructed (e.g., we could construct outputs with the first level of set/object plugged and propagate binding lists for each key/value pair.) Also, updated test runner to accept OPA_TRACE_TEST environment var to selectively enable tracing output. This is useful for debugging purposes. Fixes #505 --- topdown/eval.go | 48 +++++++++++++++++++++++++++-------------- topdown/topdown_test.go | 31 ++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 16 deletions(-) diff --git a/topdown/eval.go b/topdown/eval.go index 22bcc7e6b9..e95272f9ae 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -1328,15 +1328,6 @@ func (e evalTerm) next(iter unifyIterator, plugged *ast.Term) error { func (e evalTerm) enumerate(iter unifyIterator) error { switch v := e.term.Value.(type) { - case *ast.Set: - for _, elem := range *v { - err := e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error { - return e.next(iter, elem) - }) - if err != nil { - return err - } - } case ast.Array: for i := range v { k := ast.IntNumberTerm(i) @@ -1349,8 +1340,17 @@ func (e evalTerm) enumerate(iter unifyIterator) error { } case ast.Object: for _, pair := range v { - err := e.e.biunify(pair[0], e.ref[e.pos], e.bindings, e.bindings, func() error { - return e.next(iter, pair[0]) + err := e.e.biunify(pair[0], e.ref[e.pos], e.termbindings, e.bindings, func() error { + return e.next(iter, e.bindings.Plug(e.ref[e.pos])) + }) + if err != nil { + return err + } + } + case *ast.Set: + for _, elem := range *v { + err := e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error { + return e.next(iter, e.bindings.Plug(e.ref[e.pos])) }) if err != nil { return err @@ -1364,13 +1364,29 @@ func (e evalTerm) enumerate(iter unifyIterator) error { func (e evalTerm) get(plugged *ast.Term) (*ast.Term, *bindings) { switch v := e.term.Value.(type) { case *ast.Set: - if v.Contains(plugged) { - return e.termbindings.apply(plugged) + if v.IsGround() { + if v.Contains(plugged) { + return e.termbindings.apply(plugged) + } + } else { + for _, elem := range *v { + if e.termbindings.Plug(elem).Equal(plugged) { + return e.termbindings.apply(plugged) + } + } } case ast.Object: - term := v.Get(plugged) - if term != nil { - return e.termbindings.apply(term) + if v.IsGround() { + term := v.Get(plugged) + if term != nil { + return e.termbindings.apply(term) + } + } else { + for i := range v { + if e.termbindings.Plug(v[i][0]).Equal(plugged) { + return e.termbindings.apply(v[i][1]) + } + } } case ast.Array: term := v.Get(plugged) diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 1f9244d28e..f7b19603a7 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "os" "reflect" "sort" "strings" @@ -471,6 +472,26 @@ func TestTopDownVirtualDocs(t *testing.T) { {"empty partial set", []string{"p[1] { a[0] = 100 }"}, "[]"}, {"empty partial object", []string{`p["x"] = 1 { a[0] = 100 }`}, "{}"}, + + {"input: non-ground object keys", []string{ + `p = x { q.a.b = x }`, + `q = {x: {y: 1}} { x = "a"; y = "b" }`, + }, "1"}, + + {"input: non-ground set elements", []string{ + `p { q["c"] }`, + `q = {x, "b", z} { x = "a"; z = "c" }`, + }, "true"}, + + {"output: non-ground object keys", []string{ + `p[x] { q[i][j] = x }`, + `q = {x: {x1: 1}, y: {y1: 2}} { x = "a"; y = "b"; x1 = "a1"; y1 = "b1" }`, + }, "[1, 2]"}, + + {"output: non-ground set elements", []string{ + `p[x] { q[x] }`, + `q = {x, "b", z} { x = "a"; z = "c" }`, + }, `["a", "b", "c"]`}, } data := loadSmallTestData() @@ -2375,6 +2396,12 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S WithTransaction(txn). WithInput(inputTerm) + var tracer BufferTracer + + if os.Getenv("OPA_TRACE_TEST") != "" { + query = query.WithTracer(&tracer) + } + testutil.Subtest(t, note, func(t *testing.T) { switch e := expected.(type) { case error: @@ -2391,6 +2418,10 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S case string: qrs, err := query.Run(ctx) + if tracer != nil { + PrettyTrace(os.Stdout, tracer) + } + if err != nil { t.Fatalf("Unexpected error: %v", err) }