Skip to content

Commit

Permalink
chanutils: add BatchWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
ellemouton committed Jul 11, 2023
1 parent aaf9521 commit c74c72f
Show file tree
Hide file tree
Showing 2 changed files with 364 additions and 0 deletions.
154 changes: 154 additions & 0 deletions chanutils/batch_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package chanutils

import (
"sync"
"time"

"github.com/btcsuite/btclog"
)

// BatchWriterConfig holds the configuration options for BatchWriter.
type BatchWriterConfig[T any] struct {
// QueueBufferSize sets the buffer size of the output channel of the
// concurrent queue used by the BatchWriter.
QueueBufferSize int

// MaxBatch is the maximum number of filters to be persisted to the DB
// in one go.
MaxBatch int

// DBWritesTickerDuration is the time after receiving a filter that the
// writer will wait for more filters before writing the current batch
// to the DB.
DBWritesTickerDuration time.Duration

// Logger is the logger that the BatchWriter should use for any logs.
Logger btclog.Logger

// PutItems will be used by the BatchWriter to persist filters in
// batches.
PutItems func(...T) error
}

// BatchWriter manages writing Filters to the DB and tries to batch the writes
// as much as possible.
type BatchWriter[T any] struct {
started sync.Once
stopped sync.Once

cfg *BatchWriterConfig[T]

queue *ConcurrentQueue[T]

quit chan struct{}
wg sync.WaitGroup
}

// NewBatchWriter constructs a new BatchWriter using the given
// BatchWriterConfig.
func NewBatchWriter[T any](cfg *BatchWriterConfig[T]) *BatchWriter[T] {
return &BatchWriter[T]{
cfg: cfg,
queue: NewConcurrentQueue[T](cfg.QueueBufferSize),
quit: make(chan struct{}),
}
}

// Start starts the BatchWriter.
func (b *BatchWriter[T]) Start() {
b.started.Do(func() {
b.queue.Start()

b.wg.Add(1)
go b.manageNewItems()
})
}

// Stop stops the BatchWriter.
func (b *BatchWriter[T]) Stop() {
b.stopped.Do(func() {
close(b.quit)
b.wg.Wait()

b.queue.Stop()
})
}

// AddItem adds a given item to the BatchWriter queue.
func (b *BatchWriter[T]) AddItem(item T) {
b.queue.ChanIn() <- item
}

// manageNewItems manages collecting filters and persisting them to the DB.
// There are two conditions for writing a batch of filters to the DB: the first
// is if a certain threshold (MaxBatch) of filters has been collected and the
// other is if at least one filter has been collected and a timeout has been
// reached.
//
// NOTE: this must be run in a goroutine.
func (b *BatchWriter[T]) manageNewItems() {
defer b.wg.Done()

batch := make([]T, 0, b.cfg.MaxBatch)

// writeBatch writes the current contents of the batch slice to the
// filters DB.
writeBatch := func() {
if len(batch) == 0 {
return
}

err := b.cfg.PutItems(batch...)
if err != nil {
b.cfg.Logger.Warnf("Couldn't write filters to "+
"filterDB: %v", err)
}

// Empty the batch slice.
batch = make([]T, 0, b.cfg.MaxBatch)
}

ticker := time.NewTicker(b.cfg.DBWritesTickerDuration)
defer ticker.Stop()

// Stop the ticker since we don't want it to tick unless there is at
// least one item in the queue.
ticker.Stop()

for {
select {
case filter, ok := <-b.queue.ChanOut():
if !ok {
return
}

batch = append(batch, filter)

switch len(batch) {
// If the batch slice is full, we stop the timer and
// write the batch contents to disk.
case b.cfg.MaxBatch:
ticker.Stop()
writeBatch()

// If an item is added to the batch, we reset the timer.
// This ensures that if the batch threshold is not met
// then items are still persisted in a timely manner.
default:
ticker.Reset(b.cfg.DBWritesTickerDuration)
}

case <-ticker.C:
// If the ticker ticks, then we stop it and write the
// current batch contents to the db. If any more items
// are added, the ticker will be reset.
ticker.Stop()
writeBatch()

case <-b.quit:
writeBatch()

return
}
}
}
210 changes: 210 additions & 0 deletions chanutils/batch_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package chanutils

