Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add skip options in checks #36

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions leaktest.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (g goroutineByID) Len() int { return len(g) }
func (g goroutineByID) Less(i, j int) bool { return g[i].id < g[j].id }
func (g goroutineByID) Swap(i, j int) { g[i], g[j] = g[j], g[i] }

func interestingGoroutine(g string) (*goroutine, error) {
func interestingGoroutine(g string, opts ...SkipGoroutineOption) (*goroutine, error) {
sl := strings.SplitN(g, "\n", 2)
if len(sl) != 2 {
return nil, fmt.Errorf("error parsing stack: %q", g)
Expand Down Expand Up @@ -71,17 +71,24 @@ func interestingGoroutine(g string) (*goroutine, error) {
return nil, fmt.Errorf("error parsing goroutine id: %s", err)
}

return &goroutine{id: id, stack: strings.TrimSpace(g)}, nil
stack = strings.TrimSpace(g)
for _, opt := range opts {
if opt(stack) {
return nil, nil
}
}

return &goroutine{id: id, stack: stack}, nil
}

// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines(t ErrorReporter) []*goroutine {
func interestingGoroutines(t ErrorReporter, opts ...SkipGoroutineOption) []*goroutine {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
var gs []*goroutine
for _, g := range strings.Split(string(buf), "\n\n") {
gr, err := interestingGoroutine(g)
gr, err := interestingGoroutine(g, opts...)
if err != nil {
t.Errorf("leaktest: %s", err)
continue
Expand Down Expand Up @@ -114,17 +121,21 @@ type ErrorReporter interface {
Errorf(format string, args ...interface{})
}

// SkipGoroutineOption is a function that can be passed to check functions
// to skip some leaked goroutines based on the content of their stack.
type SkipGoroutineOption func(stack string) (skip bool)

// Check snapshots the currently-running goroutines and returns a
// function to be run at the end of tests to see whether any
// goroutines leaked, waiting up to 5 seconds in error conditions
func Check(t ErrorReporter) func() {
return CheckTimeout(t, 5*time.Second)
func Check(t ErrorReporter, opts ...SkipGoroutineOption) func() {
return CheckTimeout(t, 5*time.Second, opts...)
}

// CheckTimeout is the same as Check, but with a configurable timeout
func CheckTimeout(t ErrorReporter, dur time.Duration) func() {
func CheckTimeout(t ErrorReporter, dur time.Duration, opts ...SkipGoroutineOption) func() {
ctx, cancel := context.WithCancel(context.Background())
fn := CheckContext(ctx, t)
fn := CheckContext(ctx, t, opts...)
return func() {
timer := time.AfterFunc(dur, cancel)
fn()
Expand All @@ -136,16 +147,16 @@ func CheckTimeout(t ErrorReporter, dur time.Duration) func() {

// CheckContext is the same as Check, but uses a context.Context for
// cancellation and timeout control
func CheckContext(ctx context.Context, t ErrorReporter) func() {
func CheckContext(ctx context.Context, t ErrorReporter, opts ...SkipGoroutineOption) func() {
orig := map[uint64]bool{}
for _, g := range interestingGoroutines(t) {
for _, g := range interestingGoroutines(t, opts...) {
orig[g.id] = true
}
return func() {
var leaked []string
var ok bool
// fast check if we have no leaks
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t)); ok {
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t, opts...)); ok {
return
}
ticker := time.NewTicker(TickerInterval)
Expand All @@ -154,7 +165,7 @@ func CheckContext(ctx context.Context, t ErrorReporter) func() {
for {
select {
case <-ticker.C:
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t)); ok {
if leaked, ok = leakedGoroutines(orig, interestingGoroutines(t, opts...)); ok {
return
}
continue
Expand Down
23 changes: 20 additions & 3 deletions leaktest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"regexp"
"sync"
"testing"
"time"
Expand All @@ -28,6 +29,7 @@ func TestCheck(t *testing.T) {
leakyFuncs := []struct {
f func()
name string
opts []SkipGoroutineOption
expectLeak bool
}{
{
Expand Down Expand Up @@ -105,6 +107,21 @@ func TestCheck(t *testing.T) {
}
},
},
{
name: "Skip leak with option",
opts: []SkipGoroutineOption{
func(stack string) bool {
re := regexp.MustCompile("created by .*leaktest\\.TestCheck")
return re.MatchString(stack)
},
},
expectLeak: false,
f: func() {
go func() {
time.Sleep(2 * time.Second)
}()
},
},
}

// Start our keep alive server for keep alive tests
Expand All @@ -119,7 +136,7 @@ func TestCheck(t *testing.T) {

t.Run(leakyTestcase.name, func(t *testing.T) {
checker := &testReporter{}
snapshot := CheckTimeout(checker, time.Second)
snapshot := CheckTimeout(checker, time.Second, leakyTestcase.opts...)
go leakyTestcase.f()

snapshot()
Expand All @@ -128,7 +145,7 @@ func TestCheck(t *testing.T) {
t.Error("didn't catch sleeping goroutine")
}
if checker.failed && !leakyTestcase.expectLeak {
t.Error("got leak but didn't expect it")
t.Errorf("got leak but didn't expect it:\n\t%v", checker.msg)
}
})
}
Expand All @@ -138,7 +155,7 @@ func TestCheck(t *testing.T) {
// be based on time after the test finishes rather than time after the test's
// start.
func TestSlowTest(t *testing.T) {
defer CheckTimeout(t, 1000 * time.Millisecond)()
defer CheckTimeout(t, 1000*time.Millisecond)()

go time.Sleep(1500 * time.Millisecond)
time.Sleep(750 * time.Millisecond)
Expand Down