Skip to content

Commit

Permalink
Support "data" query in REPL
Browse files Browse the repository at this point in the history
Now that the QueryCompiler is being used, we can repl.OneShot("data") and see
the entire global document evaluated!

Fixes open-policy-agent#150
  • Loading branch information
tsandall committed Nov 24, 2016
1 parent 79d3a95 commit 431b146
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 30 deletions.
22 changes: 14 additions & 8 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ type QueryContext struct {
Imports []*Import
}

// NewQueryContext returns a new QueryContext object.
func NewQueryContext(pkg *Package, imports []*Import) *QueryContext {
return &QueryContext{
Package: pkg,
Imports: imports,
}
}

// NewQueryContextForModule returns a new QueryContext object based on the
// provided module.
func NewQueryContextForModule(mod *Module) *QueryContext {
return NewQueryContext(mod.Package, mod.Imports)
}

// Copy returns a deep copy of qc.
func (qc *QueryContext) Copy() *QueryContext {
if qc == nil {
Expand All @@ -92,14 +106,6 @@ func (qc *QueryContext) Copy() *QueryContext {
return &cpy
}

// NewQueryContext returns a new QueryContext object.
func NewQueryContext(pkg *Package, imports []*Import) *QueryContext {
return &QueryContext{
Package: pkg,
Imports: imports,
}
}

// QueryCompiler defines the interface for compiling ad-hoc queries.
type QueryCompiler interface {

Expand Down
21 changes: 4 additions & 17 deletions repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,18 +333,6 @@ func (r *REPL) cmdUnset(args []string) bool {
}

func (r *REPL) compileBody(body ast.Body) (ast.Body, error) {
name := r.generateRuleName()

rule := &ast.Rule{
Location: body[0].Location,
Name: name,
Value: ast.BooleanTerm(true),
Body: body,
}

mod := r.modules[r.currentModuleID]
prev := mod.Rules
mod.Rules = append(mod.Rules, rule)

policies := r.store.ListPolicies(r.txn)

Expand All @@ -355,14 +343,13 @@ func (r *REPL) compileBody(body ast.Body) (ast.Body, error) {
compiler := ast.NewCompiler()

if compiler.Compile(policies); compiler.Failed() {
mod.Rules = prev
return nil, compiler.Errors
}

compiledMod := compiler.Modules[r.currentModuleID]
compiledBody := compiledMod.Rules[len(prev)].Body

return compiledBody, nil
qctx := ast.NewQueryContextForModule(r.modules[r.currentModuleID])
return compiler.QueryCompiler().
WithContext(qctx).
Compile(body)
}

func (r *REPL) compileRule(rule *ast.Rule) error {
Expand Down
56 changes: 51 additions & 5 deletions repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestUnset(t *testing.T) {
repl.OneShot("unset p")
repl.OneShot("p")
result := buffer.String()
if result != "error: 1 error occurred: 1:1: repl2: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" {
if result != "error: 1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" {
t.Errorf("Expected p to be unsafe but got: %v", result)
return
}
Expand All @@ -155,7 +155,7 @@ func TestUnset(t *testing.T) {
repl.OneShot("unset p")
repl.OneShot("p")
result = buffer.String()
if result != "error: 1 error occurred: 1:1: repl4: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" {
if result != "error: 1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" {
t.Errorf("Expected p to be unsafe but got: %v", result)
return
}
Expand Down Expand Up @@ -301,6 +301,52 @@ func TestOneShotJSON(t *testing.T) {
}
}

func TestEvalData(t *testing.T) {
store := newTestStore()
var buffer bytes.Buffer
repl := newRepl(store, &buffer)
testmod := ast.MustParseModule(`package ex
p = [1,2,3]`)
if err := storage.InsertPolicy(store, "test", testmod, nil, false); err != nil {
panic(err)
}
repl.OneShot("data")
expected := parseJSON(`
{
"a": [
{
"b": {
"c": [
true,
2,
false
]
}
},
{
"b": {
"c": [
false,
true,
1
]
}
}
],
"ex": {
"p": [
1,
2,
3
]
}
}`)
result := parseJSON(buffer.String())
if !reflect.DeepEqual(result, expected) {
t.Fatalf("Expected:\n%v\n\nGot:\n%v", expected, result)
}
}

func TestEvalFalse(t *testing.T) {
store := newTestStore()
var buffer bytes.Buffer
Expand Down Expand Up @@ -497,7 +543,7 @@ func TestEvalRuleCompileError(t *testing.T) {
repl.OneShot("p = true :- true")
result = buffer.String()
if result != "" {
t.Errorf("Expected valid rule to compile (because state should have been rolled back) but got: %v", result)
t.Errorf("Expected valid rule to compile (because state should be unaffected) but got: %v", result)
}
}

Expand All @@ -508,7 +554,7 @@ func TestEvalBodyCompileError(t *testing.T) {
repl.outputFormat = "json"
repl.OneShot("x = 1, y > x")
result1 := buffer.String()
expected1 := "error: 1 error occurred: 1:1: repl0: y is unsafe (variable y must appear in the output position of at least one non-negated expression)\n"
expected1 := "error: 1 error occurred: 1:1: y is unsafe (variable y must appear in the output position of at least one non-negated expression)\n"
if result1 != expected1 {
t.Errorf("Expected error message in output but got`: %v", result1)
return
Expand Down Expand Up @@ -615,7 +661,7 @@ func TestEvalPackage(t *testing.T) {
repl.OneShot("package baz.qux")
buffer.Reset()
repl.OneShot("p")
if buffer.String() != "error: 1 error occurred: 1:1: repl0: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" {
if buffer.String() != "error: 1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" {
t.Errorf("Expected unsafe variable error but got: %v", buffer.String())
return
}
Expand Down

0 comments on commit 431b146

Please sign in to comment.