diff --git a/datadriven.go b/datadriven.go index a7692f5..d6f9e04 100644 --- a/datadriven.go +++ b/datadriven.go @@ -22,6 +22,7 @@ import ( "os" "path/filepath" "regexp" + "runtime" "strconv" "strings" "testing" @@ -568,6 +569,53 @@ func (td *TestData) ScanArgs(t testing.TB, key string, dests ...interface{}) { arg.scan(t, td.Pos, dests...) } +// Retry is used for tests that depend on background goroutines to finish work. +// It takes a function that produces the output of the testcase and calls it +// repeatedly until it matches the expected output (for at most 1 second). +// +// Returns the last value returned by f (which can be directly returned from the +// function passed to RunTest). +// +// If --rewrite is used, just sleeps for 100ms. +func (td *TestData) Retry(tb testing.TB, f func() string) string { + return td.RetryFor(tb, time.Second, f) +} + +// RetryFor is like Retry but with a custom timeout. +func (td *TestData) RetryFor(tb testing.TB, d time.Duration, f func() string) string { + if td.Rewrite { + // For rewrite mode, we have nothing to compare the output to. Just sleep a + // reasonable amount, under the assumption that --rewrite won't be used + // under stress or a loaded system. + time.Sleep(d / 10) + return f() + } + runtime.Gosched() + // We are going to evaluate f until it produces the correct answer numStable + // times in a row. + const numAttempts = 100 + const numStable = 3 + // numOk is the number of consecutive calls of f() that have returned the + // correct answer. + numOk := 0 + expected := strings.TrimSpace(td.Expected) + for i := 0; ; i++ { + s := f() + if strings.TrimSpace(s) == expected { + numOk++ + } else { + numOk = 0 + } + if numOk == numStable || i == numAttempts { + if i >= numStable { + td.Logf(tb, "retried for %s (%d times)", time.Duration(i-numStable+1)*d/numAttempts, i-numStable+1) + } + return s + } + time.Sleep(d/numAttempts + 1) + } +} + // CmdArg contains information about an argument on the directive line. An // argument is specified in one of the following forms: // - argument @@ -766,6 +814,13 @@ func (arg CmdArg) scanScalarErr(i int, dest interface{}) error { return nil } +// Logf is a wrapper for tb.Logf which adds file position information, so +// that it's easy to locate the source of the log. +func (td TestData) Logf(tb testing.TB, format string, args ...interface{}) { + tb.Helper() + tb.Logf("%s: %s", td.Pos, fmt.Sprintf(format, args...)) +} + // Fatalf wraps a fatal testing error with test file position information, so // that it's easy to locate the source of the error. func (td TestData) Fatalf(tb testing.TB, format string, args ...interface{}) { diff --git a/datadriven_test.go b/datadriven_test.go index 7fe0e0e..0f790d6 100644 --- a/datadriven_test.go +++ b/datadriven_test.go @@ -18,11 +18,13 @@ import ( "bytes" "fmt" "io/ioutil" + "math/rand" "os" "path/filepath" "reflect" "sort" "strings" + "sync/atomic" "testing" "time" @@ -140,6 +142,33 @@ output` }) } +func TestRetry(t *testing.T) { + var v atomic.Uint32 + RunTest(t, "testdata/retry", func(t *testing.T, d *TestData) string { + switch d.Cmd { + case "inc": + n := 1 + d.MaybeScanArgs(t, "n", &n) + for i := 0; i < n; i++ { + go func() { + time.Sleep(time.Duration(rand.Intn(10)) * time.Microsecond) + v.Add(1) + }() + } + return "" + + case "read": + return d.Retry(t, func() string { + return fmt.Sprint(v.Load()) + }) + + default: + t.Fatalf("unknown directive: %s", d.Cmd) + } + return d.Expected + }) +} + func TestDirective(t *testing.T) { RunTest(t, "testdata/directive", func(t *testing.T, d *TestData) string { var buf bytes.Buffer diff --git a/testdata/retry b/testdata/retry new file mode 100644 index 0000000..703b15b --- /dev/null +++ b/testdata/retry @@ -0,0 +1,17 @@ +read +---- +0 + +inc +---- + +read +---- +1 + +inc n=20 +---- + +read +---- +21