diff --git a/cmd/test.go b/cmd/test.go index ac4adbf45e..62e8aac05e 100644 --- a/cmd/test.go +++ b/cmd/test.go @@ -113,7 +113,7 @@ Example test run: } func opaTest(args []string) int { - ctx, cancel := context.WithTimeout(context.Background(), testParams.timeout) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() filter := loaderFilter{ @@ -175,7 +175,8 @@ func opaTest(args []string) int { EnableFailureLine(testParams.failureLine). SetRuntime(info). SetModules(modules). - SetBundles(bundles) + SetBundles(bundles). + SetTimeout(testParams.timeout) ch, err := runner.RunTests(ctx, txn) if err != nil { diff --git a/tester/runner.go b/tester/runner.go index ae8008db0d..940c8e448e 100644 --- a/tester/runner.go +++ b/tester/runner.go @@ -99,6 +99,7 @@ type Runner struct { trace bool runtime *ast.Term failureLine bool + timeout time.Duration modules map[string]*ast.Module bundles map[string]*bundle.Bundle } @@ -151,6 +152,12 @@ func (r *Runner) SetRuntime(term *ast.Term) *Runner { return r } +// SetTimeout sets the timeout for the individual test cases +func (r *Runner) SetTimeout(timout time.Duration) *Runner { + r.timeout = timout + return r +} + // SetModules will add modules to the Runner which will be compiled then used // for discovering and evaluating tests. func (r *Runner) SetModules(modules map[string]*ast.Module) *Runner { @@ -259,7 +266,9 @@ func (r *Runner) RunTests(ctx context.Context, txn storage.Transaction) (ch chan if !strings.HasPrefix(string(rule.Head.Name), TestPrefix) { continue } - tr, stop := r.runTest(ctx, txn, module, rule) + runCtx, cancel := context.WithTimeout(ctx, r.timeout) + defer cancel() + tr, stop := r.runTest(runCtx, txn, module, rule) ch <- tr if stop { return diff --git a/tester/runner_test.go b/tester/runner_test.go index ecebc4b3fa..dc4de2416b 100644 --- a/tester/runner_test.go +++ b/tester/runner_test.go @@ -154,20 +154,8 @@ func TestRun(t *testing.T) { func TestRunnerCancel(t *testing.T) { - ast.RegisterBuiltin(&ast.Builtin{ - Name: "test.sleep", - Decl: types.NewFunction( - types.Args(types.S), - types.NewNull(), - ), - }) - - topdown.RegisterFunctionalBuiltin1("test.sleep", func(a ast.Value) (ast.Value, error) { - d, _ := time.ParseDuration(string(a.(ast.String))) - time.Sleep(d) - return ast.Null{}, nil - }) - + registerSleepBuiltin() + ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -189,5 +177,57 @@ func TestRunnerCancel(t *testing.T) { t.Fatalf("Expected cancel error but got: %v", results[0].Error) } }) +} + +func TestRunner_Timeout(t *testing.T) { + + registerSleepBuiltin() + + ctx := context.Background() + + files := map[string]string{ + "/a_test.rego": `package foo + + test_1 { test.sleep("100ms") } + test_2 { true }`, + } + test.WithTempFS(files, func(d string) { + paths := []string{d} + modules, store, err := tester.Load(paths, nil) + if err != nil { + t.Fatal(err) + } + duration, err := time.ParseDuration("1ns") + if err != nil { + t.Fatal(err) + } + ch, err := tester.NewRunner().SetTimeout(duration).SetStore(store).Run(ctx, modules) + if err != nil { + t.Fatal(err) + } + var results []*tester.Result + for r := range ch { + results = append(results, r) + } + if !topdown.IsCancel(results[0].Error) { + t.Fatalf("Expected cancel error but got: %v", results[0].Error) + } + }) } + +func registerSleepBuiltin() { + ast.RegisterBuiltin(&ast.Builtin{ + Name: "test.sleep", + Decl: types.NewFunction( + types.Args(types.S), + types.NewNull(), + ), + }) + + topdown.RegisterFunctionalBuiltin1("test.sleep", func(a ast.Value) (ast.Value, error) { + d, _ := time.ParseDuration(string(a.(ast.String))) + time.Sleep(d) + return ast.Null{}, nil + }) +} \ No newline at end of file