Skip to content

Commit

Permalink
Fix reordering for unsafe ref heads
Browse files Browse the repository at this point in the history
Evaluation assumes that ref heads are safe when evaluating expressions.
In the future this might be relaxed, however for now, reordering must
not consider vars safe if they would unify with a ref that has an unsafe
head.

These changes also update the compiler to keep track of generated vars
in each module. This allows the compiler to suppress error messages for
unsafe generated vars--this is fine as an unsafe generated var message
will always be accompanied with another unsafe var message. It would be
confusing to users if unsafe var messages were reported for vars that
they had not specified.

Also, refactored safety check test case. It was becoming difficult to
track down errors because the test case exercised the entire module at
once. Now the individual test cases are isolated and it's easier to
identify errors.

Fixes #297
  • Loading branch information
tsandall committed Mar 22, 2017
1 parent f3a57a1 commit 9b6adb6
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 116 deletions.
32 changes: 22 additions & 10 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ type Compiler struct {
// A rule depends on another rule if it refers to it.
RuleGraph map[*Rule]map[*Rule]struct{}

moduleLoader ModuleLoader
stages []stage
generatedVars map[*Module]VarSet
moduleLoader ModuleLoader
stages []stage
}

// QueryContext contains contextual information for running an ad-hoc query.
Expand Down Expand Up @@ -155,8 +156,9 @@ type stage struct {
func NewCompiler() *Compiler {

c := &Compiler{
Modules: map[string]*Module{},
RuleGraph: map[*Rule]map[*Rule]struct{}{},
Modules: map[string]*Module{},
RuleGraph: map[*Rule]map[*Rule]struct{}{},
generatedVars: map[*Module]VarSet{},
}

c.ModuleTree = NewModuleTree(nil)
Expand Down Expand Up @@ -435,7 +437,9 @@ func (c *Compiler) checkSafetyRuleBodies() {
reordered, unsafe := reorderBodyForSafety(safe, r.Body)
if len(unsafe) != 0 {
for v := range unsafe.Vars() {
c.err(NewError(UnsafeVarErr, r.Loc(), "%v %v is unsafe", VarTypeName, v))
if !c.generatedVars[m].Contains(v) {
c.err(NewError(UnsafeVarErr, r.Loc(), "%v %v is unsafe", VarTypeName, v))
}
}
} else {
r.Body = reordered
Expand All @@ -456,7 +460,9 @@ func (c *Compiler) checkSafetyRuleHeads() {
for _, r := range m.Rules {
unsafe := r.Head.Vars().Diff(r.Body.Vars(safetyCheckVarVisitorParams))
for v := range unsafe {
c.err(NewError(UnsafeVarErr, r.Loc(), "%v %v is unsafe", VarTypeName, v))
if !c.generatedVars[m].Contains(v) {
c.err(NewError(UnsafeVarErr, r.Loc(), "%v %v is unsafe", VarTypeName, v))
}
}
}
}
Expand Down Expand Up @@ -634,6 +640,7 @@ func (c *Compiler) rewriteRefsInHead() {
}
}
}
c.generatedVars[mod] = generator.Generated()
}
}

