-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aaf9521
commit c74c72f
Showing
2 changed files
with
364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
package chanutils | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
|
||
"github.com/btcsuite/btclog" | ||
) | ||
|
||
// BatchWriterConfig holds the configuration options for BatchWriter. | ||
type BatchWriterConfig[T any] struct { | ||
// QueueBufferSize sets the buffer size of the output channel of the | ||
// concurrent queue used by the BatchWriter. | ||
QueueBufferSize int | ||
|
||
// MaxBatch is the maximum number of filters to be persisted to the DB | ||
// in one go. | ||
MaxBatch int | ||
|
||
// DBWritesTickerDuration is the time after receiving a filter that the | ||
// writer will wait for more filters before writing the current batch | ||
// to the DB. | ||
DBWritesTickerDuration time.Duration | ||
|
||
// Logger is the logger that the BatchWriter should use for any logs. | ||
Logger btclog.Logger | ||
|
||
// PutItems will be used by the BatchWriter to persist filters in | ||
// batches. | ||
PutItems func(...T) error | ||
} | ||
|
||
// BatchWriter manages writing Filters to the DB and tries to batch the writes | ||
// as much as possible. | ||
type BatchWriter[T any] struct { | ||
started sync.Once | ||
stopped sync.Once | ||
|
||
cfg *BatchWriterConfig[T] | ||
|
||
queue *ConcurrentQueue[T] | ||
|
||
quit chan struct{} | ||
wg sync.WaitGroup | ||
} | ||
|
||
// NewBatchWriter constructs a new BatchWriter using the given | ||
// BatchWriterConfig. | ||
func NewBatchWriter[T any](cfg *BatchWriterConfig[T]) *BatchWriter[T] { | ||
return &BatchWriter[T]{ | ||
cfg: cfg, | ||
queue: NewConcurrentQueue[T](cfg.QueueBufferSize), | ||
quit: make(chan struct{}), | ||
} | ||
} | ||
|
||
// Start starts the BatchWriter. | ||
func (b *BatchWriter[T]) Start() { | ||
b.started.Do(func() { | ||
b.queue.Start() | ||
|
||
b.wg.Add(1) | ||
go b.manageNewItems() | ||
}) | ||
} | ||
|
||
// Stop stops the BatchWriter. | ||
func (b *BatchWriter[T]) Stop() { | ||
b.stopped.Do(func() { | ||
close(b.quit) | ||
b.wg.Wait() | ||
|
||
b.queue.Stop() | ||
}) | ||
} | ||
|
||
// AddItem adds a given item to the BatchWriter queue. | ||
func (b *BatchWriter[T]) AddItem(item T) { | ||
b.queue.ChanIn() <- item | ||
} | ||
|
||
// manageNewItems manages collecting filters and persisting them to the DB. | ||
// There are two conditions for writing a batch of filters to the DB: the first | ||
// is if a certain threshold (MaxBatch) of filters has been collected and the | ||
// other is if at least one filter has been collected and a timeout has been | ||
// reached. | ||
// | ||
// NOTE: this must be run in a goroutine. | ||
func (b *BatchWriter[T]) manageNewItems() { | ||
defer b.wg.Done() | ||
|
||
batch := make([]T, 0, b.cfg.MaxBatch) | ||
|
||
// writeBatch writes the current contents of the batch slice to the | ||
// filters DB. | ||
writeBatch := func() { | ||
if len(batch) == 0 { | ||
return | ||
} | ||
|
||
err := b.cfg.PutItems(batch...) | ||
if err != nil { | ||
b.cfg.Logger.Warnf("Couldn't write filters to "+ | ||
"filterDB: %v", err) | ||
} | ||
|
||
// Empty the batch slice. | ||
batch = make([]T, 0, b.cfg.MaxBatch) | ||
} | ||
|
||
ticker := time.NewTicker(b.cfg.DBWritesTickerDuration) | ||
defer ticker.Stop() | ||
|
||
// Stop the ticker since we don't want it to tick unless there is at | ||
// least one item in the queue. | ||
ticker.Stop() | ||
|
||
for { | ||
select { | ||
case filter, ok := <-b.queue.ChanOut(): | ||
if !ok { | ||
return | ||
} | ||
|
||
batch = append(batch, filter) | ||
|
||
switch len(batch) { | ||
// If the batch slice is full, we stop the timer and | ||
// write the batch contents to disk. | ||
case b.cfg.MaxBatch: | ||
ticker.Stop() | ||
writeBatch() | ||
|
||
// If an item is added to the batch, we reset the timer. | ||
// This ensures that if the batch threshold is not met | ||
// then items are still persisted in a timely manner. | ||
default: | ||
ticker.Reset(b.cfg.DBWritesTickerDuration) | ||
} | ||
|
||
case <-ticker.C: | ||
// If the ticker ticks, then we stop it and write the | ||
// current batch contents to the db. If any more items | ||
// are added, the ticker will be reset. | ||
ticker.Stop() | ||
writeBatch() | ||
|
||
case <-b.quit: | ||
writeBatch() | ||
|
||
return | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
package chanutils | ||
|
||
import ( | ||
"fmt" | ||
"math/rand" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const waitTime = time.Second * 5 | ||
|
||
// TestBatchWriter tests that the BatchWriter behaves as expected. | ||
func TestBatchWriter(t *testing.T) { | ||
t.Parallel() | ||
rand.Seed(time.Now().UnixNano()) | ||
|
||
// waitForItems is a helper function that will wait for a given set of | ||
// items to appear in the db. | ||
waitForItems := func(db *mockItemsDB, items ...*item) { | ||
err := waitFor(func() bool { | ||
return db.hasItems(items...) | ||
}, waitTime) | ||
require.NoError(t, err) | ||
} | ||
|
||
t.Run("filters persisted after ticker", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Create a mock filters DB. | ||
db := newMockItemsDB() | ||
|
||
// Construct a new BatchWriter backed by the mock db. | ||
b := NewBatchWriter[*item](&BatchWriterConfig[*item]{ | ||
QueueBufferSize: 10, | ||
MaxBatch: 20, | ||
DBWritesTickerDuration: time.Millisecond * 500, | ||
PutItems: db.PutItems, | ||
}) | ||
b.Start() | ||
t.Cleanup(b.Stop) | ||
|
||
fs := genFilterSet(5) | ||
for _, f := range fs { | ||
b.AddItem(f) | ||
} | ||
waitForItems(db, fs...) | ||
}) | ||
|
||
t.Run("write once threshold is reached", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Create a mock filters DB. | ||
db := newMockItemsDB() | ||
|
||
// Construct a new BatchWriter backed by the mock db. | ||
// Make the DB writes ticker duration extra long so that we | ||
// can explicitly test that the batch gets persisted if the | ||
// MaxBatch threshold is reached. | ||
b := NewBatchWriter[*item](&BatchWriterConfig[*item]{ | ||
QueueBufferSize: 10, | ||
MaxBatch: 20, | ||
DBWritesTickerDuration: time.Hour, | ||
PutItems: db.PutItems, | ||
}) | ||
b.Start() | ||
t.Cleanup(b.Stop) | ||
|
||
// Generate 30 filters and add each one to the batch writer. | ||
fs := genFilterSet(30) | ||
for _, f := range fs { | ||
b.AddItem(f) | ||
} | ||
|
||
// Since the MaxBatch threshold has been reached, we expect the | ||
// first 20 filters to be persisted. | ||
waitForItems(db, fs[:20]...) | ||
|
||
// Since the last 10 filters don't reach the threshold and since | ||
// the ticker has definitely not ticked yet, we don't expect the | ||
// last 10 filters to be in the db yet. | ||
require.False(t, db.hasItems(fs[21:]...)) | ||
}) | ||
|
||
t.Run("stress test", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Create a mock filters DB. | ||
db := newMockItemsDB() | ||
|
||
// Construct a new BatchWriter backed by the mock db. | ||
// Make the DB writes ticker duration extra long so that we | ||
// can explicitly test that the batch gets persisted if the | ||
// MaxBatch threshold is reached. | ||
b := NewBatchWriter[*item](&BatchWriterConfig[*item]{ | ||
QueueBufferSize: 5, | ||
MaxBatch: 5, | ||
DBWritesTickerDuration: time.Millisecond * 2, | ||
PutItems: db.PutItems, | ||
}) | ||
b.Start() | ||
t.Cleanup(b.Stop) | ||
|
||
// Generate lots of filters and add each to the batch writer. | ||
// Sleep for a bit between each filter to ensure that we | ||
// sometimes hit the timeout write and sometimes the threshold | ||
// write. | ||
fs := genFilterSet(1000) | ||
for _, f := range fs { | ||
b.AddItem(f) | ||
|
||
n := rand.Intn(3) | ||
time.Sleep(time.Duration(n) * time.Millisecond) | ||
} | ||
|
||
// Since the MaxBatch threshold has been reached, we expect the | ||
// first 20 filters to be persisted. | ||
waitForItems(db, fs...) | ||
}) | ||
} | ||
|
||
type item struct { | ||
i int | ||
} | ||
|
||
// mockItemsDB is a mock DB that holds a set of items. | ||
type mockItemsDB struct { | ||
items map[int]bool | ||
mu sync.Mutex | ||
} | ||
|
||
// newMockItemsDB constructs a new mockItemsDB. | ||
func newMockItemsDB() *mockItemsDB { | ||
return &mockItemsDB{ | ||
items: make(map[int]bool), | ||
} | ||
} | ||
|
||
// hasItems returns true if the db contains all the given items. | ||
func (m *mockItemsDB) hasItems(items ...*item) bool { | ||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
|
||
for _, i := range items { | ||
_, ok := m.items[i.i] | ||
if !ok { | ||
return false | ||
} | ||
} | ||
|
||
return true | ||
} | ||
|
||
// PutItems adds a set of items to the db. | ||
func (m *mockItemsDB) PutItems(items ...*item) error { | ||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
|
||
for _, i := range items { | ||
m.items[i.i] = true | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// genItemSet generates a set of numFilters items. | ||
func genFilterSet(numFilters int) []*item { | ||
res := make([]*item, numFilters) | ||
for i := 0; i < numFilters; i++ { | ||
res[i] = &item{i: i} | ||
} | ||
|
||
return res | ||
} | ||
|
||
// pollInterval is a constant specifying a 200 ms interval. | ||
const pollInterval = 200 * time.Millisecond | ||
|
||
// waitFor is a helper test function that will wait for a timeout period of | ||
// time until the passed predicate returns true. This function is helpful as | ||
// timing doesn't always line up well when running integration tests with | ||
// several running lnd nodes. This function gives callers a way to assert that | ||
// some property is upheld within a particular time frame. | ||
func waitFor(pred func() bool, timeout time.Duration) error { | ||
exitTimer := time.After(timeout) | ||
result := make(chan bool, 1) | ||
|
||
for { | ||
<-time.After(pollInterval) | ||
|
||
go func() { | ||
result <- pred() | ||
}() | ||
|
||
// Each time we call the pred(), we expect a result to be | ||
// returned otherwise it will timeout. | ||
select { | ||
case <-exitTimer: | ||
return fmt.Errorf("predicate not satisfied after " + | ||
"time out") | ||
|
||
case succeed := <-result: | ||
if succeed { | ||
return nil | ||
} | ||
} | ||
} | ||
} |