diff --git a/common/txmgr/broadcaster.go b/common/txmgr/broadcaster.go index 9f2204f37e2..4b5834d21a8 100644 --- a/common/txmgr/broadcaster.go +++ b/common/txmgr/broadcaster.go @@ -208,7 +208,7 @@ func (eb *Broadcaster[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) star return errors.New("Broadcaster is already started") } var err error - eb.enabledAddresses, err = eb.ks.EnabledAddressesForChain(eb.chainID) + eb.enabledAddresses, err = eb.ks.EnabledAddressesForChain(ctx, eb.chainID) if err != nil { return fmt.Errorf("Broadcaster: failed to load EnabledAddressesForChain: %w", err) } diff --git a/common/txmgr/confirmer.go b/common/txmgr/confirmer.go index c28216467a1..073a6b90fa4 100644 --- a/common/txmgr/confirmer.go +++ b/common/txmgr/confirmer.go @@ -183,7 +183,7 @@ func NewConfirmer[ } // Start is a comment to appease the linter -func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(_ context.Context) error { +func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx context.Context) error { return ec.StartOnce("Confirmer", func() error { if ec.feeConfig.BumpThreshold() == 0 { ec.lggr.Infow("Gas bumping is disabled (FeeEstimator.BumpThreshold set to 0)", "feeBumpThreshold", 0) @@ -191,18 +191,18 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Sta ec.lggr.Infow(fmt.Sprintf("Fee bumping is enabled, unconfirmed transactions will have their fee bumped every %d blocks", ec.feeConfig.BumpThreshold()), "feeBumpThreshold", ec.feeConfig.BumpThreshold()) } - return ec.startInternal() + return ec.startInternal(ctx) }) } -func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) startInternal() error { +func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) startInternal(ctx context.Context) error { ec.initSync.Lock() defer ec.initSync.Unlock() if ec.isStarted { return errors.New("Confirmer is already started") } var err error - ec.enabledAddresses, err = ec.ks.EnabledAddressesForChain(ec.chainID) + ec.enabledAddresses, err = ec.ks.EnabledAddressesForChain(ctx, ec.chainID) if err != nil { return fmt.Errorf("Confirmer: failed to load EnabledAddressesForChain: %w", err) } @@ -1065,7 +1065,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) For if overrideGasLimit != 0 { etx.FeeLimit = overrideGasLimit } - attempt, _, err := ec.NewCustomTxAttempt(*etx, fee, etx.FeeLimit, 0x0, ec.lggr) + attempt, _, err := ec.NewCustomTxAttempt(ctx, *etx, fee, etx.FeeLimit, 0x0, ec.lggr) if err != nil { ec.lggr.Errorw("ForceRebroadcast: failed to create new attempt", "txID", etx.ID, "err", err) continue diff --git a/common/txmgr/resender.go b/common/txmgr/resender.go index 384c0c7a2c0..8c2dd6b827e 100644 --- a/common/txmgr/resender.go +++ b/common/txmgr/resender.go @@ -102,7 +102,7 @@ func NewResender[ } // Start is a comment which satisfies the linter -func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start() { +func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx context.Context) { er.logger.Debugf("Enabled with poll interval of %s and age threshold of %s", er.interval, er.txConfig.ResendAfterThreshold()) go er.runLoop() } @@ -116,7 +116,7 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Stop() { func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { defer close(er.chDone) - if err := er.resendUnconfirmed(); err != nil { + if err := er.resendUnconfirmed(er.ctx); err != nil { er.logger.Warnw("Failed to resend unconfirmed transactions", "err", err) } @@ -127,15 +127,15 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() case <-er.ctx.Done(): return case <-ticker.C: - if err := er.resendUnconfirmed(); err != nil { + if err := er.resendUnconfirmed(er.ctx); err != nil { er.logger.Warnw("Failed to resend unconfirmed transactions", "err", err) } } } } -func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnconfirmed() error { - resendAddresses, err := er.ks.EnabledAddressesForChain(er.chainID) +func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnconfirmed(ctx context.Context) error { + resendAddresses, err := er.ks.EnabledAddressesForChain(ctx, er.chainID) if err != nil { return fmt.Errorf("Resender failed getting enabled keys for chain %s: %w", er.chainID.String(), err) } diff --git a/common/txmgr/test_helpers.go b/common/txmgr/test_helpers.go index 6c0c5680ea7..dbc07861ffe 100644 --- a/common/txmgr/test_helpers.go +++ b/common/txmgr/test_helpers.go @@ -35,7 +35,7 @@ func (eb *Broadcaster[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) XXXT } func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestStartInternal() error { - return ec.startInternal() + return ec.startInternal(ec.ctx) } func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestCloseInternal() error { @@ -43,7 +43,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXX } func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestResendUnconfirmed() error { - return er.resendUnconfirmed() + return er.resendUnconfirmed(er.ctx) } func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) XXXTestAbandon(addr ADDR) (err error) { diff --git a/common/txmgr/tracker.go b/common/txmgr/tracker.go index 8b66668c41e..c63d9c264fc 100644 --- a/common/txmgr/tracker.go +++ b/common/txmgr/tracker.go @@ -91,25 +91,25 @@ func NewTracker[ } } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(_ context.Context) (err error) { +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx context.Context) (err error) { tr.lggr.Info("Abandoned transaction tracking enabled") return tr.StartOnce("Tracker", func() error { - return tr.startInternal() + return tr.startInternal(ctx) }) } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) startInternal() (err error) { +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) startInternal(ctx context.Context) (err error) { tr.lock.Lock() defer tr.lock.Unlock() tr.ctx, tr.ctxCancel = context.WithCancel(context.Background()) - if err := tr.setEnabledAddresses(); err != nil { + if err := tr.setEnabledAddresses(ctx); err != nil { return fmt.Errorf("failed to set enabled addresses: %w", err) } tr.lggr.Info("Enabled addresses set") - if err := tr.trackAbandonedTxes(tr.ctx); err != nil { + if err := tr.trackAbandonedTxes(ctx); err != nil { return fmt.Errorf("failed to track abandoned txes: %w", err) } @@ -194,8 +194,8 @@ func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) IsStarted() return tr.isStarted } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) setEnabledAddresses() error { - enabledAddrs, err := tr.keyStore.EnabledAddressesForChain(tr.chainID) +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) setEnabledAddresses(ctx context.Context) error { + enabledAddrs, err := tr.keyStore.EnabledAddressesForChain(ctx, tr.chainID) if err != nil { return fmt.Errorf("failed to get enabled addresses for chain: %w", err) } diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index 3e3fa9a20db..d0b33ad7e30 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -209,7 +209,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx } if b.resender != nil { - b.resender.Start() + b.resender.Start(ctx) } if b.fwdMgr != nil { @@ -308,10 +308,13 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) HealthRepo } func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { + ctx, cancel := b.chStop.NewCtx() + defer cancel() + // eb, ec and keyStates can all be modified by the runloop. // This is concurrent-safe because the runloop ensures serial access. defer b.wg.Done() - keysChanged, unsub := b.keyStore.SubscribeToKeyChanges() + keysChanged, unsub := b.keyStore.SubscribeToKeyChanges(ctx) defer unsub() close(b.chSubbed) @@ -321,7 +324,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() // execReset is defined as an inline function here because it closes over // eb, ec and stopped - execReset := func(r *reset) { + execReset := func(ctx context.Context, r *reset) { // These should always close successfully, since it should be logically // impossible to enter this code path with ec/eb in a state other than // "Started" @@ -348,8 +351,6 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() wg.Add(2) go func() { defer wg.Done() - ctx, cancel := b.chStop.NewCtx() - defer cancel() // Retry indefinitely on failure backoff := iutils.NewRedialBackoff() for { @@ -361,7 +362,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() continue } return - case <-ctx.Done(): + case <-b.chStop: stopOnce.Do(func() { stopped = true }) return } @@ -374,7 +375,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() for { select { case <-time.After(backoff.Duration()): - if err := b.confirmer.startInternal(); err != nil { + if err := b.confirmer.startInternal(ctx); err != nil { b.logger.Criticalw("Failed to start Confirmer", "err", err) b.SvcErrBuffer.Append(err) continue @@ -408,7 +409,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() reset.done <- errors.New("Txm was stopped") continue } - execReset(&reset) + execReset(ctx, &reset) case <-b.chStop: // close and exit // @@ -441,7 +442,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() if stopped { continue } - enabledAddresses, err := b.keyStore.EnabledAddressesForChain(b.chainID) + enabledAddresses, err := b.keyStore.EnabledAddressesForChain(ctx, b.chainID) if err != nil { b.logger.Critical("Failed to reload key states after key change") b.SvcErrBuffer.Append(err) @@ -449,7 +450,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() } b.logger.Debugw("Keys changed, reloading", "enabledAddresses", enabledAddresses) - execReset(nil) + execReset(ctx, nil) } } } @@ -496,7 +497,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) CreateTran } } - if err = b.checkEnabled(txRequest.FromAddress); err != nil { + if err = b.checkEnabled(ctx, txRequest.FromAddress); err != nil { return tx, err } @@ -543,8 +544,8 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForward return } -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) checkEnabled(addr ADDR) error { - if err := b.keyStore.CheckEnabled(addr, b.chainID); err != nil { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) checkEnabled(ctx context.Context, addr ADDR) error { + if err := b.keyStore.CheckEnabled(ctx, addr, b.chainID); err != nil { return fmt.Errorf("cannot send transaction from %s on chain ID %s: %w", addr, b.chainID.String(), err) } return nil diff --git a/common/txmgr/types/client.go b/common/txmgr/types/client.go index 0db50e97ad3..32527e5896e 100644 --- a/common/txmgr/types/client.go +++ b/common/txmgr/types/client.go @@ -62,7 +62,7 @@ type TransactionClient[ ) (client.SendTxReturnCode, error) SendEmptyTransaction( ctx context.Context, - newTxAttempt func(seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error), + newTxAttempt func(ctx context.Context, seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error), seq SEQ, gasLimit uint32, fee FEE, diff --git a/common/txmgr/types/keystore.go b/common/txmgr/types/keystore.go index 9c5b8cfce37..0eecc49be70 100644 --- a/common/txmgr/types/keystore.go +++ b/common/txmgr/types/keystore.go @@ -1,6 +1,8 @@ package types import ( + "context" + "github.com/smartcontractkit/chainlink/v2/common/types" ) @@ -15,7 +17,7 @@ type KeyStore[ // Chain's sequence type. For example, EVM chains use nonce, bitcoin uses UTXO. SEQ types.Sequence, ] interface { - CheckEnabled(address ADDR, chainID CHAIN_ID) error - EnabledAddressesForChain(chainId CHAIN_ID) ([]ADDR, error) - SubscribeToKeyChanges() (ch chan struct{}, unsub func()) + CheckEnabled(ctx context.Context, address ADDR, chainID CHAIN_ID) error + EnabledAddressesForChain(ctx context.Context, chainId CHAIN_ID) ([]ADDR, error) + SubscribeToKeyChanges(ctx context.Context) (ch chan struct{}, unsub func()) } diff --git a/common/txmgr/types/mocks/key_store.go b/common/txmgr/types/mocks/key_store.go index d440528a41d..7e825322977 100644 --- a/common/txmgr/types/mocks/key_store.go +++ b/common/txmgr/types/mocks/key_store.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + mock "github.com/stretchr/testify/mock" types "github.com/smartcontractkit/chainlink/v2/common/types" @@ -13,17 +15,17 @@ type KeyStore[ADDR types.Hashable, CHAIN_ID types.ID, SEQ types.Sequence] struct mock.Mock } -// CheckEnabled provides a mock function with given fields: address, chainID -func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) CheckEnabled(address ADDR, chainID CHAIN_ID) error { - ret := _m.Called(address, chainID) +// CheckEnabled provides a mock function with given fields: ctx, address, chainID +func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) CheckEnabled(ctx context.Context, address ADDR, chainID CHAIN_ID) error { + ret := _m.Called(ctx, address, chainID) if len(ret) == 0 { panic("no return value specified for CheckEnabled") } var r0 error - if rf, ok := ret.Get(0).(func(ADDR, CHAIN_ID) error); ok { - r0 = rf(address, chainID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, CHAIN_ID) error); ok { + r0 = rf(ctx, address, chainID) } else { r0 = ret.Error(0) } @@ -31,9 +33,9 @@ func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) CheckEnabled(address ADDR, chainID CHAI return r0 } -// EnabledAddressesForChain provides a mock function with given fields: chainId -func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) EnabledAddressesForChain(chainId CHAIN_ID) ([]ADDR, error) { - ret := _m.Called(chainId) +// EnabledAddressesForChain provides a mock function with given fields: ctx, chainId +func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) EnabledAddressesForChain(ctx context.Context, chainId CHAIN_ID) ([]ADDR, error) { + ret := _m.Called(ctx, chainId) if len(ret) == 0 { panic("no return value specified for EnabledAddressesForChain") @@ -41,19 +43,19 @@ func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) EnabledAddressesForChain(chainId CHAIN_ var r0 []ADDR var r1 error - if rf, ok := ret.Get(0).(func(CHAIN_ID) ([]ADDR, error)); ok { - return rf(chainId) + if rf, ok := ret.Get(0).(func(context.Context, CHAIN_ID) ([]ADDR, error)); ok { + return rf(ctx, chainId) } - if rf, ok := ret.Get(0).(func(CHAIN_ID) []ADDR); ok { - r0 = rf(chainId) + if rf, ok := ret.Get(0).(func(context.Context, CHAIN_ID) []ADDR); ok { + r0 = rf(ctx, chainId) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ADDR) } } - if rf, ok := ret.Get(1).(func(CHAIN_ID) error); ok { - r1 = rf(chainId) + if rf, ok := ret.Get(1).(func(context.Context, CHAIN_ID) error); ok { + r1 = rf(ctx, chainId) } else { r1 = ret.Error(1) } @@ -61,9 +63,9 @@ func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) EnabledAddressesForChain(chainId CHAIN_ return r0, r1 } -// SubscribeToKeyChanges provides a mock function with given fields: -func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) SubscribeToKeyChanges() (chan struct{}, func()) { - ret := _m.Called() +// SubscribeToKeyChanges provides a mock function with given fields: ctx +func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) SubscribeToKeyChanges(ctx context.Context) (chan struct{}, func()) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for SubscribeToKeyChanges") @@ -71,19 +73,19 @@ func (_m *KeyStore[ADDR, CHAIN_ID, SEQ]) SubscribeToKeyChanges() (chan struct{}, var r0 chan struct{} var r1 func() - if rf, ok := ret.Get(0).(func() (chan struct{}, func())); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (chan struct{}, func())); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() chan struct{}); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) chan struct{}); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(chan struct{}) } } - if rf, ok := ret.Get(1).(func() func()); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) func()); ok { + r1 = rf(ctx) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(func()) diff --git a/common/txmgr/types/mocks/tx_attempt_builder.go b/common/txmgr/types/mocks/tx_attempt_builder.go index b3b6ff761fb..5b9b3e505ad 100644 --- a/common/txmgr/types/mocks/tx_attempt_builder.go +++ b/common/txmgr/types/mocks/tx_attempt_builder.go @@ -125,9 +125,9 @@ func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) return r0, r1, r2, r3, r4 } -// NewCustomTxAttempt provides a mock function with given fields: tx, fee, gasLimit, txType, lggr -func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) NewCustomTxAttempt(tx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], fee FEE, gasLimit uint32, txType int, lggr logger.Logger) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], bool, error) { - ret := _m.Called(tx, fee, gasLimit, txType, lggr) +// NewCustomTxAttempt provides a mock function with given fields: ctx, tx, fee, gasLimit, txType, lggr +func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) NewCustomTxAttempt(ctx context.Context, tx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], fee FEE, gasLimit uint32, txType int, lggr logger.Logger) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], bool, error) { + ret := _m.Called(ctx, tx, fee, gasLimit, txType, lggr) if len(ret) == 0 { panic("no return value specified for NewCustomTxAttempt") @@ -136,23 +136,23 @@ func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) var r0 txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] var r1 bool var r2 error - if rf, ok := ret.Get(0).(func(txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], bool, error)); ok { - return rf(tx, fee, gasLimit, txType, lggr) + if rf, ok := ret.Get(0).(func(context.Context, txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], bool, error)); ok { + return rf(ctx, tx, fee, gasLimit, txType, lggr) } - if rf, ok := ret.Get(0).(func(txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { - r0 = rf(tx, fee, gasLimit, txType, lggr) + if rf, ok := ret.Get(0).(func(context.Context, txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { + r0 = rf(ctx, tx, fee, gasLimit, txType, lggr) } else { r0 = ret.Get(0).(txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) } - if rf, ok := ret.Get(1).(func(txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) bool); ok { - r1 = rf(tx, fee, gasLimit, txType, lggr) + if rf, ok := ret.Get(1).(func(context.Context, txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) bool); ok { + r1 = rf(ctx, tx, fee, gasLimit, txType, lggr) } else { r1 = ret.Get(1).(bool) } - if rf, ok := ret.Get(2).(func(txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) error); ok { - r2 = rf(tx, fee, gasLimit, txType, lggr) + if rf, ok := ret.Get(2).(func(context.Context, txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], FEE, uint32, int, logger.Logger) error); ok { + r2 = rf(ctx, tx, fee, gasLimit, txType, lggr) } else { r2 = ret.Error(2) } @@ -160,9 +160,9 @@ func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) return r0, r1, r2 } -// NewEmptyTxAttempt provides a mock function with given fields: seq, feeLimit, fee, fromAddress -func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) NewEmptyTxAttempt(seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { - ret := _m.Called(seq, feeLimit, fee, fromAddress) +// NewEmptyTxAttempt provides a mock function with given fields: ctx, seq, feeLimit, fee, fromAddress +func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) NewEmptyTxAttempt(ctx context.Context, seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { + ret := _m.Called(ctx, seq, feeLimit, fee, fromAddress) if len(ret) == 0 { panic("no return value specified for NewEmptyTxAttempt") @@ -170,17 +170,17 @@ func (_m *TxAttemptBuilder[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) var r0 txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] var r1 error - if rf, ok := ret.Get(0).(func(SEQ, uint32, FEE, ADDR) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { - return rf(seq, feeLimit, fee, fromAddress) + if rf, ok := ret.Get(0).(func(context.Context, SEQ, uint32, FEE, ADDR) (txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { + return rf(ctx, seq, feeLimit, fee, fromAddress) } - if rf, ok := ret.Get(0).(func(SEQ, uint32, FEE, ADDR) txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { - r0 = rf(seq, feeLimit, fee, fromAddress) + if rf, ok := ret.Get(0).(func(context.Context, SEQ, uint32, FEE, ADDR) txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { + r0 = rf(ctx, seq, feeLimit, fee, fromAddress) } else { r0 = ret.Get(0).(txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) } - if rf, ok := ret.Get(1).(func(SEQ, uint32, FEE, ADDR) error); ok { - r1 = rf(seq, feeLimit, fee, fromAddress) + if rf, ok := ret.Get(1).(func(context.Context, SEQ, uint32, FEE, ADDR) error); ok { + r1 = rf(ctx, seq, feeLimit, fee, fromAddress) } else { r1 = ret.Error(1) } diff --git a/common/txmgr/types/tx_attempt_builder.go b/common/txmgr/types/tx_attempt_builder.go index 383b6d862f0..47c71abea35 100644 --- a/common/txmgr/types/tx_attempt_builder.go +++ b/common/txmgr/types/tx_attempt_builder.go @@ -37,8 +37,8 @@ type TxAttemptBuilder[ NewBumpTxAttempt(ctx context.Context, tx Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], previousAttempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], priorAttempts []TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], lggr logger.Logger) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], bumpedFee FEE, bumpedFeeLimit uint32, retryable bool, err error) // NewCustomTxAttempt builds a transaction using the passed in fee + tx type - NewCustomTxAttempt(tx Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], fee FEE, gasLimit uint32, txType int, lggr logger.Logger) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], retryable bool, err error) + NewCustomTxAttempt(ctx context.Context, tx Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], fee FEE, gasLimit uint32, txType int, lggr logger.Logger) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], retryable bool, err error) // NewEmptyTxAttempt is used in ForceRebroadcast to create a signed tx with zero value sent to the zero address - NewEmptyTxAttempt(seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) + NewEmptyTxAttempt(ctx context.Context, seq SEQ, feeLimit uint32, fee FEE, fromAddress ADDR) (attempt TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) } diff --git a/core/chains/evm/monitor/balance.go b/core/chains/evm/monitor/balance.go index b0f0fbc9c91..bb271ad1d46 100644 --- a/core/chains/evm/monitor/balance.go +++ b/core/chains/evm/monitor/balance.go @@ -174,7 +174,7 @@ func (w *worker) Work() { } func (w *worker) WorkCtx(ctx context.Context) { - enabledAddresses, err := w.bm.ethKeyStore.EnabledAddressesForChain(w.bm.chainID) + enabledAddresses, err := w.bm.ethKeyStore.EnabledAddressesForChain(ctx, w.bm.chainID) if err != nil { w.bm.logger.Error("BalanceMonitor: error getting keys", err) } diff --git a/core/chains/evm/txmgr/attempts.go b/core/chains/evm/txmgr/attempts.go index 91645bae6f6..e37f0e4d2d8 100644 --- a/core/chains/evm/txmgr/attempts.go +++ b/core/chains/evm/txmgr/attempts.go @@ -20,7 +20,7 @@ import ( ) type TxAttemptSigner[ADDR commontypes.Hashable] interface { - SignTx(fromAddress ADDR, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) + SignTx(ctx context.Context, fromAddress ADDR, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) } var _ TxAttemptBuilder = (*evmTxAttemptBuilder)(nil) @@ -62,7 +62,7 @@ func (c *evmTxAttemptBuilder) NewTxAttemptWithType(ctx context.Context, etx Tx, return attempt, fee, feeLimit, true, errors.Wrap(err, "failed to get fee") // estimator errors are retryable } - attempt, retryable, err = c.NewCustomTxAttempt(etx, fee, feeLimit, txType, lggr) + attempt, retryable, err = c.NewCustomTxAttempt(ctx, etx, fee, feeLimit, txType, lggr) return attempt, fee, feeLimit, retryable, err } @@ -76,13 +76,13 @@ func (c *evmTxAttemptBuilder) NewBumpTxAttempt(ctx context.Context, etx Tx, prev return attempt, bumpedFee, bumpedFeeLimit, true, errors.Wrap(err, "failed to bump fee") // estimator errors are retryable } - attempt, retryable, err = c.NewCustomTxAttempt(etx, bumpedFee, bumpedFeeLimit, previousAttempt.TxType, lggr) + attempt, retryable, err = c.NewCustomTxAttempt(ctx, etx, bumpedFee, bumpedFeeLimit, previousAttempt.TxType, lggr) return attempt, bumpedFee, bumpedFeeLimit, retryable, err } // NewCustomTxAttempt is the lowest level func where the fee parameters + tx type must be passed in // used in the txm for force rebroadcast where fees and tx type are pre-determined without an estimator -func (c *evmTxAttemptBuilder) NewCustomTxAttempt(etx Tx, fee gas.EvmFee, gasLimit uint32, txType int, lggr logger.Logger) (attempt TxAttempt, retryable bool, err error) { +func (c *evmTxAttemptBuilder) NewCustomTxAttempt(ctx context.Context, etx Tx, fee gas.EvmFee, gasLimit uint32, txType int, lggr logger.Logger) (attempt TxAttempt, retryable bool, err error) { switch txType { case 0x0: // legacy if fee.Legacy == nil { @@ -90,7 +90,7 @@ func (c *evmTxAttemptBuilder) NewCustomTxAttempt(etx Tx, fee gas.EvmFee, gasLimi logger.Sugared(lggr).AssumptionViolation(err.Error()) return attempt, false, err // not retryable } - attempt, err = c.newLegacyAttempt(etx, fee.Legacy, gasLimit) + attempt, err = c.newLegacyAttempt(ctx, etx, fee.Legacy, gasLimit) return attempt, true, err case 0x2: // dynamic, EIP1559 if !fee.ValidDynamic() { @@ -98,7 +98,7 @@ func (c *evmTxAttemptBuilder) NewCustomTxAttempt(etx Tx, fee gas.EvmFee, gasLimi logger.Sugared(lggr).AssumptionViolation(err.Error()) return attempt, false, err // not retryable } - attempt, err = c.newDynamicFeeAttempt(etx, gas.DynamicFee{ + attempt, err = c.newDynamicFeeAttempt(ctx, etx, gas.DynamicFee{ FeeCap: fee.DynamicFeeCap, TipCap: fee.DynamicTipCap, }, gasLimit) @@ -112,7 +112,7 @@ func (c *evmTxAttemptBuilder) NewCustomTxAttempt(etx Tx, fee gas.EvmFee, gasLimi } // NewEmptyTxAttempt is used in ForceRebroadcast to create a signed tx with zero value sent to the zero address -func (c *evmTxAttemptBuilder) NewEmptyTxAttempt(nonce evmtypes.Nonce, feeLimit uint32, fee gas.EvmFee, fromAddress common.Address) (attempt TxAttempt, err error) { +func (c *evmTxAttemptBuilder) NewEmptyTxAttempt(ctx context.Context, nonce evmtypes.Nonce, feeLimit uint32, fee gas.EvmFee, fromAddress common.Address) (attempt TxAttempt, err error) { value := big.NewInt(0) payload := []byte{} @@ -130,7 +130,7 @@ func (c *evmTxAttemptBuilder) NewEmptyTxAttempt(nonce evmtypes.Nonce, feeLimit u ) transaction := types.NewTx(&tx) - hash, signedTxBytes, err := c.SignTx(fromAddress, transaction) + hash, signedTxBytes, err := c.SignTx(ctx, fromAddress, transaction) if err != nil { return attempt, errors.Wrapf(err, "error using account %s to sign empty transaction", fromAddress.String()) } @@ -141,7 +141,7 @@ func (c *evmTxAttemptBuilder) NewEmptyTxAttempt(nonce evmtypes.Nonce, feeLimit u } -func (c *evmTxAttemptBuilder) newDynamicFeeAttempt(etx Tx, fee gas.DynamicFee, gasLimit uint32) (attempt TxAttempt, err error) { +func (c *evmTxAttemptBuilder) newDynamicFeeAttempt(ctx context.Context, etx Tx, fee gas.DynamicFee, gasLimit uint32) (attempt TxAttempt, err error) { if err = validateDynamicFeeGas(c.feeConfig, c.feeConfig.TipCapMin(), fee, gasLimit, etx); err != nil { return attempt, errors.Wrap(err, "error validating gas") } @@ -157,7 +157,7 @@ func (c *evmTxAttemptBuilder) newDynamicFeeAttempt(etx Tx, fee gas.DynamicFee, g etx.EncodedPayload, ) tx := types.NewTx(&d) - attempt, err = c.newSignedAttempt(etx, tx) + attempt, err = c.newSignedAttempt(ctx, etx, tx) if err != nil { return attempt, err } @@ -226,8 +226,8 @@ func newDynamicFeeTransaction(nonce uint64, to common.Address, value *big.Int, g } } -func (c *evmTxAttemptBuilder) newLegacyAttempt(etx Tx, gasPrice *assets.Wei, gasLimit uint32) (attempt TxAttempt, err error) { - if err = validateLegacyGas(c.feeConfig, c.feeConfig.PriceMin(), gasPrice, gasLimit, etx); err != nil { +func (c *evmTxAttemptBuilder) newLegacyAttempt(ctx context.Context, etx Tx, gasPrice *assets.Wei, gasLimit uint32) (attempt TxAttempt, err error) { + if err = validateLegacyGas(ctx, c.feeConfig, c.feeConfig.PriceMin(), gasPrice, gasLimit, etx); err != nil { return attempt, errors.Wrap(err, "error validating gas") } @@ -241,7 +241,7 @@ func (c *evmTxAttemptBuilder) newLegacyAttempt(etx Tx, gasPrice *assets.Wei, gas ) transaction := types.NewTx(&tx) - hash, signedTxBytes, err := c.SignTx(etx.FromAddress, transaction) + hash, signedTxBytes, err := c.SignTx(ctx, etx.FromAddress, transaction) if err != nil { return attempt, errors.Wrapf(err, "error using account %s to sign transaction %v", etx.FromAddress, etx.ID) } @@ -260,7 +260,7 @@ func (c *evmTxAttemptBuilder) newLegacyAttempt(etx Tx, gasPrice *assets.Wei, gas // validateLegacyGas is a sanity check - we have other checks elsewhere, but this // makes sure we _never_ create an invalid attempt -func validateLegacyGas(kse keySpecificEstimator, minGasPriceWei, gasPrice *assets.Wei, gasLimit uint32, etx Tx) error { +func validateLegacyGas(ctx context.Context, kse keySpecificEstimator, minGasPriceWei, gasPrice *assets.Wei, gasLimit uint32, etx Tx) error { if gasPrice == nil { panic("gas price missing") } @@ -275,8 +275,8 @@ func validateLegacyGas(kse keySpecificEstimator, minGasPriceWei, gasPrice *asset return nil } -func (c *evmTxAttemptBuilder) newSignedAttempt(etx Tx, tx *types.Transaction) (attempt TxAttempt, err error) { - hash, signedTxBytes, err := c.SignTx(etx.FromAddress, tx) +func (c *evmTxAttemptBuilder) newSignedAttempt(ctx context.Context, etx Tx, tx *types.Transaction) (attempt TxAttempt, err error) { + hash, signedTxBytes, err := c.SignTx(ctx, etx.FromAddress, tx) if err != nil { return attempt, errors.Wrapf(err, "error using account %s to sign transaction %v", etx.FromAddress.String(), etx.ID) } @@ -301,8 +301,8 @@ func newLegacyTransaction(nonce uint64, to common.Address, value *big.Int, gasLi } } -func (c *evmTxAttemptBuilder) SignTx(address common.Address, tx *types.Transaction) (common.Hash, []byte, error) { - signedTx, err := c.keystore.SignTx(address, tx, &c.chainID) +func (c *evmTxAttemptBuilder) SignTx(ctx context.Context, address common.Address, tx *types.Transaction) (common.Hash, []byte, error) { + signedTx, err := c.keystore.SignTx(ctx, address, tx, &c.chainID) if err != nil { return common.Hash{}, nil, fmt.Errorf("failed to sign tx: %w", err) } diff --git a/core/chains/evm/txmgr/attempts_test.go b/core/chains/evm/txmgr/attempts_test.go index c7373e2c4f6..e0b3cd59ce7 100644 --- a/core/chains/evm/txmgr/attempts_test.go +++ b/core/chains/evm/txmgr/attempts_test.go @@ -68,9 +68,9 @@ func TestTxm_SignTx(t *testing.T) { t.Run("returns correct hash for non-okex chains", func(t *testing.T) { chainID := big.NewInt(1) kst := ksmocks.NewEth(t) - kst.On("SignTx", to, tx, chainID).Return(tx, nil).Once() + kst.On("SignTx", mock.Anything, to, tx, chainID).Return(tx, nil).Once() cks := txmgr.NewEvmTxAttemptBuilder(*chainID, newFeeConfig(), kst, nil) - hash, rawBytes, err := cks.SignTx(addr, tx) + hash, rawBytes, err := cks.SignTx(testutils.Context(t), addr, tx) require.NoError(t, err) require.NotNil(t, rawBytes) require.Equal(t, "0xdd68f554373fdea7ec6713a6e437e7646465d553a6aa0b43233093366cc87ef0", hash.String()) @@ -79,9 +79,9 @@ func TestTxm_SignTx(t *testing.T) { t.Run("returns correct hash for okex chains", func(t *testing.T) { chainID := big.NewInt(1) kst := ksmocks.NewEth(t) - kst.On("SignTx", to, tx, chainID).Return(tx, nil).Once() + kst.On("SignTx", mock.Anything, to, tx, chainID).Return(tx, nil).Once() cks := txmgr.NewEvmTxAttemptBuilder(*chainID, newFeeConfig(), kst, nil) - hash, rawBytes, err := cks.SignTx(addr, tx) + hash, rawBytes, err := cks.SignTx(testutils.Context(t), addr, tx) require.NoError(t, err) require.NotNil(t, rawBytes) require.Equal(t, "0xdd68f554373fdea7ec6713a6e437e7646465d553a6aa0b43233093366cc87ef0", hash.String()) @@ -89,10 +89,10 @@ func TestTxm_SignTx(t *testing.T) { t.Run("can properly encoded and decode raw transaction for LegacyTx", func(t *testing.T) { chainID := big.NewInt(1) kst := ksmocks.NewEth(t) - kst.On("SignTx", to, tx, chainID).Return(tx, nil).Once() + kst.On("SignTx", mock.Anything, to, tx, chainID).Return(tx, nil).Once() cks := txmgr.NewEvmTxAttemptBuilder(*chainID, newFeeConfig(), kst, nil) - _, rawBytes, err := cks.SignTx(addr, tx) + _, rawBytes, err := cks.SignTx(testutils.Context(t), addr, tx) require.NoError(t, err) require.NotNil(t, rawBytes) require.Equal(t, "0xe42a82015681f294b921f7763960b296b9cbad586ff066a18d749724818e83010203808080", hexutil.Encode(rawBytes)) @@ -112,9 +112,9 @@ func TestTxm_SignTx(t *testing.T) { Gas: 242, Data: []byte{1, 2, 3}, }) - kst.On("SignTx", to, typedTx, chainID).Return(typedTx, nil).Once() + kst.On("SignTx", mock.Anything, to, typedTx, chainID).Return(typedTx, nil).Once() cks := txmgr.NewEvmTxAttemptBuilder(*chainID, newFeeConfig(), kst, nil) - _, rawBytes, err := cks.SignTx(addr, typedTx) + _, rawBytes, err := cks.SignTx(testutils.Context(t), addr, typedTx) require.NoError(t, err) require.NotNil(t, rawBytes) require.Equal(t, "0xa702e5802a808081f294b921f7763960b296b9cbad586ff066a18d749724818e83010203c0808080", hexutil.Encode(rawBytes)) @@ -130,7 +130,7 @@ func TestTxm_NewDynamicFeeTx(t *testing.T) { addr := NewEvmAddress() tx := types.NewTx(&types.DynamicFeeTx{}) kst := ksmocks.NewEth(t) - kst.On("SignTx", addr, mock.Anything, big.NewInt(1)).Return(tx, nil) + kst.On("SignTx", mock.Anything, addr, mock.Anything, big.NewInt(1)).Return(tx, nil) var n evmtypes.Nonce lggr := logger.Test(t) @@ -139,7 +139,7 @@ func TestTxm_NewDynamicFeeTx(t *testing.T) { feeCfg.priceMax = assets.GWei(200) cks := txmgr.NewEvmTxAttemptBuilder(*big.NewInt(1), feeCfg, kst, nil) dynamicFee := gas.DynamicFee{TipCap: assets.GWei(100), FeeCap: assets.GWei(200)} - a, _, err := cks.NewCustomTxAttempt(txmgr.Tx{Sequence: &n, FromAddress: addr}, gas.EvmFee{ + a, _, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{Sequence: &n, FromAddress: addr}, gas.EvmFee{ DynamicTipCap: dynamicFee.TipCap, DynamicFeeCap: dynamicFee.FeeCap, }, 100, 0x2, lggr) @@ -181,7 +181,7 @@ func TestTxm_NewDynamicFeeTx(t *testing.T) { cfg := evmtest.NewChainScopedConfig(t, gcfg) cks := txmgr.NewEvmTxAttemptBuilder(*big.NewInt(1), cfg.EVM().GasEstimator(), kst, nil) dynamicFee := gas.DynamicFee{TipCap: test.tipcap, FeeCap: test.feecap} - _, _, err := cks.NewCustomTxAttempt(txmgr.Tx{Sequence: &n, FromAddress: addr}, gas.EvmFee{ + _, _, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{Sequence: &n, FromAddress: addr}, gas.EvmFee{ DynamicTipCap: dynamicFee.TipCap, DynamicFeeCap: dynamicFee.FeeCap, }, 100, 0x2, lggr) @@ -199,7 +199,7 @@ func TestTxm_NewLegacyAttempt(t *testing.T) { addr := NewEvmAddress() kst := ksmocks.NewEth(t) tx := types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", addr, mock.Anything, big.NewInt(1)).Return(tx, nil) + kst.On("SignTx", mock.Anything, addr, mock.Anything, big.NewInt(1)).Return(tx, nil) gc := newFeeConfig() gc.priceMin = assets.NewWeiI(10) gc.priceMax = assets.NewWeiI(50) @@ -208,7 +208,7 @@ func TestTxm_NewLegacyAttempt(t *testing.T) { t.Run("creates attempt with fields", func(t *testing.T) { var n evmtypes.Nonce - a, _, err := cks.NewCustomTxAttempt(txmgr.Tx{Sequence: &n, FromAddress: addr}, gas.EvmFee{Legacy: assets.NewWeiI(25)}, 100, 0x0, lggr) + a, _, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{Sequence: &n, FromAddress: addr}, gas.EvmFee{Legacy: assets.NewWeiI(25)}, 100, 0x0, lggr) require.NoError(t, err) assert.Equal(t, 100, int(a.ChainSpecificFeeLimit)) assert.NotNil(t, a.TxFee.Legacy) @@ -218,7 +218,7 @@ func TestTxm_NewLegacyAttempt(t *testing.T) { }) t.Run("verifies max gas price", func(t *testing.T) { - _, _, err := cks.NewCustomTxAttempt(txmgr.Tx{FromAddress: addr}, gas.EvmFee{Legacy: assets.NewWeiI(100)}, 100, 0x0, lggr) + _, _, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{FromAddress: addr}, gas.EvmFee{Legacy: assets.NewWeiI(100)}, 100, 0x0, lggr) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("specified gas price of 100 wei would exceed max configured gas price of 50 wei for key %s", addr.String())) }) @@ -235,7 +235,7 @@ func TestTxm_NewCustomTxAttempt_NonRetryableErrors(t *testing.T) { legacyFee := assets.NewWeiI(100) t.Run("dynamic fee with legacy tx type", func(t *testing.T) { - _, retryable, err := cks.NewCustomTxAttempt(txmgr.Tx{}, gas.EvmFee{ + _, retryable, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{}, gas.EvmFee{ DynamicTipCap: dynamicFee.TipCap, DynamicFeeCap: dynamicFee.FeeCap, }, 100, 0x0, lggr) @@ -243,13 +243,13 @@ func TestTxm_NewCustomTxAttempt_NonRetryableErrors(t *testing.T) { assert.False(t, retryable) }) t.Run("legacy fee with dynamic tx type", func(t *testing.T) { - _, retryable, err := cks.NewCustomTxAttempt(txmgr.Tx{}, gas.EvmFee{Legacy: legacyFee}, 100, 0x2, lggr) + _, retryable, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{}, gas.EvmFee{Legacy: legacyFee}, 100, 0x2, lggr) require.Error(t, err) assert.False(t, retryable) }) t.Run("invalid type", func(t *testing.T) { - _, retryable, err := cks.NewCustomTxAttempt(txmgr.Tx{}, gas.EvmFee{}, 100, 0xA, lggr) + _, retryable, err := cks.NewCustomTxAttempt(testutils.Context(t), txmgr.Tx{}, gas.EvmFee{}, 100, 0xA, lggr) require.Error(t, err) assert.False(t, retryable) }) diff --git a/core/chains/evm/txmgr/broadcaster_test.go b/core/chains/evm/txmgr/broadcaster_test.go index 67e9b0d8f04..2af22d76a66 100644 --- a/core/chains/evm/txmgr/broadcaster_test.go +++ b/core/chains/evm/txmgr/broadcaster_test.go @@ -1648,7 +1648,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_KeystoreErrors(t *testing.T) { kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Once() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Once() ethClient.On("PendingNonceAt", mock.Anything, fromAddress).Return(uint64(0), nil).Once() eb := NewTestEthBroadcaster(t, txStore, ethClient, kst, evmcfg, &testCheckerFactory{}, false) ctx := testutils.Context(t) @@ -1658,7 +1658,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_KeystoreErrors(t *testing.T) { t.Run("tx signing fails", func(t *testing.T) { etx := mustCreateUnstartedTx(t, txStore, fromAddress, toAddress, encodedPayload, gasLimit, value, &cltest.FixtureChainID) tx := *gethTypes.NewTx(&gethTypes.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.AnythingOfType("*types.Transaction"), mock.MatchedBy(func(chainID *big.Int) bool { @@ -1696,7 +1696,7 @@ func TestEthBroadcaster_GetNextNonce(t *testing.T) { kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Once() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Once() ethClient.On("PendingNonceAt", mock.Anything, fromAddress).Return(uint64(0), nil).Once() eb := NewTestEthBroadcaster(t, txStore, ethClient, kst, evmcfg, &testCheckerFactory{}, false) nonce := getLocalNextNonce(t, eb, fromAddress) @@ -1714,7 +1714,7 @@ func TestEthBroadcaster_IncrementNextNonce(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Once() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Once() ethClient.On("PendingNonceAt", mock.Anything, fromAddress).Return(uint64(0), nil).Once() eb := NewTestEthBroadcaster(t, txStore, ethClient, kst, evmcfg, &testCheckerFactory{}, false) @@ -1777,7 +1777,7 @@ func TestEthBroadcaster_SyncNonce(t *testing.T) { kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Once() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Once() ethClient.On("PendingNonceAt", mock.Anything, fromAddress).Return(uint64(0), nil).Once() eb := txmgr.NewEvmBroadcaster(txStore, txmgr.NewEvmTxmClient(ethClient), evmTxmCfg, txmgr.NewEvmTxmFeeConfig(ge), evmcfg.EVM().Transactions(), cfg.Database().Listener(), kst, txBuilder, nil, lggr, checkerFactory, false) err := eb.Start(ctx) @@ -1795,7 +1795,7 @@ func TestEthBroadcaster_SyncNonce(t *testing.T) { txNonceSyncer := txmgr.NewNonceSyncer(txStore, lggr, ethClient) kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Once() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Once() ethClient.On("PendingNonceAt", mock.Anything, fromAddress).Return(uint64(0), nil).Once() eb := txmgr.NewEvmBroadcaster(txStore, txmgr.NewEvmTxmClient(ethClient), evmTxmCfg, txmgr.NewEvmTxmFeeConfig(ge), evmcfg.EVM().Transactions(), cfg.Database().Listener(), kst, txBuilder, txNonceSyncer, lggr, checkerFactory, true) @@ -1824,7 +1824,7 @@ func TestEthBroadcaster_SyncNonce(t *testing.T) { kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Once() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Once() ethClient.On("PendingNonceAt", mock.Anything, fromAddress).Return(uint64(0), nil).Once() eb := txmgr.NewEvmBroadcaster(txStore, txmgr.NewEvmTxmClient(ethClient), evmTxmCfg, txmgr.NewEvmTxmFeeConfig(evmcfg.EVM().GasEstimator()), evmcfg.EVM().Transactions(), cfg.Database().Listener(), kst, txBuilder, txNonceSyncer, lggr, checkerFactory, true) diff --git a/core/chains/evm/txmgr/client.go b/core/chains/evm/txmgr/client.go index 0aa03536276..e794f56ba31 100644 --- a/core/chains/evm/txmgr/client.go +++ b/core/chains/evm/txmgr/client.go @@ -145,7 +145,7 @@ func (c *evmTxmClient) BatchGetReceipts(ctx context.Context, attempts []TxAttemp // May be useful for clearing stuck nonces func (c *evmTxmClient) SendEmptyTransaction( ctx context.Context, - newTxAttempt func(seq evmtypes.Nonce, feeLimit uint32, fee gas.EvmFee, fromAddress common.Address) (attempt TxAttempt, err error), + newTxAttempt func(ctx context.Context, seq evmtypes.Nonce, feeLimit uint32, fee gas.EvmFee, fromAddress common.Address) (attempt TxAttempt, err error), seq evmtypes.Nonce, gasLimit uint32, fee gas.EvmFee, @@ -153,7 +153,7 @@ func (c *evmTxmClient) SendEmptyTransaction( ) (txhash string, err error) { defer utils.WrapIfError(&err, "sendEmptyTransaction failed") - attempt, err := newTxAttempt(seq, gasLimit, fee, fromAddress) + attempt, err := newTxAttempt(ctx, seq, gasLimit, fee, fromAddress) if err != nil { return txhash, err } diff --git a/core/chains/evm/txmgr/confirmer_test.go b/core/chains/evm/txmgr/confirmer_test.go index 6cb14a8d618..1860b557335 100644 --- a/core/chains/evm/txmgr/confirmer_test.go +++ b/core/chains/evm/txmgr/confirmer_test.go @@ -1646,7 +1646,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_WithConnectivityCheck(t *testing feeEstimator := gas.NewWrappedEvmEstimator(lggr, newEst, ge.EIP1559DynamicFees(), nil) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, kst, feeEstimator) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Maybe() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() // Create confirmer with necessary state ec := txmgr.NewEvmConfirmer(txStore, txmgr.NewEvmTxmClient(ethClient), ccfg.EVM(), txmgr.NewEvmTxmFeeConfig(ccfg.EVM().GasEstimator()), ccfg.EVM().Transactions(), cfg.Database(), kst, txBuilder, lggr) servicetest.Run(t, ec) @@ -1693,7 +1693,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_WithConnectivityCheck(t *testing feeEstimator := gas.NewWrappedEvmEstimator(lggr, newEst, ge.EIP1559DynamicFees(), nil) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, kst, feeEstimator) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Maybe() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() ec := txmgr.NewEvmConfirmer(txStore, txmgr.NewEvmTxmClient(ethClient), ccfg.EVM(), txmgr.NewEvmTxmFeeConfig(ccfg.EVM().GasEstimator()), ccfg.EVM().Transactions(), cfg.Database(), kst, txBuilder, lggr) servicetest.Run(t, ec) currentHead := int64(30) @@ -1738,7 +1738,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_MaxFeeScenario(t *testing.T) { kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Maybe() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() // Use a mock keystore for this test ec := newEthConfirmer(t, txStore, ethClient, evmcfg, kst, nil) currentHead := int64(30) @@ -1753,7 +1753,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_MaxFeeScenario(t *testing.T) { t.Run("treats an exceeds max fee attempt as a success", func(t *testing.T) { ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if tx.Nonce() != uint64(*etx.Sequence) { @@ -1805,7 +1805,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Maybe() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() // Use a mock keystore for this test ec := newEthConfirmer(t, txStore, ethClient, evmcfg, kst, nil) currentHead := int64(30) @@ -1825,7 +1825,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { t.Run("re-sends previous transaction on keystore error", func(t *testing.T) { // simulate bumped transaction that is somehow impossible to sign - kst.On("SignTx", fromAddress, + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { return tx.Nonce() == uint64(*etx.Sequence) }), @@ -1845,7 +1845,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { t.Run("does nothing and continues on fatal error", func(t *testing.T) { ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if tx.Nonce() != uint64(*etx.Sequence) { @@ -1879,7 +1879,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.Greater(t, expectedBumpedGasPrice.Int64(), attempt1_1.TxFee.Legacy.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -1927,7 +1927,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.Greater(t, expectedBumpedGasPrice.Int64(), attempt1_2.TxFee.Legacy.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx.Sequence || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -1966,7 +1966,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { ethTx := *types.NewTx(&types.LegacyTx{}) receipt := evmtypes.Receipt{BlockNumber: big.NewInt(40)} - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx.Sequence || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -2013,13 +2013,13 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.NoError(t, db.Get(&dbAttempt, `UPDATE evm.tx_attempts SET broadcast_before_block_num=$1 WHERE id=$2 RETURNING *`, oldEnough, attempt2_1.ID)) var attempt2_2 txmgr.TxAttempt - t.Run("saves in_progress attempt on temporary error and returns error", func(t *testing.T) { + t.Run("saves in-progress attempt on temporary error and returns error", func(t *testing.T) { expectedBumpedGasPrice := big.NewInt(20000000000) require.Greater(t, expectedBumpedGasPrice.Int64(), attempt2_1.TxFee.Legacy.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) n := *etx2.Sequence - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != n || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -2086,7 +2086,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { ethTx := *types.NewTx(&types.LegacyTx{}) n := *etx2.Sequence - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != n || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -2127,7 +2127,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.Greater(t, expectedBumpedGasPrice.Int64(), attempt3_1.TxFee.Legacy.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx3.Sequence || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -2164,7 +2164,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.Greater(t, expectedBumpedGasPrice.Int64(), attempt3_1.TxFee.Legacy.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx3.Sequence || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -2203,7 +2203,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.Greater(t, expectedBumpedGasPrice.Int64(), attempt3_2.TxFee.Legacy.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx3.Sequence || expectedBumpedGasPrice.Cmp(tx.GasPrice()) != 0 { @@ -2299,7 +2299,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { t.Run("EIP-1559: bumps using EIP-1559 rules when existing attempts are of type 0x2", func(t *testing.T) { ethTx := *types.NewTx(&types.DynamicFeeTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx4.Sequence { @@ -2367,7 +2367,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary(t *testing.T) { require.Greater(t, expectedBumpedTipCap.Int64(), attempt4_2.TxFee.DynamicTipCap.ToInt().Int64()) ethTx := *types.NewTx(&types.LegacyTx{}) - kst.On("SignTx", + kst.On("SignTx", mock.Anything, fromAddress, mock.MatchedBy(func(tx *types.Transaction) bool { if evmtypes.Nonce(tx.Nonce()) != *etx4.Sequence || expectedBumpedTipCap.ToInt().Cmp(tx.GasTipCap()) != 0 { @@ -2417,7 +2417,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_TerminallyUnderpriced_ThenGoesTh // Use a mock keystore for this test kst := ksmocks.NewEth(t) addresses := []gethCommon.Address{fromAddress} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addresses, nil).Maybe() + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() currentHead := int64(30) oldEnough := 5 nonce := int64(0) @@ -2440,7 +2440,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_TerminallyUnderpriced_ThenGoesTh // Succeed the second time after bumping gas. ethClient.On("SendTransactionReturnCode", mock.Anything, mock.Anything, fromAddress).Return( commonclient.Successful, nil).Once() - kst.On("SignTx", mock.Anything, mock.Anything, mock.Anything).Return( + kst.On("SignTx", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( signedTx, nil, ).Once() require.NoError(t, ec.RebroadcastWhereNecessary(testutils.Context(t), currentHead)) @@ -2464,14 +2464,14 @@ func TestEthConfirmer_RebroadcastWhereNecessary_TerminallyUnderpriced_ThenGoesTh ethClient.On("SendTransactionReturnCode", mock.Anything, mock.Anything, fromAddress).Return( commonclient.Successful, nil).Once() signedLegacyTx := new(types.Transaction) - kst.On("SignTx", mock.Anything, mock.MatchedBy(func(tx *types.Transaction) bool { + kst.On("SignTx", mock.Anything, mock.Anything, mock.MatchedBy(func(tx *types.Transaction) bool { return tx.Type() == 0x0 && tx.Nonce() == uint64(*etx.Sequence) }), mock.Anything).Return( signedLegacyTx, nil, ).Run(func(args mock.Arguments) { - unsignedLegacyTx := args.Get(1).(*types.Transaction) + unsignedLegacyTx := args.Get(2).(*types.Transaction) // Use the real keystore to do the actual signing - thisSignedLegacyTx, err := ethKeyStore.SignTx(fromAddress, unsignedLegacyTx, testutils.FixtureChainID) + thisSignedLegacyTx, err := ethKeyStore.SignTx(testutils.Context(t), fromAddress, unsignedLegacyTx, testutils.FixtureChainID) require.NoError(t, err) *signedLegacyTx = *thisSignedLegacyTx }).Times(4) // 3 failures 1 success @@ -2496,14 +2496,14 @@ func TestEthConfirmer_RebroadcastWhereNecessary_TerminallyUnderpriced_ThenGoesTh ethClient.On("SendTransactionReturnCode", mock.Anything, mock.Anything, fromAddress).Return( commonclient.Successful, nil).Once() signedDxFeeTx := new(types.Transaction) - kst.On("SignTx", mock.Anything, mock.MatchedBy(func(tx *types.Transaction) bool { + kst.On("SignTx", mock.Anything, mock.Anything, mock.MatchedBy(func(tx *types.Transaction) bool { return tx.Type() == 0x2 && tx.Nonce() == uint64(*etx.Sequence) }), mock.Anything).Return( signedDxFeeTx, nil, ).Run(func(args mock.Arguments) { - unsignedDxFeeTx := args.Get(1).(*types.Transaction) + unsignedDxFeeTx := args.Get(2).(*types.Transaction) // Use the real keystore to do the actual signing - thisSignedDxFeeTx, err := ethKeyStore.SignTx(fromAddress, unsignedDxFeeTx, testutils.FixtureChainID) + thisSignedDxFeeTx, err := ethKeyStore.SignTx(testutils.Context(t), fromAddress, unsignedDxFeeTx, testutils.FixtureChainID) require.NoError(t, err) *signedDxFeeTx = *thisSignedDxFeeTx }).Times(4) // 3 failures 1 success @@ -2513,7 +2513,6 @@ func TestEthConfirmer_RebroadcastWhereNecessary_TerminallyUnderpriced_ThenGoesTh func TestEthConfirmer_RebroadcastWhereNecessary_WhenOutOfEth(t *testing.T) { t.Parallel() - db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) txStore := cltest.NewTestTxStore(t, db, cfg.Database()) @@ -2523,7 +2522,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_WhenOutOfEth(t *testing.T) { _, fromAddress := cltest.MustInsertRandomKeyReturningState(t, ethKeyStore) - _, err := ethKeyStore.EnabledKeysForChain(testutils.FixtureChainID) + _, err := ethKeyStore.EnabledKeysForChain(testutils.Context(t), testutils.FixtureChainID) require.NoError(t, err) require.NoError(t, err) // keyStates, err := ethKeyStore.GetStatesForKeys(keys) diff --git a/core/chains/evm/txmgr/evm_tx_store_test.go b/core/chains/evm/txmgr/evm_tx_store_test.go index 35d684727d1..c3bf531276a 100644 --- a/core/chains/evm/txmgr/evm_tx_store_test.go +++ b/core/chains/evm/txmgr/evm_tx_store_test.go @@ -411,8 +411,8 @@ func TestORM_SetBroadcastBeforeBlockNum(t *testing.T) { }) t.Run("only updates evm.tx_attempts for the current chain", func(t *testing.T) { - require.NoError(t, ethKeyStore.Add(fromAddress, testutils.SimulatedChainID)) - require.NoError(t, ethKeyStore.Enable(fromAddress, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Add(testutils.Context(t), fromAddress, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Enable(testutils.Context(t), fromAddress, testutils.SimulatedChainID)) etxThisChain := cltest.MustInsertUnconfirmedEthTxWithBroadcastLegacyAttempt(t, txStore, 1, fromAddress, cfg.EVM().ChainID()) etxOtherChain := cltest.MustInsertUnconfirmedEthTxWithBroadcastLegacyAttempt(t, txStore, 0, fromAddress, testutils.SimulatedChainID) diff --git a/core/chains/evm/txmgr/resender_test.go b/core/chains/evm/txmgr/resender_test.go index fd3d1745010..e8c1c9e079f 100644 --- a/core/chains/evm/txmgr/resender_test.go +++ b/core/chains/evm/txmgr/resender_test.go @@ -155,6 +155,7 @@ func Test_EthResender_Start(t *testing.T) { lggr := logger.Test(t) t.Run("resends transactions that have been languishing unconfirmed for too long", func(t *testing.T) { + ctx := testutils.Context(t) ethClient := evmtest.NewEthClientMockWithDefaultChain(t) er := txmgr.NewEvmResender(lggr, txStore, txmgr.NewEvmTxmClient(ethClient), txmgr.NewEvmTracker(txStore, ethKeyStore, big.NewInt(0), lggr), ethKeyStore, 100*time.Millisecond, ccfg.EVM(), ccfg.EVM().Transactions()) @@ -180,7 +181,7 @@ func Test_EthResender_Start(t *testing.T) { }) func() { - er.Start() + er.Start(ctx) defer er.Stop() cltest.EventuallyExpectationsMet(t, ethClient, 5*time.Second, time.Second) diff --git a/core/chains/evm/txmgr/txmgr_test.go b/core/chains/evm/txmgr/txmgr_test.go index 0e28f2948ee..ddcb281944b 100644 --- a/core/chains/evm/txmgr/txmgr_test.go +++ b/core/chains/evm/txmgr/txmgr_test.go @@ -481,11 +481,11 @@ func TestTxm_Lifecycle(t *testing.T) { evmConfig.ReaperThreshold = 1 * time.Hour evmConfig.ReaperInterval = 1 * time.Hour - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return([]common.Address{}, nil) + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return([]common.Address{}, nil) keyChangeCh := make(chan struct{}) unsub := cltest.NewAwaiter() - kst.On("SubscribeToKeyChanges").Return(keyChangeCh, unsub.ItHappened) + kst.On("SubscribeToKeyChanges", mock.Anything).Return(keyChangeCh, unsub.ItHappened) estimator := gas.NewEstimator(logger.Test(t), ethClient, config, evmConfig.GasEstimator()) txm, err := makeTestEvmTxm(t, db, ethClient, estimator, evmConfig, evmConfig.GasEstimator(), evmConfig.Transactions(), dbConfig, dbConfig.Listener(), kst) require.NoError(t, err) @@ -506,7 +506,7 @@ func TestTxm_Lifecycle(t *testing.T) { keyState := cltest.MustGenerateRandomKeyState(t) addr := []common.Address{keyState.Address.Address()} - kst.On("EnabledAddressesForChain", &cltest.FixtureChainID).Return(addr, nil) + kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addr, nil) ethClient.On("PendingNonceAt", mock.AnythingOfType("*context.cancelCtx"), common.Address{}).Return(uint64(0), nil).Maybe() keyChangeCh <- struct{}{} diff --git a/core/chains/evm/types/models_test.go b/core/chains/evm/types/models_test.go index a32deba697d..80f09d7a779 100644 --- a/core/chains/evm/types/models_test.go +++ b/core/chains/evm/types/models_test.go @@ -99,7 +99,7 @@ func TestEthTxAttempt_GetSignedTx(t *testing.T) { chainID := big.NewInt(3) - signedTx, err := ethKeyStore.SignTx(fromAddress, tx, chainID) + signedTx, err := ethKeyStore.SignTx(testutils.Context(t), fromAddress, tx, chainID) require.NoError(t, err) rlp := new(bytes.Buffer) require.NoError(t, signedTx.EncodeRLP(rlp)) diff --git a/core/cmd/eth_keys_commands_test.go b/core/cmd/eth_keys_commands_test.go index de40a5bf873..2f22cd1d3ae 100644 --- a/core/cmd/eth_keys_commands_test.go +++ b/core/cmd/eth_keys_commands_test.go @@ -148,7 +148,7 @@ func TestShell_ListETHKeys_Disabled(t *testing.T) { withMocks(ethClient), ) client, r := app.NewShellAndRenderer() - keys, err := app.KeyStore.Eth().GetAll() + keys, err := app.KeyStore.Eth().GetAll(testutils.Context(t)) require.NoError(t, err) require.Equal(t, 1, len(keys)) k := keys[0] @@ -186,7 +186,7 @@ func TestShell_CreateETHKey(t *testing.T) { client, _ := app.NewShellAndRenderer() cltest.AssertCount(t, db, "evm.key_states", 1) // The initial funding key - keys, err := app.KeyStore.Eth().GetAll() + keys, err := app.KeyStore.Eth().GetAll(testutils.Context(t)) require.NoError(t, err) require.Equal(t, 1, len(keys)) @@ -202,7 +202,7 @@ func TestShell_CreateETHKey(t *testing.T) { assert.NoError(t, client.CreateETHKey(c)) cltest.AssertCount(t, db, "evm.key_states", 2) - keys, err = app.KeyStore.Eth().GetAll() + keys, err = app.KeyStore.Eth().GetAll(testutils.Context(t)) require.NoError(t, err) require.Equal(t, 2, len(keys)) } @@ -221,7 +221,7 @@ func TestShell_DeleteETHKey(t *testing.T) { client, _ := app.NewShellAndRenderer() // Create the key - key, err := ethKeyStore.Create(&cltest.FixtureChainID) + key, err := ethKeyStore.Create(testutils.Context(t), &cltest.FixtureChainID) require.NoError(t, err) // Delete the key @@ -235,7 +235,7 @@ func TestShell_DeleteETHKey(t *testing.T) { err = client.DeleteETHKey(c) require.NoError(t, err) - _, err = ethKeyStore.Get(key.Address.Hex()) + _, err = ethKeyStore.Get(testutils.Context(t), key.Address.Hex()) assert.Error(t, err) } @@ -303,7 +303,7 @@ func TestShell_ImportExportETHKey_NoChains(t *testing.T) { c = cli.NewContext(nil, set, nil) err = client.DeleteETHKey(c) require.NoError(t, err) - _, err = ethKeyStore.Get(address) + _, err = ethKeyStore.Get(testutils.Context(t), address) require.Error(t, err) cltest.AssertCount(t, app.GetSqlxDB(), "evm.key_states", 0) @@ -328,7 +328,7 @@ func TestShell_ImportExportETHKey_NoChains(t *testing.T) { err = client.ListETHKeys(c) require.NoError(t, err) require.Len(t, *r.Renders[0].(*cmd.EthKeyPresenters), 1) - _, err = ethKeyStore.Get(address) + _, err = ethKeyStore.Get(testutils.Context(t), address) require.NoError(t, err) // Export test invalid id @@ -411,7 +411,7 @@ func TestShell_ImportExportETHKey_WithChains(t *testing.T) { c = cli.NewContext(nil, set, nil) err = client.DeleteETHKey(c) require.NoError(t, err) - _, err = ethKeyStore.Get(address) + _, err = ethKeyStore.Get(testutils.Context(t), address) require.Error(t, err) // Import the key @@ -435,7 +435,7 @@ func TestShell_ImportExportETHKey_WithChains(t *testing.T) { err = client.ListETHKeys(c) require.NoError(t, err) require.Len(t, *r.Renders[0].(*cmd.EthKeyPresenters), 1) - _, err = ethKeyStore.Get(address) + _, err = ethKeyStore.Get(testutils.Context(t), address) require.NoError(t, err) // Export test invalid id diff --git a/core/cmd/ocr2vrf_configure_commands.go b/core/cmd/ocr2vrf_configure_commands.go index 06f26ddb6a4..906c27374c8 100644 --- a/core/cmd/ocr2vrf_configure_commands.go +++ b/core/cmd/ocr2vrf_configure_commands.go @@ -173,7 +173,7 @@ func (s *Shell) ConfigureOCR2VRFNode(c *cli.Context, owner *bind.TransactOpts, e // Initialize keystore and generate keys. keyStore := app.GetKeyStore() - err = setupKeystore(s, app, keyStore) + err = setupKeystore(ctx, s, app, keyStore) if err != nil { return nil, s.errorOut(err) } @@ -191,7 +191,7 @@ func (s *Shell) ConfigureOCR2VRFNode(c *cli.Context, owner *bind.TransactOpts, e var sendingKeys []string var sendingKeysAddresses []common.Address useForwarder := c.Bool("use-forwarder") - ethKeys, err := app.GetKeyStore().Eth().EnabledKeysForChain(big.NewInt(chainID)) + ethKeys, err := app.GetKeyStore().Eth().EnabledKeysForChain(ctx, big.NewInt(chainID)) if err != nil { return nil, s.errorOut(err) } @@ -205,7 +205,7 @@ func (s *Shell) ConfigureOCR2VRFNode(c *cli.Context, owner *bind.TransactOpts, e if useForwarder { // Add extra sending keys if using a forwarder. - sendingKeys, sendingKeysAddresses, err = s.appendForwarders(chainID, app.GetKeyStore().Eth(), sendingKeys, sendingKeysAddresses) + sendingKeys, sendingKeysAddresses, err = s.appendForwarders(ctx, chainID, app.GetKeyStore().Eth(), sendingKeys, sendingKeysAddresses) if err != nil { return nil, err } @@ -298,16 +298,16 @@ func (s *Shell) ConfigureOCR2VRFNode(c *cli.Context, owner *bind.TransactOpts, e }, nil } -func (s *Shell) appendForwarders(chainID int64, ks keystore.Eth, sendingKeys []string, sendingKeysAddresses []common.Address) ([]string, []common.Address, error) { +func (s *Shell) appendForwarders(ctx context.Context, chainID int64, ks keystore.Eth, sendingKeys []string, sendingKeysAddresses []common.Address) ([]string, []common.Address, error) { for i := 0; i < forwarderAdditionalEOACount; i++ { // Create the sending key in the keystore. - k, err := ks.Create() + k, err := ks.Create(ctx) if err != nil { return nil, nil, err } // Enable the sending key for the current chain. - err = ks.Enable(k.Address, big.NewInt(chainID)) + err = ks.Enable(ctx, k.Address, big.NewInt(chainID)) if err != nil { return nil, nil, err } @@ -351,7 +351,7 @@ func (s *Shell) authorizeForwarder(c *cli.Context, db *sqlx.DB, lggr logger.Logg return nil } -func setupKeystore(cli *Shell, app chainlink.Application, keyStore keystore.Master) error { +func setupKeystore(ctx context.Context, cli *Shell, app chainlink.Application, keyStore keystore.Master) error { if err := cli.KeyStoreAuthenticator.authenticate(keyStore, cli.Config.Password()); err != nil { return errors.Wrap(err, "error authenticating keystore") } @@ -362,7 +362,7 @@ func setupKeystore(cli *Shell, app chainlink.Application, keyStore keystore.Mast return fmt.Errorf("failed to get legacy evm chains") } for _, ch := range chains { - if err = keyStore.Eth().EnsureKeys(ch.ID()); err != nil { + if err = keyStore.Eth().EnsureKeys(ctx, ch.ID()); err != nil { return errors.Wrap(err, "failed to ensure keystore keys") } } diff --git a/core/cmd/shell_local.go b/core/cmd/shell_local.go index 350e6abf77d..7e03fe719e1 100644 --- a/core/cmd/shell_local.go +++ b/core/cmd/shell_local.go @@ -378,7 +378,7 @@ func (s *Shell) runNode(c *cli.Context) error { for _, ch := range chainList { if ch.Config().EVM().AutoCreateKey() { lggr.Debugf("AutoCreateKey=true, will ensure EVM key for chain %s", ch.ID()) - err2 := app.GetKeyStore().Eth().EnsureKeys(ch.ID()) + err2 := app.GetKeyStore().Eth().EnsureKeys(rootCtx, ch.ID()) if err2 != nil { return errors.Wrap(err2, "failed to ensure keystore keys") } @@ -625,7 +625,7 @@ func (s *Shell) RebroadcastTransactions(c *cli.Context) (err error) { return s.errorOut(errors.Wrap(err, "error authenticating keystore")) } - if err = keyStore.Eth().CheckEnabled(address, chain.ID()); err != nil { + if err = keyStore.Eth().CheckEnabled(ctx, address, chain.ID()); err != nil { return s.errorOut(err) } diff --git a/core/cmd/shell_local_test.go b/core/cmd/shell_local_test.go index 72c2f2b5bbd..d6f4946dd9d 100644 --- a/core/cmd/shell_local_test.go +++ b/core/cmd/shell_local_test.go @@ -181,7 +181,7 @@ func TestShell_RunNodeWithAPICredentialsFile(t *testing.T) { pgtest.MustExec(t, db, "DELETE FROM users;") keyStore := cltest.NewKeyStore(t, db, cfg.Database()) - _, err := keyStore.Eth().Create(&cltest.FixtureChainID) + _, err := keyStore.Eth().Create(testutils.Context(t), &cltest.FixtureChainID) require.NoError(t, err) ethClient := evmtest.NewEthClientMock(t) @@ -436,7 +436,6 @@ func TestShell_RebroadcastTransactions_AddressCheck(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - config, sqlxDB := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.Database.Dialect = dialects.Postgres @@ -450,7 +449,7 @@ func TestShell_RebroadcastTransactions_AddressCheck(t *testing.T) { _, fromAddress := cltest.MustInsertRandomKey(t, keyStore.Eth()) if !test.enableAddress { - err := keyStore.Eth().Disable(fromAddress, testutils.FixtureChainID) + err := keyStore.Eth().Disable(testutils.Context(t), fromAddress, testutils.FixtureChainID) require.NoError(t, err, "failed to disable test key") } diff --git a/core/internal/cltest/cltest.go b/core/internal/cltest/cltest.go index 332513b28d4..08766d64c8b 100644 --- a/core/internal/cltest/cltest.go +++ b/core/internal/cltest/cltest.go @@ -236,6 +236,7 @@ func NewApplicationWithKey(t *testing.T, flagsAndDeps ...interface{}) *TestAppli // NewApplicationWithConfigAndKey creates a new TestApplication with the given testorm // it will also provide an unlocked account on the keystore func NewApplicationWithConfigAndKey(t testing.TB, c chainlink.GeneralConfig, flagsAndDeps ...interface{}) *TestApplication { + ctx := testutils.Context(t) app := NewApplicationWithConfig(t, c, flagsAndDeps...) chainID := *ubig.New(&FixtureChainID) @@ -252,9 +253,9 @@ func NewApplicationWithConfigAndKey(t testing.TB, c chainlink.GeneralConfig, fla } else { id, ks := chainID.ToInt(), app.KeyStore.Eth() for _, k := range app.Keys { - ks.XXXTestingOnlyAdd(k) - require.NoError(t, ks.Add(k.Address, id)) - require.NoError(t, ks.Enable(k.Address, id)) + ks.XXXTestingOnlyAdd(ctx, k) + require.NoError(t, ks.Add(ctx, k.Address, id)) + require.NoError(t, ks.Enable(ctx, k.Address, id)) } } @@ -571,9 +572,9 @@ func (ta *TestApplication) MustSeedNewSession(email string) (id string) { } // ImportKey adds private key to the application keystore and database -func (ta *TestApplication) Import(content string) { +func (ta *TestApplication) Import(ctx context.Context, content string) { require.NoError(ta.t, ta.KeyStore.Unlock(Password)) - _, err := ta.KeyStore.Eth().Import([]byte(content), Password, &FixtureChainID) + _, err := ta.KeyStore.Eth().Import(ctx, []byte(content), Password, &FixtureChainID) require.NoError(ta.t, err) } diff --git a/core/internal/cltest/factories.go b/core/internal/cltest/factories.go index 804dbe2d088..2649cc47c6c 100644 --- a/core/internal/cltest/factories.go +++ b/core/internal/cltest/factories.go @@ -258,18 +258,19 @@ type RandomKey struct { } func (r RandomKey) MustInsert(t testing.TB, keystore keystore.Eth) (ethkey.KeyV2, common.Address) { + ctx := testutils.Context(t) if r.chainIDs == nil { r.chainIDs = []ubig.Big{*ubig.New(&FixtureChainID)} } key := MustGenerateRandomKey(t) - keystore.XXXTestingOnlyAdd(key) + keystore.XXXTestingOnlyAdd(ctx, key) for _, cid := range r.chainIDs { - require.NoError(t, keystore.Add(key.Address, cid.ToInt())) - require.NoError(t, keystore.Enable(key.Address, cid.ToInt())) + require.NoError(t, keystore.Add(ctx, key.Address, cid.ToInt())) + require.NoError(t, keystore.Enable(ctx, key.Address, cid.ToInt())) if r.Disabled { - require.NoError(t, keystore.Disable(key.Address, cid.ToInt())) + require.NoError(t, keystore.Disable(ctx, key.Address, cid.ToInt())) } } @@ -277,8 +278,9 @@ func (r RandomKey) MustInsert(t testing.TB, keystore keystore.Eth) (ethkey.KeyV2 } func (r RandomKey) MustInsertWithState(t testing.TB, keystore keystore.Eth) (ethkey.State, common.Address) { + ctx := testutils.Context(t) k, address := r.MustInsert(t, keystore) - state, err := keystore.GetStateForKey(k) + state, err := keystore.GetStateForKey(ctx, k) require.NoError(t, err) return state, address } diff --git a/core/internal/features/features_test.go b/core/internal/features/features_test.go index 1c4d097d633..d94452f2512 100644 --- a/core/internal/features/features_test.go +++ b/core/internal/features/features_test.go @@ -361,6 +361,7 @@ func TestIntegration_DirectRequest(t *testing.T) { for _, tt := range tests { test := tt t.Run(test.name, func(t *testing.T) { + ctx := testutils.Context(t) // Simulate a consumer contract calling to obtain ETH quotes in 3 different currencies // in a single callback. config := configtest.NewGeneralConfigSimulated(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -371,7 +372,7 @@ func TestIntegration_DirectRequest(t *testing.T) { b := operatorContracts.sim app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, b) - sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(ctx, testutils.SimulatedChainID) require.NoError(t, err) authorizedSenders := []common.Address{sendingKeys[0].Address} tx, err := operatorContracts.operator.SetAuthorizedSenders(operatorContracts.user, authorizedSenders) @@ -474,7 +475,7 @@ func setupAppForEthTx(t *testing.T, operatorContracts OperatorContracts) (app *c app = cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, cfg, b, lggr) b.Commit() - sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) require.Len(t, sendingKeys, 1) @@ -703,7 +704,7 @@ func setupNode(t *testing.T, owner *bind.TransactOpts, portV2 int, app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, b, p2pKey) - sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) transmitter := sendingKeys[0].Address @@ -745,7 +746,7 @@ func setupForwarderEnabledNode(t *testing.T, owner *bind.TransactOpts, portV2 in app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, b, p2pKey) - sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) transmitter := sendingKeys[0].Address @@ -1360,7 +1361,7 @@ func TestIntegration_BlockHistoryEstimator(t *testing.T) { func triggerAllKeys(t *testing.T, app *cltest.TestApplication) { for _, chain := range app.GetRelayers().LegacyEVMChains().Slice() { - keys, err := app.KeyStore.Eth().EnabledKeysForChain(chain.ID()) + keys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), chain.ID()) require.NoError(t, err) for _, k := range keys { chain.TxManager().Trigger(k.Address) diff --git a/core/internal/features/ocr2/features_ocr2_test.go b/core/internal/features/ocr2/features_ocr2_test.go index 938b7aa2a66..e089970951b 100644 --- a/core/internal/features/ocr2/features_ocr2_test.go +++ b/core/internal/features/ocr2/features_ocr2_test.go @@ -135,7 +135,7 @@ func setupNodeOCR2( app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, b, p2pKey) - sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) require.Len(t, sendingKeys, 1) transmitter := sendingKeys[0].Address diff --git a/core/scripts/chaincli/handler/keeper_upkeep_history.go b/core/scripts/chaincli/handler/keeper_upkeep_history.go index d4237b708f2..487d06dd5ed 100644 --- a/core/scripts/chaincli/handler/keeper_upkeep_history.go +++ b/core/scripts/chaincli/handler/keeper_upkeep_history.go @@ -127,7 +127,7 @@ func (k *Keeper) UpkeepHistory(ctx context.Context, upkeepId *big.Int, from, to, panic("unsupported registry version") } - turnBinary, err2 := turnBlockHashBinary(block, bcpt, defaultLookBackRange, k.client) + turnBinary, err2 := turnBlockHashBinary(ctx, block, bcpt, defaultLookBackRange, k.client) if err2 != nil { log.Fatal("failed to calculate turn block hash: ", err2) } @@ -265,9 +265,9 @@ func printResultsToConsole(parsedResults []result) { fmt.Fprintf(writer, "\n %s\t\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t\n", "----", "----", "----", "----", "----", "----", "----", "----", "----", "----") } -func turnBlockHashBinary(blockNum, bcpt, lookback uint64, ethClient *ethclient.Client) (string, error) { +func turnBlockHashBinary(ctx context.Context, blockNum, bcpt, lookback uint64, ethClient *ethclient.Client) (string, error) { turnBlock := blockNum - (blockNum % bcpt) - lookback - block, err := ethClient.BlockByNumber(context.Background(), big.NewInt(int64(turnBlock))) + block, err := ethClient.BlockByNumber(ctx, big.NewInt(int64(turnBlock))) if err != nil { return "", err } diff --git a/core/services/blockhashstore/bhs.go b/core/services/blockhashstore/bhs.go index 3de00f64590..0ca91c682e7 100644 --- a/core/services/blockhashstore/bhs.go +++ b/core/services/blockhashstore/bhs.go @@ -91,7 +91,7 @@ func (c *BulletproofBHS) Store(ctx context.Context, blockNum uint64) error { return errors.Wrap(err, "packing args") } - fromAddress, err := c.gethks.GetRoundRobinAddress(c.chainID, SendingKeys(c.fromAddresses)...) + fromAddress, err := c.gethks.GetRoundRobinAddress(ctx, c.chainID, SendingKeys(c.fromAddresses)...) if err != nil { return errors.Wrap(err, "getting next from address") } @@ -132,7 +132,7 @@ func (c *BulletproofBHS) StoreTrusted( } // Create a transaction from the given batch and send it to the TXM. - fromAddress, err := c.gethks.GetRoundRobinAddress(c.chainID, SendingKeys(c.fromAddresses)...) + fromAddress, err := c.gethks.GetRoundRobinAddress(ctx, c.chainID, SendingKeys(c.fromAddresses)...) if err != nil { return errors.Wrap(err, "getting next from address") } @@ -186,7 +186,7 @@ func (c *BulletproofBHS) StoreEarliest(ctx context.Context) error { return errors.Wrap(err, "packing args") } - fromAddress, err := c.gethks.GetRoundRobinAddress(c.chainID, c.sendingKeys()...) + fromAddress, err := c.gethks.GetRoundRobinAddress(ctx, c.chainID, c.sendingKeys()...) if err != nil { return errors.Wrap(err, "getting next from address") } diff --git a/core/services/blockhashstore/bhs_test.go b/core/services/blockhashstore/bhs_test.go index 44205ec7b86..f8d33b51a34 100644 --- a/core/services/blockhashstore/bhs_test.go +++ b/core/services/blockhashstore/bhs_test.go @@ -24,6 +24,7 @@ import ( ) func TestStoreRotatesFromAddresses(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) ethClient := evmtest.NewEthClientMockWithDefaultChain(t) cfg := configtest.NewTestGeneralConfig(t) @@ -36,9 +37,9 @@ func TestStoreRotatesFromAddresses(t *testing.T) { lggr := logger.TestLogger(t) ks := keystore.New(db, utils.FastScryptParams, lggr, cfg.Database()) require.NoError(t, ks.Unlock("blah")) - k1, err := ks.Eth().Create(&cltest.FixtureChainID) + k1, err := ks.Eth().Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) - k2, err := ks.Eth().Create(&cltest.FixtureChainID) + k2, err := ks.Eth().Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) fromAddresses := []ethkey.EIP55Address{k1.EIP55Address, k2.EIP55Address} txm := new(txmmocks.MockEvmTxManager) @@ -66,8 +67,6 @@ func TestStoreRotatesFromAddresses(t *testing.T) { return tx.FromAddress.String() == k2.Address.String() })).Once().Return(txmgr.Tx{}, nil) - ctx := testutils.Context(t) - // store 2 blocks err = bhs.Store(ctx, 1) require.NoError(t, err) diff --git a/core/services/blockhashstore/delegate.go b/core/services/blockhashstore/delegate.go index d6c27acd0b5..d07efcb95fe 100644 --- a/core/services/blockhashstore/delegate.go +++ b/core/services/blockhashstore/delegate.go @@ -51,7 +51,7 @@ func (d *Delegate) JobType() job.Type { } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { if jb.BlockhashStoreSpec == nil { return nil, errors.Errorf( "blockhashstore.Delegate expects a BlockhashStoreSpec to be present, got %+v", jb) @@ -67,7 +67,7 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { return nil, errors.New("log poller must be enabled to run blockhashstore") } - keys, err := d.ks.EnabledKeysForChain(chain.ID()) + keys, err := d.ks.EnabledKeysForChain(ctx, chain.ID()) if err != nil { return nil, errors.Wrap(err, "getting sending keys") } diff --git a/core/services/blockhashstore/delegate_test.go b/core/services/blockhashstore/delegate_test.go index 6fffcfdd493..5f5118afeaf 100644 --- a/core/services/blockhashstore/delegate_test.go +++ b/core/services/blockhashstore/delegate_test.go @@ -91,7 +91,7 @@ func TestDelegate_ServicesForSpec(t *testing.T) { t.Run("happy", func(t *testing.T) { spec := job.Job{BlockhashStoreSpec: &job.BlockhashStoreSpec{WaitBlocks: defaultWaitBlocks, EVMChainID: (*big.Big)(testutils.FixtureChainID)}} - services, err := delegate.ServicesForSpec(spec) + services, err := delegate.ServicesForSpec(testutils.Context(t), spec) require.NoError(t, err) require.Len(t, services, 1) @@ -109,7 +109,7 @@ func TestDelegate_ServicesForSpec(t *testing.T) { CoordinatorV2PlusAddress: &coordinatorV2Plus, EVMChainID: (*big.Big)(testutils.FixtureChainID), }} - services, err := delegate.ServicesForSpec(spec) + services, err := delegate.ServicesForSpec(testutils.Context(t), spec) require.NoError(t, err) require.Len(t, services, 1) @@ -117,7 +117,7 @@ func TestDelegate_ServicesForSpec(t *testing.T) { t.Run("missing BlockhashStoreSpec", func(t *testing.T) { spec := job.Job{BlockhashStoreSpec: nil} - _, err := delegate.ServicesForSpec(spec) + _, err := delegate.ServicesForSpec(testutils.Context(t), spec) assert.Error(t, err) }) @@ -125,18 +125,19 @@ func TestDelegate_ServicesForSpec(t *testing.T) { spec := job.Job{BlockhashStoreSpec: &job.BlockhashStoreSpec{ EVMChainID: big.NewI(123), }} - _, err := delegate.ServicesForSpec(spec) + _, err := delegate.ServicesForSpec(testutils.Context(t), spec) assert.Error(t, err) }) t.Run("missing EnabledKeysForChain", func(t *testing.T) { - _, err := testData.ethKeyStore.Delete(testData.sendingKey.ID()) + ctx := testutils.Context(t) + _, err := testData.ethKeyStore.Delete(ctx, testData.sendingKey.ID()) require.NoError(t, err) spec := job.Job{BlockhashStoreSpec: &job.BlockhashStoreSpec{ WaitBlocks: defaultWaitBlocks, }} - _, err = delegate.ServicesForSpec(spec) + _, err = delegate.ServicesForSpec(testutils.Context(t), spec) assert.Error(t, err) }) } @@ -154,7 +155,7 @@ func TestDelegate_StartStop(t *testing.T) { RunTimeout: testutils.WaitTimeout(t), EVMChainID: (*big.Big)(testutils.FixtureChainID), }} - services, err := delegate.ServicesForSpec(spec) + services, err := delegate.ServicesForSpec(testutils.Context(t), spec) require.NoError(t, err) require.Len(t, services, 1) diff --git a/core/services/blockheaderfeeder/block_header_feeder.go b/core/services/blockheaderfeeder/block_header_feeder.go index a5bcb003613..d1bcab4297a 100644 --- a/core/services/blockheaderfeeder/block_header_feeder.go +++ b/core/services/blockheaderfeeder/block_header_feeder.go @@ -148,7 +148,7 @@ func (f *BlockHeaderFeeder) Run(ctx context.Context) error { } // use 1 sending key for all batches because ordering matters for StoreVerifyHeader - fromAddress, err := f.gethks.GetRoundRobinAddress(f.chainID, blockhashstore.SendingKeys(f.fromAddresses)...) + fromAddress, err := f.gethks.GetRoundRobinAddress(ctx, f.chainID, blockhashstore.SendingKeys(f.fromAddresses)...) if err != nil { return errors.Wrap(err, "getting round robin address") } diff --git a/core/services/blockheaderfeeder/block_header_feeder_test.go b/core/services/blockheaderfeeder/block_header_feeder_test.go index 6c1ec0946e7..1b855caf9d2 100644 --- a/core/services/blockheaderfeeder/block_header_feeder_test.go +++ b/core/services/blockheaderfeeder/block_header_feeder_test.go @@ -202,7 +202,7 @@ func (test testCase) testFeeder(t *testing.T) { fromAddress := "0x469aA2CD13e037DC5236320783dCfd0e641c0559" fromAddresses := []ethkey.EIP55Address{ethkey.EIP55Address(fromAddress)} ks := keystoremocks.NewEth(t) - ks.On("GetRoundRobinAddress", testutils.FixtureChainID, mock.Anything).Maybe().Return(common.HexToAddress(fromAddress), nil) + ks.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, mock.Anything).Maybe().Return(common.HexToAddress(fromAddress), nil) feeder := NewBlockHeaderFeeder( lggr, @@ -246,7 +246,7 @@ func TestFeeder_CachesStoredBlocks(t *testing.T) { fromAddress := "0x469aA2CD13e037DC5236320783dCfd0e641c0559" fromAddresses := []ethkey.EIP55Address{ethkey.EIP55Address(fromAddress)} ks := keystoremocks.NewEth(t) - ks.On("GetRoundRobinAddress", testutils.FixtureChainID, mock.Anything).Maybe().Return(common.HexToAddress(fromAddress), nil) + ks.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, mock.Anything).Maybe().Return(common.HexToAddress(fromAddress), nil) feeder := NewBlockHeaderFeeder( logger.TestLogger(t), diff --git a/core/services/blockheaderfeeder/delegate.go b/core/services/blockheaderfeeder/delegate.go index 53f514cee27..d78782f6592 100644 --- a/core/services/blockheaderfeeder/delegate.go +++ b/core/services/blockheaderfeeder/delegate.go @@ -49,7 +49,7 @@ func (d *Delegate) JobType() job.Type { } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { if jb.BlockHeaderFeederSpec == nil { return nil, errors.Errorf("Delegate expects a BlockHeaderFeederSpec to be present, got %+v", jb) } @@ -70,14 +70,14 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { chain.Config().EVM().FinalityDepth(), jb.BlockHeaderFeederSpec.LookbackBlocks) } - keys, err := d.ks.EnabledKeysForChain(chain.ID()) + keys, err := d.ks.EnabledKeysForChain(ctx, chain.ID()) if err != nil { return nil, errors.Wrap(err, "getting sending keys") } if len(keys) == 0 { return nil, fmt.Errorf("missing sending keys for chain ID: %v", chain.ID()) } - if err = CheckFromAddressesExist(jb, d.ks); err != nil { + if err = CheckFromAddressesExist(ctx, jb, d.ks); err != nil { return nil, err } fromAddresses := jb.BlockHeaderFeederSpec.FromAddresses @@ -269,9 +269,9 @@ func (s *service) runFeeder() { // CheckFromAddressesExist returns an error if and only if one of the addresses // in the BlockHeaderFeeder spec's fromAddresses field does not exist in the keystore. -func CheckFromAddressesExist(jb job.Job, gethks keystore.Eth) (err error) { +func CheckFromAddressesExist(ctx context.Context, jb job.Job, gethks keystore.Eth) (err error) { for _, a := range jb.BlockHeaderFeederSpec.FromAddresses { - _, err2 := gethks.Get(a.Hex()) + _, err2 := gethks.Get(ctx, a.Hex()) err = multierr.Append(err, err2) } return diff --git a/core/services/cron/cron_test.go b/core/services/cron/cron_test.go index b561248eddb..5c968c75824 100644 --- a/core/services/cron/cron_test.go +++ b/core/services/cron/cron_test.go @@ -41,7 +41,7 @@ func TestCronV2Pipeline(t *testing.T) { delegate := cron.NewDelegate(runner, lggr) require.NoError(t, jobORM.CreateJob(jb)) - serviceArray, err := delegate.ServicesForSpec(*jb) + serviceArray, err := delegate.ServicesForSpec(testutils.Context(t), *jb) require.NoError(t, err) assert.Len(t, serviceArray, 1) service := serviceArray[0] diff --git a/core/services/cron/delegate.go b/core/services/cron/delegate.go index c227fd60d00..4a08fec5a40 100644 --- a/core/services/cron/delegate.go +++ b/core/services/cron/delegate.go @@ -1,6 +1,8 @@ package cron import ( + "context" + "github.com/pkg/errors" "github.com/smartcontractkit/chainlink/v2/core/logger" @@ -33,7 +35,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(spec job.Job, q pg.Queryer) error { return nil } // ServicesForSpec returns the scheduler to be used for running cron jobs -func (d *Delegate) ServicesForSpec(spec job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { if spec.CronSpec == nil { return nil, errors.Errorf("services.Delegate expects a *jobSpec.CronSpec to be present, got %v", spec) } diff --git a/core/services/directrequest/delegate.go b/core/services/directrequest/delegate.go index cfdf1eed116..083e6f02266 100644 --- a/core/services/directrequest/delegate.go +++ b/core/services/directrequest/delegate.go @@ -69,7 +69,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(spec job.Job, q pg.Queryer) error { return nil } // ServicesForSpec returns the log listener service for a direct request job -func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { if jb.DirectRequestSpec == nil { return nil, errors.Errorf("DirectRequest: directrequest.Delegate expects a *job.DirectRequestSpec to be present, got %v", jb) } diff --git a/core/services/directrequest/delegate_test.go b/core/services/directrequest/delegate_test.go index 3b80ba2f915..d8540b4471c 100644 --- a/core/services/directrequest/delegate_test.go +++ b/core/services/directrequest/delegate_test.go @@ -54,13 +54,13 @@ func TestDelegate_ServicesForSpec(t *testing.T) { t.Run("Spec without DirectRequestSpec", func(t *testing.T) { spec := job.Job{} - _, err := delegate.ServicesForSpec(spec) + _, err := delegate.ServicesForSpec(testutils.Context(t), spec) assert.Error(t, err, "expects a *job.DirectRequestSpec to be present") }) t.Run("Spec with DirectRequestSpec", func(t *testing.T) { spec := job.Job{DirectRequestSpec: &job.DirectRequestSpec{EVMChainID: (*ubig.Big)(testutils.FixtureChainID)}, PipelineSpec: &pipeline.Spec{}} - services, err := delegate.ServicesForSpec(spec) + services, err := delegate.ServicesForSpec(testutils.Context(t), spec) require.NoError(t, err) assert.Len(t, services, 1) }) @@ -100,7 +100,7 @@ func NewDirectRequestUniverseWithConfig(t *testing.T, cfg chainlink.GeneralConfi specF(jb) } require.NoError(t, jobORM.CreateJob(jb)) - serviceArray, err := delegate.ServicesForSpec(*jb) + serviceArray, err := delegate.ServicesForSpec(testutils.Context(t), *jb) require.NoError(t, err) assert.Len(t, serviceArray, 1) service := serviceArray[0] diff --git a/core/services/fluxmonitorv2/contract_submitter.go b/core/services/fluxmonitorv2/contract_submitter.go index 8f3f40c309d..c5a6e599f5d 100644 --- a/core/services/fluxmonitorv2/contract_submitter.go +++ b/core/services/fluxmonitorv2/contract_submitter.go @@ -52,7 +52,7 @@ func NewFluxAggregatorContractSubmitter( // Submit submits the answer by writing a EthTx for the txmgr to // pick up func (c *FluxAggregatorContractSubmitter) Submit(ctx context.Context, roundID *big.Int, submission *big.Int, idempotencyKey *string) error { - fromAddress, err := c.keyStore.GetRoundRobinAddress(c.chainID) + fromAddress, err := c.keyStore.GetRoundRobinAddress(ctx, c.chainID) if err != nil { return err } diff --git a/core/services/fluxmonitorv2/contract_submitter_test.go b/core/services/fluxmonitorv2/contract_submitter_test.go index c3b2ca7e715..4c8ce019bfd 100644 --- a/core/services/fluxmonitorv2/contract_submitter_test.go +++ b/core/services/fluxmonitorv2/contract_submitter_test.go @@ -32,7 +32,7 @@ func TestFluxAggregatorContractSubmitter_Submit(t *testing.T) { payload, err := fluxmonitorv2.FluxAggregatorABI.Pack("submit", roundID, submission) assert.NoError(t, err) - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID).Return(fromAddress, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID).Return(fromAddress, nil) fluxAggregator.On("Address").Return(toAddress) idempotencyKey := uuid.New().String() diff --git a/core/services/fluxmonitorv2/delegate.go b/core/services/fluxmonitorv2/delegate.go index 99e2b688f5d..5de59432d11 100644 --- a/core/services/fluxmonitorv2/delegate.go +++ b/core/services/fluxmonitorv2/delegate.go @@ -1,6 +1,8 @@ package fluxmonitorv2 import ( + "context" + "github.com/pkg/errors" "github.com/jmoiron/sqlx" @@ -60,7 +62,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(spec job.Job, q pg.Queryer) error { return nil } // ServicesForSpec returns the flux monitor service for the job spec -func (d *Delegate) ServicesForSpec(jb job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { if jb.FluxMonitorSpec == nil { return nil, errors.Errorf("Delegate expects a *job.FluxMonitorSpec to be present, got %v", jb) } diff --git a/core/services/fluxmonitorv2/flux_monitor.go b/core/services/fluxmonitorv2/flux_monitor.go index 8fe0fc7c70e..f21e56cbc80 100644 --- a/core/services/fluxmonitorv2/flux_monitor.go +++ b/core/services/fluxmonitorv2/flux_monitor.go @@ -466,12 +466,17 @@ func formatTime(at time.Time) string { // SetOracleAddress sets the oracle address which matches the node's keys. // If none match, it uses the first available key func (fm *FluxMonitor) SetOracleAddress() error { + + // fm on deprecation path, using dangling context + ctx, cancel := fm.chStop.NewCtx() + defer cancel() + oracleAddrs, err := fm.fluxAggregator.GetOracles(nil) if err != nil { fm.logger.Error("failed to get list of oracles from FluxAggregator contract") return errors.Wrap(err, "failed to get list of oracles from FluxAggregator contract") } - keys, err := fm.keyStore.EnabledKeysForChain(fm.chainID) + keys, err := fm.keyStore.EnabledKeysForChain(ctx, fm.chainID) if err != nil { return errors.Wrap(err, "failed to load keys") } diff --git a/core/services/fluxmonitorv2/flux_monitor_test.go b/core/services/fluxmonitorv2/flux_monitor_test.go index b13edcc12d8..e4db716bbbb 100644 --- a/core/services/fluxmonitorv2/flux_monitor_test.go +++ b/core/services/fluxmonitorv2/flux_monitor_test.go @@ -358,7 +358,7 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { fm, tm := setup(t, db) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(tc.connected).Once() // Setup Answers @@ -508,7 +508,7 @@ func TestFluxMonitor_PollIfEligible_Creates_JobErr(t *testing.T) { fm, tm := setup(t, db) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Once() tm.jobORM. @@ -559,7 +559,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { readyToFillQueue := cltest.NewAwaiter() logsAwaiter := cltest.NewAwaiter() - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.fluxAggregator.On("Address").Return(common.Address{}) tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(freshContractRoundDataResponse()).Maybe() @@ -754,7 +754,7 @@ func TestFluxMonitor_TriggerIdleTimeThreshold(t *testing.T) { fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(tc.idleTimerDisabled), setIdleTimerPeriod(tc.idleDuration), withORM(orm)) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -833,7 +833,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { setHibernationState(t, true), ) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -926,7 +926,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { setFlags(flags), ) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -1034,7 +1034,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { setIdleTimerPeriod(2*time.Second), ) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -1139,7 +1139,7 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutAtZero(t *testing.T) { fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) tm.keyStore. - On("EnabledKeysForChain", testutils.FixtureChainID). + On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID). Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil). Twice() // Once called from the test, once during start @@ -1200,7 +1200,7 @@ func TestFluxMonitor_UsesPreviousRoundStateOnStartup_RoundTimeout(t *testing.T) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("Register", mock.Anything, mock.Anything).Return(func() {}) tm.logBroadcaster.On("IsConnected").Return(true).Maybe() @@ -1273,7 +1273,7 @@ func TestFluxMonitor_UsesPreviousRoundStateOnStartup_IdleTimer(t *testing.T) { ) initialPollOccurred := make(chan struct{}, 1) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("Register", mock.Anything, mock.Anything).Return(func() {}) tm.logBroadcaster.On("IsConnected").Return(true).Maybe() tm.fluxAggregator.On("Address").Return(common.Address{}) @@ -1331,7 +1331,7 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutNotZero(t *testing.T) { fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -1466,7 +1466,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { run := &pipeline.Run{ID: 1} - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() // Mocks initiated by the New Round log @@ -1581,7 +1581,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { ) run := &pipeline.Run{ID: 1} - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() // First, force the node to try to poll, which should result in a submission @@ -1677,7 +1677,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { answer = 100 ) run := &pipeline.Run{ID: 1} - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() // First, force the node to try to poll, which should result in a submission @@ -1823,7 +1823,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), enableDrumbeatTicker("@every 3s", 2*time.Second)) - tm.keyStore.On("EnabledKeysForChain", testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil) + tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil) const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) diff --git a/core/services/fluxmonitorv2/key_store.go b/core/services/fluxmonitorv2/key_store.go index 185b59311cb..070d392a922 100644 --- a/core/services/fluxmonitorv2/key_store.go +++ b/core/services/fluxmonitorv2/key_store.go @@ -1,6 +1,7 @@ package fluxmonitorv2 import ( + "context" "math/big" "github.com/ethereum/go-ethereum/common" @@ -13,8 +14,8 @@ import ( // KeyStoreInterface defines an interface to interact with the keystore type KeyStoreInterface interface { - EnabledKeysForChain(chainID *big.Int) ([]ethkey.KeyV2, error) - GetRoundRobinAddress(chainID *big.Int, addrs ...common.Address) (common.Address, error) + EnabledKeysForChain(ctx context.Context, chainID *big.Int) ([]ethkey.KeyV2, error) + GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addrs ...common.Address) (common.Address, error) } // KeyStore implements KeyStoreInterface diff --git a/core/services/fluxmonitorv2/key_store_test.go b/core/services/fluxmonitorv2/key_store_test.go index ed0485d3b3c..fdef2ade210 100644 --- a/core/services/fluxmonitorv2/key_store_test.go +++ b/core/services/fluxmonitorv2/key_store_test.go @@ -13,6 +13,7 @@ import ( func TestKeyStore_EnabledKeysForChain(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := pgtest.NewQConfig(true) @@ -20,17 +21,17 @@ func TestKeyStore_EnabledKeysForChain(t *testing.T) { ks := fluxmonitorv2.NewKeyStore(ethKeyStore) - key, err := ethKeyStore.Create(testutils.FixtureChainID) + key, err := ethKeyStore.Create(ctx, testutils.FixtureChainID) require.NoError(t, err) - key2, err := ethKeyStore.Create(testutils.SimulatedChainID) + key2, err := ethKeyStore.Create(ctx, testutils.SimulatedChainID) require.NoError(t, err) - keys, err := ks.EnabledKeysForChain(testutils.FixtureChainID) + keys, err := ks.EnabledKeysForChain(ctx, testutils.FixtureChainID) require.NoError(t, err) require.Len(t, keys, 1) require.Equal(t, key, keys[0]) - keys, err = ks.EnabledKeysForChain(testutils.SimulatedChainID) + keys, err = ks.EnabledKeysForChain(ctx, testutils.SimulatedChainID) require.NoError(t, err) require.Len(t, keys, 1) require.Equal(t, key2, keys[0]) @@ -39,6 +40,8 @@ func TestKeyStore_EnabledKeysForChain(t *testing.T) { func TestKeyStore_GetRoundRobinAddress(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) cfg := pgtest.NewQConfig(true) ethKeyStore := cltest.NewKeyStore(t, db, cfg).Eth() @@ -48,7 +51,7 @@ func TestKeyStore_GetRoundRobinAddress(t *testing.T) { ks := fluxmonitorv2.NewKeyStore(ethKeyStore) // Gets the only address in the keystore - addr, err := ks.GetRoundRobinAddress(testutils.FixtureChainID) + addr, err := ks.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) require.Equal(t, k0Address, addr) } diff --git a/core/services/fluxmonitorv2/mocks/key_store_interface.go b/core/services/fluxmonitorv2/mocks/key_store_interface.go index 98f5ab71020..7b2aac75e2f 100644 --- a/core/services/fluxmonitorv2/mocks/key_store_interface.go +++ b/core/services/fluxmonitorv2/mocks/key_store_interface.go @@ -3,9 +3,11 @@ package mocks import ( + context "context" big "math/big" common "github.com/ethereum/go-ethereum/common" + ethkey "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" mock "github.com/stretchr/testify/mock" @@ -16,9 +18,9 @@ type KeyStoreInterface struct { mock.Mock } -// EnabledKeysForChain provides a mock function with given fields: chainID -func (_m *KeyStoreInterface) EnabledKeysForChain(chainID *big.Int) ([]ethkey.KeyV2, error) { - ret := _m.Called(chainID) +// EnabledKeysForChain provides a mock function with given fields: ctx, chainID +func (_m *KeyStoreInterface) EnabledKeysForChain(ctx context.Context, chainID *big.Int) ([]ethkey.KeyV2, error) { + ret := _m.Called(ctx, chainID) if len(ret) == 0 { panic("no return value specified for EnabledKeysForChain") @@ -26,19 +28,19 @@ func (_m *KeyStoreInterface) EnabledKeysForChain(chainID *big.Int) ([]ethkey.Key var r0 []ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func(*big.Int) ([]ethkey.KeyV2, error)); ok { - return rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) ([]ethkey.KeyV2, error)); ok { + return rf(ctx, chainID) } - if rf, ok := ret.Get(0).(func(*big.Int) []ethkey.KeyV2); ok { - r0 = rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) []ethkey.KeyV2); ok { + r0 = rf(ctx, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ethkey.KeyV2) } } - if rf, ok := ret.Get(1).(func(*big.Int) error); ok { - r1 = rf(chainID) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { + r1 = rf(ctx, chainID) } else { r1 = ret.Error(1) } @@ -46,14 +48,14 @@ func (_m *KeyStoreInterface) EnabledKeysForChain(chainID *big.Int) ([]ethkey.Key return r0, r1 } -// GetRoundRobinAddress provides a mock function with given fields: chainID, addrs -func (_m *KeyStoreInterface) GetRoundRobinAddress(chainID *big.Int, addrs ...common.Address) (common.Address, error) { +// GetRoundRobinAddress provides a mock function with given fields: ctx, chainID, addrs +func (_m *KeyStoreInterface) GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addrs ...common.Address) (common.Address, error) { _va := make([]interface{}, len(addrs)) for _i := range addrs { _va[_i] = addrs[_i] } var _ca []interface{} - _ca = append(_ca, chainID) + _ca = append(_ca, ctx, chainID) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -63,19 +65,19 @@ func (_m *KeyStoreInterface) GetRoundRobinAddress(chainID *big.Int, addrs ...com var r0 common.Address var r1 error - if rf, ok := ret.Get(0).(func(*big.Int, ...common.Address) (common.Address, error)); ok { - return rf(chainID, addrs...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, ...common.Address) (common.Address, error)); ok { + return rf(ctx, chainID, addrs...) } - if rf, ok := ret.Get(0).(func(*big.Int, ...common.Address) common.Address); ok { - r0 = rf(chainID, addrs...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, ...common.Address) common.Address); ok { + r0 = rf(ctx, chainID, addrs...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(common.Address) } } - if rf, ok := ret.Get(1).(func(*big.Int, ...common.Address) error); ok { - r1 = rf(chainID, addrs...) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int, ...common.Address) error); ok { + r1 = rf(ctx, chainID, addrs...) } else { r1 = ret.Error(1) } diff --git a/core/services/gateway/delegate.go b/core/services/gateway/delegate.go index 8a97f68d1ea..3100877e96a 100644 --- a/core/services/gateway/delegate.go +++ b/core/services/gateway/delegate.go @@ -1,6 +1,7 @@ package gateway import ( + "context" "encoding/json" "github.com/google/uuid" @@ -46,7 +47,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(spec job.Job, q pg.Queryer) error { return nil } // ServicesForSpec returns the scheduler to be used for running observer jobs -func (d *Delegate) ServicesForSpec(spec job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { if spec.GatewaySpec == nil { return nil, errors.Errorf("services.Delegate expects a *jobSpec.GatewaySpec to be present, got %v", spec) } diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index 3590b526022..9716231868c 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -900,57 +900,62 @@ func TestORM_ValidateKeyStoreMatch(t *testing.T) { } t.Run("test ETH key validation", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.Relay = relay.EVM - err := job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, "bad key") + err := job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, "bad key") require.EqualError(t, err, "no EVM key matching: \"bad key\"") _, evmKey := cltest.MustInsertRandomKey(t, keyStore.Eth()) - err = job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, evmKey.String()) + err = job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, evmKey.String()) require.NoError(t, err) }) t.Run("test Cosmos key validation", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.Relay = relay.Cosmos - err := job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, "bad key") + err := job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, "bad key") require.EqualError(t, err, "no Cosmos key matching: \"bad key\"") cosmosKey, err := keyStore.Cosmos().Create() require.NoError(t, err) - err = job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, cosmosKey.ID()) + err = job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, cosmosKey.ID()) require.NoError(t, err) }) t.Run("test Solana key validation", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.Relay = relay.Solana - err := job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, "bad key") + err := job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, "bad key") require.EqualError(t, err, "no Solana key matching: \"bad key\"") solanaKey, err := keyStore.Solana().Create() require.NoError(t, err) - err = job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, solanaKey.ID()) + err = job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, solanaKey.ID()) require.NoError(t, err) }) t.Run("test Starknet key validation", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.Relay = relay.StarkNet - err := job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, "bad key") + err := job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, "bad key") require.EqualError(t, err, "no Starknet key matching: \"bad key\"") starkNetKey, err := keyStore.StarkNet().Create() require.NoError(t, err) - err = job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, starkNetKey.ID()) + err = job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, starkNetKey.ID()) require.NoError(t, err) }) t.Run("test Mercury ETH key validation", func(t *testing.T) { + ctx := testutils.Context(t) jb.OCR2OracleSpec.PluginType = types.Mercury - err := job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, "bad key") + err := job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, "bad key") require.EqualError(t, err, "no CSA key matching: \"bad key\"") csaKey, err := keyStore.CSA().Create() require.NoError(t, err) - err = job.ValidateKeyStoreMatch(jb.OCR2OracleSpec, keyStore, csaKey.ID()) + err = job.ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keyStore, csaKey.ID()) require.NoError(t, err) }) } diff --git a/core/services/job/orm.go b/core/services/job/orm.go index 2e7bb0a90a5..c608e2cc544 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -192,7 +192,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } } if jb.OCROracleSpec.TransmitterAddress != nil { - _, err := o.keyStore.Eth().Get(jb.OCROracleSpec.TransmitterAddress.Hex()) + _, err := o.keyStore.Eth().Get(q.ParentCtx, jb.OCROracleSpec.TransmitterAddress.Hex()) if err != nil { return errors.Wrapf(ErrNoSuchTransmitterKey, "no key matching transmitter address: %s", jb.OCROracleSpec.TransmitterAddress.Hex()) } @@ -239,7 +239,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } // checks if they are present and if they are valid - sendingKeysDefined, err := areSendingKeysDefined(jb, o.keyStore) + sendingKeysDefined, err := areSendingKeysDefined(q.ParentCtx, jb, o.keyStore) if err != nil { return err } @@ -249,7 +249,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } if !sendingKeysDefined { - if err = ValidateKeyStoreMatch(jb.OCR2OracleSpec, o.keyStore, jb.OCR2OracleSpec.TransmitterID.String); err != nil { + if err = ValidateKeyStoreMatch(q.ParentCtx, jb.OCR2OracleSpec, o.keyStore, jb.OCR2OracleSpec.TransmitterID.String); err != nil { return errors.Wrap(ErrNoSuchTransmitterKey, err.Error()) } } @@ -467,7 +467,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { } // ValidateKeyStoreMatch confirms that the key has a valid match in the keystore -func ValidateKeyStoreMatch(spec *OCR2OracleSpec, keyStore keystore.Master, key string) (err error) { +func ValidateKeyStoreMatch(ctx context.Context, spec *OCR2OracleSpec, keyStore keystore.Master, key string) (err error) { switch spec.PluginType { case types.Mercury, types.LLO: _, err = keyStore.CSA().Get(key) @@ -475,15 +475,15 @@ func ValidateKeyStoreMatch(spec *OCR2OracleSpec, keyStore keystore.Master, key s err = errors.Errorf("no CSA key matching: %q", key) } default: - err = validateKeyStoreMatchForRelay(spec.Relay, keyStore, key) + err = validateKeyStoreMatchForRelay(ctx, spec.Relay, keyStore, key) } return } -func validateKeyStoreMatchForRelay(network relay.Network, keyStore keystore.Master, key string) error { +func validateKeyStoreMatchForRelay(ctx context.Context, network relay.Network, keyStore keystore.Master, key string) error { switch network { case relay.EVM: - _, err := keyStore.Eth().Get(key) + _, err := keyStore.Eth().Get(ctx, key) if err != nil { return errors.Errorf("no EVM key matching: %q", key) } @@ -506,7 +506,7 @@ func validateKeyStoreMatchForRelay(network relay.Network, keyStore keystore.Mast return nil } -func areSendingKeysDefined(jb *Job, keystore keystore.Master) (bool, error) { +func areSendingKeysDefined(ctx context.Context, jb *Job, keystore keystore.Master) (bool, error) { if jb.OCR2OracleSpec.RelayConfig["sendingKeys"] != nil { sendingKeys, err := SendingKeysForJob(jb) if err != nil { @@ -514,7 +514,7 @@ func areSendingKeysDefined(jb *Job, keystore keystore.Master) (bool, error) { } for _, sendingKey := range sendingKeys { - if err = ValidateKeyStoreMatch(jb.OCR2OracleSpec, keystore, sendingKey); err != nil { + if err = ValidateKeyStoreMatch(ctx, jb.OCR2OracleSpec, keystore, sendingKey); err != nil { return false, errors.Wrap(ErrNoSuchSendingKey, err.Error()) } } diff --git a/core/services/job/runner_integration_test.go b/core/services/job/runner_integration_test.go index fb671982ec5..2722e190e24 100644 --- a/core/services/job/runner_integration_test.go +++ b/core/services/job/runner_integration_test.go @@ -465,7 +465,7 @@ answer1 [type=median index=0]; config.Database(), servicetest.Run(t, mailboxtest.NewMonitor(t)), ) - _, err = sd.ServicesForSpec(jb) + _, err = sd.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) }) @@ -499,7 +499,7 @@ answer1 [type=median index=0]; config.Database(), servicetest.Run(t, mailboxtest.NewMonitor(t)), ) - _, err = sd.ServicesForSpec(jb) + _, err = sd.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) }) @@ -527,7 +527,7 @@ answer1 [type=median index=0]; config.Database(), servicetest.Run(t, mailboxtest.NewMonitor(t)), ) - _, err = sd.ServicesForSpec(jb) + _, err = sd.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) }) @@ -584,7 +584,7 @@ answer1 [type=median index=0]; ) jb.OCROracleSpec.CaptureEATelemetry = tc.jbCaptureEATelemetry - services, err := sd.ServicesForSpec(jb) + services, err := sd.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) enhancedTelemetryServiceCreated := false @@ -626,7 +626,7 @@ answer1 [type=median index=0]; config.Database(), servicetest.Run(t, mailboxtest.NewMonitor(t)), ) - services, err := sd.ServicesForSpec(*jb) + services, err := sd.ServicesForSpec(testutils.Context(t), *jb) require.NoError(t, err) // Return an error getting the contract code. diff --git a/core/services/job/spawner.go b/core/services/job/spawner.go index a16466fbef1..f0486df1c25 100644 --- a/core/services/job/spawner.go +++ b/core/services/job/spawner.go @@ -70,7 +70,7 @@ type ( // job. In case a given job type relies upon well-defined startup/shutdown // ordering for services, they are started in the order they are given // and stopped in reverse order. - ServicesForSpec(Job) ([]ServiceCtx, error) + ServicesForSpec(context.Context, Job) ([]ServiceCtx, error) AfterJobCreated(Job) BeforeJobDeleted(Job) // OnDeleteJob will be called from within DELETE db transaction. Any db @@ -215,7 +215,7 @@ func (js *spawner) StartService(ctx context.Context, jb Job, qopts ...pg.QOpt) e jb.PipelineSpec.GasLimit = &jb.GasLimit.Uint32 } - srvs, err := delegate.ServicesForSpec(jb) + srvs, err := delegate.ServicesForSpec(ctx, jb) if err != nil { lggr.Errorw("Error creating services for job", "err", err) cctx, cancel := js.chStop.NewCtx() @@ -391,7 +391,7 @@ func (n *NullDelegate) JobType() Type { } // ServicesForSpec does no-op. -func (n *NullDelegate) ServicesForSpec(spec Job) (s []ServiceCtx, err error) { +func (n *NullDelegate) ServicesForSpec(ctx context.Context, spec Job) (s []ServiceCtx, err error) { return } diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index 9dde7a47721..71357a675c3 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -1,6 +1,7 @@ package job_test import ( + "context" "testing" "time" @@ -52,7 +53,7 @@ func (d delegate) JobType() job.Type { } // ServicesForSpec satisfies the job.Delegate interface. -func (d delegate) ServicesForSpec(js job.Job) ([]job.ServiceCtx, error) { +func (d delegate) ServicesForSpec(ctx context.Context, js job.Job) ([]job.ServiceCtx, error) { if js.Type != d.jobType { return nil, nil } diff --git a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 4418bea670a..c2c546fcd33 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -1,6 +1,8 @@ package keeper import ( + "context" + "github.com/pkg/errors" "github.com/jmoiron/sqlx" @@ -55,7 +57,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(spec job.Job, q pg.Queryer) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(spec job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { if spec.KeeperSpec == nil { return nil, errors.Errorf("Delegate expects a *job.KeeperSpec to be present, got %v", spec) } diff --git a/core/services/keeper/upkeep_executer_test.go b/core/services/keeper/upkeep_executer_test.go index 590c9720cb2..1018d5f2aab 100644 --- a/core/services/keeper/upkeep_executer_test.go +++ b/core/services/keeper/upkeep_executer_test.go @@ -226,14 +226,15 @@ func Test_UpkeepExecuter_PerformsUpkeep_Happy(t *testing.T) { }) t.Run("errors if submission key not found", func(t *testing.T) { + ctx := testutils.Context(t) _, _, ethMock, executer, registry, _, job, jpv2, _, keyStore, _, _ := setup(t, mockEstimator(t), func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].ChainID = (*ubig.Big)(testutils.SimulatedChainID) }) // replace expected key with random one - _, err := keyStore.Eth().Create(testutils.SimulatedChainID) + _, err := keyStore.Eth().Create(ctx, testutils.SimulatedChainID) require.NoError(t, err) - _, err = keyStore.Eth().Delete(job.KeeperSpec.FromAddress.Hex()) + _, err = keyStore.Eth().Delete(ctx, job.KeeperSpec.FromAddress.Hex()) require.NoError(t, err) registryMock := cltest.NewContractMockReceiver(t, ethMock, keeper.Registry1_1ABI, registry.ContractAddress.Address()) diff --git a/core/services/keystore/eth.go b/core/services/keystore/eth.go index 0e86cccacae..be59cb5e54c 100644 --- a/core/services/keystore/eth.go +++ b/core/services/keystore/eth.go @@ -1,6 +1,7 @@ package keystore import ( + "context" "fmt" "math/big" "sort" @@ -21,34 +22,34 @@ import ( // //go:generate mockery --quiet --name Eth --output mocks/ --case=underscore type Eth interface { - Get(id string) (ethkey.KeyV2, error) - GetAll() ([]ethkey.KeyV2, error) - Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) - Delete(id string) (ethkey.KeyV2, error) - Import(keyJSON []byte, password string, chainIDs ...*big.Int) (ethkey.KeyV2, error) - Export(id string, password string) ([]byte, error) + Get(ctx context.Context, id string) (ethkey.KeyV2, error) + GetAll(ctx context.Context) ([]ethkey.KeyV2, error) + Create(ctx context.Context, chainIDs ...*big.Int) (ethkey.KeyV2, error) + Delete(ctx context.Context, id string) (ethkey.KeyV2, error) + Import(ctx context.Context, keyJSON []byte, password string, chainIDs ...*big.Int) (ethkey.KeyV2, error) + Export(ctx context.Context, id string, password string) ([]byte, error) - Enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error - Disable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error - Add(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error + Enable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error + Disable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error + Add(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error - EnsureKeys(chainIDs ...*big.Int) error - SubscribeToKeyChanges() (ch chan struct{}, unsub func()) + EnsureKeys(ctx context.Context, chainIDs ...*big.Int) error + SubscribeToKeyChanges(ctx context.Context) (ch chan struct{}, unsub func()) - SignTx(fromAddress common.Address, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) + SignTx(ctx context.Context, fromAddress common.Address, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) - EnabledKeysForChain(chainID *big.Int) (keys []ethkey.KeyV2, err error) - GetRoundRobinAddress(chainID *big.Int, addresses ...common.Address) (address common.Address, err error) - CheckEnabled(address common.Address, chainID *big.Int) error + EnabledKeysForChain(ctx context.Context, chainID *big.Int) (keys []ethkey.KeyV2, err error) + GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addresses ...common.Address) (address common.Address, err error) + CheckEnabled(ctx context.Context, address common.Address, chainID *big.Int) error - GetState(id string, chainID *big.Int) (ethkey.State, error) - GetStatesForKeys([]ethkey.KeyV2) ([]ethkey.State, error) - GetStateForKey(ethkey.KeyV2) (ethkey.State, error) - GetStatesForChain(chainID *big.Int) ([]ethkey.State, error) - EnabledAddressesForChain(chainID *big.Int) (addresses []common.Address, err error) + GetState(ctx context.Context, id string, chainID *big.Int) (ethkey.State, error) + GetStatesForKeys(ctx context.Context, keys []ethkey.KeyV2) ([]ethkey.State, error) + GetStateForKey(ctx context.Context, key ethkey.KeyV2) (ethkey.State, error) + GetStatesForChain(ctx context.Context, chainID *big.Int) ([]ethkey.State, error) + EnabledAddressesForChain(ctx context.Context, chainID *big.Int) (addresses []common.Address, err error) - XXXTestingOnlySetState(ethkey.State) - XXXTestingOnlyAdd(key ethkey.KeyV2) + XXXTestingOnlySetState(ctx context.Context, keyState ethkey.State) + XXXTestingOnlyAdd(ctx context.Context, key ethkey.KeyV2) } type eth struct { @@ -71,7 +72,7 @@ func newEthKeyStore(km *keyManager, orm keystateORM, q pg.Q) *eth { } } -func (ks *eth) Get(id string) (ethkey.KeyV2, error) { +func (ks *eth) Get(ctx context.Context, id string) (ethkey.KeyV2, error) { ks.lock.RLock() defer ks.lock.RUnlock() if ks.isLocked() { @@ -80,17 +81,17 @@ func (ks *eth) Get(id string) (ethkey.KeyV2, error) { return ks.getByID(id) } -func (ks *eth) GetAll() (keys []ethkey.KeyV2, _ error) { +func (ks *eth) GetAll(ctx context.Context) (keys []ethkey.KeyV2, _ error) { ks.lock.RLock() defer ks.lock.RUnlock() if ks.isLocked() { return nil, ErrLocked } - return ks.getAll(), nil + return ks.getAll(ctx), nil } // caller must hold lock! -func (ks *eth) getAll() (keys []ethkey.KeyV2) { +func (ks *eth) getAll(ctx context.Context) (keys []ethkey.KeyV2) { for _, key := range ks.keyRing.Eth { keys = append(keys, key) } @@ -99,7 +100,7 @@ func (ks *eth) getAll() (keys []ethkey.KeyV2) { } // Create generates a fresh new key and enables it for the given chain IDs -func (ks *eth) Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) { +func (ks *eth) Create(ctx context.Context, chainIDs ...*big.Int) (ethkey.KeyV2, error) { ks.lock.Lock() defer ks.lock.Unlock() if ks.isLocked() { @@ -109,7 +110,7 @@ func (ks *eth) Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) { if err != nil { return ethkey.KeyV2{}, err } - err = ks.add(key, chainIDs...) + err = ks.add(ctx, key, chainIDs...) if err == nil { ks.notify() } @@ -120,7 +121,7 @@ func (ks *eth) Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) { // EnsureKeys ensures that each chain has at least one key with a state // linked to that chain. If a key and state exists for a chain but it is // disabled, we do not enable it automatically here. -func (ks *eth) EnsureKeys(chainIDs ...*big.Int) (err error) { +func (ks *eth) EnsureKeys(ctx context.Context, chainIDs ...*big.Int) (err error) { ks.lock.Lock() defer ks.lock.Unlock() if ks.isLocked() { @@ -136,7 +137,7 @@ func (ks *eth) EnsureKeys(chainIDs ...*big.Int) (err error) { if err != nil { return err } - err = ks.add(newKey, chainID) + err = ks.add(ctx, newKey, chainID) if err != nil { return err } @@ -146,7 +147,7 @@ func (ks *eth) EnsureKeys(chainIDs ...*big.Int) (err error) { return nil } -func (ks *eth) Import(keyJSON []byte, password string, chainIDs ...*big.Int) (ethkey.KeyV2, error) { +func (ks *eth) Import(ctx context.Context, keyJSON []byte, password string, chainIDs ...*big.Int) (ethkey.KeyV2, error) { ks.lock.Lock() defer ks.lock.Unlock() if ks.isLocked() { @@ -160,7 +161,7 @@ func (ks *eth) Import(keyJSON []byte, password string, chainIDs ...*big.Int) (et if _, found := ks.keyRing.Eth[key.ID()]; found { return ethkey.KeyV2{}, ErrKeyExists } - err = ks.add(key, chainIDs...) + err = ks.add(ctx, key, chainIDs...) if err != nil { return ethkey.KeyV2{}, errors.Wrap(err, "unable to add eth key") } @@ -168,7 +169,7 @@ func (ks *eth) Import(keyJSON []byte, password string, chainIDs ...*big.Int) (et return key, nil } -func (ks *eth) Export(id string, password string) ([]byte, error) { +func (ks *eth) Export(ctx context.Context, id string, password string) ([]byte, error) { ks.lock.RLock() defer ks.lock.RUnlock() if ks.isLocked() { @@ -181,23 +182,25 @@ func (ks *eth) Export(id string, password string) ([]byte, error) { return key.ToEncryptedJSON(password, ks.scryptParams) } -func (ks *eth) Add(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +func (ks *eth) Add(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { ks.lock.Lock() defer ks.lock.Unlock() _, found := ks.keyRing.Eth[address.Hex()] if !found { return ErrKeyNotFound } - return ks.addKey(address, chainID, qopts...) + return ks.addKey(ctx, address, chainID, qopts...) } // caller must hold lock! -func (ks *eth) addKey(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +func (ks *eth) addKey(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { state := new(ethkey.State) sql := `INSERT INTO evm.key_states (address, disabled, evm_chain_id, created_at, updated_at) VALUES ($1, false, $2, NOW(), NOW()) RETURNING *;` q := ks.q.WithOpts(qopts...) + q = q.WithOpts(pg.WithParentCtx(ctx)) + if err := q.Get(state, sql, address, chainID.String()); err != nil { return errors.Wrap(err, "failed to insert evm_key_state") } @@ -207,18 +210,18 @@ func (ks *eth) addKey(address common.Address, chainID *big.Int, qopts ...pg.QOpt return nil } -func (ks *eth) Enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +func (ks *eth) Enable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { ks.lock.Lock() defer ks.lock.Unlock() _, found := ks.keyRing.Eth[address.Hex()] if !found { return ErrKeyNotFound } - return ks.enable(address, chainID, qopts...) + return ks.enable(ctx, address, chainID, qopts...) } // caller must hold lock! -func (ks *eth) enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +func (ks *eth) enable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { state := new(ethkey.State) q := ks.q.WithOpts(qopts...) sql := `INSERT INTO evm.key_states as key_states ("address", "evm_chain_id", "disabled", "created_at", "updated_at") VALUES ($1, $2, false, NOW(), NOW()) @@ -237,17 +240,17 @@ func (ks *eth) enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt return nil } -func (ks *eth) Disable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +func (ks *eth) Disable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { ks.lock.Lock() defer ks.lock.Unlock() _, found := ks.keyRing.Eth[address.Hex()] if !found { return errors.Errorf("no key exists with ID %s", address.Hex()) } - return ks.disable(address, chainID, qopts...) + return ks.disable(ctx, address, chainID, qopts...) } -func (ks *eth) disable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +func (ks *eth) disable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { state := new(ethkey.State) q := ks.q.WithOpts(qopts...) sql := `INSERT INTO evm.key_states as key_states ("address", "evm_chain_id", "disabled", "created_at", "updated_at") VALUES ($1, $2, true, NOW(), NOW()) @@ -266,7 +269,7 @@ func (ks *eth) disable(address common.Address, chainID *big.Int, qopts ...pg.QOp return nil } -func (ks *eth) Delete(id string) (ethkey.KeyV2, error) { +func (ks *eth) Delete(ctx context.Context, id string) (ethkey.KeyV2, error) { ks.lock.Lock() defer ks.lock.Unlock() if ks.isLocked() { @@ -288,7 +291,7 @@ func (ks *eth) Delete(id string) (ethkey.KeyV2, error) { return key, nil } -func (ks *eth) SubscribeToKeyChanges() (ch chan struct{}, unsub func()) { +func (ks *eth) SubscribeToKeyChanges(ctx context.Context) (ch chan struct{}, unsub func()) { ch = make(chan struct{}, 1) ks.subscribersMu.Lock() defer ks.subscribersMu.Unlock() @@ -305,7 +308,7 @@ func (ks *eth) SubscribeToKeyChanges() (ch chan struct{}, unsub func()) { } } -func (ks *eth) SignTx(address common.Address, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { +func (ks *eth) SignTx(ctx context.Context, address common.Address, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { ks.lock.RLock() defer ks.lock.RUnlock() if ks.isLocked() { @@ -320,7 +323,7 @@ func (ks *eth) SignTx(address common.Address, tx *types.Transaction, chainID *bi } // EnabledKeysForChain returns all keys that are enabled for the given chain -func (ks *eth) EnabledKeysForChain(chainID *big.Int) (sendingKeys []ethkey.KeyV2, err error) { +func (ks *eth) EnabledKeysForChain(ctx context.Context, chainID *big.Int) (sendingKeys []ethkey.KeyV2, err error) { if chainID == nil { return nil, errors.New("chainID must be non-nil") } @@ -332,7 +335,7 @@ func (ks *eth) EnabledKeysForChain(chainID *big.Int) (sendingKeys []ethkey.KeyV2 return ks.enabledKeysForChain(chainID), nil } -func (ks *eth) GetRoundRobinAddress(chainID *big.Int, whitelist ...common.Address) (common.Address, error) { +func (ks *eth) GetRoundRobinAddress(ctx context.Context, chainID *big.Int, whitelist ...common.Address) (common.Address, error) { if chainID == nil { return common.Address{}, errors.New("chainID must be non-nil") } @@ -381,7 +384,7 @@ func (ks *eth) GetRoundRobinAddress(chainID *big.Int, whitelist ...common.Addres // CheckEnabled returns nil if state is present and enabled // The complexity here comes because we want to return nice, useful error messages -func (ks *eth) CheckEnabled(address common.Address, chainID *big.Int) error { +func (ks *eth) CheckEnabled(ctx context.Context, address common.Address, chainID *big.Int) error { if utils.IsZero(address) { return errors.Errorf("empty address provided as input") } @@ -423,7 +426,7 @@ func (ks *eth) CheckEnabled(address common.Address, chainID *big.Int) error { return nil } -func (ks *eth) GetState(id string, chainID *big.Int) (ethkey.State, error) { +func (ks *eth) GetState(ctx context.Context, id string, chainID *big.Int) (ethkey.State, error) { ks.lock.RLock() defer ks.lock.RUnlock() if ks.isLocked() { @@ -436,7 +439,7 @@ func (ks *eth) GetState(id string, chainID *big.Int) (ethkey.State, error) { return *state, nil } -func (ks *eth) GetStatesForKeys(keys []ethkey.KeyV2) (states []ethkey.State, err error) { +func (ks *eth) GetStatesForKeys(ctx context.Context, keys []ethkey.KeyV2) (states []ethkey.State, err error) { ks.lock.RLock() defer ks.lock.RUnlock() for _, state := range ks.keyStates.All { @@ -451,7 +454,7 @@ func (ks *eth) GetStatesForKeys(keys []ethkey.KeyV2) (states []ethkey.State, err } // Useful to fetch the ChainID for a given key -func (ks *eth) GetStateForKey(key ethkey.KeyV2) (state ethkey.State, err error) { +func (ks *eth) GetStateForKey(ctx context.Context, key ethkey.KeyV2) (state ethkey.State, err error) { ks.lock.RLock() defer ks.lock.RUnlock() for _, state := range ks.keyStates.All { @@ -463,7 +466,7 @@ func (ks *eth) GetStateForKey(key ethkey.KeyV2) (state ethkey.State, err error) return } -func (ks *eth) GetStatesForChain(chainID *big.Int) (states []ethkey.State, err error) { +func (ks *eth) GetStatesForChain(ctx context.Context, chainID *big.Int) (states []ethkey.State, err error) { ks.lock.RLock() defer ks.lock.RUnlock() if ks.isLocked() { @@ -476,7 +479,7 @@ func (ks *eth) GetStatesForChain(chainID *big.Int) (states []ethkey.State, err e return } -func (ks *eth) EnabledAddressesForChain(chainID *big.Int) (addresses []common.Address, err error) { +func (ks *eth) EnabledAddressesForChain(ctx context.Context, chainID *big.Int) (addresses []common.Address, err error) { ks.lock.RLock() defer ks.lock.RUnlock() if chainID == nil { @@ -495,7 +498,7 @@ func (ks *eth) EnabledAddressesForChain(chainID *big.Int) (addresses []common.Ad } // XXXTestingOnlySetState is only used in tests to manually update a key's state -func (ks *eth) XXXTestingOnlySetState(state ethkey.State) { +func (ks *eth) XXXTestingOnlySetState(ctx context.Context, state ethkey.State) { ks.lock.Lock() defer ks.lock.Unlock() if ks.isLocked() { @@ -515,7 +518,7 @@ func (ks *eth) XXXTestingOnlySetState(state ethkey.State) { } // XXXTestingOnlyAdd is only used in tests to manually add a key -func (ks *eth) XXXTestingOnlyAdd(key ethkey.KeyV2) { +func (ks *eth) XXXTestingOnlyAdd(ctx context.Context, key ethkey.KeyV2) { ks.lock.Lock() defer ks.lock.Unlock() if ks.isLocked() { @@ -524,7 +527,7 @@ func (ks *eth) XXXTestingOnlyAdd(key ethkey.KeyV2) { if _, found := ks.keyRing.Eth[key.ID()]; found { panic(fmt.Sprintf("key with ID %s already exists", key.ID())) } - err := ks.add(key) + err := ks.add(ctx, key) if err != nil { panic(err.Error()) } @@ -561,10 +564,10 @@ func (ks *eth) keysForChain(chainID *big.Int, includeDisabled bool) (keys []ethk } // caller must hold lock! -func (ks *eth) add(key ethkey.KeyV2, chainIDs ...*big.Int) (err error) { +func (ks *eth) add(ctx context.Context, key ethkey.KeyV2, chainIDs ...*big.Int) (err error) { err = ks.safeAddKey(key, func(tx pg.Queryer) (serr error) { for _, chainID := range chainIDs { - if serr = ks.addKey(key.Address, chainID, pg.WithQueryer(tx)); serr != nil { + if serr = ks.addKey(ctx, key.Address, chainID, pg.WithQueryer(tx)); serr != nil { return serr } } diff --git a/core/services/keystore/eth_test.go b/core/services/keystore/eth_test.go index dd42f4049c3..573830638ab 100644 --- a/core/services/keystore/eth_test.go +++ b/core/services/keystore/eth_test.go @@ -43,14 +43,15 @@ func Test_EthKeyStore(t *testing.T) { const statesTableName = "evm.key_states" t.Run("Create / GetAll / Get", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ethKeyStore.Create(&cltest.FixtureChainID) + key, err := ethKeyStore.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) - retrievedKeys, err := ethKeyStore.GetAll() + retrievedKeys, err := ethKeyStore.GetAll(ctx) require.NoError(t, err) require.Equal(t, 1, len(retrievedKeys)) require.Equal(t, key.Address, retrievedKeys[0].Address) - foundKey, err := ethKeyStore.Get(key.Address.Hex()) + foundKey, err := ethKeyStore.Get(ctx, key.Address.Hex()) require.NoError(t, err) require.Equal(t, key, foundKey) // adds ethkey.State @@ -62,27 +63,28 @@ func Test_EthKeyStore(t *testing.T) { // adds key to db keyStore.ResetXXXTestOnly() require.NoError(t, keyStore.Unlock(cltest.Password)) - retrievedKeys, err = ethKeyStore.GetAll() + retrievedKeys, err = ethKeyStore.GetAll(ctx) require.NoError(t, err) require.Equal(t, 1, len(retrievedKeys)) require.Equal(t, key.Address, retrievedKeys[0].Address) // adds 2nd key - _, err = ethKeyStore.Create(&cltest.FixtureChainID) + _, err = ethKeyStore.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) - retrievedKeys, err = ethKeyStore.GetAll() + retrievedKeys, err = ethKeyStore.GetAll(ctx) require.NoError(t, err) require.Equal(t, 2, len(retrievedKeys)) }) t.Run("GetAll ordering", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() var keys []ethkey.KeyV2 for i := 0; i < 5; i++ { - key, err := ethKeyStore.Create(&cltest.FixtureChainID) + key, err := ethKeyStore.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) keys = append(keys, key) } - retrievedKeys, err := ethKeyStore.GetAll() + retrievedKeys, err := ethKeyStore.GetAll(ctx) require.NoError(t, err) require.Equal(t, 5, len(retrievedKeys)) @@ -92,20 +94,22 @@ func Test_EthKeyStore(t *testing.T) { }) t.Run("RemoveKey", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ethKeyStore.Create(&cltest.FixtureChainID) + key, err := ethKeyStore.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) - _, err = ethKeyStore.Delete(key.ID()) + _, err = ethKeyStore.Delete(ctx, key.ID()) require.NoError(t, err) - retrievedKeys, err := ethKeyStore.GetAll() + retrievedKeys, err := ethKeyStore.GetAll(ctx) require.NoError(t, err) require.Equal(t, 0, len(retrievedKeys)) cltest.AssertCount(t, db, statesTableName, 0) }) t.Run("Delete removes key even if evm.txes are present", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ethKeyStore.Create(&cltest.FixtureChainID) + key, err := ethKeyStore.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) // ensure at least one state is present cltest.AssertCount(t, db, statesTableName, 1) @@ -114,27 +118,28 @@ func Test_EthKeyStore(t *testing.T) { txStore := cltest.NewTestTxStore(t, db, cfg.Database()) cltest.MustInsertConfirmedEthTxWithLegacyAttempt(t, txStore, 0, 42, key.Address) - _, err = ethKeyStore.Delete(key.ID()) + _, err = ethKeyStore.Delete(ctx, key.ID()) require.NoError(t, err) - retrievedKeys, err := ethKeyStore.GetAll() + retrievedKeys, err := ethKeyStore.GetAll(ctx) require.NoError(t, err) require.Equal(t, 0, len(retrievedKeys)) cltest.AssertCount(t, db, statesTableName, 0) }) t.Run("EnsureKeys / EnabledKeysForChain", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - err := ethKeyStore.EnsureKeys(&cltest.FixtureChainID) + err := ethKeyStore.EnsureKeys(ctx, &cltest.FixtureChainID) assert.NoError(t, err) - sendingKeys1, err := ethKeyStore.EnabledKeysForChain(testutils.FixtureChainID) + sendingKeys1, err := ethKeyStore.EnabledKeysForChain(ctx, testutils.FixtureChainID) assert.NoError(t, err) require.Equal(t, 1, len(sendingKeys1)) cltest.AssertCount(t, db, statesTableName, 1) - err = ethKeyStore.EnsureKeys(&cltest.FixtureChainID) + err = ethKeyStore.EnsureKeys(ctx, &cltest.FixtureChainID) assert.NoError(t, err) - sendingKeys2, err := ethKeyStore.EnabledKeysForChain(testutils.FixtureChainID) + sendingKeys2, err := ethKeyStore.EnabledKeysForChain(ctx, testutils.FixtureChainID) assert.NoError(t, err) require.Equal(t, 1, len(sendingKeys2)) @@ -142,63 +147,65 @@ func Test_EthKeyStore(t *testing.T) { }) t.Run("EnabledKeysForChain with specified chain ID", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ethKeyStore.Create(testutils.FixtureChainID) + key, err := ethKeyStore.Create(ctx, testutils.FixtureChainID) require.NoError(t, err) - key2, err := ethKeyStore.Create(big.NewInt(1337)) + key2, err := ethKeyStore.Create(ctx, big.NewInt(1337)) require.NoError(t, err) - keys, err := ethKeyStore.EnabledKeysForChain(testutils.FixtureChainID) + keys, err := ethKeyStore.EnabledKeysForChain(ctx, testutils.FixtureChainID) require.NoError(t, err) require.Len(t, keys, 1) require.Equal(t, key, keys[0]) - keys, err = ethKeyStore.EnabledKeysForChain(big.NewInt(1337)) + keys, err = ethKeyStore.EnabledKeysForChain(ctx, big.NewInt(1337)) require.NoError(t, err) require.Len(t, keys, 1) require.Equal(t, key2, keys[0]) - _, err = ethKeyStore.EnabledKeysForChain(nil) + _, err = ethKeyStore.EnabledKeysForChain(ctx, nil) assert.Error(t, err) assert.EqualError(t, err, "chainID must be non-nil") }) t.Run("EnabledAddressesForChain with specified chain ID", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ethKeyStore.Create(testutils.FixtureChainID) + key, err := ethKeyStore.Create(ctx, testutils.FixtureChainID) require.NoError(t, err) - key2, err := ethKeyStore.Create(big.NewInt(1337)) + key2, err := ethKeyStore.Create(ctx, big.NewInt(1337)) require.NoError(t, err) testutils.AssertCount(t, db, "evm.key_states", 2) - keys, err := ethKeyStore.GetAll() + keys, err := ethKeyStore.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 2) //get enabled addresses for FixtureChainID - enabledAddresses, err := ethKeyStore.EnabledAddressesForChain(testutils.FixtureChainID) + enabledAddresses, err := ethKeyStore.EnabledAddressesForChain(ctx, testutils.FixtureChainID) require.NoError(t, err) require.Len(t, enabledAddresses, 1) require.Equal(t, key.Address, enabledAddresses[0]) //get enabled addresses for chain 1337 - enabledAddresses, err = ethKeyStore.EnabledAddressesForChain(big.NewInt(1337)) + enabledAddresses, err = ethKeyStore.EnabledAddressesForChain(ctx, big.NewInt(1337)) require.NoError(t, err) require.Len(t, enabledAddresses, 1) require.Equal(t, key2.Address, enabledAddresses[0]) // /get enabled addresses for nil chain ID - _, err = ethKeyStore.EnabledAddressesForChain(nil) + _, err = ethKeyStore.EnabledAddressesForChain(ctx, nil) assert.Error(t, err) assert.EqualError(t, err, "chainID must be non-nil") // disable the key for chain FixtureChainID - err = ethKeyStore.Disable(key.Address, testutils.FixtureChainID) + err = ethKeyStore.Disable(ctx, key.Address, testutils.FixtureChainID) require.NoError(t, err) - enabledAddresses, err = ethKeyStore.EnabledAddressesForChain(testutils.FixtureChainID) + enabledAddresses, err = ethKeyStore.EnabledAddressesForChain(ctx, testutils.FixtureChainID) require.NoError(t, err) assert.Len(t, enabledAddresses, 0) - enabledAddresses, err = ethKeyStore.EnabledAddressesForChain(big.NewInt(1337)) + enabledAddresses, err = ethKeyStore.EnabledAddressesForChain(ctx, big.NewInt(1337)) require.NoError(t, err) assert.Len(t, enabledAddresses, 1) require.Equal(t, key2.Address, enabledAddresses[0]) @@ -206,6 +213,7 @@ func Test_EthKeyStore(t *testing.T) { } func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { + ctx := testutils.Context(t) t.Parallel() db := pgtest.NewSqlxDB(t) @@ -215,7 +223,8 @@ func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { ethKeyStore := keyStore.Eth() t.Run("should error when no addresses", func(t *testing.T) { - _, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + ctx1 := testutils.Context(t) + _, err := ethKeyStore.GetRoundRobinAddress(ctx1, testutils.FixtureChainID) require.Error(t, err) }) @@ -231,38 +240,38 @@ func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { // - key 4 // enabled - fixture k1, _ := cltest.MustInsertRandomKeyNoChains(t, ethKeyStore) - require.NoError(t, ethKeyStore.Add(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ethKeyStore.Add(k1.Address, testutils.SimulatedChainID)) - require.NoError(t, ethKeyStore.Enable(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ethKeyStore.Enable(k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Add(ctx, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ethKeyStore.Add(ctx, k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Enable(ctx, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ethKeyStore.Enable(ctx, k1.Address, testutils.SimulatedChainID)) k2, _ := cltest.MustInsertRandomKeyNoChains(t, ethKeyStore) - require.NoError(t, ethKeyStore.Add(k2.Address, testutils.FixtureChainID)) - require.NoError(t, ethKeyStore.Add(k2.Address, testutils.SimulatedChainID)) - require.NoError(t, ethKeyStore.Enable(k2.Address, testutils.FixtureChainID)) - require.NoError(t, ethKeyStore.Enable(k2.Address, testutils.SimulatedChainID)) - require.NoError(t, ethKeyStore.Disable(k2.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Add(ctx, k2.Address, testutils.FixtureChainID)) + require.NoError(t, ethKeyStore.Add(ctx, k2.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Enable(ctx, k2.Address, testutils.FixtureChainID)) + require.NoError(t, ethKeyStore.Enable(ctx, k2.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Disable(ctx, k2.Address, testutils.SimulatedChainID)) k3, _ := cltest.MustInsertRandomKeyNoChains(t, ethKeyStore) - require.NoError(t, ethKeyStore.Add(k3.Address, testutils.SimulatedChainID)) - require.NoError(t, ethKeyStore.Enable(k3.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Add(ctx, k3.Address, testutils.SimulatedChainID)) + require.NoError(t, ethKeyStore.Enable(ctx, k3.Address, testutils.SimulatedChainID)) k4, _ := cltest.MustInsertRandomKeyNoChains(t, ethKeyStore) - require.NoError(t, ethKeyStore.Add(k4.Address, testutils.FixtureChainID)) - require.NoError(t, ethKeyStore.Enable(k4.Address, testutils.FixtureChainID)) + require.NoError(t, ethKeyStore.Add(ctx, k4.Address, testutils.FixtureChainID)) + require.NoError(t, ethKeyStore.Enable(ctx, k4.Address, testutils.FixtureChainID)) t.Run("with no address filter, rotates between all enabled addresses", func(t *testing.T) { - address1, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + address1, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) - address2, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + address2, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) - address3, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + address3, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) - address4, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + address4, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) - address5, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + address5, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) - address6, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID) + address6, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID) require.NoError(t, err) assert.NotEqual(t, address1, address2) @@ -278,13 +287,13 @@ func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { // k3 is a disabled address for FixtureChainID so even though it's whitelisted, it will be ignored addresses := []common.Address{k4.Address, k3.Address, k1.Address, k2.Address, testutils.NewAddress()} - address1, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID, addresses...) + address1, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID, addresses...) require.NoError(t, err) - address2, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID, addresses...) + address2, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID, addresses...) require.NoError(t, err) - address3, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID, addresses...) + address3, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID, addresses...) require.NoError(t, err) - address4, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID, addresses...) + address4, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID, addresses...) require.NoError(t, err) assert.NotEqual(t, k3.Address, address1) @@ -301,13 +310,13 @@ func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { // k2 and k4 are disabled address for SimulatedChainID so even though it's whitelisted, it will be ignored addresses := []common.Address{k4.Address, k3.Address, k1.Address, k2.Address, testutils.NewAddress()} - address1, err := ethKeyStore.GetRoundRobinAddress(testutils.SimulatedChainID, addresses...) + address1, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.SimulatedChainID, addresses...) require.NoError(t, err) - address2, err := ethKeyStore.GetRoundRobinAddress(testutils.SimulatedChainID, addresses...) + address2, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.SimulatedChainID, addresses...) require.NoError(t, err) - address3, err := ethKeyStore.GetRoundRobinAddress(testutils.SimulatedChainID, addresses...) + address3, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.SimulatedChainID, addresses...) require.NoError(t, err) - address4, err := ethKeyStore.GetRoundRobinAddress(testutils.SimulatedChainID, addresses...) + address4, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.SimulatedChainID, addresses...) require.NoError(t, err) assert.True(t, address1 == k1.Address || address1 == k3.Address) @@ -320,7 +329,7 @@ func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { t.Run("with address filter when no address matches", func(t *testing.T) { addr := testutils.NewAddress() - _, err := ethKeyStore.GetRoundRobinAddress(testutils.FixtureChainID, []common.Address{addr}...) + _, err := ethKeyStore.GetRoundRobinAddress(ctx, testutils.FixtureChainID, []common.Address{addr}...) require.Error(t, err) require.Equal(t, fmt.Sprintf("no sending keys available for chain %s that match whitelist: [%s]", testutils.FixtureChainID.String(), addr.Hex()), err.Error()) }) @@ -329,6 +338,8 @@ func Test_EthKeyStore_GetRoundRobinAddress(t *testing.T) { func Test_EthKeyStore_SignTx(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) config := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) @@ -340,10 +351,10 @@ func Test_EthKeyStore_SignTx(t *testing.T) { tx := cltest.NewLegacyTransaction(0, testutils.NewAddress(), big.NewInt(53), 21000, big.NewInt(1000000000), []byte{1, 2, 3, 4}) randomAddress := testutils.NewAddress() - _, err := ethKeyStore.SignTx(randomAddress, tx, chainID) + _, err := ethKeyStore.SignTx(ctx, randomAddress, tx, chainID) require.EqualError(t, err, "Key not found") - signed, err := ethKeyStore.SignTx(k.Address, tx, chainID) + signed, err := ethKeyStore.SignTx(ctx, k.Address, tx, chainID) require.NoError(t, err) require.NotEqual(t, tx, signed) @@ -367,78 +378,85 @@ func Test_EthKeyStore_E2E(t *testing.T) { } t.Run("initializes with an empty state", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - keys, err := ks.GetAll() + keys, err := ks.GetAll(ctx) require.NoError(t, err) require.Equal(t, 0, len(keys)) }) t.Run("errors when getting non-existent ID", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - _, err := ks.Get("non-existent-id") + _, err := ks.Get(ctx, "non-existent-id") require.Error(t, err) }) t.Run("creates a key", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ks.Create(&cltest.FixtureChainID) + key, err := ks.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) - retrievedKey, err := ks.Get(key.ID()) + retrievedKey, err := ks.Get(ctx, key.ID()) require.NoError(t, err) require.Equal(t, key, retrievedKey) }) t.Run("imports and exports a key", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() - key, err := ks.Create(&cltest.FixtureChainID) + key, err := ks.Create(ctx, &cltest.FixtureChainID) require.NoError(t, err) - exportJSON, err := ks.Export(key.ID(), cltest.Password) + exportJSON, err := ks.Export(ctx, key.ID(), cltest.Password) require.NoError(t, err) - _, err = ks.Delete(key.ID()) + _, err = ks.Delete(ctx, key.ID()) require.NoError(t, err) - _, err = ks.Get(key.ID()) + _, err = ks.Get(ctx, key.ID()) require.Error(t, err) - importedKey, err := ks.Import(exportJSON, cltest.Password, &cltest.FixtureChainID) + importedKey, err := ks.Import(ctx, exportJSON, cltest.Password, &cltest.FixtureChainID) require.NoError(t, err) require.Equal(t, key.ID(), importedKey.ID()) - retrievedKey, err := ks.Get(key.ID()) + retrievedKey, err := ks.Get(ctx, key.ID()) require.NoError(t, err) require.Equal(t, importedKey, retrievedKey) }) t.Run("adds an externally created key / deletes a key", func(t *testing.T) { + ctx := testutils.Context(t) defer reset() newKey, err := ethkey.NewV2() require.NoError(t, err) - ks.XXXTestingOnlyAdd(newKey) - keys, err := ks.GetAll() + ks.XXXTestingOnlyAdd(ctx, newKey) + keys, err := ks.GetAll(ctx) require.NoError(t, err) assert.Equal(t, 1, len(keys)) - _, err = ks.Delete(newKey.ID()) + _, err = ks.Delete(ctx, newKey.ID()) require.NoError(t, err) - keys, err = ks.GetAll() + keys, err = ks.GetAll(ctx) require.NoError(t, err) assert.Equal(t, 0, len(keys)) - _, err = ks.Get(newKey.ID()) + _, err = ks.Get(ctx, newKey.ID()) assert.Error(t, err) - _, err = ks.Delete(newKey.ID()) + _, err = ks.Delete(ctx, newKey.ID()) assert.Error(t, err) }) t.Run("imports a key exported from a v1 keystore", func(t *testing.T) { + ctx := testutils.Context(t) exportedKey := `{"address":"0dd359b4f22a30e44b2fd744b679971941865820","crypto":{"cipher":"aes-128-ctr","ciphertext":"b30af964a3b3f37894e599446b4cf2314bbfcd1062e6b35b620d3d20bd9965cc","cipherparams":{"iv":"58a8d75629cc1945da7cf8c24520d1dc"},"kdf":"scrypt","kdfparams":{"dklen":32,"n":262144,"p":1,"r":8,"salt":"c352887e9d427d8a6a1869082619b73fac4566082a99f6e367d126f11b434f28"},"mac":"fd76a588210e0bf73d01332091e0e83a4584ee2df31eaec0e27f9a1b94f024b4"},"id":"a5ee0802-1d7b-45b6-aeb8-ea8a3351e715","version":3}` - importedKey, err := ks.Import([]byte(exportedKey), "p4SsW0rD1!@#_", &cltest.FixtureChainID) + importedKey, err := ks.Import(ctx, []byte(exportedKey), "p4SsW0rD1!@#_", &cltest.FixtureChainID) require.NoError(t, err) assert.Equal(t, "0x0dd359b4f22a30E44b2fD744B679971941865820", importedKey.ID()) - k, err := ks.Import([]byte(exportedKey), cltest.Password, &cltest.FixtureChainID) + k, err := ks.Import(ctx, []byte(exportedKey), cltest.Password, &cltest.FixtureChainID) assert.Empty(t, k) assert.Error(t, err) }) t.Run("fails to export a non-existent key", func(t *testing.T) { - k, err := ks.Export("non-existent", cltest.Password) + ctx := testutils.Context(t) + k, err := ks.Export(ctx, "non-existent", cltest.Password) assert.Empty(t, k) assert.Error(t, err) @@ -448,24 +466,25 @@ func Test_EthKeyStore_E2E(t *testing.T) { defer reset() t.Run("returns states for keys", func(t *testing.T) { + ctx := testutils.Context(t) k1, err := ethkey.NewV2() require.NoError(t, err) k2, err := ethkey.NewV2() require.NoError(t, err) - ks.XXXTestingOnlyAdd(k1) - ks.XXXTestingOnlyAdd(k2) - require.NoError(t, ks.Add(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.FixtureChainID)) + ks.XXXTestingOnlyAdd(ctx, k1) + ks.XXXTestingOnlyAdd(ctx, k2) + require.NoError(t, ks.Add(ctx, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Enable(ctx, k1.Address, testutils.FixtureChainID)) - states, err := ks.GetStatesForKeys([]ethkey.KeyV2{k1, k2}) + states, err := ks.GetStatesForKeys(ctx, []ethkey.KeyV2{k1, k2}) require.NoError(t, err) assert.Len(t, states, 1) - chainStates, err := ks.GetStatesForChain(testutils.FixtureChainID) + chainStates, err := ks.GetStatesForChain(ctx, testutils.FixtureChainID) require.NoError(t, err) assert.Len(t, chainStates, 2) // one created here, one created above - chainStates, err = ks.GetStatesForChain(testutils.SimulatedChainID) + chainStates, err = ks.GetStatesForChain(ctx, testutils.SimulatedChainID) require.NoError(t, err) assert.Len(t, chainStates, 0) }) @@ -475,13 +494,15 @@ func Test_EthKeyStore_E2E(t *testing.T) { func Test_EthKeyStore_SubscribeToKeyChanges(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) + chDone := make(chan struct{}) defer func() { close(chDone) }() db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, cfg.Database()) ks := keyStore.Eth() - chSub, unsubscribe := ks.SubscribeToKeyChanges() + chSub, unsubscribe := ks.SubscribeToKeyChanges(ctx) defer unsubscribe() var count atomic.Int32 @@ -517,28 +538,28 @@ func Test_EthKeyStore_SubscribeToKeyChanges(t *testing.T) { count.Store(0) } - err := ks.EnsureKeys(&cltest.FixtureChainID) + err := ks.EnsureKeys(ctx, &cltest.FixtureChainID) require.NoError(t, err) assertCountAtLeast(1) drainAndReset() // Create the key includes a state, triggering notify - k1, err := ks.Create(testutils.FixtureChainID) + k1, err := ks.Create(ctx, testutils.FixtureChainID) require.NoError(t, err) assertCountAtLeast(1) drainAndReset() // Enabling the key for a new state triggers the notification callback again - require.NoError(t, ks.Add(k1.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(ctx, k1.Address, testutils.SimulatedChainID)) assertCountAtLeast(1) drainAndReset() // Disabling triggers a notify - require.NoError(t, ks.Disable(k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Disable(ctx, k1.Address, testutils.SimulatedChainID)) assertCountAtLeast(1) } @@ -551,38 +572,42 @@ func Test_EthKeyStore_Enable(t *testing.T) { ks := keyStore.Eth() t.Run("already existing disabled key gets enabled", func(t *testing.T) { + ctx := testutils.Context(t) k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Add(k.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Disable(k.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Enable(k.Address, testutils.SimulatedChainID)) - key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, ks.Add(ctx, k.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Disable(ctx, k.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(ctx, k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(ctx, k.Address.Hex(), testutils.SimulatedChainID) require.NoError(t, err) require.Equal(t, key.Disabled, false) }) t.Run("creates key, deletes it unsafely and then enable creates it again", func(t *testing.T) { + ctx := testutils.Context(t) k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Add(k.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, k.Address, testutils.SimulatedChainID)) _, err := db.Exec("DELETE FROM evm.key_states WHERE address = $1", k.Address) require.NoError(t, err) - require.NoError(t, ks.Enable(k.Address, testutils.SimulatedChainID)) - key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, ks.Enable(ctx, k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(ctx, k.Address.Hex(), testutils.SimulatedChainID) require.NoError(t, err) require.Equal(t, key.Disabled, false) }) t.Run("creates key and enables it if it exists in the keystore, but is missing from key states db table", func(t *testing.T) { + ctx := testutils.Context(t) k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Enable(k.Address, testutils.SimulatedChainID)) - key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, ks.Enable(ctx, k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(ctx, k.Address.Hex(), testutils.SimulatedChainID) require.NoError(t, err) require.Equal(t, key.Disabled, false) }) t.Run("errors if key is not present in keystore", func(t *testing.T) { + ctx := testutils.Context(t) addrNotInKs := testutils.NewAddress() - require.Error(t, ks.Enable(addrNotInKs, testutils.SimulatedChainID)) - _, err := ks.GetState(addrNotInKs.Hex(), testutils.SimulatedChainID) + require.Error(t, ks.Enable(ctx, addrNotInKs, testutils.SimulatedChainID)) + _, err := ks.GetState(ctx, addrNotInKs.Hex(), testutils.SimulatedChainID) require.Error(t, err) }) } @@ -591,69 +616,72 @@ func Test_EthKeyStore_EnsureKeys(t *testing.T) { t.Parallel() t.Run("creates one unique key per chain if none exist", func(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, cfg.Database()) ks := keyStore.Eth() testutils.AssertCount(t, db, "evm.key_states", 0) - err := ks.EnsureKeys(testutils.FixtureChainID, testutils.SimulatedChainID) + err := ks.EnsureKeys(ctx, testutils.FixtureChainID, testutils.SimulatedChainID) require.NoError(t, err) testutils.AssertCount(t, db, "evm.key_states", 2) - keys, err := ks.GetAll() + keys, err := ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 2) }) t.Run("does nothing if a key exists for a chain", func(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, cfg.Database()) ks := keyStore.Eth() // Add one enabled key - _, err := ks.Create(testutils.FixtureChainID) + _, err := ks.Create(ctx, testutils.FixtureChainID) require.NoError(t, err) testutils.AssertCount(t, db, "evm.key_states", 1) - keys, err := ks.GetAll() + keys, err := ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 1) // this adds one more key for the additional chain - err = ks.EnsureKeys(testutils.FixtureChainID, testutils.SimulatedChainID) + err = ks.EnsureKeys(ctx, testutils.FixtureChainID, testutils.SimulatedChainID) require.NoError(t, err) testutils.AssertCount(t, db, "evm.key_states", 2) - keys, err = ks.GetAll() + keys, err = ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 2) }) t.Run("does nothing if a key exists but is disabled for a chain", func(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, cfg.Database()) ks := keyStore.Eth() // Add one enabled key - k, err := ks.Create(testutils.FixtureChainID) + k, err := ks.Create(ctx, testutils.FixtureChainID) require.NoError(t, err) testutils.AssertCount(t, db, "evm.key_states", 1) - keys, err := ks.GetAll() + keys, err := ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 1) // disable the key - err = ks.Disable(k.Address, testutils.FixtureChainID) + err = ks.Disable(ctx, k.Address, testutils.FixtureChainID) require.NoError(t, err) // this does nothing - err = ks.EnsureKeys(testutils.FixtureChainID) + err = ks.EnsureKeys(ctx, testutils.FixtureChainID) require.NoError(t, err) testutils.AssertCount(t, db, "evm.key_states", 1) - keys, err = ks.GetAll() + keys, err = ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 1) - state, err := ks.GetState(k.Address.Hex(), testutils.FixtureChainID) + state, err := ks.GetState(ctx, k.Address.Hex(), testutils.FixtureChainID) require.NoError(t, err) assert.True(t, state.Disabled) }) @@ -662,52 +690,56 @@ func Test_EthKeyStore_EnsureKeys(t *testing.T) { func Test_EthKeyStore_Delete(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, cfg.Database()) ks := keyStore.Eth() randKeyID := utils.RandomAddress().Hex() - _, err := ks.Delete(randKeyID) + _, err := ks.Delete(ctx, randKeyID) require.Error(t, err) assert.Contains(t, err.Error(), "Key not found") _, addr1 := cltest.MustInsertRandomKey(t, ks) _, addr2 := cltest.MustInsertRandomKey(t, ks) cltest.MustInsertRandomKey(t, ks, *ubig.New(testutils.SimulatedChainID)) - require.NoError(t, ks.Add(addr1, testutils.SimulatedChainID)) - require.NoError(t, ks.Enable(addr1, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, addr1, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(ctx, addr1, testutils.SimulatedChainID)) testutils.AssertCount(t, db, "evm.key_states", 4) - keys, err := ks.GetAll() + keys, err := ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 3) - _, err = ks.GetState(addr1.Hex(), testutils.FixtureChainID) + _, err = ks.GetState(ctx, addr1.Hex(), testutils.FixtureChainID) require.NoError(t, err) - _, err = ks.GetState(addr1.Hex(), testutils.SimulatedChainID) + _, err = ks.GetState(ctx, addr1.Hex(), testutils.SimulatedChainID) require.NoError(t, err) - deletedK, err := ks.Delete(addr1.String()) + deletedK, err := ks.Delete(ctx, addr1.String()) require.NoError(t, err) assert.Equal(t, addr1, deletedK.Address) testutils.AssertCount(t, db, "evm.key_states", 2) - keys, err = ks.GetAll() + keys, err = ks.GetAll(ctx) require.NoError(t, err) assert.Len(t, keys, 2) - _, err = ks.GetState(addr1.Hex(), testutils.FixtureChainID) + _, err = ks.GetState(ctx, addr1.Hex(), testutils.FixtureChainID) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("state not found for eth key ID %s", addr1.Hex())) - _, err = ks.GetState(addr1.Hex(), testutils.SimulatedChainID) + _, err = ks.GetState(ctx, addr1.Hex(), testutils.SimulatedChainID) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("state not found for eth key ID %s", addr1.Hex())) - _, err = ks.GetState(addr2.Hex(), testutils.FixtureChainID) + _, err = ks.GetState(ctx, addr2.Hex(), testutils.FixtureChainID) require.NoError(t, err) } func Test_EthKeyStore_CheckEnabled(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) + db := pgtest.NewSqlxDB(t) cfg := configtest.NewTestGeneralConfig(t) keyStore := cltest.NewKeyStore(t, db, cfg.Database()) @@ -725,29 +757,30 @@ func Test_EthKeyStore_CheckEnabled(t *testing.T) { // - key 4 // enabled - fixture k1, addr1 := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Add(k1.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Add(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Add(ctx, k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Enable(ctx, k1.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(ctx, k1.Address, testutils.FixtureChainID)) k2, addr2 := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Add(k2.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Add(k2.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Enable(k2.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Enable(k2.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Disable(k2.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, k2.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Add(ctx, k2.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(ctx, k2.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Enable(ctx, k2.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Disable(ctx, k2.Address, testutils.SimulatedChainID)) k3, addr3 := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Add(k3.Address, testutils.SimulatedChainID)) - require.NoError(t, ks.Enable(k3.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, k3.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Enable(ctx, k3.Address, testutils.SimulatedChainID)) t.Run("enabling the same key multiple times does not create duplicate states", func(t *testing.T) { - require.NoError(t, ks.Enable(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.FixtureChainID)) - require.NoError(t, ks.Enable(k1.Address, testutils.FixtureChainID)) + ctx2 := testutils.Context(t) + require.NoError(t, ks.Enable(ctx2, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Enable(ctx2, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Enable(ctx2, k1.Address, testutils.FixtureChainID)) + require.NoError(t, ks.Enable(ctx2, k1.Address, testutils.FixtureChainID)) - states, err := ks.GetStatesForKeys([]ethkey.KeyV2{k1}) + states, err := ks.GetStatesForKeys(ctx2, []ethkey.KeyV2{k1}) require.NoError(t, err) assert.Len(t, states, 2) var cids []*big.Int @@ -764,27 +797,27 @@ func Test_EthKeyStore_CheckEnabled(t *testing.T) { }) t.Run("returns nil when key is enabled for given chain", func(t *testing.T) { - err := ks.CheckEnabled(addr1, testutils.FixtureChainID) + err := ks.CheckEnabled(ctx, addr1, testutils.FixtureChainID) assert.NoError(t, err) - err = ks.CheckEnabled(addr1, testutils.SimulatedChainID) + err = ks.CheckEnabled(ctx, addr1, testutils.SimulatedChainID) assert.NoError(t, err) }) t.Run("returns error when key does not exist", func(t *testing.T) { addr := utils.RandomAddress() - err := ks.CheckEnabled(addr, testutils.FixtureChainID) + err := ks.CheckEnabled(ctx, addr, testutils.FixtureChainID) assert.Error(t, err) require.Contains(t, err.Error(), fmt.Sprintf("no eth key exists with address %s", addr.Hex())) }) t.Run("returns error when key exists but has never been enabled (no state) for the given chain", func(t *testing.T) { - err := ks.CheckEnabled(addr3, testutils.FixtureChainID) + err := ks.CheckEnabled(ctx, addr3, testutils.FixtureChainID) assert.Error(t, err) require.Contains(t, err.Error(), fmt.Sprintf("eth key with address %s exists but is has not been enabled for chain 0 (enabled only for chain IDs: 1337)", addr3.Hex())) }) t.Run("returns error when key exists but is disabled for the given chain", func(t *testing.T) { - err := ks.CheckEnabled(addr2, testutils.SimulatedChainID) + err := ks.CheckEnabled(ctx, addr2, testutils.SimulatedChainID) assert.Error(t, err) require.Contains(t, err.Error(), fmt.Sprintf("eth key with address %s exists but is disabled for chain 1337 (enabled only for chain IDs: 0)", addr2.Hex())) }) @@ -799,28 +832,31 @@ func Test_EthKeyStore_Disable(t *testing.T) { ks := keyStore.Eth() t.Run("creates key, deletes it unsafely and then enable creates it again", func(t *testing.T) { + ctx := testutils.Context(t) k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Add(k.Address, testutils.SimulatedChainID)) + require.NoError(t, ks.Add(ctx, k.Address, testutils.SimulatedChainID)) _, err := db.Exec("DELETE FROM evm.key_states WHERE address = $1", k.Address) require.NoError(t, err) - require.NoError(t, ks.Disable(k.Address, testutils.SimulatedChainID)) - key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, ks.Disable(ctx, k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(ctx, k.Address.Hex(), testutils.SimulatedChainID) require.NoError(t, err) require.Equal(t, key.Disabled, true) }) t.Run("creates key and enables it if it exists in the keystore, but is missing from key states db table", func(t *testing.T) { + ctx := testutils.Context(t) k, _ := cltest.MustInsertRandomKeyNoChains(t, ks) - require.NoError(t, ks.Disable(k.Address, testutils.SimulatedChainID)) - key, err := ks.GetState(k.Address.Hex(), testutils.SimulatedChainID) + require.NoError(t, ks.Disable(ctx, k.Address, testutils.SimulatedChainID)) + key, err := ks.GetState(ctx, k.Address.Hex(), testutils.SimulatedChainID) require.NoError(t, err) require.Equal(t, key.Disabled, true) }) t.Run("errors if key is not present in keystore", func(t *testing.T) { + ctx := testutils.Context(t) addrNotInKs := testutils.NewAddress() - require.Error(t, ks.Disable(addrNotInKs, testutils.SimulatedChainID)) - _, err := ks.GetState(addrNotInKs.Hex(), testutils.SimulatedChainID) + require.Error(t, ks.Disable(ctx, addrNotInKs, testutils.SimulatedChainID)) + _, err := ks.GetState(ctx, addrNotInKs.Hex(), testutils.SimulatedChainID) require.Error(t, err) }) } diff --git a/core/services/keystore/mocks/eth.go b/core/services/keystore/mocks/eth.go index b3827398fd5..a5dd612d9e5 100644 --- a/core/services/keystore/mocks/eth.go +++ b/core/services/keystore/mocks/eth.go @@ -3,9 +3,11 @@ package mocks import ( + context "context" big "math/big" common "github.com/ethereum/go-ethereum/common" + ethkey "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" mock "github.com/stretchr/testify/mock" @@ -20,14 +22,14 @@ type Eth struct { mock.Mock } -// Add provides a mock function with given fields: address, chainID, qopts -func (_m *Eth) Add(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +// Add provides a mock function with given fields: ctx, address, chainID, qopts +func (_m *Eth) Add(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { _va := make([]interface{}, len(qopts)) for _i := range qopts { _va[_i] = qopts[_i] } var _ca []interface{} - _ca = append(_ca, address, chainID) + _ca = append(_ca, ctx, address, chainID) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -36,8 +38,8 @@ func (_m *Eth) Add(address common.Address, chainID *big.Int, qopts ...pg.QOpt) e } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, *big.Int, ...pg.QOpt) error); ok { - r0 = rf(address, chainID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, ...pg.QOpt) error); ok { + r0 = rf(ctx, address, chainID, qopts...) } else { r0 = ret.Error(0) } @@ -45,17 +47,17 @@ func (_m *Eth) Add(address common.Address, chainID *big.Int, qopts ...pg.QOpt) e return r0 } -// CheckEnabled provides a mock function with given fields: address, chainID -func (_m *Eth) CheckEnabled(address common.Address, chainID *big.Int) error { - ret := _m.Called(address, chainID) +// CheckEnabled provides a mock function with given fields: ctx, address, chainID +func (_m *Eth) CheckEnabled(ctx context.Context, address common.Address, chainID *big.Int) error { + ret := _m.Called(ctx, address, chainID) if len(ret) == 0 { panic("no return value specified for CheckEnabled") } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, *big.Int) error); ok { - r0 = rf(address, chainID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) error); ok { + r0 = rf(ctx, address, chainID) } else { r0 = ret.Error(0) } @@ -63,13 +65,14 @@ func (_m *Eth) CheckEnabled(address common.Address, chainID *big.Int) error { return r0 } -// Create provides a mock function with given fields: chainIDs -func (_m *Eth) Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) { +// Create provides a mock function with given fields: ctx, chainIDs +func (_m *Eth) Create(ctx context.Context, chainIDs ...*big.Int) (ethkey.KeyV2, error) { _va := make([]interface{}, len(chainIDs)) for _i := range chainIDs { _va[_i] = chainIDs[_i] } var _ca []interface{} + _ca = append(_ca, ctx) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -79,17 +82,17 @@ func (_m *Eth) Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) { var r0 ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func(...*big.Int) (ethkey.KeyV2, error)); ok { - return rf(chainIDs...) + if rf, ok := ret.Get(0).(func(context.Context, ...*big.Int) (ethkey.KeyV2, error)); ok { + return rf(ctx, chainIDs...) } - if rf, ok := ret.Get(0).(func(...*big.Int) ethkey.KeyV2); ok { - r0 = rf(chainIDs...) + if rf, ok := ret.Get(0).(func(context.Context, ...*big.Int) ethkey.KeyV2); ok { + r0 = rf(ctx, chainIDs...) } else { r0 = ret.Get(0).(ethkey.KeyV2) } - if rf, ok := ret.Get(1).(func(...*big.Int) error); ok { - r1 = rf(chainIDs...) + if rf, ok := ret.Get(1).(func(context.Context, ...*big.Int) error); ok { + r1 = rf(ctx, chainIDs...) } else { r1 = ret.Error(1) } @@ -97,9 +100,9 @@ func (_m *Eth) Create(chainIDs ...*big.Int) (ethkey.KeyV2, error) { return r0, r1 } -// Delete provides a mock function with given fields: id -func (_m *Eth) Delete(id string) (ethkey.KeyV2, error) { - ret := _m.Called(id) +// Delete provides a mock function with given fields: ctx, id +func (_m *Eth) Delete(ctx context.Context, id string) (ethkey.KeyV2, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for Delete") @@ -107,17 +110,17 @@ func (_m *Eth) Delete(id string) (ethkey.KeyV2, error) { var r0 ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func(string) (ethkey.KeyV2, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) (ethkey.KeyV2, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(string) ethkey.KeyV2); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) ethkey.KeyV2); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(ethkey.KeyV2) } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -125,14 +128,14 @@ func (_m *Eth) Delete(id string) (ethkey.KeyV2, error) { return r0, r1 } -// Disable provides a mock function with given fields: address, chainID, qopts -func (_m *Eth) Disable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +// Disable provides a mock function with given fields: ctx, address, chainID, qopts +func (_m *Eth) Disable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { _va := make([]interface{}, len(qopts)) for _i := range qopts { _va[_i] = qopts[_i] } var _ca []interface{} - _ca = append(_ca, address, chainID) + _ca = append(_ca, ctx, address, chainID) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -141,8 +144,8 @@ func (_m *Eth) Disable(address common.Address, chainID *big.Int, qopts ...pg.QOp } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, *big.Int, ...pg.QOpt) error); ok { - r0 = rf(address, chainID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, ...pg.QOpt) error); ok { + r0 = rf(ctx, address, chainID, qopts...) } else { r0 = ret.Error(0) } @@ -150,14 +153,14 @@ func (_m *Eth) Disable(address common.Address, chainID *big.Int, qopts ...pg.QOp return r0 } -// Enable provides a mock function with given fields: address, chainID, qopts -func (_m *Eth) Enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { +// Enable provides a mock function with given fields: ctx, address, chainID, qopts +func (_m *Eth) Enable(ctx context.Context, address common.Address, chainID *big.Int, qopts ...pg.QOpt) error { _va := make([]interface{}, len(qopts)) for _i := range qopts { _va[_i] = qopts[_i] } var _ca []interface{} - _ca = append(_ca, address, chainID) + _ca = append(_ca, ctx, address, chainID) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -166,8 +169,8 @@ func (_m *Eth) Enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, *big.Int, ...pg.QOpt) error); ok { - r0 = rf(address, chainID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, ...pg.QOpt) error); ok { + r0 = rf(ctx, address, chainID, qopts...) } else { r0 = ret.Error(0) } @@ -175,9 +178,9 @@ func (_m *Eth) Enable(address common.Address, chainID *big.Int, qopts ...pg.QOpt return r0 } -// EnabledAddressesForChain provides a mock function with given fields: chainID -func (_m *Eth) EnabledAddressesForChain(chainID *big.Int) ([]common.Address, error) { - ret := _m.Called(chainID) +// EnabledAddressesForChain provides a mock function with given fields: ctx, chainID +func (_m *Eth) EnabledAddressesForChain(ctx context.Context, chainID *big.Int) ([]common.Address, error) { + ret := _m.Called(ctx, chainID) if len(ret) == 0 { panic("no return value specified for EnabledAddressesForChain") @@ -185,19 +188,19 @@ func (_m *Eth) EnabledAddressesForChain(chainID *big.Int) ([]common.Address, err var r0 []common.Address var r1 error - if rf, ok := ret.Get(0).(func(*big.Int) ([]common.Address, error)); ok { - return rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) ([]common.Address, error)); ok { + return rf(ctx, chainID) } - if rf, ok := ret.Get(0).(func(*big.Int) []common.Address); ok { - r0 = rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) []common.Address); ok { + r0 = rf(ctx, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]common.Address) } } - if rf, ok := ret.Get(1).(func(*big.Int) error); ok { - r1 = rf(chainID) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { + r1 = rf(ctx, chainID) } else { r1 = ret.Error(1) } @@ -205,9 +208,9 @@ func (_m *Eth) EnabledAddressesForChain(chainID *big.Int) ([]common.Address, err return r0, r1 } -// EnabledKeysForChain provides a mock function with given fields: chainID -func (_m *Eth) EnabledKeysForChain(chainID *big.Int) ([]ethkey.KeyV2, error) { - ret := _m.Called(chainID) +// EnabledKeysForChain provides a mock function with given fields: ctx, chainID +func (_m *Eth) EnabledKeysForChain(ctx context.Context, chainID *big.Int) ([]ethkey.KeyV2, error) { + ret := _m.Called(ctx, chainID) if len(ret) == 0 { panic("no return value specified for EnabledKeysForChain") @@ -215,19 +218,19 @@ func (_m *Eth) EnabledKeysForChain(chainID *big.Int) ([]ethkey.KeyV2, error) { var r0 []ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func(*big.Int) ([]ethkey.KeyV2, error)); ok { - return rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) ([]ethkey.KeyV2, error)); ok { + return rf(ctx, chainID) } - if rf, ok := ret.Get(0).(func(*big.Int) []ethkey.KeyV2); ok { - r0 = rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) []ethkey.KeyV2); ok { + r0 = rf(ctx, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ethkey.KeyV2) } } - if rf, ok := ret.Get(1).(func(*big.Int) error); ok { - r1 = rf(chainID) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { + r1 = rf(ctx, chainID) } else { r1 = ret.Error(1) } @@ -235,13 +238,14 @@ func (_m *Eth) EnabledKeysForChain(chainID *big.Int) ([]ethkey.KeyV2, error) { return r0, r1 } -// EnsureKeys provides a mock function with given fields: chainIDs -func (_m *Eth) EnsureKeys(chainIDs ...*big.Int) error { +// EnsureKeys provides a mock function with given fields: ctx, chainIDs +func (_m *Eth) EnsureKeys(ctx context.Context, chainIDs ...*big.Int) error { _va := make([]interface{}, len(chainIDs)) for _i := range chainIDs { _va[_i] = chainIDs[_i] } var _ca []interface{} + _ca = append(_ca, ctx) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -250,8 +254,8 @@ func (_m *Eth) EnsureKeys(chainIDs ...*big.Int) error { } var r0 error - if rf, ok := ret.Get(0).(func(...*big.Int) error); ok { - r0 = rf(chainIDs...) + if rf, ok := ret.Get(0).(func(context.Context, ...*big.Int) error); ok { + r0 = rf(ctx, chainIDs...) } else { r0 = ret.Error(0) } @@ -259,9 +263,9 @@ func (_m *Eth) EnsureKeys(chainIDs ...*big.Int) error { return r0 } -// Export provides a mock function with given fields: id, password -func (_m *Eth) Export(id string, password string) ([]byte, error) { - ret := _m.Called(id, password) +// Export provides a mock function with given fields: ctx, id, password +func (_m *Eth) Export(ctx context.Context, id string, password string) ([]byte, error) { + ret := _m.Called(ctx, id, password) if len(ret) == 0 { panic("no return value specified for Export") @@ -269,19 +273,19 @@ func (_m *Eth) Export(id string, password string) ([]byte, error) { var r0 []byte var r1 error - if rf, ok := ret.Get(0).(func(string, string) ([]byte, error)); ok { - return rf(id, password) + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]byte, error)); ok { + return rf(ctx, id, password) } - if rf, ok := ret.Get(0).(func(string, string) []byte); ok { - r0 = rf(id, password) + if rf, ok := ret.Get(0).(func(context.Context, string, string) []byte); ok { + r0 = rf(ctx, id, password) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) } } - if rf, ok := ret.Get(1).(func(string, string) error); ok { - r1 = rf(id, password) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, id, password) } else { r1 = ret.Error(1) } @@ -289,9 +293,9 @@ func (_m *Eth) Export(id string, password string) ([]byte, error) { return r0, r1 } -// Get provides a mock function with given fields: id -func (_m *Eth) Get(id string) (ethkey.KeyV2, error) { - ret := _m.Called(id) +// Get provides a mock function with given fields: ctx, id +func (_m *Eth) Get(ctx context.Context, id string) (ethkey.KeyV2, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for Get") @@ -299,17 +303,17 @@ func (_m *Eth) Get(id string) (ethkey.KeyV2, error) { var r0 ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func(string) (ethkey.KeyV2, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) (ethkey.KeyV2, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(string) ethkey.KeyV2); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, string) ethkey.KeyV2); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(ethkey.KeyV2) } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -317,9 +321,9 @@ func (_m *Eth) Get(id string) (ethkey.KeyV2, error) { return r0, r1 } -// GetAll provides a mock function with given fields: -func (_m *Eth) GetAll() ([]ethkey.KeyV2, error) { - ret := _m.Called() +// GetAll provides a mock function with given fields: ctx +func (_m *Eth) GetAll(ctx context.Context) ([]ethkey.KeyV2, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for GetAll") @@ -327,19 +331,19 @@ func (_m *Eth) GetAll() ([]ethkey.KeyV2, error) { var r0 []ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func() ([]ethkey.KeyV2, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]ethkey.KeyV2, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []ethkey.KeyV2); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []ethkey.KeyV2); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ethkey.KeyV2) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -347,14 +351,14 @@ func (_m *Eth) GetAll() ([]ethkey.KeyV2, error) { return r0, r1 } -// GetRoundRobinAddress provides a mock function with given fields: chainID, addresses -func (_m *Eth) GetRoundRobinAddress(chainID *big.Int, addresses ...common.Address) (common.Address, error) { +// GetRoundRobinAddress provides a mock function with given fields: ctx, chainID, addresses +func (_m *Eth) GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addresses ...common.Address) (common.Address, error) { _va := make([]interface{}, len(addresses)) for _i := range addresses { _va[_i] = addresses[_i] } var _ca []interface{} - _ca = append(_ca, chainID) + _ca = append(_ca, ctx, chainID) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -364,19 +368,19 @@ func (_m *Eth) GetRoundRobinAddress(chainID *big.Int, addresses ...common.Addres var r0 common.Address var r1 error - if rf, ok := ret.Get(0).(func(*big.Int, ...common.Address) (common.Address, error)); ok { - return rf(chainID, addresses...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, ...common.Address) (common.Address, error)); ok { + return rf(ctx, chainID, addresses...) } - if rf, ok := ret.Get(0).(func(*big.Int, ...common.Address) common.Address); ok { - r0 = rf(chainID, addresses...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, ...common.Address) common.Address); ok { + r0 = rf(ctx, chainID, addresses...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(common.Address) } } - if rf, ok := ret.Get(1).(func(*big.Int, ...common.Address) error); ok { - r1 = rf(chainID, addresses...) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int, ...common.Address) error); ok { + r1 = rf(ctx, chainID, addresses...) } else { r1 = ret.Error(1) } @@ -384,9 +388,9 @@ func (_m *Eth) GetRoundRobinAddress(chainID *big.Int, addresses ...common.Addres return r0, r1 } -// GetState provides a mock function with given fields: id, chainID -func (_m *Eth) GetState(id string, chainID *big.Int) (ethkey.State, error) { - ret := _m.Called(id, chainID) +// GetState provides a mock function with given fields: ctx, id, chainID +func (_m *Eth) GetState(ctx context.Context, id string, chainID *big.Int) (ethkey.State, error) { + ret := _m.Called(ctx, id, chainID) if len(ret) == 0 { panic("no return value specified for GetState") @@ -394,17 +398,17 @@ func (_m *Eth) GetState(id string, chainID *big.Int) (ethkey.State, error) { var r0 ethkey.State var r1 error - if rf, ok := ret.Get(0).(func(string, *big.Int) (ethkey.State, error)); ok { - return rf(id, chainID) + if rf, ok := ret.Get(0).(func(context.Context, string, *big.Int) (ethkey.State, error)); ok { + return rf(ctx, id, chainID) } - if rf, ok := ret.Get(0).(func(string, *big.Int) ethkey.State); ok { - r0 = rf(id, chainID) + if rf, ok := ret.Get(0).(func(context.Context, string, *big.Int) ethkey.State); ok { + r0 = rf(ctx, id, chainID) } else { r0 = ret.Get(0).(ethkey.State) } - if rf, ok := ret.Get(1).(func(string, *big.Int) error); ok { - r1 = rf(id, chainID) + if rf, ok := ret.Get(1).(func(context.Context, string, *big.Int) error); ok { + r1 = rf(ctx, id, chainID) } else { r1 = ret.Error(1) } @@ -412,9 +416,9 @@ func (_m *Eth) GetState(id string, chainID *big.Int) (ethkey.State, error) { return r0, r1 } -// GetStateForKey provides a mock function with given fields: _a0 -func (_m *Eth) GetStateForKey(_a0 ethkey.KeyV2) (ethkey.State, error) { - ret := _m.Called(_a0) +// GetStateForKey provides a mock function with given fields: ctx, key +func (_m *Eth) GetStateForKey(ctx context.Context, key ethkey.KeyV2) (ethkey.State, error) { + ret := _m.Called(ctx, key) if len(ret) == 0 { panic("no return value specified for GetStateForKey") @@ -422,17 +426,17 @@ func (_m *Eth) GetStateForKey(_a0 ethkey.KeyV2) (ethkey.State, error) { var r0 ethkey.State var r1 error - if rf, ok := ret.Get(0).(func(ethkey.KeyV2) (ethkey.State, error)); ok { - return rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, ethkey.KeyV2) (ethkey.State, error)); ok { + return rf(ctx, key) } - if rf, ok := ret.Get(0).(func(ethkey.KeyV2) ethkey.State); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, ethkey.KeyV2) ethkey.State); ok { + r0 = rf(ctx, key) } else { r0 = ret.Get(0).(ethkey.State) } - if rf, ok := ret.Get(1).(func(ethkey.KeyV2) error); ok { - r1 = rf(_a0) + if rf, ok := ret.Get(1).(func(context.Context, ethkey.KeyV2) error); ok { + r1 = rf(ctx, key) } else { r1 = ret.Error(1) } @@ -440,9 +444,9 @@ func (_m *Eth) GetStateForKey(_a0 ethkey.KeyV2) (ethkey.State, error) { return r0, r1 } -// GetStatesForChain provides a mock function with given fields: chainID -func (_m *Eth) GetStatesForChain(chainID *big.Int) ([]ethkey.State, error) { - ret := _m.Called(chainID) +// GetStatesForChain provides a mock function with given fields: ctx, chainID +func (_m *Eth) GetStatesForChain(ctx context.Context, chainID *big.Int) ([]ethkey.State, error) { + ret := _m.Called(ctx, chainID) if len(ret) == 0 { panic("no return value specified for GetStatesForChain") @@ -450,19 +454,19 @@ func (_m *Eth) GetStatesForChain(chainID *big.Int) ([]ethkey.State, error) { var r0 []ethkey.State var r1 error - if rf, ok := ret.Get(0).(func(*big.Int) ([]ethkey.State, error)); ok { - return rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) ([]ethkey.State, error)); ok { + return rf(ctx, chainID) } - if rf, ok := ret.Get(0).(func(*big.Int) []ethkey.State); ok { - r0 = rf(chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) []ethkey.State); ok { + r0 = rf(ctx, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ethkey.State) } } - if rf, ok := ret.Get(1).(func(*big.Int) error); ok { - r1 = rf(chainID) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { + r1 = rf(ctx, chainID) } else { r1 = ret.Error(1) } @@ -470,9 +474,9 @@ func (_m *Eth) GetStatesForChain(chainID *big.Int) ([]ethkey.State, error) { return r0, r1 } -// GetStatesForKeys provides a mock function with given fields: _a0 -func (_m *Eth) GetStatesForKeys(_a0 []ethkey.KeyV2) ([]ethkey.State, error) { - ret := _m.Called(_a0) +// GetStatesForKeys provides a mock function with given fields: ctx, keys +func (_m *Eth) GetStatesForKeys(ctx context.Context, keys []ethkey.KeyV2) ([]ethkey.State, error) { + ret := _m.Called(ctx, keys) if len(ret) == 0 { panic("no return value specified for GetStatesForKeys") @@ -480,19 +484,19 @@ func (_m *Eth) GetStatesForKeys(_a0 []ethkey.KeyV2) ([]ethkey.State, error) { var r0 []ethkey.State var r1 error - if rf, ok := ret.Get(0).(func([]ethkey.KeyV2) ([]ethkey.State, error)); ok { - return rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, []ethkey.KeyV2) ([]ethkey.State, error)); ok { + return rf(ctx, keys) } - if rf, ok := ret.Get(0).(func([]ethkey.KeyV2) []ethkey.State); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, []ethkey.KeyV2) []ethkey.State); ok { + r0 = rf(ctx, keys) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ethkey.State) } } - if rf, ok := ret.Get(1).(func([]ethkey.KeyV2) error); ok { - r1 = rf(_a0) + if rf, ok := ret.Get(1).(func(context.Context, []ethkey.KeyV2) error); ok { + r1 = rf(ctx, keys) } else { r1 = ret.Error(1) } @@ -500,14 +504,14 @@ func (_m *Eth) GetStatesForKeys(_a0 []ethkey.KeyV2) ([]ethkey.State, error) { return r0, r1 } -// Import provides a mock function with given fields: keyJSON, password, chainIDs -func (_m *Eth) Import(keyJSON []byte, password string, chainIDs ...*big.Int) (ethkey.KeyV2, error) { +// Import provides a mock function with given fields: ctx, keyJSON, password, chainIDs +func (_m *Eth) Import(ctx context.Context, keyJSON []byte, password string, chainIDs ...*big.Int) (ethkey.KeyV2, error) { _va := make([]interface{}, len(chainIDs)) for _i := range chainIDs { _va[_i] = chainIDs[_i] } var _ca []interface{} - _ca = append(_ca, keyJSON, password) + _ca = append(_ca, ctx, keyJSON, password) _ca = append(_ca, _va...) ret := _m.Called(_ca...) @@ -517,17 +521,17 @@ func (_m *Eth) Import(keyJSON []byte, password string, chainIDs ...*big.Int) (et var r0 ethkey.KeyV2 var r1 error - if rf, ok := ret.Get(0).(func([]byte, string, ...*big.Int) (ethkey.KeyV2, error)); ok { - return rf(keyJSON, password, chainIDs...) + if rf, ok := ret.Get(0).(func(context.Context, []byte, string, ...*big.Int) (ethkey.KeyV2, error)); ok { + return rf(ctx, keyJSON, password, chainIDs...) } - if rf, ok := ret.Get(0).(func([]byte, string, ...*big.Int) ethkey.KeyV2); ok { - r0 = rf(keyJSON, password, chainIDs...) + if rf, ok := ret.Get(0).(func(context.Context, []byte, string, ...*big.Int) ethkey.KeyV2); ok { + r0 = rf(ctx, keyJSON, password, chainIDs...) } else { r0 = ret.Get(0).(ethkey.KeyV2) } - if rf, ok := ret.Get(1).(func([]byte, string, ...*big.Int) error); ok { - r1 = rf(keyJSON, password, chainIDs...) + if rf, ok := ret.Get(1).(func(context.Context, []byte, string, ...*big.Int) error); ok { + r1 = rf(ctx, keyJSON, password, chainIDs...) } else { r1 = ret.Error(1) } @@ -535,9 +539,9 @@ func (_m *Eth) Import(keyJSON []byte, password string, chainIDs ...*big.Int) (et return r0, r1 } -// SignTx provides a mock function with given fields: fromAddress, tx, chainID -func (_m *Eth) SignTx(fromAddress common.Address, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { - ret := _m.Called(fromAddress, tx, chainID) +// SignTx provides a mock function with given fields: ctx, fromAddress, tx, chainID +func (_m *Eth) SignTx(ctx context.Context, fromAddress common.Address, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { + ret := _m.Called(ctx, fromAddress, tx, chainID) if len(ret) == 0 { panic("no return value specified for SignTx") @@ -545,19 +549,19 @@ func (_m *Eth) SignTx(fromAddress common.Address, tx *types.Transaction, chainID var r0 *types.Transaction var r1 error - if rf, ok := ret.Get(0).(func(common.Address, *types.Transaction, *big.Int) (*types.Transaction, error)); ok { - return rf(fromAddress, tx, chainID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *types.Transaction, *big.Int) (*types.Transaction, error)); ok { + return rf(ctx, fromAddress, tx, chainID) } - if rf, ok := ret.Get(0).(func(common.Address, *types.Transaction, *big.Int) *types.Transaction); ok { - r0 = rf(fromAddress, tx, chainID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *types.Transaction, *big.Int) *types.Transaction); ok { + r0 = rf(ctx, fromAddress, tx, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*types.Transaction) } } - if rf, ok := ret.Get(1).(func(common.Address, *types.Transaction, *big.Int) error); ok { - r1 = rf(fromAddress, tx, chainID) + if rf, ok := ret.Get(1).(func(context.Context, common.Address, *types.Transaction, *big.Int) error); ok { + r1 = rf(ctx, fromAddress, tx, chainID) } else { r1 = ret.Error(1) } @@ -565,9 +569,9 @@ func (_m *Eth) SignTx(fromAddress common.Address, tx *types.Transaction, chainID return r0, r1 } -// SubscribeToKeyChanges provides a mock function with given fields: -func (_m *Eth) SubscribeToKeyChanges() (chan struct{}, func()) { - ret := _m.Called() +// SubscribeToKeyChanges provides a mock function with given fields: ctx +func (_m *Eth) SubscribeToKeyChanges(ctx context.Context) (chan struct{}, func()) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for SubscribeToKeyChanges") @@ -575,19 +579,19 @@ func (_m *Eth) SubscribeToKeyChanges() (chan struct{}, func()) { var r0 chan struct{} var r1 func() - if rf, ok := ret.Get(0).(func() (chan struct{}, func())); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (chan struct{}, func())); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() chan struct{}); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) chan struct{}); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(chan struct{}) } } - if rf, ok := ret.Get(1).(func() func()); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) func()); ok { + r1 = rf(ctx) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(func()) @@ -597,14 +601,14 @@ func (_m *Eth) SubscribeToKeyChanges() (chan struct{}, func()) { return r0, r1 } -// XXXTestingOnlyAdd provides a mock function with given fields: key -func (_m *Eth) XXXTestingOnlyAdd(key ethkey.KeyV2) { - _m.Called(key) +// XXXTestingOnlyAdd provides a mock function with given fields: ctx, key +func (_m *Eth) XXXTestingOnlyAdd(ctx context.Context, key ethkey.KeyV2) { + _m.Called(ctx, key) } -// XXXTestingOnlySetState provides a mock function with given fields: _a0 -func (_m *Eth) XXXTestingOnlySetState(_a0 ethkey.State) { - _m.Called(_a0) +// XXXTestingOnlySetState provides a mock function with given fields: ctx, keyState +func (_m *Eth) XXXTestingOnlySetState(ctx context.Context, keyState ethkey.State) { + _m.Called(ctx, keyState) } // NewEth creates a new instance of Eth. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index 0eed680a3d8..6d7757ea528 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -1,6 +1,7 @@ package ocr import ( + "context" "fmt" "strings" "time" @@ -87,7 +88,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(spec job.Job, q pg.Queryer) error { return nil } // ServicesForSpec returns the OCR services that need to run for this job -func (d *Delegate) ServicesForSpec(jb job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { if jb.OCROracleSpec == nil { return nil, errors.Errorf("offchainreporting.Delegate expects an *job.OffchainreportingOracleSpec to be present, got %v", jb) } diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 336d1ae3800..8677f48e8b5 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -352,7 +352,7 @@ func (d *Delegate) cleanupEVM(jb job.Job, q pg.Queryer, relayID relay.ID) error } // ServicesForSpec returns the OCR2 services that need to run for this job -func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { spec := jb.OCR2OracleSpec if spec == nil { return nil, errors.Errorf("offchainreporting2.Delegate expects an *job.OCR2OracleSpec to be present, got %v", jb) @@ -438,7 +438,7 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { spec.CaptureEATelemetry = d.cfg.OCR2().CaptureEATelemetry() - ctx := lggrCtx.ContextWithValues(context.Background()) + ctx = lggrCtx.ContextWithValues(ctx) switch spec.PluginType { case types.Mercury: return d.newServicesMercury(ctx, lggr, jb, bootstrapPeers, kb, ocrDB, lc, ocrLogger) @@ -466,7 +466,7 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { ) thresholdPluginDB := NewDB(d.db, spec.ID, thresholdPluginId, lggr, d.cfg.Database()) s4PluginDB := NewDB(d.db, spec.ID, s4PluginId, lggr, d.cfg.Database()) - return d.newServicesOCR2Functions(lggr, jb, bootstrapPeers, kb, ocrDB, thresholdPluginDB, s4PluginDB, lc, ocrLogger) + return d.newServicesOCR2Functions(ctx, lggr, jb, bootstrapPeers, kb, ocrDB, thresholdPluginDB, s4PluginDB, lc, ocrLogger) case types.GenericPlugin: return d.newServicesGenericPlugin(ctx, lggr, jb, bootstrapPeers, kb, ocrDB, lc, ocrLogger, d.capabilitiesRegistry) @@ -1510,6 +1510,7 @@ func (d *Delegate) newServicesOCR2Keepers20( } func (d *Delegate) newServicesOCR2Functions( + ctx context.Context, lggr logger.SugaredLogger, jb job.Job, bootstrapPeers []commontypes.BootstrapperLocator, @@ -1535,6 +1536,7 @@ func (d *Delegate) newServicesOCR2Functions( } createPluginProvider := func(pluginType functionsRelay.FunctionsPluginType, relayerName string) (evmrelaytypes.FunctionsProvider, error) { return evmrelay.NewFunctionsProvider( + ctx, chain, types.RelayArgs{ ExternalJobID: jb.ExternalJobID, @@ -1646,7 +1648,7 @@ func (d *Delegate) newServicesOCR2Functions( LogPollerWrapper: functionsProvider.LogPollerWrapper(), } - functionsServices, err := functions.NewFunctionsServices(&functionsOracleArgs, &thresholdOracleArgs, &s4OracleArgs, &functionsServicesConfig) + functionsServices, err := functions.NewFunctionsServices(ctx, &functionsOracleArgs, &thresholdOracleArgs, &s4OracleArgs, &functionsServicesConfig) if err != nil { return nil, errors.Wrap(err, "error calling NewFunctionsServices") } diff --git a/core/services/ocr2/plugins/functions/integration_tests/v1/internal/testutils.go b/core/services/ocr2/plugins/functions/integration_tests/v1/internal/testutils.go index 3061d818bf1..7f562c9adb3 100644 --- a/core/services/ocr2/plugins/functions/integration_tests/v1/internal/testutils.go +++ b/core/services/ocr2/plugins/functions/integration_tests/v1/internal/testutils.go @@ -348,7 +348,7 @@ func StartNewNode( app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, b, p2pKey) - sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) require.Len(t, sendingKeys, 1) transmitter := sendingKeys[0].Address diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index a49ce4be90a..27835127d0d 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -1,6 +1,7 @@ package functions import ( + "context" "encoding/json" "math/big" "slices" @@ -59,7 +60,7 @@ const ( ) // Create all OCR2 plugin Oracles and all extra services needed to run a Functions job. -func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs *libocr2.OCR2OracleArgs, conf *FunctionsServicesConfig) ([]job.ServiceCtx, error) { +func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOracleArgs, s4OracleArgs *libocr2.OCR2OracleArgs, conf *FunctionsServicesConfig) ([]job.ServiceCtx, error) { pluginORM := functions.NewORM(conf.DB, conf.Logger, conf.QConfig, common.HexToAddress(conf.ContractID)) s4ORM := s4.NewPostgresORM(conf.DB, conf.Logger, conf.QConfig, s4.SharedTableName, FunctionsS4Namespace) @@ -156,7 +157,7 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return nil, errors.Wrap(err, "failed to create a OnchainSubscriptions") } connectorLogger := conf.Logger.Named("GatewayConnector").With("jobName", conf.Job.PipelineSpec.JobName) - connector, err2 := NewConnector(&pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger) + connector, err2 := NewConnector(ctx, &pluginConfig, conf.EthKeystore, conf.Chain.ID(), s4Storage, allowlist, rateLimiter, subscriptions, functionsListener, offchainTransmitter, connectorLogger) if err2 != nil { return nil, errors.Wrap(err, "failed to create a GatewayConnector") } @@ -183,8 +184,8 @@ func NewFunctionsServices(functionsOracleArgs, thresholdOracleArgs, s4OracleArgs return allServices, nil } -func NewConnector(pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwAllowlist.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwSubscriptions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, error) { - enabledKeys, err := ethKeystore.EnabledKeysForChain(chainID) +func NewConnector(ctx context.Context, pluginConfig *config.PluginConfig, ethKeystore keystore.Eth, chainID *big.Int, s4Storage s4.Storage, allowlist gwAllowlist.OnchainAllowlist, rateLimiter *hc.RateLimiter, subscriptions gwSubscriptions.OnchainSubscriptions, listener functions.FunctionsListener, offchainTransmitter functions.OffchainTransmitter, lggr logger.Logger) (connector.GatewayConnector, error) { + enabledKeys, err := ethKeystore.EnabledKeysForChain(ctx, chainID) if err != nil { return nil, err } diff --git a/core/services/ocr2/plugins/functions/plugin_test.go b/core/services/ocr2/plugins/functions/plugin_test.go index cd7956240df..fdd20b0a932 100644 --- a/core/services/ocr2/plugins/functions/plugin_test.go +++ b/core/services/ocr2/plugins/functions/plugin_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" sfmocks "github.com/smartcontractkit/chainlink/v2/core/services/functions/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" @@ -23,6 +24,9 @@ import ( func TestNewConnector_Success(t *testing.T) { t.Parallel() + + ctx := testutils.Context(t) + keyV2, err := ethkey.NewV2() require.NoError(t, err) @@ -39,16 +43,19 @@ func TestNewConnector_Success(t *testing.T) { require.NoError(t, err) listener := sfmocks.NewFunctionsListener(t) offchainTransmitter := sfmocks.NewOffchainTransmitter(t) - ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{keyV2}, nil) + ethKeystore.On("EnabledKeysForChain", mock.Anything, mock.Anything).Return([]ethkey.KeyV2{keyV2}, nil) config := &config.PluginConfig{ GatewayConnectorConfig: gwcCfg, } - _, err = functions.NewConnector(config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) + _, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) require.NoError(t, err) } func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) { t.Parallel() + + ctx := testutils.Context(t) + addresses := []string{ "0x00000000DE801ceE9471ADf23370c48b011f82a6", "0x11111111DE801ceE9471ADf23370c48b011f82a6", @@ -67,10 +74,10 @@ func TestNewConnector_NoKeyForConfiguredAddress(t *testing.T) { require.NoError(t, err) listener := sfmocks.NewFunctionsListener(t) offchainTransmitter := sfmocks.NewOffchainTransmitter(t) - ethKeystore.On("EnabledKeysForChain", mock.Anything).Return([]ethkey.KeyV2{{Address: common.HexToAddress(addresses[1])}}, nil) + ethKeystore.On("EnabledKeysForChain", mock.Anything, mock.Anything).Return([]ethkey.KeyV2{{Address: common.HexToAddress(addresses[1])}}, nil) config := &config.PluginConfig{ GatewayConnectorConfig: gwcCfg, } - _, err = functions.NewConnector(config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) + _, err = functions.NewConnector(ctx, config, ethKeystore, chainID, s4Storage, allowlist, rateLimiter, subscriptions, listener, offchainTransmitter, logger.TestLogger(t)) require.Error(t, err) } diff --git a/core/services/ocr2/plugins/ocr2vrf/internal/ocr2vrf_integration_test.go b/core/services/ocr2/plugins/ocr2vrf/internal/ocr2vrf_integration_test.go index 4a01ee7904f..8d50b8076bb 100644 --- a/core/services/ocr2/plugins/ocr2vrf/internal/ocr2vrf_integration_test.go +++ b/core/services/ocr2/plugins/ocr2vrf/internal/ocr2vrf_integration_test.go @@ -226,6 +226,7 @@ func setupNodeOCR2( useForwarders bool, p2pV2Bootstrappers []commontypes.BootstrapperLocator, ) *ocr2Node { + ctx := testutils.Context(t) p2pKey := keystest.NewP2PKeyV2(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.Insecure.OCRDevelopmentMode = ptr(true) // Disables ocr spec validation so we can have fast polling for the test. @@ -255,7 +256,7 @@ func setupNodeOCR2( var sendingKeys []ethkey.KeyV2 { var err error - sendingKeys, err = app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + sendingKeys, err = app.KeyStore.Eth().EnabledKeysForChain(ctx, testutils.SimulatedChainID) require.NoError(t, err) require.Len(t, sendingKeys, 1) } @@ -266,10 +267,10 @@ func setupNodeOCR2( sendingKeysAddresses := []common.Address{sendingKeys[0].Address} // Add new sending key. - k, err := app.KeyStore.Eth().Create() + k, err := app.KeyStore.Eth().Create(ctx) require.NoError(t, err) - require.NoError(t, app.KeyStore.Eth().Add(k.Address, testutils.SimulatedChainID)) - require.NoError(t, app.KeyStore.Eth().Enable(k.Address, testutils.SimulatedChainID)) + require.NoError(t, app.KeyStore.Eth().Add(ctx, k.Address, testutils.SimulatedChainID)) + require.NoError(t, app.KeyStore.Eth().Enable(ctx, k.Address, testutils.SimulatedChainID)) sendingKeys = append(sendingKeys, k) sendingKeysAddresses = append(sendingKeysAddresses, k.Address) @@ -296,7 +297,7 @@ func setupNodeOCR2( var sendingKeyStrings []string for _, k := range sendingKeys { sendingKeyStrings = append(sendingKeyStrings, k.Address.String()) - n, err := b.NonceAt(testutils.Context(t), owner.From, nil) + n, err := b.NonceAt(ctx, owner.From, nil) require.NoError(t, err) tx := cltest.NewLegacyTransaction( @@ -307,7 +308,7 @@ func setupNodeOCR2( nil) signedTx, err := owner.Signer(owner.From, tx) require.NoError(t, err) - err = b.SendTransaction(testutils.Context(t), signedTx) + err = b.SendTransaction(ctx, signedTx) require.NoError(t, err) b.Commit() } @@ -429,7 +430,7 @@ func runOCR2VRFTest(t *testing.T, useForwarders bool) { ) t.Log("Adding bootstrap node job") - err = bootstrapNode.app.Start(testutils.Context(t)) + err = bootstrapNode.app.Start(ctx) require.NoError(t, err) evmChains := bootstrapNode.app.GetRelayers().LegacyEVMChains() @@ -457,7 +458,7 @@ fromBlock = %d for x := 1; x < len(sendingKeys[i]); x++ { sendingKeysString = fmt.Sprintf(`%s,"%s"`, sendingKeysString, sendingKeys[i][x]) } - err = apps[i].Start(testutils.Context(t)) + err = apps[i].Start(ctx) require.NoError(t, err) jobSpec := fmt.Sprintf(` @@ -511,7 +512,7 @@ linkEthFeedAddress = "%s" emptyHash := crypto.Keccak256Hash(emptyKH[:]) gomega.NewWithT(t).Eventually(func() bool { kh, err2 := uni.beacon.SProvingKeyHash(&bind.CallOpts{ - Context: testutils.Context(t), + Context: ctx, }) require.NoError(t, err2) t.Log("proving keyhash:", hexutil.Encode(kh[:])) @@ -661,7 +662,7 @@ linkEthFeedAddress = "%s" totalNopPayout := new(big.Int) for idx, payeeTransactor := range payeeTransactors { // Fund the payee with some ETH. - n, err2 := uni.backend.NonceAt(testutils.Context(t), uni.owner.From, nil) + n, err2 := uni.backend.NonceAt(ctx, uni.owner.From, nil) require.NoError(t, err2) tx := cltest.NewLegacyTransaction( n, payeeTransactor.From, @@ -671,7 +672,7 @@ linkEthFeedAddress = "%s" nil) signedTx, err2 := uni.owner.Signer(uni.owner.From, tx) require.NoError(t, err2) - err2 = uni.backend.SendTransaction(testutils.Context(t), signedTx) + err2 = uni.backend.SendTransaction(ctx, signedTx) require.NoError(t, err2) _, err2 = uni.beacon.WithdrawPayment(payeeTransactor, transmitters[idx]) diff --git a/core/services/ocrbootstrap/delegate.go b/core/services/ocrbootstrap/delegate.go index 27ddd53bd52..46c664007bc 100644 --- a/core/services/ocrbootstrap/delegate.go +++ b/core/services/ocrbootstrap/delegate.go @@ -79,7 +79,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) { } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(jb job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { spec := jb.BootstrapSpec if spec == nil { return nil, errors.Errorf("Bootstrap.Delegate expects an *job.BootstrapSpec to be present, got %v", jb) @@ -109,7 +109,7 @@ func (d *Delegate) ServicesForSpec(jb job.Job) (services []job.ServiceCtx, err e ContractID: spec.ContractID, FeedID: spec.FeedID, } - ctx := ctxVals.ContextWithValues(context.Background()) + ctx = ctxVals.ContextWithValues(ctx) var relayCfg relayConfig if err = json.Unmarshal(spec.RelayConfig.Bytes(), &relayCfg); err != nil { diff --git a/core/services/ocrcommon/transmitter.go b/core/services/ocrcommon/transmitter.go index 9cdf6a0c5a9..1c4173798ea 100644 --- a/core/services/ocrcommon/transmitter.go +++ b/core/services/ocrcommon/transmitter.go @@ -12,7 +12,7 @@ import ( ) type roundRobinKeystore interface { - GetRoundRobinAddress(chainID *big.Int, addresses ...common.Address) (address common.Address, err error) + GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addresses ...common.Address) (address common.Address, err error) } type txManager interface { @@ -66,7 +66,7 @@ func NewTransmitter( func (t *transmitter) CreateEthTransaction(ctx context.Context, toAddress common.Address, payload []byte, txMeta *txmgr.TxMeta) error { - roundRobinFromAddress, err := t.keystore.GetRoundRobinAddress(t.chainID, t.fromAddresses...) + roundRobinFromAddress, err := t.keystore.GetRoundRobinAddress(ctx, t.chainID, t.fromAddresses...) if err != nil { return errors.Wrap(err, "skipped OCR transmission, error getting round-robin address") } diff --git a/core/services/pipeline/task.eth_tx.go b/core/services/pipeline/task.eth_tx.go index 1687c974140..ffd496c486d 100644 --- a/core/services/pipeline/task.eth_tx.go +++ b/core/services/pipeline/task.eth_tx.go @@ -47,7 +47,7 @@ type ETHTxTask struct { } type ETHKeyStore interface { - GetRoundRobinAddress(chainID *big.Int, addrs ...common.Address) (common.Address, error) + GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addrs ...common.Address) (common.Address, error) } var _ Task = (*ETHTxTask)(nil) @@ -127,7 +127,7 @@ func (t *ETHTxTask) Run(ctx context.Context, lggr logger.Logger, vars Vars, inpu return Result{Error: err}, runInfo } - fromAddr, err := t.keyStore.GetRoundRobinAddress(chain.ID(), fromAddrs...) + fromAddr, err := t.keyStore.GetRoundRobinAddress(ctx, chain.ID(), fromAddrs...) if err != nil { err = errors.Wrap(err, "ETHTxTask failed to get fromAddress") lggr.Error(err) diff --git a/core/services/pipeline/task.eth_tx_test.go b/core/services/pipeline/task.eth_tx_test.go index 5f5019d1967..d50949e120d 100644 --- a/core/services/pipeline/task.eth_tx_test.go +++ b/core/services/pipeline/task.eth_tx_test.go @@ -83,7 +83,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -131,7 +131,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -169,7 +169,7 @@ func TestETHTxTask(t *testing.T) { nil, func(keyStore *keystoremocks.Eth, txManager *txmmocks.MockEvmTxManager) { addr := common.HexToAddress("0x882969652440ccf14a5dbb9bd53eb21cb1e11e5c") - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, addr).Return(addr, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, addr).Return(addr, nil) txManager.On("CreateTransaction", mock.Anything, mock.MatchedBy(func(tx txmgr.TxRequest) bool { return tx.MinConfirmations == clnull.Uint32From(2) })).Return(txmgr.Tx{}, nil) @@ -209,7 +209,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -255,7 +255,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -286,7 +286,7 @@ func TestETHTxTask(t *testing.T) { data := []byte("foobar") gasLimit := uint32(12345) txMeta := &txmgr.TxMeta{FailOnRevert: null.BoolFrom(false)} - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -321,7 +321,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -356,7 +356,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -395,7 +395,7 @@ func TestETHTxTask(t *testing.T) { nil, func(keyStore *keystoremocks.Eth, txManager *txmmocks.MockEvmTxManager) { - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID).Return(nil, errors.New("uh oh")) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID).Return(nil, errors.New("uh oh")) }, nil, pipeline.ErrTaskRunFailed, "while querying keystore", pipeline.RunInfo{IsRetryable: true}, }, @@ -422,7 +422,7 @@ func TestETHTxTask(t *testing.T) { RequestTxHash: &reqTxHash, FailOnRevert: null.BoolFrom(false), } - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, txmgr.TxRequest{ FromAddress: from, ToAddress: to, @@ -519,7 +519,7 @@ func TestETHTxTask(t *testing.T) { nil, func(keyStore *keystoremocks.Eth, txManager *txmmocks.MockEvmTxManager) { from := common.HexToAddress("0x882969652440ccf14a5dbb9bd53eb21cb1e11e5c") - keyStore.On("GetRoundRobinAddress", testutils.FixtureChainID, from).Return(from, nil) + keyStore.On("GetRoundRobinAddress", mock.Anything, testutils.FixtureChainID, from).Return(from, nil) txManager.On("CreateTransaction", mock.Anything, mock.MatchedBy(func(tx txmgr.TxRequest) bool { return tx.MinConfirmations == clnull.Uint32From(3) && tx.PipelineTaskRunID != nil })).Return(txmgr.Tx{}, nil) diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index dcccbb90c79..a02885cb556 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -159,6 +159,10 @@ func (r *Relayer) HealthReport() (report map[string]error) { } func (r *Relayer) NewPluginProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (commontypes.PluginProvider, error) { + + // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + ctx := context.Background() + lggr := r.lggr.Named("PluginProvider").Named(rargs.ExternalJobID.String()) configWatcher, err := newStandardConfigProvider(r.lggr, r.chain, types.NewRelayOpts(rargs)) @@ -166,7 +170,7 @@ func (r *Relayer) NewPluginProvider(rargs commontypes.RelayArgs, pargs commontyp return nil, err } - transmitter, err := newOnChainContractTransmitter(r.lggr, rargs, pargs.TransmitterID, r.ks.Eth(), configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) + transmitter, err := newOnChainContractTransmitter(ctx, r.lggr, rargs, pargs.TransmitterID, r.ks.Eth(), configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) if err != nil { return nil, err } @@ -300,9 +304,13 @@ func (r *Relayer) NewLLOProvider(rargs commontypes.RelayArgs, pargs commontypes. } func (r *Relayer) NewFunctionsProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (commontypes.FunctionsProvider, error) { + + // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + ctx := context.Background() + lggr := r.lggr.Named("FunctionsProvider").Named(rargs.ExternalJobID.String()) // TODO(FUN-668): Not ready yet (doesn't implement FunctionsEvents() properly) - return NewFunctionsProvider(r.chain, rargs, pargs, lggr, r.ks.Eth(), functions.FunctionsPlugin) + return NewFunctionsProvider(ctx, r.chain, rargs, pargs, lggr, r.ks.Eth(), functions.FunctionsPlugin) } // NewConfigProvider is called by bootstrap jobs @@ -450,7 +458,7 @@ type configTransmitterOpts struct { pluginGasLimit *uint32 } -func newOnChainContractTransmitter(lggr logger.Logger, rargs commontypes.RelayArgs, transmitterID string, ethKeystore keystore.Eth, configWatcher *configWatcher, opts configTransmitterOpts, transmissionContractABI abi.ABI) (*contractTransmitter, error) { +func newOnChainContractTransmitter(ctx context.Context, lggr logger.Logger, rargs commontypes.RelayArgs, transmitterID string, ethKeystore keystore.Eth, configWatcher *configWatcher, opts configTransmitterOpts, transmissionContractABI abi.ABI) (*contractTransmitter, error) { var relayConfig types.RelayConfig if err := json.Unmarshal(rargs.RelayConfig, &relayConfig); err != nil { return nil, err @@ -473,7 +481,7 @@ func newOnChainContractTransmitter(lggr logger.Logger, rargs commontypes.RelayAr if sendingKeysLength > 1 && s == effectiveTransmitterAddress.String() { return nil, pkgerrors.New("the transmitter is a local sending key with transaction forwarding enabled") } - if err := ethKeystore.CheckEnabled(common.HexToAddress(s), configWatcher.chain.Config().EVM().ChainID()); err != nil { + if err := ethKeystore.CheckEnabled(ctx, common.HexToAddress(s), configWatcher.chain.Config().EVM().ChainID()); err != nil { return nil, pkgerrors.Wrap(err, "one of the sending keys given is not enabled") } fromAddresses = append(fromAddresses, common.HexToAddress(s)) @@ -523,6 +531,9 @@ func newOnChainContractTransmitter(lggr logger.Logger, rargs commontypes.RelayAr } func (r *Relayer) NewMedianProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (commontypes.MedianProvider, error) { + // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + ctx := context.Background() + lggr := r.lggr.Named("MedianProvider").Named(rargs.ExternalJobID.String()) relayOpts := types.NewRelayOpts(rargs) relayConfig, err := relayOpts.RelayConfig() @@ -545,7 +556,7 @@ func (r *Relayer) NewMedianProvider(rargs commontypes.RelayArgs, pargs commontyp reportCodec := evmreportcodec.ReportCodec{} - contractTransmitter, err := newOnChainContractTransmitter(lggr, rargs, pargs.TransmitterID, r.ks.Eth(), configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) + contractTransmitter, err := newOnChainContractTransmitter(ctx, lggr, rargs, pargs.TransmitterID, r.ks.Eth(), configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) if err != nil { return nil, err } diff --git a/core/services/relay/evm/functions.go b/core/services/relay/evm/functions.go index d4e91034496..c10134f3acc 100644 --- a/core/services/relay/evm/functions.go +++ b/core/services/relay/evm/functions.go @@ -90,7 +90,7 @@ func (p *functionsProvider) Codec() commontypes.Codec { return nil } -func NewFunctionsProvider(chain legacyevm.Chain, rargs commontypes.RelayArgs, pargs commontypes.PluginArgs, lggr logger.Logger, ethKeystore keystore.Eth, pluginType functionsRelay.FunctionsPluginType) (evmRelayTypes.FunctionsProvider, error) { +func NewFunctionsProvider(ctx context.Context, chain legacyevm.Chain, rargs commontypes.RelayArgs, pargs commontypes.PluginArgs, lggr logger.Logger, ethKeystore keystore.Eth, pluginType functionsRelay.FunctionsPluginType) (evmRelayTypes.FunctionsProvider, error) { relayOpts := evmRelayTypes.NewRelayOpts(rargs) relayConfig, err := relayOpts.RelayConfig() if err != nil { @@ -121,7 +121,7 @@ func NewFunctionsProvider(chain legacyevm.Chain, rargs commontypes.RelayArgs, pa } var contractTransmitter ContractTransmitter if relayConfig.SendingKeys != nil { - contractTransmitter, err = newFunctionsContractTransmitter(pluginConfig.ContractVersion, rargs, pargs.TransmitterID, configWatcher, ethKeystore, logPollerWrapper, lggr) + contractTransmitter, err = newFunctionsContractTransmitter(ctx, pluginConfig.ContractVersion, rargs, pargs.TransmitterID, configWatcher, ethKeystore, logPollerWrapper, lggr) if err != nil { return nil, err } @@ -154,7 +154,7 @@ func newFunctionsConfigProvider(pluginType functionsRelay.FunctionsPluginType, c return newConfigWatcher(lggr, routerContractAddress, offchainConfigDigester, cp, chain, fromBlock, args.New), nil } -func newFunctionsContractTransmitter(contractVersion uint32, rargs commontypes.RelayArgs, transmitterID string, configWatcher *configWatcher, ethKeystore keystore.Eth, logPollerWrapper evmRelayTypes.LogPollerWrapper, lggr logger.Logger) (ContractTransmitter, error) { +func newFunctionsContractTransmitter(ctx context.Context, contractVersion uint32, rargs commontypes.RelayArgs, transmitterID string, configWatcher *configWatcher, ethKeystore keystore.Eth, logPollerWrapper evmRelayTypes.LogPollerWrapper, lggr logger.Logger) (ContractTransmitter, error) { var relayConfig evmRelayTypes.RelayConfig if err := json.Unmarshal(rargs.RelayConfig, &relayConfig); err != nil { return nil, err @@ -177,7 +177,7 @@ func newFunctionsContractTransmitter(contractVersion uint32, rargs commontypes.R if sendingKeysLength > 1 && s == effectiveTransmitterAddress.String() { return nil, errors.New("the transmitter is a local sending key with transaction forwarding enabled") } - if err := ethKeystore.CheckEnabled(common.HexToAddress(s), configWatcher.chain.Config().EVM().ChainID()); err != nil { + if err := ethKeystore.CheckEnabled(ctx, common.HexToAddress(s), configWatcher.chain.Config().EVM().ChainID()); err != nil { return nil, errors.Wrap(err, "one of the sending keys given is not enabled") } fromAddresses = append(fromAddresses, common.HexToAddress(s)) diff --git a/core/services/relay/evm/ocr2keeper.go b/core/services/relay/evm/ocr2keeper.go index 6bde444d80b..6563604945c 100644 --- a/core/services/relay/evm/ocr2keeper.go +++ b/core/services/relay/evm/ocr2keeper.go @@ -84,13 +84,17 @@ func NewOCR2KeeperRelayer(db *sqlx.DB, chain legacyevm.Chain, lggr logger.Logger } func (r *ocr2keeperRelayer) NewOCR2KeeperProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (OCR2KeeperProvider, error) { + + // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + ctx := context.Background() + cfgWatcher, err := newOCR2KeeperConfigProvider(r.lggr, r.chain, rargs) if err != nil { return nil, err } gasLimit := cfgWatcher.chain.Config().EVM().OCR2().Automation().GasLimit() - contractTransmitter, err := newOnChainContractTransmitter(r.lggr, rargs, pargs.TransmitterID, r.ethKeystore, cfgWatcher, configTransmitterOpts{pluginGasLimit: &gasLimit}, OCR2AggregatorTransmissionContractABI) + contractTransmitter, err := newOnChainContractTransmitter(ctx, r.lggr, rargs, pargs.TransmitterID, r.ethKeystore, cfgWatcher, configTransmitterOpts{pluginGasLimit: &gasLimit}, OCR2AggregatorTransmissionContractABI) if err != nil { return nil, err } diff --git a/core/services/relay/evm/ocr2vrf.go b/core/services/relay/evm/ocr2vrf.go index d421b38ea77..98753655550 100644 --- a/core/services/relay/evm/ocr2vrf.go +++ b/core/services/relay/evm/ocr2vrf.go @@ -1,6 +1,7 @@ package evm import ( + "context" "encoding/json" "fmt" @@ -59,11 +60,15 @@ func NewOCR2VRFRelayer(db *sqlx.DB, chain legacyevm.Chain, lggr logger.Logger, e } func (r *ocr2vrfRelayer) NewDKGProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (DKGProvider, error) { + + // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + ctx := context.Background() + configWatcher, err := newOCR2VRFConfigProvider(r.lggr, r.chain, rargs) if err != nil { return nil, err } - contractTransmitter, err := newOnChainContractTransmitter(r.lggr, rargs, pargs.TransmitterID, r.ethKeystore, configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) + contractTransmitter, err := newOnChainContractTransmitter(ctx, r.lggr, rargs, pargs.TransmitterID, r.ethKeystore, configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) if err != nil { return nil, err } @@ -82,11 +87,15 @@ func (r *ocr2vrfRelayer) NewDKGProvider(rargs commontypes.RelayArgs, pargs commo } func (r *ocr2vrfRelayer) NewOCR2VRFProvider(rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (OCR2VRFProvider, error) { + + // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 + ctx := context.Background() + configWatcher, err := newOCR2VRFConfigProvider(r.lggr, r.chain, rargs) if err != nil { return nil, err } - contractTransmitter, err := newOnChainContractTransmitter(r.lggr, rargs, pargs.TransmitterID, r.ethKeystore, configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) + contractTransmitter, err := newOnChainContractTransmitter(ctx, r.lggr, rargs, pargs.TransmitterID, r.ethKeystore, configWatcher, configTransmitterOpts{}, OCR2AggregatorTransmissionContractABI) if err != nil { return nil, err } diff --git a/core/services/streams/delegate.go b/core/services/streams/delegate.go index f7dc852a50b..5ea0d475d2b 100644 --- a/core/services/streams/delegate.go +++ b/core/services/streams/delegate.go @@ -43,7 +43,7 @@ func (d *Delegate) AfterJobCreated(jb job.Job) {} func (d *Delegate) BeforeJobDeleted(jb job.Job) {} func (d *Delegate) OnDeleteJob(jb job.Job, q pg.Queryer) error { return nil } -func (d *Delegate) ServicesForSpec(jb job.Job) (services []job.ServiceCtx, err error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { if jb.StreamID == nil { return nil, errors.New("streamID is required to be present for stream specs") } diff --git a/core/services/streams/delegate_test.go b/core/services/streams/delegate_test.go index e97da63d522..d177c977e1b 100644 --- a/core/services/streams/delegate_test.go +++ b/core/services/streams/delegate_test.go @@ -3,6 +3,7 @@ package streams import ( "testing" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" @@ -35,12 +36,12 @@ func Test_Delegate(t *testing.T) { t.Run("ServicesForSpec", func(t *testing.T) { jb := job.Job{PipelineSpec: &pipeline.Spec{ID: 1}} t.Run("errors if job is missing streamID", func(t *testing.T) { - _, err := d.ServicesForSpec(jb) + _, err := d.ServicesForSpec(testutils.Context(t), jb) assert.EqualError(t, err, "streamID is required to be present for stream specs") }) jb.StreamID = ptr(uint32(42)) t.Run("returns services", func(t *testing.T) { - srvs, err := d.ServicesForSpec(jb) + srvs, err := d.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) assert.Len(t, srvs, 2) diff --git a/core/services/vrf/delegate.go b/core/services/vrf/delegate.go index ecabbc09c71..14ba341b1b6 100644 --- a/core/services/vrf/delegate.go +++ b/core/services/vrf/delegate.go @@ -1,6 +1,7 @@ package vrf import ( + "context" "fmt" "time" @@ -72,7 +73,7 @@ func (d *Delegate) BeforeJobDeleted(job.Job) {} func (d *Delegate) OnDeleteJob(job.Job, pg.Queryer) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { if jb.VRFSpec == nil || jb.PipelineSpec == nil { return nil, errors.Errorf("vrf.Delegate expects a VRFSpec and PipelineSpec to be present, got %+v", jb) } @@ -128,7 +129,7 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { for _, task := range pl.Tasks { if _, ok := task.(*pipeline.VRFTaskV2Plus); ok { - if err2 := CheckFromAddressesExist(jb, d.ks.Eth()); err != nil { + if err2 := CheckFromAddressesExist(ctx, jb, d.ks.Eth()); err != nil { return nil, err2 } @@ -187,7 +188,7 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { }, nil } if _, ok := task.(*pipeline.VRFTaskV2); ok { - if err2 := CheckFromAddressesExist(jb, d.ks.Eth()); err != nil { + if err2 := CheckFromAddressesExist(ctx, jb, d.ks.Eth()); err != nil { return nil, err2 } @@ -269,9 +270,9 @@ func (d *Delegate) ServicesForSpec(jb job.Job) ([]job.ServiceCtx, error) { // CheckFromAddressesExist returns an error if and only if one of the addresses // in the VRF spec's fromAddresses field does not exist in the keystore. -func CheckFromAddressesExist(jb job.Job, gethks keystore.Eth) (err error) { +func CheckFromAddressesExist(ctx context.Context, jb job.Job, gethks keystore.Eth) (err error) { for _, a := range jb.VRFSpec.FromAddresses { - _, err2 := gethks.Get(a.Hex()) + _, err2 := gethks.Get(ctx, a.Hex()) err = multierr.Append(err, err2) } return diff --git a/core/services/vrf/delegate_test.go b/core/services/vrf/delegate_test.go index 8ad88d7b73b..bbbb2d75dff 100644 --- a/core/services/vrf/delegate_test.go +++ b/core/services/vrf/delegate_test.go @@ -91,7 +91,7 @@ func buildVrfUni(t *testing.T, db *sqlx.DB, cfg chainlink.GeneralConfig) vrfUniv legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) pr := pipeline.NewRunner(prm, btORM, cfg.JobPipeline(), cfg.WebServer(), legacyChains, ks.Eth(), ks.VRF(), lggr, nil, nil) require.NoError(t, ks.Unlock(testutils.Password)) - k, err2 := ks.Eth().Create(testutils.FixtureChainID) + k, err2 := ks.Eth().Create(testutils.Context(t), testutils.FixtureChainID) require.NoError(t, err2) submitter := k.Address require.NoError(t, err) @@ -167,7 +167,7 @@ func setup(t *testing.T) (vrfUniverse, *v1.Listener, job.Job) { require.NoError(t, err) err = vuni.jrm.CreateJob(&jb) require.NoError(t, err) - vl, err := vd.ServicesForSpec(jb) + vl, err := vd.ServicesForSpec(testutils.Context(t), jb) require.NoError(t, err) require.Len(t, vl, 1) listener := vl[0].(*v1.Listener) @@ -565,7 +565,7 @@ func Test_CheckFromAddressesExist(t *testing.T) { var fromAddresses []string for i := 0; i < 3; i++ { - k, err := ks.Eth().Create(big.NewInt(1337)) + k, err := ks.Eth().Create(testutils.Context(t), big.NewInt(1337)) assert.NoError(t, err) fromAddresses = append(fromAddresses, k.Address.Hex()) } @@ -581,7 +581,7 @@ func Test_CheckFromAddressesExist(t *testing.T) { Toml()) assert.NoError(t, err) - assert.NoError(t, vrf.CheckFromAddressesExist(jb, ks.Eth())) + assert.NoError(t, vrf.CheckFromAddressesExist(testutils.Context(t), jb, ks.Eth())) }) t.Run("one of from addresses doesn't exist", func(t *testing.T) { @@ -593,7 +593,7 @@ func Test_CheckFromAddressesExist(t *testing.T) { var fromAddresses []string for i := 0; i < 3; i++ { - k, err := ks.Eth().Create(big.NewInt(1337)) + k, err := ks.Eth().Create(testutils.Context(t), big.NewInt(1337)) assert.NoError(t, err) fromAddresses = append(fromAddresses, k.Address.Hex()) } @@ -611,7 +611,7 @@ func Test_CheckFromAddressesExist(t *testing.T) { Toml()) assert.NoError(t, err) - assert.Error(t, vrf.CheckFromAddressesExist(jb, ks.Eth())) + assert.Error(t, vrf.CheckFromAddressesExist(testutils.Context(t), jb, ks.Eth())) }) } @@ -701,7 +701,7 @@ func Test_VRFV2PlusServiceFailsWhenVRFOwnerProvided(t *testing.T) { require.NoError(t, err) err = vuni.jrm.CreateJob(&jb) require.NoError(t, err) - _, err = vd.ServicesForSpec(jb) + _, err = vd.ServicesForSpec(testutils.Context(t), jb) require.Error(t, err) require.Equal(t, "VRF Owner is not supported for VRF V2 Plus", err.Error()) } diff --git a/core/services/vrf/v2/integration_v2_test.go b/core/services/vrf/v2/integration_v2_test.go index 39acc3da3e5..9c16e8c0c7a 100644 --- a/core/services/vrf/v2/integration_v2_test.go +++ b/core/services/vrf/v2/integration_v2_test.go @@ -1652,7 +1652,7 @@ func TestIntegrationVRFV2(t *testing.T) { carolContractAddress := uni.consumerContractAddresses[0] app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, key) - keys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.SimulatedChainID) + keys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) require.Zero(t, key.Cmp(keys[0])) @@ -2067,7 +2067,7 @@ func TestStartingCountsV1(t *testing.T) { assert.Equal(t, 0, len(counts)) err = ks.Unlock(testutils.Password) require.NoError(t, err) - k, err := ks.Eth().Create(testutils.SimulatedChainID) + k, err := ks.Eth().Create(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) b := time.Now() n1, n2, n3, n4 := evmtypes.Nonce(0), evmtypes.Nonce(1), evmtypes.Nonce(2), evmtypes.Nonce(3) diff --git a/core/services/vrf/v2/listener_v2_log_processor.go b/core/services/vrf/v2/listener_v2_log_processor.go index be9457d7cee..7f61dd4cf3e 100644 --- a/core/services/vrf/v2/listener_v2_log_processor.go +++ b/core/services/vrf/v2/listener_v2_log_processor.go @@ -387,7 +387,7 @@ func (lsn *listenerV2) processRequestsPerSubBatchHelper( "blockHash", p.req.req.Raw().BlockHash, ) fromAddresses := lsn.fromAddresses() - fromAddress, err := lsn.gethks.GetRoundRobinAddress(lsn.chainID, fromAddresses...) + fromAddress, err := lsn.gethks.GetRoundRobinAddress(ctx, lsn.chainID, fromAddresses...) if err != nil { l.Errorw("Couldn't get next from address", "err", err) continue @@ -717,7 +717,7 @@ func (lsn *listenerV2) processRequestsPerSubHelper( "blockNumber", p.req.req.Raw().BlockNumber, "blockHash", p.req.req.Raw().BlockHash, ) - fromAddress, err := lsn.gethks.GetRoundRobinAddress(lsn.chainID, fromAddresses...) + fromAddress, err := lsn.gethks.GetRoundRobinAddress(ctx, lsn.chainID, fromAddresses...) if err != nil { l.Errorw("Couldn't get next from address", "err", err) continue diff --git a/core/services/vrf/v2/listener_v2_test.go b/core/services/vrf/v2/listener_v2_test.go index d8bc0a6695b..465e3dcaca9 100644 --- a/core/services/vrf/v2/listener_v2_test.go +++ b/core/services/vrf/v2/listener_v2_test.go @@ -186,7 +186,7 @@ func testMaybeSubtractReservedLink(t *testing.T, vrfVersion vrfcommon.Version) { ks := keystore.NewInMemory(db, utils.FastScryptParams, lggr, cfg) require.NoError(t, ks.Unlock("blah")) chainID := testutils.SimulatedChainID - k, err := ks.Eth().Create(chainID) + k, err := ks.Eth().Create(testutils.Context(t), chainID) require.NoError(t, err) subID := new(big.Int).SetUint64(1) @@ -236,7 +236,7 @@ func testMaybeSubtractReservedLink(t *testing.T, vrfVersion vrfcommon.Version) { require.Equal(t, "80000", start.String()) // One key's data should not affect other keys' data in the case of different subscribers. - k2, err := ks.Eth().Create(testutils.SimulatedChainID) + k2, err := ks.Eth().Create(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) anotherSubID := new(big.Int).SetUint64(3) @@ -268,7 +268,7 @@ func testMaybeSubtractReservedNative(t *testing.T, vrfVersion vrfcommon.Version) ks := keystore.NewInMemory(db, utils.FastScryptParams, lggr, cfg) require.NoError(t, ks.Unlock("blah")) chainID := testutils.SimulatedChainID - k, err := ks.Eth().Create(chainID) + k, err := ks.Eth().Create(testutils.Context(t), chainID) require.NoError(t, err) subID := new(big.Int).SetUint64(1) @@ -319,7 +319,7 @@ func testMaybeSubtractReservedNative(t *testing.T, vrfVersion vrfcommon.Version) require.Equal(t, "80000", start.String()) // One key's data should not affect other keys' data in the case of different subscribers. - k2, err := ks.Eth().Create(testutils.SimulatedChainID) + k2, err := ks.Eth().Create(testutils.Context(t), testutils.SimulatedChainID) require.NoError(t, err) anotherSubID := new(big.Int).SetUint64(3) diff --git a/core/services/vrf/v2/reverted_txns.go b/core/services/vrf/v2/reverted_txns.go index 5aead146f5f..be20738a8f5 100644 --- a/core/services/vrf/v2/reverted_txns.go +++ b/core/services/vrf/v2/reverted_txns.go @@ -655,7 +655,7 @@ func (lsn *listenerV2) enqueueForceFulfillmentForRevertedTxn( reqCommitment := revertedTxn.Commitment fromAddresses := lsn.fromAddresses() - fromAddress, err := lsn.gethks.GetRoundRobinAddress(lsn.chainID, fromAddresses...) + fromAddress, err := lsn.gethks.GetRoundRobinAddress(ctx, lsn.chainID, fromAddresses...) if err != nil { return txmgr.Tx{}, errors.Wrap(err, "failed_to_get_vrf_listener_from_address") } diff --git a/core/services/vrf/vrfcommon/types.go b/core/services/vrf/vrfcommon/types.go index 175b362b5aa..06988633e8e 100644 --- a/core/services/vrf/vrfcommon/types.go +++ b/core/services/vrf/vrfcommon/types.go @@ -1,6 +1,7 @@ package vrfcommon import ( + "context" "math/big" "github.com/ethereum/go-ethereum/common" @@ -10,7 +11,7 @@ import ( ) type GethKeyStore interface { - GetRoundRobinAddress(chainID *big.Int, addresses ...common.Address) (common.Address, error) + GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addresses ...common.Address) (common.Address, error) } //go:generate mockery --quiet --name Config --output ../mocks/ --case=underscore diff --git a/core/services/webhook/delegate.go b/core/services/webhook/delegate.go index 7e6aab0bb07..3211018d48d 100644 --- a/core/services/webhook/delegate.go +++ b/core/services/webhook/delegate.go @@ -76,7 +76,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) { func (d *Delegate) OnDeleteJob(jb job.Job, q pg.Queryer) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(spec job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.ServiceCtx, error) { service := &pseudoService{ spec: spec, webhookJobRunner: d.webhookJobRunner, diff --git a/core/services/webhook/delegate_test.go b/core/services/webhook/delegate_test.go index 24c501e7545..c020f641615 100644 --- a/core/services/webhook/delegate_test.go +++ b/core/services/webhook/delegate_test.go @@ -21,6 +21,7 @@ import ( ) func TestWebhookDelegate(t *testing.T) { + ctx := testutils.Context(t) var ( spec = &job.Job{ ID: 123, @@ -50,18 +51,18 @@ func TestWebhookDelegate(t *testing.T) { delegate = webhook.NewDelegate(runner, eiManager, logger.TestLogger(t)) ) - services, err := delegate.ServicesForSpec(*spec) + services, err := delegate.ServicesForSpec(ctx, *spec) require.NoError(t, err) require.Len(t, services, 1) service := services[0] // Should error before service is started - _, err = delegate.WebhookJobRunner().RunJob(testutils.Context(t), spec.ExternalJobID, requestBody, meta) + _, err = delegate.WebhookJobRunner().RunJob(ctx, spec.ExternalJobID, requestBody, meta) require.Error(t, err) require.Equal(t, webhook.ErrJobNotExists, errors.Cause(err)) // Should succeed after service is started upon a successful run - err = service.Start(testutils.Context(t)) + err = service.Start(ctx) require.NoError(t, err) runner.On("Run", mock.Anything, mock.AnythingOfType("*pipeline.Run"), mock.Anything, mock.Anything, mock.Anything). @@ -73,7 +74,7 @@ func TestWebhookDelegate(t *testing.T) { require.Equal(t, vars, run.Inputs.Val) }).Once() - runID, err := delegate.WebhookJobRunner().RunJob(testutils.Context(t), spec.ExternalJobID, requestBody, meta) + runID, err := delegate.WebhookJobRunner().RunJob(ctx, spec.ExternalJobID, requestBody, meta) require.NoError(t, err) require.Equal(t, int64(123), runID) @@ -83,13 +84,13 @@ func TestWebhookDelegate(t *testing.T) { runner.On("Run", mock.Anything, mock.AnythingOfType("*pipeline.Run"), mock.Anything, mock.Anything, mock.Anything). Return(false, expectedErr).Once() - _, err = delegate.WebhookJobRunner().RunJob(testutils.Context(t), spec.ExternalJobID, requestBody, meta) + _, err = delegate.WebhookJobRunner().RunJob(ctx, spec.ExternalJobID, requestBody, meta) require.Equal(t, expectedErr, errors.Cause(err)) // Should error after service is stopped err = service.Close() require.NoError(t, err) - _, err = delegate.WebhookJobRunner().RunJob(testutils.Context(t), spec.ExternalJobID, requestBody, meta) + _, err = delegate.WebhookJobRunner().RunJob(ctx, spec.ExternalJobID, requestBody, meta) require.Equal(t, webhook.ErrJobNotExists, errors.Cause(err)) } diff --git a/core/services/workflows/delegate.go b/core/services/workflows/delegate.go index 6faa0bacdb8..2951c2b4aa3 100644 --- a/core/services/workflows/delegate.go +++ b/core/services/workflows/delegate.go @@ -1,6 +1,7 @@ package workflows import ( + "context" "fmt" "github.com/google/uuid" @@ -34,7 +35,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) {} func (d *Delegate) OnDeleteJob(jb job.Job, q pg.Queryer) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. -func (d *Delegate) ServicesForSpec(spec job.Job) ([]job.ServiceCtx, error) { +func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.ServiceCtx, error) { engine, err := NewEngine(d.logger, d.registry) if err != nil { return nil, err diff --git a/core/services/workflows/engine_test.go b/core/services/workflows/engine_test.go index 339792fd06d..aa84dc29cc9 100644 --- a/core/services/workflows/engine_test.go +++ b/core/services/workflows/engine_test.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/values" coreCap "github.com/smartcontractkit/chainlink/v2/core/capabilities" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" ) @@ -66,7 +67,7 @@ func (m *mockTriggerCapability) UnregisterTrigger(ctx context.Context, req capab } func TestEngineWithHardcodedWorkflow(t *testing.T) { - ctx := context.Background() + ctx := testutils.Context(t) reg := coreCap.NewRegistry(logger.TestLogger(t)) trigger := &mockTriggerCapability{ diff --git a/core/web/eth_keys_controller.go b/core/web/eth_keys_controller.go index 4e95bc3cb89..e53f30a925a 100644 --- a/core/web/eth_keys_controller.go +++ b/core/web/eth_keys_controller.go @@ -85,13 +85,13 @@ func (ekc *ETHKeysController) Index(c *gin.Context) { ethKeyStore := ekc.app.GetKeyStore().Eth() var keys []ethkey.KeyV2 var err error - keys, err = ethKeyStore.GetAll() + keys, err = ethKeyStore.GetAll(c.Request.Context()) if err != nil { err = errors.Errorf("error getting unlocked keys: %v", err) jsonAPIError(c, http.StatusInternalServerError, err) return } - states, err := ethKeyStore.GetStatesForKeys(keys) + states, err := ethKeyStore.GetStatesForKeys(c.Request.Context(), keys) if err != nil { err = errors.Errorf("error getting key states: %v", err) jsonAPIError(c, http.StatusInternalServerError, err) @@ -99,7 +99,7 @@ func (ekc *ETHKeysController) Index(c *gin.Context) { } var resources []presenters.ETHKeyResource for _, state := range states { - key, err := ethKeyStore.Get(state.Address.Hex()) + key, err := ethKeyStore.Get(c.Request.Context(), state.Address.Hex()) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -136,13 +136,13 @@ func (ekc *ETHKeysController) Create(c *gin.Context) { return } - key, err := ethKeyStore.Create(chain.ID()) + key, err := ethKeyStore.Create(c.Request.Context(), chain.ID()) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return } - state, err := ethKeyStore.GetState(key.ID(), chain.ID()) + state, err := ethKeyStore.GetState(c.Request.Context(), key.ID(), chain.ID()) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -169,7 +169,7 @@ func (ekc *ETHKeysController) Delete(c *gin.Context) { return } - key, err := ethKeyStore.Get(keyID) + key, err := ethKeyStore.Get(c.Request.Context(), keyID) if err != nil { if errors.Is(err, keystore.ErrKeyNotFound) { jsonAPIError(c, http.StatusNotFound, err) @@ -179,13 +179,13 @@ func (ekc *ETHKeysController) Delete(c *gin.Context) { return } - state, err := ethKeyStore.GetStateForKey(key) + state, err := ethKeyStore.GetStateForKey(c.Request.Context(), key) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return } - _, err = ethKeyStore.Delete(keyID) + _, err = ethKeyStore.Delete(c.Request.Context(), keyID) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -217,13 +217,13 @@ func (ekc *ETHKeysController) Import(c *gin.Context) { return } - key, err := ethKeyStore.Import(bytes, oldPassword, chain.ID()) + key, err := ethKeyStore.Import(c.Request.Context(), bytes, oldPassword, chain.ID()) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return } - state, err := ethKeyStore.GetState(key.ID(), chain.ID()) + state, err := ethKeyStore.GetState(c.Request.Context(), key.ID(), chain.ID()) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -245,7 +245,7 @@ func (ekc *ETHKeysController) Export(c *gin.Context) { id := c.Param("address") newPassword := c.Query("newpassword") - bytes, err := ekc.app.GetKeyStore().Eth().Export(id, newPassword) + bytes, err := ekc.app.GetKeyStore().Eth().Export(c.Request.Context(), id, newPassword) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -312,9 +312,9 @@ func (ekc *ETHKeysController) Chain(c *gin.Context) { } if enabled { - err = kst.Enable(address, chain.ID()) + err = kst.Enable(c.Request.Context(), address, chain.ID()) } else { - err = kst.Disable(address, chain.ID()) + err = kst.Disable(c.Request.Context(), address, chain.ID()) } if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) @@ -322,13 +322,13 @@ func (ekc *ETHKeysController) Chain(c *gin.Context) { } } - key, err := kst.Get(keyID) + key, err := kst.Get(c.Request.Context(), keyID) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return } - state, err := kst.GetState(key.ID(), chain.ID()) + state, err := kst.GetState(c.Request.Context(), key.ID(), chain.ID()) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return diff --git a/core/web/eth_keys_controller_test.go b/core/web/eth_keys_controller_test.go index e075b3196e1..545abefcd3e 100644 --- a/core/web/eth_keys_controller_test.go +++ b/core/web/eth_keys_controller_test.go @@ -175,7 +175,7 @@ func TestETHKeysController_Index_NotDev(t *testing.T) { defer cleanup() require.Equal(t, http.StatusOK, resp.StatusCode) - expectedKeys, err := app.KeyStore.Eth().GetAll() + expectedKeys, err := app.KeyStore.Eth().GetAll(testutils.Context(t)) require.NoError(t, err) var actualBalances []webpresenters.ETHKeyResource err = cltest.ParseJSONAPIResponse(t, resp, &actualBalances) diff --git a/core/web/resolver/eth_key_test.go b/core/web/resolver/eth_key_test.go index f574c885ff9..7d1e3ff5025 100644 --- a/core/web/resolver/eth_key_test.go +++ b/core/web/resolver/eth_key_test.go @@ -96,9 +96,9 @@ func TestResolver_ETHKeys(t *testing.T) { m := map[string]legacyevm.Chain{states[0].EVMChainID.String(): f.Mocks.chain} legacyEVMChains := legacyevm.NewLegacyChains(m, cfg.EVMConfigs()) - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(states, nil) - f.Mocks.ethKs.On("Get", keys[0].Address.Hex()).Return(keys[0], nil) - f.Mocks.ethKs.On("GetAll").Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(states, nil) + f.Mocks.ethKs.On("Get", mock.Anything, keys[0].Address.Hex()).Return(keys[0], nil) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.ethClient.On("LINKBalance", mock.Anything, address, linkAddr).Return(commonassets.NewLinkFromJuels(12), nil) f.Mocks.chain.On("Client").Return(f.Mocks.ethClient) f.Mocks.balM.On("GetEthBalance", address).Return(assets.NewEth(1)) @@ -158,9 +158,9 @@ func TestResolver_ETHKeys(t *testing.T) { } chainID := *big.NewI(12) f.Mocks.legacyEVMChains.On("Get", states[0].EVMChainID.String()).Return(nil, evmrelay.ErrNoChains) - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(states, nil) - f.Mocks.ethKs.On("Get", keys[0].Address.Hex()).Return(keys[0], nil) - f.Mocks.ethKs.On("GetAll").Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(states, nil) + f.Mocks.ethKs.On("Get", mock.Anything, keys[0].Address.Hex()).Return(keys[0], nil) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.relayerChainInterops.EVMChains = f.Mocks.legacyEVMChains f.Mocks.evmORM.PutChains(toml.EVMConfig{ChainID: &chainID}) f.Mocks.relayerChainInterops.Relayers = []loop.Relayer{ @@ -202,7 +202,7 @@ func TestResolver_ETHKeys(t *testing.T) { name: "generic error on GetAll()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.ethKs.On("GetAll").Return(nil, gError) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(nil, gError) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.App.On("GetKeyStore").Return(f.Mocks.keystore) }, @@ -221,8 +221,8 @@ func TestResolver_ETHKeys(t *testing.T) { name: "generic error on GetStatesForKeys()", authenticated: true, before: func(f *gqlTestFramework) { - f.Mocks.ethKs.On("GetAll").Return(keys, nil) - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(nil, gError) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(nil, gError) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.App.On("GetKeyStore").Return(f.Mocks.keystore) }, @@ -251,9 +251,9 @@ func TestResolver_ETHKeys(t *testing.T) { }, } - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(states, nil) - f.Mocks.ethKs.On("Get", keys[0].Address.Hex()).Return(ethkey.KeyV2{}, gError) - f.Mocks.ethKs.On("GetAll").Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(states, nil) + f.Mocks.ethKs.On("Get", mock.Anything, keys[0].Address.Hex()).Return(ethkey.KeyV2{}, gError) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.App.On("GetKeyStore").Return(f.Mocks.keystore) }, @@ -283,9 +283,9 @@ func TestResolver_ETHKeys(t *testing.T) { }, } - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(states, nil) - f.Mocks.ethKs.On("Get", keys[0].Address.Hex()).Return(ethkey.KeyV2{}, nil) - f.Mocks.ethKs.On("GetAll").Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(states, nil) + f.Mocks.ethKs.On("Get", mock.Anything, keys[0].Address.Hex()).Return(ethkey.KeyV2{}, nil) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.Mocks.legacyEVMChains.On("Get", states[0].EVMChainID.String()).Return(f.Mocks.chain, gError) f.Mocks.relayerChainInterops.EVMChains = f.Mocks.legacyEVMChains @@ -316,9 +316,9 @@ func TestResolver_ETHKeys(t *testing.T) { chainID := *big.NewI(12) linkAddr := common.HexToAddress("0x5431F5F973781809D18643b87B44921b11355d81") - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(states, nil) - f.Mocks.ethKs.On("Get", keys[0].Address.Hex()).Return(keys[0], nil) - f.Mocks.ethKs.On("GetAll").Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(states, nil) + f.Mocks.ethKs.On("Get", mock.Anything, keys[0].Address.Hex()).Return(keys[0], nil) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.keystore.On("Eth").Return(f.Mocks.ethKs) f.Mocks.ethClient.On("LINKBalance", mock.Anything, address, linkAddr).Return(commonassets.NewLinkFromJuels(12), gError) f.Mocks.legacyEVMChains.On("Get", states[0].EVMChainID.String()).Return(f.Mocks.chain, nil) @@ -378,9 +378,9 @@ func TestResolver_ETHKeys(t *testing.T) { chainID := *big.NewI(12) linkAddr := common.HexToAddress("0x5431F5F973781809D18643b87B44921b11355d81") - f.Mocks.ethKs.On("GetStatesForKeys", keys).Return(states, nil) - f.Mocks.ethKs.On("Get", keys[0].Address.Hex()).Return(keys[0], nil) - f.Mocks.ethKs.On("GetAll").Return(keys, nil) + f.Mocks.ethKs.On("GetStatesForKeys", mock.Anything, keys).Return(states, nil) + f.Mocks.ethKs.On("Get", mock.Anything, keys[0].Address.Hex()).Return(keys[0], nil) + f.Mocks.ethKs.On("GetAll", mock.Anything).Return(keys, nil) f.Mocks.ethClient.On("LINKBalance", mock.Anything, address, linkAddr).Return(commonassets.NewLinkFromJuels(12), nil) f.Mocks.chain.On("Client").Return(f.Mocks.ethClient) f.Mocks.chain.On("BalanceMonitor").Return(nil) diff --git a/core/web/resolver/query.go b/core/web/resolver/query.go index ccc9da2ab91..f9039fd17fc 100644 --- a/core/web/resolver/query.go +++ b/core/web/resolver/query.go @@ -417,12 +417,12 @@ func (r *Resolver) ETHKeys(ctx context.Context) (*ETHKeysPayloadResolver, error) ks := r.App.GetKeyStore().Eth() - keys, err := ks.GetAll() + keys, err := ks.GetAll(ctx) if err != nil { return nil, fmt.Errorf("error getting unlocked keys: %v", err) } - states, err := ks.GetStatesForKeys(keys) + states, err := ks.GetStatesForKeys(ctx, keys) if err != nil { return nil, fmt.Errorf("error getting key states: %v", err) } @@ -430,7 +430,7 @@ func (r *Resolver) ETHKeys(ctx context.Context) (*ETHKeysPayloadResolver, error) var ethKeys []ETHKey for _, state := range states { - k, err := ks.Get(state.Address.Hex()) + k, err := ks.Get(ctx, state.Address.Hex()) if err != nil { return nil, err }