Skip to content

Commit

Permalink
feat: add skip options in checks
Browse files Browse the repository at this point in the history
  • Loading branch information
affo committed Mar 10, 2020
1 parent d73c753 commit e4877f4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
35 changes: 23 additions & 12 deletions leaktest.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (g goroutineByID) Len() int { return len(g) }
func (g goroutineByID) Less(i, j int) bool { return g[i].id < g[j].id }
func (g goroutineByID) Swap(i, j int) { g[i], g[j] = g[j], g[i] }

func interestingGoroutine(g string) (*goroutine, error) {
func interestingGoroutine(g string, opts ...SkipGoroutineOption) (*goroutine, error) {
sl := strings.SplitN(g, "\n", 2)
if len(sl) != 2 {
return nil, fmt.Errorf("error parsing stack: %q", g)
Expand Down Expand Up @@ -71,17 +71,24 @@ func interestingGoroutine(g string) (*goroutine, error) {
return nil, fmt.Errorf("error parsing goroutine id: %s", err)
}

return &goroutine{id: id, stack: strings.TrimSpace(g)}, nil
stack = strings.TrimSpace(g)
for _, opt := range opts {
if opt(id, stack) {
return nil, nil
}
}

return &goroutine{id: id, stack: stack}, nil
}

// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines(t ErrorReporter) []*goroutine {
func interestingGoroutines(t ErrorReporter, opts ...SkipGoroutineOption) []*goroutine {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
var gs []*goroutine
for _, g := range strings.Split(string(buf), "\n\n") {
gr, err := interestingGoroutine(g)
gr, err := interestingGoroutine(g, opts...)
if err != nil {
t.Errorf("leaktest: %s", err)
continue
Expand Down Expand Up @@ -114,17 +121,21 @@ type ErrorReporter interface {
Errorf(format string, args ...interface{})
}

// SkipGoroutineOption is a function that can be passed to check functions
// to skip some leaked goroutines based on their stack/id.
type SkipGoroutineOption func(id uint64, stack string) (skip bool)

// Check snapshots the currently-running goroutines and returns a
// function to be run at the end of tests to see whether any
// goroutines leaked, waiting up to 5 seconds in error conditions
func Check(t ErrorReporter) func() {
return CheckTimeout(t, 5*time.Second)
func Check(t ErrorReporter, opts ...SkipGoroutineOption) func() {
return CheckTimeout(t, 5*time.Second, opts...)
}

// CheckTimeout is the same as Check, but with a configurable timeout
func CheckTimeout(t ErrorReporter, dur time.Duration) func() {
func CheckTimeout(t ErrorReporter, dur time.Duration, opts ...SkipGoroutineOption) func() {
ctx, cancel := context.WithCancel(context.Background())
fn := CheckContext(ctx, t)
fn := CheckContext(ctx, t, opts...)
return func() {
timer := time.AfterFunc(dur, cancel)
fn()
Expand All @@ -136,16 +147,16 @@ func CheckTimeout(t ErrorReporter, dur time.Duration) func() {

// CheckContext is the same as Check, but uses a context.Context for
// cancellation and timeout control
func CheckContext(ctx context.Context, t ErrorReporter) func() {
func CheckContext(ctx context.Context, t ErrorReporter, opts ...SkipGoroutineOption) func() {
orig := map[uint64]bool{}
for _, g := range interestingGoroutines(t) {
for _, g := range interestingGoroutines(t, opts...) {
orig[g.id] = true
}
return func() {
var leaked []string
var ok bool
// fast check if we have no leaks
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t)); ok {
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t, opts...)); ok {
return
}
ticker := time.NewTicker(TickerInterval)
Expand All @@ -154,7 +165,7 @@ func CheckContext(ctx context.Context, t ErrorReporter) func() {
for {
select {
case <-ticker.C:
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t)); ok {
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t, opts...)); ok {
return
}
continue
Expand Down
22 changes: 19 additions & 3 deletions leaktest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
Expand All @@ -28,6 +29,7 @@ func TestCheck(t *testing.T) {
leakyFuncs := []struct {
f func()
name string
opts []SkipGoroutineOption
expectLeak bool
}{
{
Expand Down Expand Up @@ -105,6 +107,20 @@ func TestCheck(t *testing.T) {
}
},
},
{
name: "Skip leak with option",
opts: []SkipGoroutineOption{
func(id uint64, stack string) bool {
return strings.Contains(stack, "created by leaktest.TestCheck")
},
},
expectLeak: false,
f: func() {
go func() {
time.Sleep(time.Second)
}()
},
},
}

// Start our keep alive server for keep alive tests
Expand All @@ -119,7 +135,7 @@ func TestCheck(t *testing.T) {

t.Run(leakyTestcase.name, func(t *testing.T) {
checker := &testReporter{}
snapshot := CheckTimeout(checker, time.Second)
snapshot := CheckTimeout(checker, time.Second, leakyTestcase.opts...)
go leakyTestcase.f()

snapshot()
Expand All @@ -128,7 +144,7 @@ func TestCheck(t *testing.T) {
t.Error("didn't catch sleeping goroutine")
}
if checker.failed && !leakyTestcase.expectLeak {
t.Error("got leak but didn't expect it")
t.Errorf("got leak but didn't expect it:\n\t%v", checker.msg)
}
})
}
Expand All @@ -138,7 +154,7 @@ func TestCheck(t *testing.T) {
// be based on time after the test finishes rather than time after the test's
// start.
func TestSlowTest(t *testing.T) {
defer CheckTimeout(t, 1000 * time.Millisecond)()
defer CheckTimeout(t, 1000*time.Millisecond)()

go time.Sleep(1500 * time.Millisecond)
time.Sleep(750 * time.Millisecond)
Expand Down

0 comments on commit e4877f4

Please sign in to comment.