Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to how query arguments are handled #197

Merged
merged 12 commits into from
Dec 20, 2016
4 changes: 2 additions & 2 deletions ast/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ func TestCompareModule(t *testing.T) {
}

a = MustParseModule(`package a.b.c
import x.y`)
import request.x.y`)
b = MustParseModule(`package a.b.c
import x.z`)
import request.x.z`)
result = Compare(a, b)

if result != -1 {
Expand Down
41 changes: 10 additions & 31 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ func (c *Compiler) checkSafetyRuleBodies() {
func (c *Compiler) checkSafetyRuleHeads() {
for _, m := range c.Modules {
for _, r := range m.Rules {
unsafe := r.HeadVars().Diff(r.Body.Vars(true))
unsafe := r.HeadVars().Diff(r.Body.Vars(VarVisitorParams{SkipClosures: true}))
for v := range unsafe {
c.err(NewError(UnsafeVarErr, r.Location, "%v: %v is unsafe (variable %v must appear in at least one expression within the body of %v)", r.Name, v, v, r.Name))
}
Expand Down Expand Up @@ -455,7 +455,8 @@ func (c *Compiler) resolveAllRefs() {
rule.Body = resolveRefsInBody(globals, rule.Body)
}

mod.Imports = rewriteImports(mod.Imports)
// Once imports have been resolved, they are no longer needed.
mod.Imports = nil
}

if c.moduleLoader != nil {
Expand Down Expand Up @@ -603,7 +604,7 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error
exports = exist.([]Var)
}
globals = getGlobals(qctx.Package, exports, qc.qctx.Imports)
qctx.Imports = rewriteImports(qctx.Imports)
qctx.Imports = nil
}

return resolveRefsInBody(globals, body), nil
Expand Down Expand Up @@ -954,7 +955,7 @@ func reorderBodyForSafety(globals VarSet, body Body) (Body, unsafeVars) {
safe := VarSet{}

for _, e := range body {
for v := range e.Vars(true) {
for v := range e.Vars(VarVisitorParams{SkipClosures: true}) {
if globals.Contains(v) {
safe.Add(v)
} else {
Expand Down Expand Up @@ -996,7 +997,7 @@ func reorderBodyForSafety(globals VarSet, body Body) (Body, unsafeVars) {
g := globals.Copy()
for i, e := range reordered {
if i > 0 {
g.Update(reordered[i-1].Vars(true))
g.Update(reordered[i-1].Vars(VarVisitorParams{SkipClosures: true}))
}
vis := &bodySafetyVisitor{
current: e,
Expand Down Expand Up @@ -1035,7 +1036,7 @@ func (vis *bodySafetyVisitor) Visit(x interface{}) Visitor {
func (vis *bodySafetyVisitor) checkArrayComprehensionSafety(ac *ArrayComprehension) {
// Check term for safety. This is analogous to the rule head safety check.
tv := ac.Term.Vars()
bv := ac.Body.Vars(true)
bv := ac.Body.Vars(VarVisitorParams{SkipClosures: true})
bv.Update(vis.globals)
uv := tv.Diff(bv)
for v := range uv {
Expand Down Expand Up @@ -1071,15 +1072,15 @@ func reorderBodyForClosures(globals VarSet, body Body) (Body, unsafeVars) {
// expression.
vs := VarSet{}
WalkClosures(e, func(x interface{}) bool {
vis := &varVisitor{vars: vs}
vis := &VarVisitor{vars: vs}
Walk(vis, x)
return true
})

// Compute vars that are closed over from the body but not yet
// contained in the output position of an expression in the reordered
// body. These vars are considered unsafe.
cv := vs.Intersect(body.Vars(true)).Diff(globals)
cv := vs.Intersect(body.Vars(VarVisitorParams{SkipClosures: true})).Diff(globals)
uv := cv.Diff(reordered.OutputVars(globals))

if len(uv) == 0 {
Expand All @@ -1106,7 +1107,7 @@ type localVarGenerator struct {

func newLocalVarGenerator(module *Module) *localVarGenerator {
exclude := NewVarSet()
vis := &varVisitor{
vis := &VarVisitor{
vars: exclude,
}
Walk(vis, module)
Expand Down Expand Up @@ -1264,25 +1265,3 @@ func resolveRefsInTerm(globals map[Var]Value, term *Term) *Term {
return term
}
}

// rewriteImports returns an updated slice of imports that replace the imports
// in a module or query context. Imports against the default root document are
// removed, aliases are unset, and the remaining imports are shortened to the
// head variable. The result is a set of imports that effectively ground
// variables appearing in rules and queries (which refer to query inputs).
func rewriteImports(imports []*Import) (result []*Import) {
for _, imp := range imports {
switch path := imp.Path.Value.(type) {
case Ref:
if !path[0].Equal(DefaultRootDocument) {
imp.Path = path[0]
imp.Alias = Var("")
result = append(result, imp)
}
case Var:
imp.Alias = Var("")
result = append(result, imp)
}
}
return result
}
53 changes: 26 additions & 27 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) {
"newMod": MustParseModule(`
package a.b

import aref.b.c as foo
import avar as bar
import request.aref.b.c as foo
import request.avar as bar
import data.m.n as baz

# a would be unbound
Expand Down Expand Up @@ -433,9 +433,9 @@ func TestCompilerResolveAllRefs(t *testing.T) {
c := NewCompiler()
c.Modules = getCompilerTestModules()
c.Modules["head"] = MustParseModule(`package head
import x.y.foo
import data.doc1 as bar
import qux as baz
import request.x.y.foo
import request.qux as baz
p[foo[bar[i]]] = {"baz": baz} :- true
`)
compileStages(c, "", "resolveAllRefs")
Expand Down Expand Up @@ -470,7 +470,7 @@ func TestCompilerResolveAllRefs(t *testing.T) {
mod3 := c.Modules["mod3"]
expr4 := mod3.Rules[0].Body[0]
term = expr4.Terms.([]*Term)[2]
e = MustParseTerm("{x.secret: [{x.keyid}]}")
e = MustParseTerm("{request.x.secret: [{request.x.keyid}]}")
if !term.Equal(e) {
t.Errorf("Wrong term (nested refs): expected %v but got: %v", e, term)
}
Expand All @@ -483,42 +483,42 @@ func TestCompilerResolveAllRefs(t *testing.T) {
}

acTerm1 := ac(mod5.Rules[0])
assertTermEqual(t, acTerm1.Term, MustParseTerm("x.a"))
assertTermEqual(t, acTerm1.Term, MustParseTerm("request.x.a"))
acTerm2 := ac(mod5.Rules[1])
assertTermEqual(t, acTerm2.Term, MustParseTerm("a.b.c.q.a"))
assertTermEqual(t, acTerm2.Term, MustParseTerm("request.a.b.c.q.a"))
acTerm3 := ac(mod5.Rules[2])
assertTermEqual(t, acTerm3.Body[0].Terms.([]*Term)[1], MustParseTerm("x.a"))
assertTermEqual(t, acTerm3.Body[0].Terms.([]*Term)[1], MustParseTerm("request.x.a"))
acTerm4 := ac(mod5.Rules[3])
assertTermEqual(t, acTerm4.Body[0].Terms.([]*Term)[1], MustParseTerm("a.b.c.q[i]"))
assertTermEqual(t, acTerm4.Body[0].Terms.([]*Term)[1], MustParseTerm("request.a.b.c.q[i]"))
acTerm5 := ac(mod5.Rules[4])
assertTermEqual(t, acTerm5.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Term, MustParseTerm("x.a"))
assertTermEqual(t, acTerm5.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Term, MustParseTerm("request.x.a"))
acTerm6 := ac(mod5.Rules[5])
assertTermEqual(t, acTerm6.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Body[0].Terms.([]*Term)[1], MustParseTerm("a.b.c.q[i]"))
assertTermEqual(t, acTerm6.Body[0].Terms.([]*Term)[2].Value.(*ArrayComprehension).Body[0].Terms.([]*Term)[1], MustParseTerm("request.a.b.c.q[i]"))

// Nested references.
mod6 := c.Modules["mod6"]
nested1 := mod6.Rules[0].Body[0].Terms.(*Term)
assertTermEqual(t, nested1, MustParseTerm("data.x[x[i].a[data.z.b[j]]]"))
assertTermEqual(t, nested1, MustParseTerm("data.x[request.x[i].a[data.z.b[j]]]"))

nested2 := mod6.Rules[1].Body[1].Terms.(*Term)
assertTermEqual(t, nested2, MustParseTerm("v[x[i]]"))
assertTermEqual(t, nested2, MustParseTerm("v[request.x[i]]"))

nested3 := mod6.Rules[3].Body[0].Terms.(*Term)
assertTermEqual(t, nested3, MustParseTerm("data.x[data.a.b.nested.r]"))

// Refs in head.
mod7 := c.Modules["head"]
assertTermEqual(t, mod7.Rules[0].Key, MustParseTerm("x.y.foo[data.doc1[i]]"))
assertTermEqual(t, mod7.Rules[0].Value, MustParseTerm(`{"baz": qux}`))
assertTermEqual(t, mod7.Rules[0].Key, MustParseTerm("request.x.y.foo[data.doc1[i]]"))
assertTermEqual(t, mod7.Rules[0].Value, MustParseTerm(`{"baz": request.qux}`))
}

func TestCompilerRewriteRefsInHead(t *testing.T) {
c := NewCompiler()
c.Modules["head"] = MustParseModule(`package head
import x.y.foo
import data.doc1 as bar
import qux as baz
import data.doc2 as corge
import request.x.y.foo
import request.qux as baz
p[foo[bar[i]]] = {"baz": baz, "corge": corge} :- true
`)

Expand All @@ -534,8 +534,8 @@ func TestCompilerRewriteRefsInHead(t *testing.T) {
t.Fatalf("Expected rule body to contain 3 expressions but got: %v", rule)
}

assertExprEqual(t, rule.Body[1], MustParseExpr("__local0__ = x.y.foo[data.doc1[i]]"))
assertExprEqual(t, rule.Body[2], MustParseExpr(`__local1__ = {"baz": qux, "corge": data.doc2}`))
assertExprEqual(t, rule.Body[1], MustParseExpr("__local0__ = request.x.y.foo[data.doc1[i]]"))
assertExprEqual(t, rule.Body[2], MustParseExpr(`__local1__ = {"baz": request.qux, "corge": data.doc2}`))
}

func TestCompilerSetRuleGraph(t *testing.T) {
Expand Down Expand Up @@ -869,7 +869,7 @@ func TestCompilerLazyLoading(t *testing.T) {

mod3 := MustParseModule(`package x
import data.foo.bar
import input
import request.input
z1 :- [ localvar | count(bar.baz.qux, localvar) ]`)

mod4 := MustParseModule(`
Expand Down Expand Up @@ -960,7 +960,7 @@ func TestQueryCompiler(t *testing.T) {
{"exports resolved", "z", "package a.b.c", nil, "data.a.b.c.z"},
{"imports resolved", "z", "package a.b.c.d", []string{"import data.a.b.c.z"}, "data.a.b.c.z"},
{"unsafe vars", "z", "", nil, fmt.Errorf("1 error occurred: 1:1: z is unsafe (variable z must appear in the output position of at least one non-negated expression)")},
{"safe vars", "data, abc", "package ex", []string{"import xyz as abc"}, "data, xyz"},
{"safe vars", "data, abc", "package ex", []string{"import request.xyz as abc"}, "data, request.xyz"},
{"reorder", "x != 1, x = 0", "", nil, "x = 0, x != 1"},
{"bad builtin", "deadbeef(1,2,3)", "", nil, fmt.Errorf("1 error occurred: 1:1: deadbeef is unsafe (variable deadbeef must appear in the output position of at least one non-negated expression)")},
}
Expand Down Expand Up @@ -1033,9 +1033,8 @@ func getCompilerTestModules() map[string]*Module {

mod3 := MustParseModule(`
package a.b.d
import req
import x as y
t = true :- req = {y.secret: [{y.keyid}]}
import request.x as y
t = true :- request = {y.secret: [{y.keyid}]}
x = false :- true
`)

Expand All @@ -1046,8 +1045,8 @@ func getCompilerTestModules() map[string]*Module {
mod5 := MustParseModule(`
package a.b.compr

import x as y
import a.b.c.q
import request.x as y
import request.a.b.c.q

p :- [y.a | true]
r :- [q.a | true]
Expand All @@ -1061,8 +1060,8 @@ func getCompilerTestModules() map[string]*Module {
package a.b.nested

import data.x
import x as y
import data.z
import request.x as y

p :- x[y[i].a[z.b[j]]]
q :- x = v, v[y[i]]
Expand Down
10 changes: 5 additions & 5 deletions ast/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func ExampleCompiler_Compile() {
package opa.example

import data.foo
import bar
import request.bar

p[x] :- foo[x], not bar[x], x >= min_x

Expand Down Expand Up @@ -50,7 +50,7 @@ func ExampleCompiler_Compile() {
// Output:
//
// Expr 1: data.foo[x]
// Expr 2: not bar[x]
// Expr 2: not request.bar[x]
// Expr 3: gte(x, data.opa.example.min_x)
}

Expand All @@ -62,7 +62,7 @@ func ExampleQueryCompiler_Compile() {
package opa.example

import data.foo
import bar
import request.bar

p[x] :- foo[x], not bar[x], x >= min_x

Expand Down Expand Up @@ -98,7 +98,7 @@ func ExampleQueryCompiler_Compile() {
// ast.Parse<X> functions that return meaningful error messages
// instead.
ast.MustParsePackage("package opa.example"),
ast.MustParseImports("import queryinput"),
ast.MustParseImports("import request.queryinput"),
))

// Parse the input query to obtain the AST representation.
Expand All @@ -116,5 +116,5 @@ func ExampleQueryCompiler_Compile() {

// Output:
//
// Compiled: data.opa.example.p[x], lt(x, queryinput)
// Compiled: data.opa.example.p[x], lt(x, request.queryinput)
}
Loading