Skip to content

Commit

Permalink
[Functions] Allowlist support
Browse files Browse the repository at this point in the history
  • Loading branch information
bolekk committed Jun 20, 2023
1 parent f6da19d commit a95408c
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 22 deletions.
2 changes: 1 addition & 1 deletion core/services/gateway/delegate.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (d *Delegate) ServicesForSpec(spec job.Job) (services []job.ServiceCtx, err
if err2 != nil {
return nil, errors.Wrap(err2, "unmarshal gateway config")
}
handlerFactory := NewHandlerFactory(d.lggr)
handlerFactory := NewHandlerFactory(d.chains, d.lggr)
gateway, err := NewGatewayFromConfig(&gatewayConfig, handlerFactory, d.lggr)
if err != nil {
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions core/services/gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ HandlerName = "dummy"
`)

lggr := logger.TestLogger(t)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr)
require.NoError(t, err)
}

Expand All @@ -69,7 +69,7 @@ HandlerName = "dummy"
`)

lggr := logger.TestLogger(t)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr)
require.Error(t, err)
}

Expand All @@ -83,7 +83,7 @@ HandlerName = "no_such_handler"
`)

lggr := logger.TestLogger(t)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr)
require.Error(t, err)
}

Expand All @@ -97,7 +97,7 @@ SomeOtherField = "abcd"
`)

lggr := logger.TestLogger(t)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(lggr), lggr)
_, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr)
require.Error(t, err)
}

Expand Down
10 changes: 6 additions & 4 deletions core/services/gateway/handler_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/config"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers"
Expand All @@ -16,19 +17,20 @@ const (
)

type handlerFactory struct {
lggr logger.Logger
chains evm.ChainSet
lggr logger.Logger
}

var _ HandlerFactory = (*handlerFactory)(nil)

