diff --git a/go.mod b/go.mod index 5f51b2b86..6123b652b 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/linkedin/goavro/v2 v2.12.0 github.com/mitchellh/mapstructure v1.5.0 github.com/mwitkow/grpc-proxy v0.0.0-20230212185441-f345521cb9c9 + github.com/pelletier/go-toml/v2 v2.1.1 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.17.0 github.com/riferrei/srclient v0.5.4 diff --git a/go.sum b/go.sum index 19224b433..b6a482b19 100644 --- a/go.sum +++ b/go.sum @@ -237,6 +237,8 @@ github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3 github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opencontainers/runc v1.1.3 h1:vIXrkId+0/J2Ymu2m7VjGvbSlAId9XNRPhn2p4b+d8w= github.com/opencontainers/runc v1.1.3/go.mod h1:1J5XiS+vdZ3wCyZybsuxXZWGrgSr8fFJHLXuG2PsnNg= +github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= +github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/pkg/config/toml.go b/pkg/config/toml.go new file mode 100644 index 000000000..f51db7636 --- /dev/null +++ b/pkg/config/toml.go @@ -0,0 +1,22 @@ +package config + +import ( + "errors" + "io" + + "github.com/pelletier/go-toml/v2" +) + +// DecodeTOML decodes toml from r in to v. +// Requires strict field matches and returns full toml.StrictMissingError details. +func DecodeTOML(r io.Reader, v any) error { + d := toml.NewDecoder(r).DisallowUnknownFields() + if err := d.Decode(v); err != nil { + var strict *toml.StrictMissingError + if errors.As(err, &strict) { + return errors.New(strict.String()) + } + return err + } + return nil +} diff --git a/pkg/utils/bytes/bytes.go b/pkg/utils/bytes/bytes.go index a29e25dda..01586ec3a 100644 --- a/pkg/utils/bytes/bytes.go +++ b/pkg/utils/bytes/bytes.go @@ -15,3 +15,13 @@ func TrimQuotes(input []byte) []byte { } return input } + +// IsEmpty returns true if bytes contains only zero values, or has len 0. +func IsEmpty(bytes []byte) bool { + for _, b := range bytes { + if b != 0 { + return false + } + } + return true +} diff --git a/pkg/utils/bytes/bytes_test.go b/pkg/utils/bytes/bytes_test.go new file mode 100644 index 000000000..fa514493a --- /dev/null +++ b/pkg/utils/bytes/bytes_test.go @@ -0,0 +1,16 @@ +package bytes_test + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/utils/bytes" + "github.com/stretchr/testify/require" +) + +func TestIsEmpty(t *testing.T) { + t.Parallel() + + require.True(t, bytes.IsEmpty([]byte{0, 0, 0})) + require.True(t, bytes.IsEmpty([]byte{})) + require.False(t, bytes.IsEmpty([]byte{1, 2, 3, 5})) +} diff --git a/pkg/utils/sleeper_task.go b/pkg/utils/sleeper_task.go new file mode 100644 index 000000000..bbbcfb287 --- /dev/null +++ b/pkg/utils/sleeper_task.go @@ -0,0 +1,123 @@ +package utils + +import ( + "fmt" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/services" +) + +// Worker is a simple interface that represents some work to do repeatedly +type Worker interface { + Work() + Name() string +} + +// SleeperTask represents a task that waits in the background to process some work. +type SleeperTask struct { + services.StateMachine + worker Worker + chQueue chan struct{} + chStop chan struct{} + chDone chan struct{} + chWorkDone chan struct{} +} + +// NewSleeperTask takes a worker and returns a SleeperTask. +// +// SleeperTask is guaranteed to call Work on the worker at least once for every +// WakeUp call. +// If the Worker is busy when WakeUp is called, the Worker will be called again +// immediately after it is finished. For this reason you should take care to +// make sure that Worker is idempotent. +// WakeUp does not block. +func NewSleeperTask(worker Worker) *SleeperTask { + s := &SleeperTask{ + worker: worker, + chQueue: make(chan struct{}, 1), + chStop: make(chan struct{}), + chDone: make(chan struct{}), + chWorkDone: make(chan struct{}, 10), + } + + _ = s.StartOnce("SleeperTask-"+worker.Name(), func() error { + go s.workerLoop() + return nil + }) + + return s +} + +// Stop stops the SleeperTask +func (s *SleeperTask) Stop() error { + return s.StopOnce("SleeperTask-"+s.worker.Name(), func() error { + close(s.chStop) + select { + case <-s.chDone: + case <-time.After(15 * time.Second): + return fmt.Errorf("SleeperTask-%s took too long to stop", s.worker.Name()) + } + return nil + }) +} + +func (s *SleeperTask) WakeUpIfStarted() { + s.IfStarted(func() { + select { + case s.chQueue <- struct{}{}: + default: + } + }) +} + +// WakeUp wakes up the sleeper task, asking it to execute its Worker. +func (s *SleeperTask) WakeUp() { + if !s.IfStarted(func() { + select { + case s.chQueue <- struct{}{}: + default: + } + }) { + panic("cannot wake up stopped sleeper task") + } +} + +func (s *SleeperTask) workDone() { + select { + case s.chWorkDone <- struct{}{}: + default: + } +} + +// WorkDone isn't part of the SleeperTask interface, but can be +// useful in tests to assert that the work has been done. +func (s *SleeperTask) WorkDone() <-chan struct{} { + return s.chWorkDone +} + +func (s *SleeperTask) workerLoop() { + defer close(s.chDone) + + for { + select { + case <-s.chQueue: + s.worker.Work() + s.workDone() + case <-s.chStop: + return + } + } +} + +type sleeperTaskWorker struct { + name string + work func() +} + +// SleeperFuncTask returns a Worker to execute the given work function. +func SleeperFuncTask(work func(), name string) Worker { + return &sleeperTaskWorker{name: name, work: work} +} + +func (w *sleeperTaskWorker) Name() string { return w.name } +func (w *sleeperTaskWorker) Work() { w.work() } diff --git a/pkg/utils/sleeper_task_test.go b/pkg/utils/sleeper_task_test.go new file mode 100644 index 000000000..3cc233fbc --- /dev/null +++ b/pkg/utils/sleeper_task_test.go @@ -0,0 +1,159 @@ +package utils_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" +) + +type chanWorker struct { + ch chan struct{} + delay time.Duration +} + +func (t *chanWorker) Name() string { + return "ChanWorker" +} + +func (t *chanWorker) Work() { + if t.delay != 0 { + time.Sleep(t.delay) + } + t.ch <- struct{}{} +} + +func TestSleeperTask_WakeupAfterStopPanics(t *testing.T) { + t.Parallel() + + worker := &chanWorker{ch: make(chan struct{}, 1)} + sleeper := utils.NewSleeperTask(worker) + + require.NoError(t, sleeper.Stop()) + + require.Panics(t, func() { + sleeper.WakeUp() + }) + + select { + case <-worker.ch: + t.Fatal("work was performed when none was expected") + default: + } +} + +func TestSleeperTask_CallingStopTwiceFails(t *testing.T) { + t.Parallel() + + worker := &chanWorker{} + sleeper := utils.NewSleeperTask(worker) + require.NoError(t, sleeper.Stop()) + require.Error(t, sleeper.Stop()) +} + +func TestSleeperTask_WakeupPerformsWork(t *testing.T) { + t.Parallel() + ctx := tests.Context(t) + + worker := &chanWorker{ch: make(chan struct{}, 1)} + sleeper := utils.NewSleeperTask(worker) + + sleeper.WakeUp() + + select { + case <-worker.ch: + case <-ctx.Done(): + t.Error("timed out waiting for work to be performed") + } + + require.NoError(t, sleeper.Stop()) +} + +type controllableWorker struct { + chanWorker + awaitWorkStarted chan struct{} + allowResumeWork chan struct{} + ignoreSignals bool +} + +func (w *controllableWorker) Work() { + if !w.ignoreSignals { + w.awaitWorkStarted <- struct{}{} + <-w.allowResumeWork + } + w.chanWorker.Work() +} + +func TestSleeperTask_WakeupEnqueuesMaxTwice(t *testing.T) { + t.Parallel() + ctx := tests.Context(t) + + worker := &controllableWorker{chanWorker: chanWorker{ch: make(chan struct{}, 1)}, awaitWorkStarted: make(chan struct{}), allowResumeWork: make(chan struct{})} + sleeper := utils.NewSleeperTask(worker) + + sleeper.WakeUp() + <-worker.awaitWorkStarted + sleeper.WakeUp() + sleeper.WakeUp() + sleeper.WakeUp() + sleeper.WakeUp() + sleeper.WakeUp() + worker.ignoreSignals = true + worker.allowResumeWork <- struct{}{} + + for i := 0; i < 2; i++ { + select { + case <-worker.ch: + case <-ctx.Done(): + t.Error("timed out waiting for work to be performed") + } + } + + if !t.Failed() { + select { + case <-worker.ch: + t.Errorf("unexpected work performed") + case <-time.After(time.Second): + } + } + + require.NoError(t, sleeper.Stop()) +} + +func TestSleeperTask_StopWaitsUntilWorkFinishes(t *testing.T) { + t.Parallel() + + worker := &controllableWorker{chanWorker: chanWorker{ch: make(chan struct{}, 1)}, awaitWorkStarted: make(chan struct{}), allowResumeWork: make(chan struct{})} + sleeper := utils.NewSleeperTask(worker) + + sleeper.WakeUp() + <-worker.awaitWorkStarted + + select { + case <-worker.ch: + t.Error("work was performed when none was expected") + assert.NoError(t, sleeper.Stop()) + return + default: + } + + worker.allowResumeWork <- struct{}{} + + require.NoError(t, sleeper.Stop()) + + select { + case <-worker.ch: + default: + t.Fatal("work should have been performed") + } + + select { + case <-worker.ch: + t.Fatal("extra work was performed") + default: + } +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 57ea147c3..9a98ea4b9 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,8 +2,10 @@ package utils import ( "context" + "fmt" "math" mrand "math/rand" + "sync" "time" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -49,3 +51,76 @@ func IsZero[C comparable](val C) bool { var zero C return zero == val } + +// JustError takes a tuple and returns the last entry, the error. +func JustError(_ interface{}, err error) error { + return err +} + +// WrapIfError decorates an error with the given message. It is intended to +// be used with `defer` statements, like so: +// +// func SomeFunction() (err error) { +// defer WrapIfError(&err, "error in SomeFunction:") +// +// ... +// } +func WrapIfError(err *error, msg string) { + if *err != nil { + *err = fmt.Errorf("%s: %w", msg, *err) + } +} + +// AllEqual returns true iff all the provided elements are equal to each other. +func AllEqual[T comparable](elems ...T) bool { + for i := 1; i < len(elems); i++ { + if elems[i] != elems[0] { + return false + } + } + return true +} + +// WaitGroupChan creates a channel that closes when the provided sync.WaitGroup is done. +func WaitGroupChan(wg *sync.WaitGroup) <-chan struct{} { + chAwait := make(chan struct{}) + go func() { + defer close(chAwait) + wg.Wait() + }() + return chAwait +} + +// DependentAwaiter contains Dependent funcs +type DependentAwaiter interface { + AwaitDependents() <-chan struct{} + AddDependents(n int) + DependentReady() +} + +type dependentAwaiter struct { + wg *sync.WaitGroup + ch <-chan struct{} +} + +// NewDependentAwaiter creates a new DependentAwaiter +func NewDependentAwaiter() DependentAwaiter { + return &dependentAwaiter{ + wg: &sync.WaitGroup{}, + } +} + +func (da *dependentAwaiter) AwaitDependents() <-chan struct{} { + if da.ch == nil { + da.ch = WaitGroupChan(da.wg) + } + return da.ch +} + +func (da *dependentAwaiter) AddDependents(n int) { + da.wg.Add(n) +} + +func (da *dependentAwaiter) DependentReady() { + da.wg.Done() +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 000000000..93601638a --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,114 @@ +package utils_test + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/utils" +) + +func TestWrapIfError(t *testing.T) { + t.Parallel() + + t.Run("wraps error", func(t *testing.T) { + err := errors.New("this is an error") + utils.WrapIfError(&err, "wrapped message") + assert.Equal(t, "wrapped message: this is an error", err.Error()) + }) +} + +func TestAllEqual(t *testing.T) { + t.Parallel() + + require.False(t, utils.AllEqual(1, 2, 3, 4, 5)) + require.True(t, utils.AllEqual(1, 1, 1, 1, 1)) + require.False(t, utils.AllEqual(1, 1, 1, 2, 1, 1, 1)) +} + +func TestWaitGroupChan(t *testing.T) { + t.Parallel() + + wg := &sync.WaitGroup{} + wg.Add(2) + + ch := utils.WaitGroupChan(wg) + + select { + case <-ch: + t.Fatal("should not fire immediately") + default: + } + + wg.Done() + + select { + case <-ch: + t.Fatal("should not fire until finished") + default: + } + + go func() { + time.Sleep(2 * time.Second) + wg.Done() + }() + + callbackOrTimeout(t, "WaitGroupChan fires", func() { + <-ch + }, 5*time.Second) +} + +func TestDependentAwaiter(t *testing.T) { + t.Parallel() + + da := utils.NewDependentAwaiter() + da.AddDependents(2) + + select { + case <-da.AwaitDependents(): + t.Fatal("should not fire immediately") + default: + } + + da.DependentReady() + + select { + case <-da.AwaitDependents(): + t.Fatal("should not fire until finished") + default: + } + + go func() { + time.Sleep(2 * time.Second) + da.DependentReady() + }() + + callbackOrTimeout(t, "dependents are now ready", func() { + <-da.AwaitDependents() + }, 5*time.Second) +} + +func callbackOrTimeout(t testing.TB, msg string, callback func(), durationParams ...time.Duration) { + t.Helper() + + duration := 100 * time.Millisecond + if len(durationParams) > 0 { + duration = durationParams[0] + } + + done := make(chan struct{}) + go func() { + callback() + close(done) + }() + + select { + case <-done: + case <-time.After(duration): + t.Fatalf("CallbackOrTimeout: %s timed out", msg) + } +}