diff --git a/cmd/run.go b/cmd/run.go index 3922d19706..beb1ec52d7 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -28,7 +28,7 @@ var defaultAddr = ":8181" func init() { - params := &runtime.Params{} + params := runtime.NewParams() runCommand := &cobra.Command{ Use: "run", @@ -39,13 +39,13 @@ To run the interactive shell: $ opa run -To run the server without saving policies: +To run the server: $ opa run -s -To run the server and persist policies to a local directory: +To evaluate a query from the command line: - $ opa run -s -p ./policies/ + $ opa run -e 'data.repl.version[key] = value' The 'run' command starts an instance of the OPA runtime. The OPA runtime can be started as an interactive shell or a server. @@ -70,6 +70,7 @@ In addition, API calls to delete policies will remove the definition file. } runCommand.Flags().BoolVarP(¶ms.Server, "server", "s", false, "start the runtime in server mode") + runCommand.Flags().StringVarP(¶ms.Eval, "eval", "e", "", "evaluate, print, exit") runCommand.Flags().StringVarP(¶ms.HistoryPath, "history", "H", historyPath(), "set path of history file") runCommand.Flags().StringVarP(¶ms.PolicyDir, "policy-dir", "p", "", "set directory to store policy definitions") runCommand.Flags().StringVarP(¶ms.Addr, "addr", "a", defaultAddr, "set listening address of the server") diff --git a/repl/errors.go b/repl/errors.go new file mode 100644 index 0000000000..22eeed3a2e --- /dev/null +++ b/repl/errors.go @@ -0,0 +1,41 @@ +// Copyright 2016 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package repl + +import "fmt" + +// Error is the error type returned by the REPL. +type Error struct { + Code ErrCode + Message string +} + +func (err *Error) Error() string { + return fmt.Sprintf("code %v: %v", err.Code, err.Message) +} + +// ErrCode represents the collection of errors that may be returned by the REPL. +type ErrCode int + +const ( + // BadArgsErr indicates bad arguments were provided to a built-in REPL + // command. + BadArgsErr ErrCode = iota +) + +func newBadArgsErr(f string, a ...interface{}) *Error { + return &Error{ + Code: BadArgsErr, + Message: fmt.Sprintf(f, a...), + } +} + +// stop is returned by the 'exit' command to indicate to the REPL that it should +// break and return. +type stop struct{} + +func (stop) Error() string { + return "" +} diff --git a/repl/repl.go b/repl/repl.go index 51aa334c92..eac7d3575f 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -46,6 +46,9 @@ type REPL struct { initPrompt string bufferPrompt string banner string + + bufferDisabled bool + undefinedDisabled bool } type explainMode int @@ -153,9 +156,13 @@ func (r *REPL) Loop() { os.Exit(1) } - if r.OneShot(input) { - fmt.Fprintln(r.output, "Exiting") - break + if err := r.OneShot(input); err != nil { + switch err := err.(type) { + case stop: + break + default: + fmt.Fprintln(r.output, "error:", err) + } } line.AppendHistory(input) @@ -164,17 +171,14 @@ func (r *REPL) Loop() { r.saveHistory(line) } -// OneShot evaluates a single line and prints the result. Returns true if caller -// should exit. -func (r *REPL) OneShot(line string) bool { +// OneShot evaluates the line and prints the result. If an error occurs it is +// returned for the caller to display. +func (r *REPL) OneShot(line string) error { var err error - r.txn, err = r.store.NewTransaction() - if err != nil { - fmt.Fprintln(r.output, "error:", err) - return false + return err } defer r.store.Close(r.txn) @@ -211,7 +215,21 @@ func (r *REPL) OneShot(line string) bool { return r.evalBufferMulti() } - return false + return nil +} + +// DisableMultiLineBuffering causes the REPL to not buffer lines when a parse +// error occurs. Instead, the error will be returned to the caller. +func (r *REPL) DisableMultiLineBuffering(yes bool) *REPL { + r.bufferDisabled = yes + return r +} + +// DisableUndefinedOutput causes the REPL to not print any output when the query +// is undefined. +func (r *REPL) DisableUndefinedOutput(yes bool) *REPL { + r.undefinedDisabled = yes + return r } func (r *REPL) complete(line string) (c []string) { @@ -251,90 +269,80 @@ func (r *REPL) complete(line string) (c []string) { return c } -func (r *REPL) cmdDump(args []string) bool { +func (r *REPL) cmdDump(args []string) error { if len(args) == 0 { return r.cmdDumpOutput() } return r.cmdDumpPath(args[0]) } -func (r *REPL) cmdDumpOutput() bool { - if err := dumpStorage(r.store, r.txn, r.output); err != nil { - fmt.Fprintln(r.output, "error:", err) - } - return false +func (r *REPL) cmdDumpOutput() error { + return dumpStorage(r.store, r.txn, r.output) } -func (r *REPL) cmdDumpPath(filename string) bool { +func (r *REPL) cmdDumpPath(filename string) error { f, err := os.Create(filename) if err != nil { - fmt.Fprintln(r.output, "error:", err) - return false + return err } defer f.Close() - if err := dumpStorage(r.store, r.txn, f); err != nil { - fmt.Fprintln(r.output, "error:", err) - } - return false + return dumpStorage(r.store, r.txn, f) } -func (r *REPL) cmdExit() bool { - return true +func (r *REPL) cmdExit() error { + return stop{} } -func (r *REPL) cmdFormat(s string) bool { +func (r *REPL) cmdFormat(s string) error { r.outputFormat = s - return false + return nil } -func (r *REPL) cmdHelp() bool { +func (r *REPL) cmdHelp() error { fmt.Fprintln(r.output, "") printHelpExamples(r.output, r.initPrompt) printHelpCommands(r.output) - return false + return nil } -func (r *REPL) cmdShow() bool { +func (r *REPL) cmdShow() error { module := r.modules[r.currentModuleID] fmt.Fprintln(r.output, module) - return false + return nil } -func (r *REPL) cmdTrace() bool { +func (r *REPL) cmdTrace() error { if r.explain == explainTrace { r.explain = explainOff } else { r.explain = explainTrace } - return false + return nil } -func (r *REPL) cmdTruth() bool { +func (r *REPL) cmdTruth() error { if r.explain == explainTruth { r.explain = explainOff } else { r.explain = explainTruth } - return false + return nil } -func (r *REPL) cmdUnset(args []string) bool { +func (r *REPL) cmdUnset(args []string) error { if len(args) != 1 { - fmt.Fprintln(r.output, "error: unset : expects exactly one argument") - return false + return newBadArgsErr("unset : expects exactly one argument") } term, err := ast.ParseTerm(args[0]) if err != nil { - fmt.Fprintln(r.output, "error: argument must identify a rule") - return false + return newBadArgsErr("argument must identify a rule") } v, ok := term.Value.(ast.Var) if !ok { - fmt.Fprintln(r.output, "error: argument must identify a rule") - return false + return newBadArgsErr("argument must identify a rule") } mod := r.modules[r.currentModuleID] @@ -348,7 +356,7 @@ func (r *REPL) cmdUnset(args []string) bool { if len(rules) == len(mod.Rules) { fmt.Fprintln(r.output, "warning: no matching rules in current module") - return false + return nil } cpy := mod.Copy() @@ -366,13 +374,12 @@ func (r *REPL) cmdUnset(args []string) bool { compiler := ast.NewCompiler() if compiler.Compile(policies); compiler.Failed() { - fmt.Fprintln(r.output, "error:", compiler.Errors) - return false + return compiler.Errors } r.modules[r.currentModuleID] = cpy - return false + return nil } func (r *REPL) compileBody(body ast.Body) (ast.Body, error) { @@ -416,54 +423,59 @@ func (r *REPL) compileRule(rule *ast.Rule) error { return nil } -func (r *REPL) evalBufferOne() bool { +func (r *REPL) evalBufferOne() error { line := strings.Join(r.buffer, "\n") if len(strings.TrimSpace(line)) == 0 { r.buffer = []string{} - return false + return nil } // The user may enter lines with comments on the end or // multiple lines with comments interspersed. In these cases // the parser will return multiple statements. stmts, err := ast.ParseStatements("", line) - if err != nil { - return false + if r.bufferDisabled { + return err + } + return nil } r.buffer = []string{} for _, stmt := range stmts { - r.evalStatement(stmt) + if err := r.evalStatement(stmt); err != nil { + return err + } } - return false + return nil } -func (r *REPL) evalBufferMulti() bool { +func (r *REPL) evalBufferMulti() error { line := strings.Join(r.buffer, "\n") r.buffer = []string{} if len(strings.TrimSpace(line)) == 0 { - return false + return nil } stmts, err := ast.ParseStatements("", line) if err != nil { - fmt.Fprintln(r.output, "error:", err) - return false + return err } for _, stmt := range stmts { - r.evalStatement(stmt) + if err := r.evalStatement(stmt); err != nil { + return err + } } - return false + return nil } func (r *REPL) loadCompiler() (*ast.Compiler, error) { @@ -522,29 +534,26 @@ func (r *REPL) loadGlobals(compiler *ast.Compiler) (*ast.ValueMap, error) { return topdown.MakeGlobals(pairs) } -func (r *REPL) evalStatement(stmt interface{}) bool { +func (r *REPL) evalStatement(stmt interface{}) error { switch s := stmt.(type) { case ast.Body: body, err := r.compileBody(s) if err != nil { - fmt.Fprintln(r.output, "error:", err) - return false + return err } if rule := ast.ParseConstantRule(body); rule != nil { if err := r.compileRule(rule); err != nil { - fmt.Fprintln(r.output, "error:", err) + return err } - return false + return nil } compiler, err := r.loadCompiler() if err != nil { - fmt.Fprintln(r.output, "error:", err) - return false + return err } globals, err := r.loadGlobals(compiler) if err != nil { - fmt.Fprintln(r.output, "error:", err) - return false + return err } return r.evalBody(compiler, globals, body) case *ast.Rule: @@ -556,10 +565,10 @@ func (r *REPL) evalStatement(stmt interface{}) bool { case *ast.Package: return r.evalPackage(s) } - return false + return nil } -func (r *REPL) evalBody(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) bool { +func (r *REPL) evalBody(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) error { // Special case for positive, single term inputs. if len(body) == 1 { @@ -633,8 +642,7 @@ func (r *REPL) evalBody(compiler *ast.Compiler, globals *ast.ValueMap, body ast. } if err != nil { - fmt.Fprintf(r.output, "error: %v\n", err) - return false + return err } if isTrue { @@ -647,30 +655,31 @@ func (r *REPL) evalBody(compiler *ast.Compiler, globals *ast.ValueMap, body ast. fmt.Fprintln(r.output, "false") } - return false + return nil } -func (r *REPL) evalImport(i *ast.Import) bool { +func (r *REPL) evalImport(i *ast.Import) error { mod := r.modules[r.currentModuleID] + for _, other := range mod.Imports { if other.Equal(i) { - return false + return nil } } mod.Imports = append(mod.Imports, i) - return false + return nil } -func (r *REPL) evalPackage(p *ast.Package) bool { +func (r *REPL) evalPackage(p *ast.Package) error { moduleID := p.Path.String() if _, ok := r.modules[moduleID]; ok { r.currentModuleID = moduleID - return false + return nil } r.modules[moduleID] = &ast.Module{ @@ -679,7 +688,7 @@ func (r *REPL) evalPackage(p *ast.Package) bool { r.currentModuleID = moduleID - return false + return nil } // evalTermSingleValue evaluates and prints terms in cases where the term evaluates to a @@ -687,7 +696,7 @@ func (r *REPL) evalPackage(p *ast.Package) bool { // and comprehensions always evaluate to a single value. To handle references, this function // still executes the query, except it does so by rewriting the body to assign the term // to a variable. This allows the REPL to obtain the result even if the term is false. -func (r *REPL) evalTermSingleValue(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) bool { +func (r *REPL) evalTermSingleValue(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) error { term := body[0].Terms.(*ast.Term) outputVar := ast.Wildcard @@ -722,19 +731,21 @@ func (r *REPL) evalTermSingleValue(compiler *ast.Compiler, globals *ast.ValueMap } if err != nil { - fmt.Fprintln(r.output, "error:", err) - } else if isTrue { + return err + } + + if isTrue { r.printJSON(result) - } else { + } else if !r.undefinedDisabled { r.printUndefined() } - return false + return nil } // evalTermMultiValue evaluates and prints terms in cases where the term may evaluate to multiple // ground values, e.g., a[i], [servers[x]], etc. -func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) bool { +func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) error { // Mangle the expression in the same way we do for evalTermSingleValue. When handling the // evaluation result below, we will ignore this variable. @@ -807,8 +818,10 @@ func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, globals *ast.ValueMap, } if err != nil { - fmt.Fprintln(r.output, "error:", err) - } else if len(results) > 0 { + return err + } + + if len(results) > 0 { keys := []string{} for v := range vars { keys = append(keys, v) @@ -818,11 +831,11 @@ func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, globals *ast.ValueMap, keys = append(keys, resultKey) } r.printResults(keys, results) - } else { + } else if !r.undefinedDisabled { r.printUndefined() } - return false + return nil } func (r *REPL) getPrompt() string { diff --git a/repl/repl_test.go b/repl/repl_test.go index 27f44771e5..31a8c43991 100644 --- a/repl/repl_test.go +++ b/repl/repl_test.go @@ -199,65 +199,51 @@ func TestUnset(t *testing.T) { repl.OneShot("magic = 23") repl.OneShot("p = 3.14") repl.OneShot("unset p") - repl.OneShot("p") - result := buffer.String() - if result != "error: 1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" { - t.Errorf("Expected p to be unsafe but got: %v", result) - return + + err := repl.OneShot("p") + if _, ok := err.(ast.Errors); !ok { + t.Fatalf("Expected AST error but got: %v", err) } buffer.Reset() repl.OneShot("p = 3.14") repl.OneShot("p = 3 :- false") repl.OneShot("unset p") - repl.OneShot("p") - result = buffer.String() - if result != "error: 1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" { - t.Errorf("Expected p to be unsafe but got: %v", result) - return + + err = repl.OneShot("p") + if _, ok := err.(ast.Errors); !ok { + t.Fatalf("Expected AST error but got err: %v, output: %v", err, buffer.String()) } - buffer.Reset() - repl.OneShot("unset ") - result = buffer.String() - if result != "error: unset : expects exactly one argument\n" { - t.Errorf("Expected unset error for bad syntax but got: %v", result) + if err := repl.OneShot("unset "); err == nil { + t.Fatalf("Expected unset error for bad syntax but got: %v", buffer.String()) } - buffer.Reset() - repl.OneShot("unset 1=1") - result = buffer.String() - if result != "error: argument must identify a rule\n" { - t.Errorf("Expected unset error for bad syntax but got: %v", result) + if err := repl.OneShot("unset 1=1"); err == nil { + t.Fatalf("Expected unset error for bad syntax but got: %v", buffer.String()) } - buffer.Reset() - repl.OneShot(`unset "p"`) - result = buffer.String() - if result != "error: argument must identify a rule\n" { - t.Errorf("Expected unset error for bad syntax but got: %v", result) + if err := repl.OneShot(`unset "p"`); err == nil { + t.Fatalf("Expected unset error for bad syntax but got: %v", buffer.String()) } buffer.Reset() repl.OneShot(`unset q`) - result = buffer.String() - if result != "warning: no matching rules in current module\n" { - t.Errorf("Expected unset error for missing rule but got: %v", result) + if buffer.String() != "warning: no matching rules in current module\n" { + t.Fatalf("Expected unset error for missing rule but got: %v", buffer.String()) } buffer.Reset() repl.OneShot(`magic`) - result = buffer.String() - if result != "23\n" { - t.Errorf("Expected magic to be defined but got: %v", result) + if buffer.String() != "23\n" { + t.Fatalf("Expected magic to be defined but got: %v", buffer.String()) } buffer.Reset() repl.OneShot(`package data.other`) repl.OneShot(`unset magic`) - result = buffer.String() - if result != "warning: no matching rules in current module\n" { - t.Errorf("Expected unset error for bad syntax but got: %v", result) + if buffer.String() != "warning: no matching rules in current module\n" { + t.Fatalf("Expected unset error for bad syntax but got: %v", buffer.String()) } } @@ -614,17 +600,14 @@ func TestEvalBodyCompileError(t *testing.T) { var buffer bytes.Buffer repl := newRepl(store, &buffer) repl.outputFormat = "json" - repl.OneShot("x = 1, y > x") - result1 := buffer.String() - expected1 := "error: 1 error occurred: 1:1: y is unsafe (variable y must appear in the output position of at least one non-negated expression)\n" - if result1 != expected1 { - t.Errorf("Expected error message in output but got`: %v", result1) - return + err := repl.OneShot("x = 1, y > x") + if _, ok := err.(ast.Errors); !ok { + t.Fatalf("Expected error message in output but got`: %v", buffer.String()) } buffer.Reset() repl.OneShot("x = 1, y = 2, y > x") var result2 []interface{} - err := json.Unmarshal(buffer.Bytes(), &result2) + err = json.Unmarshal(buffer.Bytes(), &result2) if err != nil { t.Errorf("Expected valid JSON output but got: %v", buffer.String()) return @@ -722,10 +705,9 @@ func TestEvalPackage(t *testing.T) { repl.OneShot("p = true :- true") repl.OneShot("package baz.qux") buffer.Reset() - repl.OneShot("p") - if buffer.String() != "error: 1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)\n" { - t.Errorf("Expected unsafe variable error but got: %v", buffer.String()) - return + err := repl.OneShot("p") + if err.Error() != "1 error occurred: 1:1: p is unsafe (variable p must appear in the output position of at least one non-negated expression)" { + t.Fatalf("Expected unsafe variable error but got: %v", err) } repl.OneShot("import data.foo.bar.p") buffer.Reset() diff --git a/runtime/runtime.go b/runtime/runtime.go index 20358430c9..1155f4771d 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -30,6 +31,9 @@ type Params struct { // Addr is the listening address that the OPA server will bind to. Addr string + // Eval is a string to evaluate in the REPL. + Eval string + // HistoryPath is the filename to store the interactive shell user // input history. HistoryPath string @@ -57,6 +61,17 @@ type Params struct { // and reload the storage layer each time they change. This is useful for // interactive development. Watch bool + + // Output is the output stream used when run as an interactive shell. This + // is mostly for test purposes. + Output io.Writer +} + +// NewParams returns a new Params object. +func NewParams() *Params { + return &Params{ + Output: os.Stdout, + } } // Runtime represents a single OPA instance. @@ -77,6 +92,7 @@ func (rt *Runtime) Start(params *Params) { } else { rt.startRepl(params) } + } func (rt *Runtime) init(params *Params) error { @@ -144,18 +160,28 @@ func (rt *Runtime) startServer(params *Params) { func (rt *Runtime) startRepl(params *Params) { banner := rt.getBanner() - repl := repl.New(rt.Store, params.HistoryPath, os.Stdout, params.OutputFormat, banner) + repl := repl.New(rt.Store, params.HistoryPath, params.Output, params.OutputFormat, banner) if params.Watch { watcher, err := rt.getWatcher(params.Paths) if err != nil { - fmt.Println("error opening watch:", err) + fmt.Fprintln(params.Output, "error opening watch:", err) os.Exit(1) } go rt.readWatcher(watcher, params.Paths) } - repl.Loop() + if params.Eval == "" { + repl.Loop() + } else { + repl.DisableUndefinedOutput(true) + repl.DisableMultiLineBuffering(true) + if err := repl.OneShot(params.Eval); err != nil { + fmt.Fprintln(params.Output, "error:", err) + os.Exit(1) + } + } + } func (rt *Runtime) getWatcher(paths []string) (*fsnotify.Watcher, error) { diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 4ecf77c5f0..ee8014dd54 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -5,9 +5,12 @@ package runtime import ( + "bytes" + "encoding/json" "io/ioutil" "os" "path/filepath" + "reflect" "testing" "github.com/open-policy-agent/opa/ast" @@ -15,6 +18,21 @@ import ( "github.com/open-policy-agent/opa/util" ) +func TestEval(t *testing.T) { + params := NewParams() + var buffer bytes.Buffer + params.Output = &buffer + params.OutputFormat = "json" + params.Eval = `a = b, a = 1, c = 2, c > b` + rt := &Runtime{} + rt.Start(params) + expected := parseJSON(`[{"a": 1, "b": 1, "c": 2}]`) + result := parseJSON(buffer.String()) + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v but got: %v", expected, result) + } +} + func TestInit(t *testing.T) { tmp1, err := ioutil.TempFile("", "docFile") if err != nil { @@ -93,3 +111,11 @@ func TestInit(t *testing.T) { } } + +func parseJSON(s string) interface{} { + var x interface{} + if err := json.Unmarshal([]byte(s), &x); err != nil { + panic(err) + } + return x +}