Skip to content

Commit

Permalink
ast: Report safety errors on line where expression starts
Browse files Browse the repository at this point in the history
These changes just update the safety errors to report the line of the
expression where the error occurred instead of the rule/query that
contains the expressionm. Better locality should make it easier to
identify safety error causes.

Fixes #1497

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall authored and patrick-east committed Jun 25, 2019
1 parent 3f06359 commit 16197e4
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 16 deletions.
50 changes: 37 additions & 13 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,15 +594,15 @@ func (c *Compiler) checkSafetyRuleBodies() {
WalkRules(m, func(r *Rule) bool {
safe := ReservedVars.Copy()
safe.Update(r.Head.Args.Vars())
r.Body = c.checkBodySafety(safe, m, r.Body, r.Loc())
r.Body = c.checkBodySafety(safe, m, r.Body)
return false
})
}
}

func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body, l *Location) Body {
func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body) Body {
reordered, unsafe := reorderBodyForSafety(c.GetArity, safe, b)
if errs := safetyErrorSlice(l, unsafe); len(errs) > 0 {
if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
for _, err := range errs {
c.err(err)
}
Expand Down Expand Up @@ -1154,7 +1154,7 @@ func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, err
func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
safe := ReservedVars.Copy()
reordered, unsafe := reorderBodyForSafety(qc.compiler.GetArity, safe, body)
if errs := safetyErrorSlice(body.Loc(), unsafe); len(errs) > 0 {
if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
return nil, errs
}
return reordered, nil
Expand Down Expand Up @@ -1482,6 +1482,11 @@ type unsafePair struct {
Vars VarSet
}

type unsafeVarLoc struct {
Var Var
Loc *Location
}

type unsafeVars map[*Expr]VarSet

func (vs unsafeVars) Add(e *Expr, v Var) {
Expand All @@ -1505,12 +1510,31 @@ func (vs unsafeVars) Update(o unsafeVars) {
}
}

func (vs unsafeVars) Vars() VarSet {
r := VarSet{}
for _, s := range vs {
r.Update(s)
func (vs unsafeVars) Vars() (result []unsafeVarLoc) {

locs := map[Var]*Location{}

// If var appears in multiple sets then pick first by location.
for expr, vars := range vs {
for v := range vars {
if locs[v].Compare(expr.Location) > 0 {
locs[v] = expr.Location
}
}
}
return r

for v, loc := range locs {
result = append(result, unsafeVarLoc{
Var: v,
Loc: loc,
})
}

sort.Slice(result, func(i, j int) bool {
return result[i].Loc.Compare(result[j].Loc) < 0
})

return result
}

func (vs unsafeVars) Slice() (result []unsafePair) {
Expand Down Expand Up @@ -2962,15 +2986,15 @@ func isDataRef(term *Term) bool {
return false
}

func safetyErrorSlice(l *Location, unsafe unsafeVars) (result Errors) {
func safetyErrorSlice(unsafe unsafeVars) (result Errors) {

if len(unsafe) == 0 {
return
}

for v := range unsafe.Vars() {
if !v.IsGenerated() {
result = append(result, NewError(UnsafeVarErr, l, "var %v is unsafe", v))
for _, pair := range unsafe.Vars() {
if !pair.Var.IsGenerated() {
result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", pair.Var))
}
}

Expand Down
28 changes: 26 additions & 2 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ func TestCompilerErrorLimit(t *testing.T) {

errs := c.Errors
exp := []string{
"2:2: rego_unsafe_var_error: var x is unsafe",
"2:2: rego_unsafe_var_error: var z is unsafe",
"2:20: rego_unsafe_var_error: var x is unsafe",
"2:20: rego_unsafe_var_error: var z is unsafe",
"rego_compile_error: error limit reached",
}

Expand Down Expand Up @@ -598,6 +598,30 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) {
}
}

func TestCompilerCheckSafetyVarLoc(t *testing.T) {

_, err := CompileModules(map[string]string{"test.rego": `package test
p {
not x
x > y
}`})

if err == nil {
t.Fatal("expected error")
}

errs := err.(Errors)

if !strings.Contains(errs[0].Message, "var x is unsafe") || errs[0].Location.Row != 4 {
t.Fatal("expected error on row 4 but got:", err)
}

if !strings.Contains(errs[1].Message, "var y is unsafe") || errs[1].Location.Row != 5 {
t.Fatal("expected y is unsafe on row 5 but got:", err)
}
}

func TestCompilerCheckTypes(t *testing.T) {
c := NewCompiler()
modules := getCompilerTestModules()
Expand Down
3 changes: 2 additions & 1 deletion ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ func (loc *Location) String() string {

// Compare returns -1, 0, or 1 to indicate if this loc is less than, equal to,
// or greater than the other. Comparison is performed on the file, row, and
// column of the Location (but not on the text.)
// column of the Location (but not on the text.) Nil locations are greater than
// non-nil locations.
func (loc *Location) Compare(other *Location) int {
if loc == nil && other == nil {
return 0
Expand Down
80 changes: 80 additions & 0 deletions ast/term_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,86 @@ func TestValueToInterface(t *testing.T) {
}
}

func TestLocationCompare(t *testing.T) {

tests := []struct {
a string
b string
exp int
}{
{
a: "",
b: "",
exp: 0,
},
{
a: "",
b: `{"file": "a", "row": 1, "col": 1}`,
exp: 1,
},
{
a: `{"file": "a", "row": 1, "col": 1}`,
b: "",
exp: -1,
},
{
a: `{"file": "a", "row": 1, "col": 1}`,
b: `{"file": "a", "row": 1, "col": 1}`,
exp: 0,
},
{
a: `{"file": "a", "row": 1, "col": 1}`,
b: `{"file": "b", "row": 1, "col": 1}`,
exp: -1,
},
{
a: `{"file": "b", "row": 1, "col": 1}`,
b: `{"file": "a", "row": 1, "col": 1}`,
exp: 1,
},
{
a: `{"file": "a", "row": 1, "col": 1}`,
b: `{"file": "a", "row": 2, "col": 1}`,
exp: -1,
},
{
a: `{"file": "a", "row": 2, "col": 1}`,
b: `{"file": "a", "row": 1, "col": 1}`,
exp: 1,
},
{
a: `{"file": "a", "row": 1, "col": 1}`,
b: `{"file": "a", "row": 1, "col": 2}`,
exp: -1,
},
{
a: `{"file": "a", "row": 1, "col": 2}`,
b: `{"file": "a", "row": 1, "col": 1}`,
exp: 1,
},
}

unmarshal := func(s string) *Location {
if s != "" {
var loc Location
if err := util.Unmarshal([]byte(s), &loc); err != nil {
t.Fatal(err)
}
return &loc
}
return nil
}

for _, tc := range tests {
locA := unmarshal(tc.a)
locB := unmarshal(tc.b)
result := locA.Compare(locB)
if tc.exp != result {
t.Fatalf("Expected %v but got %v for %v.Compare(%v)", tc.exp, result, locA, locB)
}
}
}

func assertTermEqual(t *testing.T, x *Term, y *Term) {
if !x.Equal(y) {
t.Errorf("Failure on equality: \n%s and \n%s\n", x, y)
Expand Down

0 comments on commit 16197e4

Please sign in to comment.