diff --git a/core/services/gateway/delegate.go b/core/services/gateway/delegate.go index d82af8e7d0b..bc170bd4e74 100644 --- a/core/services/gateway/delegate.go +++ b/core/services/gateway/delegate.go @@ -47,7 +47,7 @@ func (d *Delegate) ServicesForSpec(spec job.Job) (services []job.ServiceCtx, err if err2 != nil { return nil, errors.Wrap(err2, "unmarshal gateway config") } - handlerFactory := NewHandlerFactory(d.lggr) + handlerFactory := NewHandlerFactory(d.chains, d.lggr) gateway, err := NewGatewayFromConfig(&gatewayConfig, handlerFactory, d.lggr) if err != nil { return nil, err diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index fb5d5d3b8bd..1490a9fccd4 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -51,7 +51,7 @@ HandlerName = "dummy" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr) require.NoError(t, err) } @@ -69,7 +69,7 @@ HandlerName = "dummy" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr) require.Error(t, err) } @@ -83,7 +83,7 @@ HandlerName = "no_such_handler" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr) require.Error(t, err) } @@ -97,7 +97,7 @@ SomeOtherField = "abcd" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr) require.Error(t, err) } diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index c80a1d3c6d6..225fe192075 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" @@ -16,19 +17,20 @@ const ( ) type handlerFactory struct { - lggr logger.Logger + chains evm.ChainSet + lggr logger.Logger } var _ HandlerFactory = (*handlerFactory)(nil) -func NewHandlerFactory(lggr logger.Logger) HandlerFactory { - return &handlerFactory{lggr} +func NewHandlerFactory(chains evm.ChainSet, lggr logger.Logger) HandlerFactory { + return &handlerFactory{chains, lggr} } func (hf *handlerFactory) NewHandler(handlerType HandlerType, handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON) (handlers.Handler, error) { switch handlerType { case FunctionsHandlerType: - return functions.NewFunctionsHandler(handlerConfig, donConfig, don, hf.lggr) + return functions.NewFunctionsHandler(handlerConfig, donConfig, don, hf.chains, hf.lggr) case DummyHandlerType: return handlers.NewDummyHandler(donConfig, don, hf.lggr) default: diff --git a/core/services/gateway/handlers/functions/allowlist.go b/core/services/gateway/handlers/functions/allowlist.go new file mode 100644 index 00000000000..a98392f127f --- /dev/null +++ b/core/services/gateway/handlers/functions/allowlist.go @@ -0,0 +1,108 @@ +package functions + +import ( + "context" + "fmt" + "math/big" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" + + evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/ocr2dr_oracle" + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// OnchainAllowlist maintains an allowlist of addresses fetched from the blockchain (EVM-only). +// Use UpdateFromContract() for a one-time update or UpdatePeriodically() for periodic updates. +// All methods are thread-safe. +// +//go:generate mockery --quiet --name OnchainAllowlist --output ./mocks/ --case=underscore +type OnchainAllowlist interface { + Allow(common.Address) bool + UpdateFromContract(ctx context.Context) error + UpdatePeriodically(ctx context.Context, updateFrequency time.Duration, updateTimeout time.Duration) +} + +type onchainAllowlist struct { + allowlist atomic.Pointer[map[common.Address]struct{}] + client evmclient.Client + contract *ocr2dr_oracle.OCR2DROracle + blockConfirmations *big.Int + lggr logger.Logger +} + +func NewOnchainAllowlist(client evmclient.Client, contractAddress common.Address, blockConfirmations int64, lggr logger.Logger) (OnchainAllowlist, error) { + if client == nil { + return nil, errors.New("client is nil") + } + if lggr == nil { + return nil, errors.New("logger is nil") + } + contract, err := ocr2dr_oracle.NewOCR2DROracle(contractAddress, client) + if err != nil { + return nil, fmt.Errorf("unexpected error during NewOCR2DROracle: %s", err) + } + allowlist := &onchainAllowlist{ + client: client, + contract: contract, + blockConfirmations: big.NewInt(blockConfirmations), + lggr: lggr.Named("OnchainAllowlist"), + } + emptyMap := make(map[common.Address]struct{}) + allowlist.allowlist.Store(&emptyMap) + return allowlist, nil +} + +func (a *onchainAllowlist) Allow(address common.Address) bool { + allowlist := *a.allowlist.Load() + _, ok := allowlist[address] + return ok +} + +func (a *onchainAllowlist) UpdateFromContract(ctx context.Context) error { + latestBlockHeight, err := a.client.LatestBlockHeight(ctx) + if err != nil { + return errors.Wrap(err, "error calling LatestBlockHeight") + } + if latestBlockHeight == nil { + return errors.New("LatestBlockHeight returned nil") + } + blockNum := big.NewInt(0).Sub(latestBlockHeight, a.blockConfirmations) + addrList, err := a.contract.GetAuthorizedSenders(&bind.CallOpts{ + Pending: false, + BlockNumber: blockNum, + Context: ctx, + }) + if err != nil { + return errors.Wrap(err, "error calling GetAuthorizedSenders") + } + newAllowlist := make(map[common.Address]struct{}) + for _, addr := range addrList { + newAllowlist[addr] = struct{}{} + } + a.allowlist.Store(&newAllowlist) + a.lggr.Infow("allowlist updated successfully", "len", len(addrList), "blockNumber", blockNum) + return nil +} + +func (a *onchainAllowlist) UpdatePeriodically(ctx context.Context, updateFrequency time.Duration, updateTimeout time.Duration) { + ticker := time.NewTicker(updateFrequency) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + timeoutCtx, cancel := context.WithTimeout(ctx, updateTimeout) + err := a.UpdateFromContract(timeoutCtx) + if err != nil { + a.lggr.Errorw("error calling UpdateFromContract", "err", err) + } + cancel() + } + } +} diff --git a/core/services/gateway/handlers/functions/allowlist_test.go b/core/services/gateway/handlers/functions/allowlist_test.go new file mode 100644 index 00000000000..5df33287dc1 --- /dev/null +++ b/core/services/gateway/handlers/functions/allowlist_test.go @@ -0,0 +1,67 @@ +package functions_test + +import ( + "context" + "encoding/hex" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" +) + +const ( + addr1 = "9ed925d8206a4f88a2f643b28b3035b315753cd6" + addr2 = "ea6721ac65bced841b8ec3fc5fedea6141a0ade4" + addr3 = "84689acc87ff22841b8ec378300da5e141a99911" +) + +func sampleEncodedAllowlist(t *testing.T) []byte { + abiEncodedAddresses := + "0000000000000000000000000000000000000000000000000000000000000020" + + "0000000000000000000000000000000000000000000000000000000000000002" + + "000000000000000000000000" + addr1 + + "000000000000000000000000" + addr2 + rawData, err := hex.DecodeString(abiEncodedAddresses) + require.NoError(t, err) + return rawData +} + +func TestAllowlist_UpdateAndCheck(t *testing.T) { + t.Parallel() + + client := mocks.NewClient(t) + client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil) + allowlist, err := functions.NewOnchainAllowlist(client, common.Address{}, 1, logger.TestLogger(t)) + require.NoError(t, err) + + require.NoError(t, allowlist.UpdateFromContract(context.Background())) + require.False(t, allowlist.Allow(common.Address{})) + require.True(t, allowlist.Allow(common.HexToAddress(addr1))) + require.True(t, allowlist.Allow(common.HexToAddress(addr2))) + require.False(t, allowlist.Allow(common.HexToAddress(addr3))) +} + +func TestAllowlist_UpdatePeriodically(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + client := mocks.NewClient(t) + client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil) + client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + cancel() + }).Return(sampleEncodedAllowlist(t), nil) + allowlist, err := functions.NewOnchainAllowlist(client, common.Address{}, 1, logger.TestLogger(t)) + require.NoError(t, err) + + allowlist.UpdatePeriodically(ctx, time.Millisecond*10, time.Second*1) + require.True(t, allowlist.Allow(common.HexToAddress(addr1))) + require.False(t, allowlist.Allow(common.HexToAddress(addr3))) +} diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index 9f9ac10258d..17718489382 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -3,7 +3,14 @@ package functions import ( "context" "encoding/json" + "errors" + "math/big" + "sync" + "time" + "github.com/ethereum/go-ethereum/common" + + "github.com/smartcontractkit/chainlink/v2/core/chains/evm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" @@ -11,35 +18,84 @@ import ( ) type FunctionsHandlerConfig struct { + AllowlistCheckEnabled bool `json:"allowlistCheckEnabled"` + AllowlistChainID int64 `json:"allowlistChainID"` + AllowlistContractAddress string `json:"allowlistContractAddress"` + AllowlistBlockConfirmations int64 `json:"allowlistBlockConfirmations"` + AllowlistUpdateFrequencySec int `json:"allowlistUpdateFrequencySec"` + AllowlistUpdateTimeoutSec int `json:"allowlistUpdateTimeoutSec"` } type functionsHandler struct { - handlerConfig *FunctionsHandlerConfig - donConfig *config.DONConfig - don handlers.DON - lggr logger.Logger + handlerConfig *FunctionsHandlerConfig + donConfig *config.DONConfig + don handlers.DON + allowlist OnchainAllowlist + serviceContext context.Context + serviceCancel context.CancelFunc + shutdownWaitGroup sync.WaitGroup + lggr logger.Logger } var _ handlers.Handler = (*functionsHandler)(nil) -func NewFunctionsHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, lggr logger.Logger) (handlers.Handler, error) { - var parsedConfig FunctionsHandlerConfig - if err := json.Unmarshal(handlerConfig, &parsedConfig); err != nil { +func NewFunctionsHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, chains evm.ChainSet, lggr logger.Logger) (handlers.Handler, error) { + cfg, err := ParseConfig(handlerConfig) + if err != nil { return nil, err } + var allowlist OnchainAllowlist + if cfg.AllowlistCheckEnabled { + chain, err := chains.Get(big.NewInt(cfg.AllowlistChainID)) + if err != nil { + return nil, err + } + allowlist, err = NewOnchainAllowlist(chain.Client(), common.HexToAddress(cfg.AllowlistContractAddress), cfg.AllowlistBlockConfirmations, lggr) + if err != nil { + return nil, err + } + } + serviceContext, serviceCancel := context.WithCancel(context.Background()) return &functionsHandler{ - handlerConfig: &parsedConfig, - donConfig: donConfig, - don: don, - lggr: lggr, + handlerConfig: cfg, + donConfig: donConfig, + don: don, + allowlist: allowlist, + serviceContext: serviceContext, + serviceCancel: serviceCancel, + lggr: lggr, }, nil } +func ParseConfig(handlerConfig json.RawMessage) (*FunctionsHandlerConfig, error) { + var cfg FunctionsHandlerConfig + if err := json.Unmarshal(handlerConfig, &cfg); err != nil { + return nil, err + } + if cfg.AllowlistCheckEnabled { + if !common.IsHexAddress(cfg.AllowlistContractAddress) { + return nil, errors.New("allowlistContractAddress is not a valid hex address") + } + if cfg.AllowlistUpdateFrequencySec <= 0 { + return nil, errors.New("allowlistUpdateFrequencySec must be positive") + } + if cfg.AllowlistUpdateTimeoutSec <= 0 { + return nil, errors.New("allowlistUpdateTimeoutSec must be positive") + } + } + return &cfg, nil +} + func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error { if err := msg.Validate(); err != nil { h.lggr.Debug("received invalid message", "err", err) return err } + sender := common.HexToAddress(msg.Body.Sender) + if h.allowlist != nil && !h.allowlist.Allow(sender) { + h.lggr.Debug("received a message from a non-allowlisted address", "sender", msg.Body.Sender) + return errors.New("sender not allowlisted") + } h.lggr.Debug("received a valid message", "sender", msg.Body.Sender) return nil } @@ -48,10 +104,21 @@ func (h *functionsHandler) HandleNodeMessage(ctx context.Context, msg *api.Messa return nil } -func (h *functionsHandler) Start(context.Context) error { +func (h *functionsHandler) Start(ctx context.Context) error { + if h.allowlist != nil { + checkFreq := time.Duration(h.handlerConfig.AllowlistUpdateFrequencySec) * time.Second + checkTimeout := time.Duration(h.handlerConfig.AllowlistUpdateTimeoutSec) * time.Second + h.shutdownWaitGroup.Add(1) + go func() { + h.allowlist.UpdatePeriodically(h.serviceContext, checkFreq, checkTimeout) + h.shutdownWaitGroup.Done() + }() + } return nil } func (h *functionsHandler) Close() error { + h.serviceCancel() + h.shutdownWaitGroup.Wait() return nil } diff --git a/core/services/gateway/handlers/functions/handler.functions_test.go b/core/services/gateway/handlers/functions/handler.functions_test.go index d1418283c4f..fb51338b5c2 100644 --- a/core/services/gateway/handlers/functions/handler.functions_test.go +++ b/core/services/gateway/handlers/functions/handler.functions_test.go @@ -15,7 +15,7 @@ import ( func TestFunctionsHandler_Basic(t *testing.T) { t.Parallel() - handler, err := functions.NewFunctionsHandler(json.RawMessage("{}"), &config.DONConfig{}, nil, logger.TestLogger(t)) + handler, err := functions.NewFunctionsHandler(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, logger.TestLogger(t)) require.NoError(t, err) // nil message diff --git a/core/services/gateway/handlers/functions/mocks/onchain_allowlist.go b/core/services/gateway/handlers/functions/mocks/onchain_allowlist.go new file mode 100644 index 00000000000..4b2f557fd9c --- /dev/null +++ b/core/services/gateway/handlers/functions/mocks/onchain_allowlist.go @@ -0,0 +1,66 @@ +// Code generated by mockery v2.28.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + common "github.com/ethereum/go-ethereum/common" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// OnchainAllowlist is an autogenerated mock type for the OnchainAllowlist type +type OnchainAllowlist struct { + mock.Mock +} + +// Allow provides a mock function with given fields: _a0 +func (_m *OnchainAllowlist) Allow(_a0 common.Address) bool { + ret := _m.Called(_a0) + + var r0 bool + if rf, ok := ret.Get(0).(func(common.Address) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// UpdateFromContract provides a mock function with given fields: ctx +func (_m *OnchainAllowlist) UpdateFromContract(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdatePeriodically provides a mock function with given fields: ctx, updateFrequency, updateTimeout +func (_m *OnchainAllowlist) UpdatePeriodically(ctx context.Context, updateFrequency time.Duration, updateTimeout time.Duration) { + _m.Called(ctx, updateFrequency, updateTimeout) +} + +type mockConstructorTestingTNewOnchainAllowlist interface { + mock.TestingT + Cleanup(func()) +} + +// NewOnchainAllowlist creates a new instance of OnchainAllowlist. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewOnchainAllowlist(t mockConstructorTestingTNewOnchainAllowlist) *OnchainAllowlist { + mock := &OnchainAllowlist{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}