Expand Down Expand Up @@ -1339,7 +1346,8 @@ func reorderBodyForClosures(globals VarSet, body Body) (Body, unsafeVars) {
const localVarFmt = "__local%d__"

type localVarGenerator struct {
exclude VarSet
exclude VarSet
generated VarSet
}

func newLocalVarGenerator(module *Module) *localVarGenerator {
Expand All @@ -1348,17 +1356,21 @@ func newLocalVarGenerator(module *Module) *localVarGenerator {
vars: exclude,
}
Walk(vis, module)
return &localVarGenerator{exclude}
return &localVarGenerator{exclude, NewVarSet()}
}

func (l *localVarGenerator) Generated() VarSet {
return l.generated
}

func (l *localVarGenerator) Generate() Var {
name := Var("")
x := 0
for len(name) == 0 || l.exclude.Contains(name) {
for len(name) == 0 || l.generated.Contains(name) || l.exclude.Contains(name) {
name = Var(fmt.Sprintf(localVarFmt, x))
x++
}
l.exclude.Add(name)
l.generated.Add(name)
return name
}

Expand Down
196 changes: 101 additions & 95 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,28 +170,36 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) {
{"with", `data.a.b.d.t with input as x; x = 1`, `x = 1; data.a.b.d.t with input as x`},
{"with-2", `data.a.b.d.t with input.x as x; x = 1`, `x = 1; data.a.b.d.t with input.x as x`},
{"with-nop", "data.somedoc[x] with input as true", "data.somedoc[x] with input as true"},
{"ref-head", `s = [["foo"], ["bar"]]; x = y[0]; y = s[_]; contains(x, "oo")`, `
s = [["foo"], ["bar"]];
y = s[_];
x = y[0];
contains(x, "oo")
`},
}

for i, tc := range tests {
c := NewCompiler()
c.Modules = getCompilerTestModules()
c.Modules["reordering"] = MustParseModule(fmt.Sprintf(
`package test
test.Subtest(t, tc.note, func(t *testing.T) {
c := NewCompiler()
c.Modules = getCompilerTestModules()
c.Modules["reordering"] = MustParseModule(fmt.Sprintf(
`package test
p { %s }`, tc.body))

compileStages(c, "", "checkSafetyBody")
compileStages(c, "", "checkSafetyBody")

if c.Failed() {
t.Errorf("%v (#%d): Unexpected compilation error: %v", tc.note, i, c.Errors)
return
}
if c.Failed() {
t.Errorf("%v (#%d): Unexpected compilation error: %v", tc.note, i, c.Errors)
return
}

expected := MustParseBody(tc.expected)
result := c.Modules["reordering"].Rules[0].Body
expected := MustParseBody(tc.expected)
result := c.Modules["reordering"].Rules[0].Body

if !expected.Equal(result) {
t.Errorf("%v (#%d): Expected body to be ordered and equal to %v but got: %v", tc.note, i, expected, result)
}
if !expected.Equal(result) {
t.Errorf("%v (#%d): Expected body to be ordered and equal to %v but got: %v", tc.note, i, expected, result)
}
})
}
}

Expand Down Expand Up @@ -226,96 +234,93 @@ q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true !=
}

func TestCompilerCheckSafetyBodyErrors(t *testing.T) {
c := NewCompiler()

c.Modules = getCompilerTestModules()
c.Modules = map[string]*Module{
"newMod": MustParseModule(`package a.b
import input.aref.b.c as foo
import input.avar as bar
import data.m.n as baz
unboundRef1 = true { a.b.c = "foo" }
unboundRef2 = true { {"foo": [{"bar": a.b.c}]} = {"foo": [{"bar": "baz"}]} }
inputPosRef = true { a = [1, 2, 3, 4]; a[i] != 100 }
unboundNegated1 = true { a = [1, 2, 3, 4]; not a[i] = x }
unboundNegated2[x] { a = [1, 2, 3, 4]; not a[i] = x }
unboundNegated3[x] = true { a = [1, 2, 3, 4]; b = [1, 2, 3, 4]; not a[i] = x; not b[j] = x }
unboundNegated4 = true { a = [{"foo": ["bar", "baz"]}]; not a[0].foo = [a[0].foo[i], a[0].foo[j]] }
unsafeBuiltin = true { count([1, 2, x], x) }
unsafeBuiltinOperator = true { count(eq, 1) }
negatedSafe = true { a = [1, 2, 3, 4]; b = [1, 2, 3, 4]; not a[i] = x; b[i] = x }
unboundNoTarget = true { x > 0; x <= 3; x != 2 }
unboundArrayComprBody1 = true { _ = [x | x = data.a[_]; y > 1] }
unboundArrayComprBody2 = true { _ = [x | x = a[_]; a = [y | y = data.a[_]; z > 1]] }
unboundArrayComprBody3 = true { _ = [v | v = [x | x = data.a[_]]; x > 1] }
unboundArrayComprTerm1 = true { _ = [u | true] }
unboundArrayComprTerm2 = true { _ = [v | v = [w | w != 0]] }
unboundArrayComprTerm3 = true { _ = [x[i] | x = []] }
unboundArrayComprMixed1 = true { _ = [x | y = [a | a = z[i]]] }
unboundBuiltinOperatorArrayCompr = true { 1 = 1; [true | eq != 2] }
unsafeClosure1 = true { x = [x | x = 1] }
unsafeClosure2 = true { x = y; x = [y | y = 1] }
unsafeNestedHead = true { count(baz[i].attr[bar[dead.beef]], n) }
negatedImport1 = true { not foo }
negatedImport2 = true { not bar }
negatedImport3 = true { not baz }
rewriteUnsafe[{"foo": dead[i]}] { true }
unsafeWithValue1 = true { data.a.b.d.t with input as x }
unsafeWithValue2 = true { x = data.a.b.d.t with input as x }`,
)}
compileStages(c, "", "checkSafetyBody")
moduleBegin := `
package a.b
import input.aref.b.c as foo
import input.avar as bar
import data.m.n as baz
`

tests := []struct {
note string
moduleContent string
expected string
}{
{"ref-head", `p { a.b.c = "foo" }`, `{a,}`},
{"ref-head-2", `p { {"foo": [{"bar": a.b.c}]} = {"foo": [{"bar": "baz"}]} }`, `{a,}`},
{"negation", `p { a = [1, 2, 3, 4]; not a[i] = x }`, `{i, x}`},
{"negation-head", `p[x] { a = [1, 2, 3, 4]; not a[i] = x }`, `{i,x}`},
{"negation-multiple", `p { a = [1, 2, 3, 4]; b = [1, 2, 3, 4]; not a[i] = x; not b[j] = x }`, `{i, x, j}`},
{"negation-nested", `p { a = [{"foo": ["bar", "baz"]}]; not a[0].foo = [a[0].foo[i], a[0].foo[j]] } `, `{i, j}`},
{"builtin-input", `p { count([1, 2, x], x) }`, `{x,}`},
{"builtin-input-name", `p { count(eq, 1) }`, `{eq,}`},
{"builtin-multiple", `p { x > 0; x <= 3; x != 2 }`, `{x,}`},
{"array-compr", `p { _ = [x | x = data.a[_]; y > 1] }`, `{y,}`},
{"array-compr-nested", `p { _ = [x | x = a[_]; a = [y | y = data.a[_]; z > 1]] }`, `{z,}`},
{"array-compr-closure", `p { _ = [v | v = [x | x = data.a[_]]; x > 1] }`, `{x,}`},
{"array-compr-term", `p { _ = [u | true] }`, `{u,}`},
{"array-compr-term-nested", `p { _ = [v | v = [w | w != 0]] }`, `{w,}`},
{"array-compr-term-output", `p { _ = [x[i] | x = []] }`, `{i,}`},
{"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`},
{"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`},
{"closure-self", `p { x = [x | x = 1] }`, `{x,}`},
{"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{y,}`},
{"nested", `p { count(baz[i].attr[bar[dead.beef]], n) }`, `{dead,}`},
{"negated-import", `p { not foo; not bar; not baz }`, `set()`},
{"rewritten", `p[{"foo": dead[i]}] { true }`, `{dead, i}`},
{"with-value", `p { data.a.b.d.t with input as x }`, `{x,}`},
{"with-value-2", `p { x = data.a.b.d.t with input as x }`, `{x,}`},
}

makeErrMsg := func(varName string) string {
return fmt.Sprintf("rego_unsafe_var_error: var %v is unsafe", varName)
}

expected := []string{
makeErrMsg("a"),
makeErrMsg("a"),
makeErrMsg("i"),
makeErrMsg("x"),
makeErrMsg("i"),
makeErrMsg("x"),
makeErrMsg("i"),
makeErrMsg("j"),
makeErrMsg("x"),
makeErrMsg("i"),
makeErrMsg("j"),
makeErrMsg("x"),
makeErrMsg("eq"),
makeErrMsg("x"),
makeErrMsg("y"),
makeErrMsg("z"),
makeErrMsg("x"),
makeErrMsg("u"),
makeErrMsg("w"),
makeErrMsg("i"),
makeErrMsg("x"),
makeErrMsg("z"),
makeErrMsg("eq"),
makeErrMsg("x"),
makeErrMsg("y"),
makeErrMsg("dead"),
makeErrMsg("dead"),
makeErrMsg("x"),
makeErrMsg("x"),
}
for _, tc := range tests {
test.Subtest(t, tc.note, func(t *testing.T) {

result := compilerErrsToStringSlice(c.Errors)
sort.Strings(expected)
// Build slice of expected error messages.
expected := []string{}

if len(result) != len(expected) {
t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n"))
}
MustParseTerm(tc.expected).Value.(*Set).Iter(func(x *Term) bool {
expected = append(expected, makeErrMsg(string(x.Value.(Var))))
return false
})

for i := range result {
if expected[i] != result[i] {
t.Errorf("Expected %v but got: %v", expected[i], result[i])
}
}
sort.Strings(expected)

// Compile test module.
c := NewCompiler()
c.Modules = map[string]*Module{
"newMod": MustParseModule(fmt.Sprintf(`
%v
%v
`, moduleBegin, tc.moduleContent)),
}

compileStages(c, "", "checkSafetyBody")

// Get errors.
result := compilerErrsToStringSlice(c.Errors)

// Check against expected.
if len(result) != len(expected) {
t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n"))
}

for i := range result {
if expected[i] != result[i] {
t.Errorf("Expected %v but got: %v", expected[i], result[i])
}
}

})
}
}

func TestCompilerCheckWithModifiers(t *testing.T) {
Expand Down Expand Up @@ -595,7 +600,8 @@ q[x] = y { p[x] = y }`),
"newMod6": MustParseModule(`package rec5
acp[x] { acq[x] }
acq[x] { a = [x | acp[x]]; a[i] = x }`,
acq[x] { a = [true | acp[_]]; a[_] = x }
`,
),
"newMod7": MustParseModule(`package rec6
Expand Down
15 changes: 9 additions & 6 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ func (expr *Expr) OutputVars(safe VarSet) VarSet {

switch terms := expr.Terms.(type) {
case *Term:
return expr.outputVarsRefs()
return expr.outputVarsRefs(safe)
case []*Term:
name := terms[0].Value.(Var)
if b := BuiltinMap[name]; b != nil {
Expand Down Expand Up @@ -832,7 +832,7 @@ func (expr *Expr) Vars(params VarVisitorParams) VarSet {

func (expr *Expr) outputVarsBuiltins(b *Builtin, safe VarSet) VarSet {

o := expr.outputVarsRefs()
o := expr.outputVarsRefs(safe)
terms := expr.Terms.([]*Term)

// Check that all input terms are ground or safe.
Expand Down Expand Up @@ -870,17 +870,20 @@ func (expr *Expr) outputVarsBuiltins(b *Builtin, safe VarSet) VarSet {

func (expr *Expr) outputVarsEquality(safe VarSet) VarSet {
ts := expr.Terms.([]*Term)
o := expr.outputVarsRefs()
o := expr.outputVarsRefs(safe)
o.Update(safe)
o.Update(Unify(o, ts[1], ts[2]))
return o.Diff(safe)
}

func (expr *Expr) outputVarsRefs() VarSet {
func (expr *Expr) outputVarsRefs(safe VarSet) VarSet {
o := VarSet{}
WalkRefs(expr, func(r Ref) bool {
o.Update(r.OutputVars())
return false
if safe.Contains(r[0].Value.(Var)) {
o.Update(r.OutputVars())
return false
}
return true
})
return o
}
Expand Down
16 changes: 11 additions & 5 deletions ast/unify.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ func (u *unifier) unify(a *Term, b *Term) {
}
case Array, Object:
u.unifyAll(a, b)
case Ref:
if u.isSafe(b[0].Value.(Var)) {
u.markSafe(a)
}
default:
u.markSafe(a)
}

case Ref:
switch b := b.Value.(type) {
case Var:
u.markSafe(b)
case Array, Object:
u.markAllSafe(b, a)
if u.isSafe(a[0].Value.(Var)) {
switch b := b.Value.(type) {
case Var:
u.markSafe(b)
case Array, Object:
u.markAllSafe(b, a)
}
}

case *ArrayComprehension:
Expand Down

0 comments on commit 9b6adb6

Please sign in to comment.