Skip to content

Commit

Permalink
fix: make assert.CollectT concurrency safe
Browse files Browse the repository at this point in the history
  • Loading branch information
czeslavo committed Jul 28, 2023
1 parent 486eb6f commit b93dd46
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
20 changes: 13 additions & 7 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"runtime"
"runtime/debug"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
Expand Down Expand Up @@ -1862,6 +1863,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
// CollectT implements the TestingT interface and collects all errors.
type CollectT struct {
errors []error
mu sync.RWMutex
}

// Errorf collects the error.
Expand Down Expand Up @@ -1912,8 +1914,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
h.Helper()
}

collect := new(CollectT)
ch := make(chan bool, 1)
var lastTickErrs []error
ch := make(chan []error, 1)

timer := time.NewTimer(waitFor)
defer timer.Stop()
Expand All @@ -1924,19 +1926,23 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
for tick := ticker.C; ; {
select {
case <-timer.C:
collect.Copy(t)
for _, err := range lastTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick:
tick = nil
collect.Reset()
go func() {
collect := new(CollectT)
condition(collect)
ch <- len(collect.errors) == 0
ch <- collect.errors
}()
case v := <-ch:
if v {
case errs := <-ch:
if len(errs) == 0 {
return true
}
// Keep the last tick's errors, so that they can be copied to t if the condition is not met on time.
lastTickErrs = errs
tick = ticker.C
}
}
Expand Down
12 changes: 12 additions & 0 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,18 @@ func TestEventuallyWithTTrue(t *testing.T) {
Len(t, mockT.errors, 0)
}

func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
mockT := new(CollectT)

condition := func(collect *CollectT) {
True(collect, false)
}

// To trigger race conditions, we run EventuallyWithT with a nanosecond tick.
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Nanosecond))
Len(t, mockT.errors, 2)
}

func TestNeverFalse(t *testing.T) {
condition := func() bool {
return false
Expand Down

0 comments on commit b93dd46

Please sign in to comment.