func NewHandlerFactory(lggr logger.Logger) HandlerFactory {
return &handlerFactory{lggr}
func NewHandlerFactory(chains evm.ChainSet, lggr logger.Logger) HandlerFactory {
return &handlerFactory{chains, lggr}
}

func (hf *handlerFactory) NewHandler(handlerType HandlerType, handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON) (handlers.Handler, error) {
switch handlerType {
case FunctionsHandlerType:
return functions.NewFunctionsHandler(handlerConfig, donConfig, don, hf.lggr)
return functions.NewFunctionsHandler(handlerConfig, donConfig, don, hf.chains, hf.lggr)
case DummyHandlerType:
return handlers.NewDummyHandler(donConfig, don, hf.lggr)
default:
Expand Down
108 changes: 108 additions & 0 deletions core/services/gateway/handlers/functions/allowlist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package functions

import (
"context"
"fmt"
"math/big"
"sync/atomic"
"time"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/pkg/errors"

evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client"
"github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/ocr2dr_oracle"
"github.com/smartcontractkit/chainlink/v2/core/logger"
)

// OnchainAllowlist maintains an allowlist of addresses fetched from the blockchain (EVM-only).
// Use UpdateFromContract() for a one-time update or UpdatePeriodically() for periodic updates.
// All methods are thread-safe.
//
//go:generate mockery --quiet --name OnchainAllowlist --output ./mocks/ --case=underscore
type OnchainAllowlist interface {
Allow(common.Address) bool
UpdateFromContract(ctx context.Context) error
UpdatePeriodically(ctx context.Context, updateFrequency time.Duration, updateTimeout time.Duration)
}

type onchainAllowlist struct {
allowlist atomic.Pointer[map[common.Address]struct{}]
client evmclient.Client
contract *ocr2dr_oracle.OCR2DROracle
blockConfirmations *big.Int
lggr logger.Logger
}

func NewOnchainAllowlist(client evmclient.Client, contractAddress common.Address, blockConfirmations int64, lggr logger.Logger) (OnchainAllowlist, error) {
if client == nil {
return nil, errors.New("client is nil")
}
if lggr == nil {
return nil, errors.New("logger is nil")
}
contract, err := ocr2dr_oracle.NewOCR2DROracle(contractAddress, client)
if err != nil {
return nil, fmt.Errorf("unexpected error during NewOCR2DROracle: %s", err)
}
allowlist := &onchainAllowlist{
client: client,
contract: contract,
blockConfirmations: big.NewInt(blockConfirmations),
lggr: lggr.Named("OnchainAllowlist"),
}
emptyMap := make(map[common.Address]struct{})
allowlist.allowlist.Store(&emptyMap)
return allowlist, nil
}

func (a *onchainAllowlist) Allow(address common.Address) bool {
allowlist := *a.allowlist.Load()
_, ok := allowlist[address]
return ok
}

func (a *onchainAllowlist) UpdateFromContract(ctx context.Context) error {
latestBlockHeight, err := a.client.LatestBlockHeight(ctx)
if err != nil {
return errors.Wrap(err, "error calling LatestBlockHeight")
}
if latestBlockHeight == nil {
return errors.New("LatestBlockHeight returned nil")
}
blockNum := big.NewInt(0).Sub(latestBlockHeight, a.blockConfirmations)
addrList, err := a.contract.GetAuthorizedSenders(&bind.CallOpts{
Pending: false,
BlockNumber: blockNum,
Context: ctx,
})
if err != nil {
return errors.Wrap(err, "error calling GetAuthorizedSenders")
}
newAllowlist := make(map[common.Address]struct{})
for _, addr := range addrList {
newAllowlist[addr] = struct{}{}
}
a.allowlist.Store(&newAllowlist)
a.lggr.Infow("allowlist updated successfully", "len", len(addrList), "blockNumber", blockNum)
return nil
}

func (a *onchainAllowlist) UpdatePeriodically(ctx context.Context, updateFrequency time.Duration, updateTimeout time.Duration) {
ticker := time.NewTicker(updateFrequency)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
timeoutCtx, cancel := context.WithTimeout(ctx, updateTimeout)
err := a.UpdateFromContract(timeoutCtx)
if err != nil {
a.lggr.Errorw("error calling UpdateFromContract", "err", err)
}
cancel()
}
}
}
67 changes: 67 additions & 0 deletions core/services/gateway/handlers/functions/allowlist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package functions_test

import (
"context"
"encoding/hex"
"math/big"
"testing"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions"
)

const (
addr1 = "9ed925d8206a4f88a2f643b28b3035b315753cd6"
addr2 = "ea6721ac65bced841b8ec3fc5fedea6141a0ade4"
addr3 = "84689acc87ff22841b8ec378300da5e141a99911"
)

func sampleEncodedAllowlist(t *testing.T) []byte {
abiEncodedAddresses :=
"0000000000000000000000000000000000000000000000000000000000000020" +
"0000000000000000000000000000000000000000000000000000000000000002" +
"000000000000000000000000" + addr1 +
"000000000000000000000000" + addr2
rawData, err := hex.DecodeString(abiEncodedAddresses)
require.NoError(t, err)
return rawData
}

func TestAllowlist_UpdateAndCheck(t *testing.T) {
t.Parallel()

client := mocks.NewClient(t)
client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(sampleEncodedAllowlist(t), nil)
allowlist, err := functions.NewOnchainAllowlist(client, common.Address{}, 1, logger.TestLogger(t))
require.NoError(t, err)

require.NoError(t, allowlist.UpdateFromContract(context.Background()))
require.False(t, allowlist.Allow(common.Address{}))
require.True(t, allowlist.Allow(common.HexToAddress(addr1)))
require.True(t, allowlist.Allow(common.HexToAddress(addr2)))
require.False(t, allowlist.Allow(common.HexToAddress(addr3)))
}

func TestAllowlist_UpdatePeriodically(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
client := mocks.NewClient(t)
client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
client.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
cancel()
}).Return(sampleEncodedAllowlist(t), nil)
allowlist, err := functions.NewOnchainAllowlist(client, common.Address{}, 1, logger.TestLogger(t))
require.NoError(t, err)

allowlist.UpdatePeriodically(ctx, time.Millisecond*10, time.Second*1)
require.True(t, allowlist.Allow(common.HexToAddress(addr1)))
require.False(t, allowlist.Allow(common.HexToAddress(addr3)))
}
91 changes: 79 additions & 12 deletions core/services/gateway/handlers/functions/handler.functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,99 @@ package functions
import (
"context"
"encoding/json"
"errors"
"math/big"
"sync"
"time"

"github.com/ethereum/go-ethereum/common"

"github.com/smartcontractkit/chainlink/v2/core/chains/evm"
"github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/api"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/config"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers"
)

type FunctionsHandlerConfig struct {
AllowlistCheckEnabled bool `json:"allowlistCheckEnabled"`
AllowlistChainID int64 `json:"allowlistChainID"`
AllowlistContractAddress string `json:"allowlistContractAddress"`
AllowlistBlockConfirmations int64 `json:"allowlistBlockConfirmations"`
AllowlistUpdateFrequencySec int `json:"allowlistUpdateFrequencySec"`
AllowlistUpdateTimeoutSec int `json:"allowlistUpdateTimeoutSec"`
}

