Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in test #337

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions sync2/device_data_ticker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,31 @@ import (
"github.com/matrix-org/sliding-sync/pubsub"
)

type syncSlice[T any] struct {
slice []T
mu sync.Mutex
}

func (s *syncSlice[T]) append(item T) {
s.mu.Lock()
defer s.mu.Unlock()
s.slice = append(s.slice, item)
}

func (s *syncSlice[T]) clone() []T {
s.mu.Lock()
defer s.mu.Unlock()
result := make([]T, len(s.slice))
copy(result, s.slice)
return result
}
kegsay marked this conversation as resolved.
Show resolved Hide resolved

func TestDeviceTickerBasic(t *testing.T) {
duration := time.Millisecond
ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData
var payloads syncSlice[*pubsub.V2DeviceData]
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload)
payloads.append(payload)
})
var wg sync.WaitGroup
wg.Add(1)
Expand All @@ -31,29 +50,31 @@ func TestDeviceTickerBasic(t *testing.T) {
DeviceID: "b",
})
time.Sleep(duration * 2)
if len(payloads) != 1 {
t.Fatalf("expected 1 callback, got %d", len(payloads))
result := payloads.clone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to clone here, sorry?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd need to lock every time we access the underlying slice, which could then lead to inconsistent results (e.g len=1 when len() is called, but then increases to len=2 later on. It's easier to snapshot the slice.

if len(result) != 1 {
t.Fatalf("expected 1 callback, got %d", len(result))
}
want := map[string][]string{
"a": {"b"},
}
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want)
// check stopping works
payloads = []*pubsub.V2DeviceData{}
payloads = syncSlice[*pubsub.V2DeviceData]{}
ticker.Stop()
wg.Wait()
time.Sleep(duration * 2)
if len(payloads) != 0 {
t.Fatalf("got extra payloads: %+v", payloads)
result = payloads.clone()
if len(result) != 0 {
t.Fatalf("got extra payloads: %+v", result)
}
}

func TestDeviceTickerBatchesCorrectly(t *testing.T) {
duration := 100 * time.Millisecond
ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData
var payloads syncSlice[*pubsub.V2DeviceData]
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload)
payloads.append(payload)
})
go ticker.Run()
defer ticker.Stop()
Expand All @@ -74,23 +95,23 @@ func TestDeviceTickerBatchesCorrectly(t *testing.T) {
DeviceID: "y", // new device and user
})
time.Sleep(duration * 2)
if len(payloads) != 1 {
t.Fatalf("expected 1 callback, got %d", len(payloads))
result := payloads.clone()
if len(result) != 1 {
t.Fatalf("expected 1 callback, got %d", len(result))
}
want := map[string][]string{
"a": {"b", "bb"},
"x": {"y"},
}
assertPayloadEqual(t, payloads[0].UserIDToDeviceIDs, want)
assertPayloadEqual(t, result[0].UserIDToDeviceIDs, want)
}

func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
duration := time.Millisecond
ticker := NewDeviceDataTicker(duration)
var payloads []*pubsub.V2DeviceData

var payloads syncSlice[*pubsub.V2DeviceData]
ticker.SetCallback(func(payload *pubsub.V2DeviceData) {
payloads = append(payloads, payload)
payloads.append(payload)
})
ticker.Remember(PollerID{
UserID: "a",
Expand All @@ -104,8 +125,9 @@ func TestDeviceTickerForgetsAfterEmitting(t *testing.T) {
DeviceID: "b",
})
time.Sleep(10 * duration)
if len(payloads) != 1 {
t.Fatalf("got %d payloads, want 1", len(payloads))
result := payloads.clone()
if len(result) != 1 {
t.Fatalf("got %d payloads, want 1", len(result))
}
}

Expand Down
30 changes: 27 additions & 3 deletions sync2/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@ import (

const initialSinceToken = "0"

// monkey patch out time.Since with a test controlled value.
// This is done in the init block so we can make sure we swap it out BEFORE any pollers
// start. If we wait until pollers exist, we get data races. This includes pollers in tests
// which don't use timeSince, hence the init block.
var (
timeSinceMu sync.Mutex
timeSinceValue = time.Duration(0) // 0 means use the real impl
)

func setTimeSinceValue(val time.Duration) {
timeSinceMu.Lock()
timeSinceValue = time.Minute * 2
timeSinceMu.Unlock()
}
Comment on lines +30 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val seems unused here?

func init() {
timeSince = func(t time.Time) time.Duration {
timeSinceMu.Lock()
defer timeSinceMu.Unlock()
if timeSinceValue == 0 {
return time.Since(t)
}
return timeSinceValue
}
}

// Tests that EnsurePolling works in the happy case
func TestPollerMapEnsurePolling(t *testing.T) {
nextSince := "next"
Expand Down Expand Up @@ -528,9 +553,8 @@ func TestPollerPollUpdateDeviceSincePeriodically(t *testing.T) {
wantSinceFromSync = next

// 4. ... some time has passed, this triggers the 1min limit
timeSince = func(d time.Time) time.Duration {
return time.Minute * 2
}
setTimeSinceValue(time.Minute * 2)
defer setTimeSinceValue(0) // reset
next = "10"
syncResponses <- &SyncResponse{NextBatch: next}
mustEqualSince(t, <-syncCalledWithSince, wantSinceFromSync)
Expand Down
Loading