diff --git a/ast/parser_ext.go b/ast/parser_ext.go index 62b8304b99..60fcf0401e 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -119,15 +119,15 @@ func MustParseTerm(input string) *Term { // ParseRuleFromBody attempts to return a rule from a body. Equality expressions // of the form = can be converted into rules of the form = // :- true. This is a concise way of defining constants inside modules. -func ParseRuleFromBody(body Body) *Rule { +func ParseRuleFromBody(body Body) (*Rule, error) { if len(body) != 1 { - return nil + return nil, fmt.Errorf("multiple %vs cannot be used for %v", ExprTypeName, HeadTypeName) } expr := body[0] if !expr.IsEquality() { - return nil + return nil, fmt.Errorf("non-equality %v cannot be used for %v", ExprTypeName, HeadTypeName) } terms := expr.Terms.([]*Term) @@ -141,13 +141,13 @@ func ParseRuleFromBody(body Body) *Rule { if v.Equal(RequestRootRef) || v.Equal(DefaultRootRef) { name = Var(v.String()) } else { - return nil + return nil, fmt.Errorf("%v cannot be used for name of %v", RefTypeName, RuleTypeName) } default: - return nil + return nil, fmt.Errorf("%v cannot be used for name of %v", TypeName(v), RuleTypeName) } - return &Rule{ + rule := &Rule{ Location: expr.Location, Name: name, Value: terms[2], @@ -155,6 +155,8 @@ func ParseRuleFromBody(body Body) *Rule { &Expr{Terms: BooleanTerm(true)}, ), } + + return rule, nil } // ParseImports returns a slice of Import objects. @@ -342,10 +344,12 @@ func parseModule(stmts []Statement) (*Module, error) { return nil, nil } + var errs Errors + _package, ok := stmts[0].(*Package) if !ok { loc := stmts[0].(Statement).Loc() - return nil, NewError(ParseErr, loc, "expected package directive (%s must come after package directive)", stmts[0]) + errs = append(errs, NewError(ParseErr, loc, "expected %v (found %v)", PackageTypeName, TypeName(stmts[0]))) } mod := &Module{ @@ -359,15 +363,25 @@ func parseModule(stmts []Statement) (*Module, error) { case *Rule: mod.Rules = append(mod.Rules, stmt) case Body: - rule := ParseRuleFromBody(stmt) - if rule == nil { - return nil, NewError(ParseErr, stmt[0].Location, "expected rule (%s must be declared inside a rule)", stmt[0].Location.Text) + rule, err := ParseRuleFromBody(stmt) + if err != nil { + errs = append(errs, NewError(ParseErr, stmt[0].Location, "expected %v (%v)", RuleTypeName, err)) + } else { + mod.Rules = append(mod.Rules, rule) } - mod.Rules = append(mod.Rules, rule) + case *Package: + errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected "+PackageTypeName)) + case *Comment: // Drop comments for now. + default: + panic("illegal value") // Indicates grammar is out-of-sync with code. } } - return mod, nil + if len(errs) == 0 { + return mod, nil + } + + return nil, errs } func postProcess(filename string, stmts []Statement) error { diff --git a/ast/parser_test.go b/ast/parser_test.go index 904752a51b..68b76987fc 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -527,6 +527,32 @@ func TestExample(t *testing.T) { }) } +func TestModuleParseErrors(t *testing.T) { + input := ` + x = 1 # expect package + package a # unexpected package + 1 = 2 # non-var head + 1 != 2 # non-equality expr + x = y, x = 1 # multiple exprs + ` + + mod, err := ParseModule("test.rego", input) + if err == nil { + t.Fatalf("Expected error but got: %v", mod) + } + + errs, ok := err.(Errors) + if !ok { + panic("unexpected error value") + } + + if len(errs) != 5 { + t.Fatalf("Expected exactly 5 errors but got: %v", err) + } + + fmt.Println(errs) +} + func TestLocation(t *testing.T) { mod, err := ParseModule("test", testModule) if err != nil { diff --git a/ast/policy.go b/ast/policy.go index 911d82766b..36c43c35b2 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -80,7 +80,7 @@ type ( Rules []*Rule } - // Comment represents + // Comment contains the raw text from the comment in the definition. Comment struct { Text []byte Location *Location diff --git a/ast/strings.go b/ast/strings.go index 7ad1088b1b..463bf5c0c4 100644 --- a/ast/strings.go +++ b/ast/strings.go @@ -26,4 +26,10 @@ const ( ObjectTypeName = "object" SetTypeName = "set" ArrayComprehensionTypeName = "arraycomprehension" + ExprTypeName = "expr" + BodyTypeName = "body" + HeadTypeName = "head" + RuleTypeName = "rule" + ImportTypeName = "import" + PackageTypeName = "package" ) diff --git a/repl/repl.go b/repl/repl.go index 35f545d1f5..b74b368c17 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -574,7 +574,7 @@ func (r *REPL) evalStatement(ctx context.Context, stmt interface{}) error { if err != nil { return err } - if rule := ast.ParseRuleFromBody(body); rule != nil { + if rule, err := ast.ParseRuleFromBody(body); err == nil { if err := r.compileRule(rule); err != nil { return err }