type functionsHandler struct {
handlerConfig *FunctionsHandlerConfig
donConfig *config.DONConfig
don handlers.DON
lggr logger.Logger
handlerConfig *FunctionsHandlerConfig
donConfig *config.DONConfig
don handlers.DON
allowlist OnchainAllowlist
serviceContext context.Context
serviceCancel context.CancelFunc
shutdownWaitGroup sync.WaitGroup
lggr logger.Logger
}

var _ handlers.Handler = (*functionsHandler)(nil)

func NewFunctionsHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, lggr logger.Logger) (handlers.Handler, error) {
var parsedConfig FunctionsHandlerConfig
if err := json.Unmarshal(handlerConfig, &parsedConfig); err != nil {
func NewFunctionsHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, chains evm.ChainSet, lggr logger.Logger) (handlers.Handler, error) {
cfg, err := ParseConfig(handlerConfig)
if err != nil {
return nil, err
}
var allowlist OnchainAllowlist
if cfg.AllowlistCheckEnabled {
chain, err := chains.Get(big.NewInt(cfg.AllowlistChainID))
if err != nil {
return nil, err
}
allowlist, err = NewOnchainAllowlist(chain.Client(), common.HexToAddress(cfg.AllowlistContractAddress), cfg.AllowlistBlockConfirmations, lggr)
if err != nil {
return nil, err
}
}
serviceContext, serviceCancel := context.WithCancel(context.Background())
return &functionsHandler{
handlerConfig: &parsedConfig,
donConfig: donConfig,
don: don,
lggr: lggr,
handlerConfig: cfg,
donConfig: donConfig,
don: don,
allowlist: allowlist,
serviceContext: serviceContext,
serviceCancel: serviceCancel,
lggr: lggr,
}, nil
}

func ParseConfig(handlerConfig json.RawMessage) (*FunctionsHandlerConfig, error) {
var cfg FunctionsHandlerConfig
if err := json.Unmarshal(handlerConfig, &cfg); err != nil {
return nil, err
}
if cfg.AllowlistCheckEnabled {
if !common.IsHexAddress(cfg.AllowlistContractAddress) {
return nil, errors.New("allowlistContractAddress is not a valid hex address")
}
if cfg.AllowlistUpdateFrequencySec <= 0 {
return nil, errors.New("allowlistUpdateFrequencySec must be positive")
}
if cfg.AllowlistUpdateTimeoutSec <= 0 {
return nil, errors.New("allowlistUpdateTimeoutSec must be positive")
}
}
return &cfg, nil
}

func (h *functionsHandler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error {
if err := msg.Validate(); err != nil {
h.lggr.Debug("received invalid message", "err", err)
return err
}
sender := common.HexToAddress(msg.Body.Sender)
if h.allowlist != nil && !h.allowlist.Allow(sender) {
h.lggr.Debug("received a message from a non-allowlisted address", "sender", msg.Body.Sender)
return errors.New("sender not allowlisted")
}
h.lggr.Debug("received a valid message", "sender", msg.Body.Sender)
return nil
}
Expand All @@ -48,10 +104,21 @@ func (h *functionsHandler) HandleNodeMessage(ctx context.Context, msg *api.Messa
return nil
}

func (h *functionsHandler) Start(context.Context) error {
func (h *functionsHandler) Start(ctx context.Context) error {
if h.allowlist != nil {
checkFreq := time.Duration(h.handlerConfig.AllowlistUpdateFrequencySec) * time.Second
checkTimeout := time.Duration(h.handlerConfig.AllowlistUpdateTimeoutSec) * time.Second
h.shutdownWaitGroup.Add(1)
go func() {
h.allowlist.UpdatePeriodically(h.serviceContext, checkFreq, checkTimeout)
h.shutdownWaitGroup.Done()
}()
}
return nil
}

func (h *functionsHandler) Close() error {
h.serviceCancel()
h.shutdownWaitGroup.Wait()
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func TestFunctionsHandler_Basic(t *testing.T) {
t.Parallel()

handler, err := functions.NewFunctionsHandler(json.RawMessage("{}"), &config.DONConfig{}, nil, logger.TestLogger(t))
handler, err := functions.NewFunctionsHandler(json.RawMessage("{}"), &config.DONConfig{}, nil, nil, logger.TestLogger(t))
require.NoError(t, err)

// nil message
Expand Down
Loading

0 comments on commit a95408c

Please sign in to comment.