import (
"fmt"
"math/rand"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

const waitTime = time.Second * 5

// TestBatchWriter tests that the BatchWriter behaves as expected.
func TestBatchWriter(t *testing.T) {
t.Parallel()
rand.Seed(time.Now().UnixNano())

// waitForItems is a helper function that will wait for a given set of
// items to appear in the db.
waitForItems := func(db *mockItemsDB, items ...*item) {
err := waitFor(func() bool {
return db.hasItems(items...)
}, waitTime)
require.NoError(t, err)
}

t.Run("filters persisted after ticker", func(t *testing.T) {
t.Parallel()

// Create a mock filters DB.
db := newMockItemsDB()

// Construct a new BatchWriter backed by the mock db.
b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
QueueBufferSize: 10,
MaxBatch: 20,
DBWritesTickerDuration: time.Millisecond * 500,
PutItems: db.PutItems,
})
b.Start()
t.Cleanup(b.Stop)

fs := genFilterSet(5)
for _, f := range fs {
b.AddItem(f)
}
waitForItems(db, fs...)
})

t.Run("write once threshold is reached", func(t *testing.T) {
t.Parallel()

// Create a mock filters DB.
db := newMockItemsDB()

// Construct a new BatchWriter backed by the mock db.
// Make the DB writes ticker duration extra long so that we
// can explicitly test that the batch gets persisted if the
// MaxBatch threshold is reached.
b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
QueueBufferSize: 10,
MaxBatch: 20,
DBWritesTickerDuration: time.Hour,
PutItems: db.PutItems,
})
b.Start()
t.Cleanup(b.Stop)

// Generate 30 filters and add each one to the batch writer.
fs := genFilterSet(30)
for _, f := range fs {
b.AddItem(f)
}

// Since the MaxBatch threshold has been reached, we expect the
// first 20 filters to be persisted.
waitForItems(db, fs[:20]...)

// Since the last 10 filters don't reach the threshold and since
// the ticker has definitely not ticked yet, we don't expect the
// last 10 filters to be in the db yet.
require.False(t, db.hasItems(fs[21:]...))
})

t.Run("stress test", func(t *testing.T) {
t.Parallel()

// Create a mock filters DB.
db := newMockItemsDB()

// Construct a new BatchWriter backed by the mock db.
// Make the DB writes ticker duration extra long so that we
// can explicitly test that the batch gets persisted if the
// MaxBatch threshold is reached.
b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
QueueBufferSize: 5,
MaxBatch: 5,
DBWritesTickerDuration: time.Millisecond * 2,
PutItems: db.PutItems,
})
b.Start()
t.Cleanup(b.Stop)

// Generate lots of filters and add each to the batch writer.
// Sleep for a bit between each filter to ensure that we
// sometimes hit the timeout write and sometimes the threshold
// write.
fs := genFilterSet(1000)
for _, f := range fs {
b.AddItem(f)

n := rand.Intn(3)
time.Sleep(time.Duration(n) * time.Millisecond)
}

// Since the MaxBatch threshold has been reached, we expect the
// first 20 filters to be persisted.
waitForItems(db, fs...)
})
}

type item struct {
i int
}

// mockItemsDB is a mock DB that holds a set of items.
type mockItemsDB struct {
items map[int]bool
mu sync.Mutex
}

// newMockItemsDB constructs a new mockItemsDB.
func newMockItemsDB() *mockItemsDB {
return &mockItemsDB{
items: make(map[int]bool),
}
}

// hasItems returns true if the db contains all the given items.
func (m *mockItemsDB) hasItems(items ...*item) bool {
m.mu.Lock()
defer m.mu.Unlock()

for _, i := range items {
_, ok := m.items[i.i]
if !ok {
return false
}
}

return true
}

// PutItems adds a set of items to the db.
func (m *mockItemsDB) PutItems(items ...*item) error {
m.mu.Lock()
defer m.mu.Unlock()

for _, i := range items {
m.items[i.i] = true
}

return nil
}

// genItemSet generates a set of numFilters items.
func genFilterSet(numFilters int) []*item {
res := make([]*item, numFilters)
for i := 0; i < numFilters; i++ {
res[i] = &item{i: i}
}

return res
}

// pollInterval is a constant specifying a 200 ms interval.
const pollInterval = 200 * time.Millisecond

// waitFor is a helper test function that will wait for a timeout period of
// time until the passed predicate returns true. This function is helpful as
// timing doesn't always line up well when running integration tests with
// several running lnd nodes. This function gives callers a way to assert that
// some property is upheld within a particular time frame.
func waitFor(pred func() bool, timeout time.Duration) error {
exitTimer := time.After(timeout)
result := make(chan bool, 1)

for {
<-time.After(pollInterval)

go func() {
result <- pred()
}()

// Each time we call the pred(), we expect a result to be
// returned otherwise it will timeout.
select {
case <-exitTimer:
return fmt.Errorf("predicate not satisfied after " +
"time out")

case succeed := <-result:
if succeed {
return nil
}
}
}
}

0 comments on commit c74c72f

Please sign in to comment.