diff --git a/.github/workflows/mock_checker.yml b/.github/workflows/mock_checker.yml new file mode 100644 index 00000000..0454e1d4 --- /dev/null +++ b/.github/workflows/mock_checker.yml @@ -0,0 +1,35 @@ +name: Check generated code is up to date + +on: + push: + branches: + - main + pull_request: + branches: + - "**" + +jobs: + generated_code: + name: generated_code + runs-on: ubuntu-22.04 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: "go.mod" + + - name: Generate code + run: | + scripts/generate.sh + + - name: Print diff + run: git --no-pager diff + + - name: Fail if diff exists + run: git --no-pager diff --quiet diff --git a/peers/app_request_network.go b/peers/app_request_network.go index d26810e6..eed93bfb 100644 --- a/peers/app_request_network.go +++ b/peers/app_request_network.go @@ -1,6 +1,8 @@ // Copyright (C) 2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. +//go:generate mockgen -source=$GOFILE -destination=./mocks/mock_app_request_network.go -package=mocks + package peers import ( diff --git a/peers/mocks/mock_app_request_network.go b/peers/mocks/mock_app_request_network.go new file mode 100644 index 00000000..f989358c --- /dev/null +++ b/peers/mocks/mock_app_request_network.go @@ -0,0 +1,128 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: app_request_network.go +// +// Generated by this command: +// +// mockgen -source=app_request_network.go -destination=./mocks/mock_app_request_network.go -package=mocks +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + ids "github.com/ava-labs/avalanchego/ids" + message "github.com/ava-labs/avalanchego/message" + subnets "github.com/ava-labs/avalanchego/subnets" + set "github.com/ava-labs/avalanchego/utils/set" + peers "github.com/ava-labs/awm-relayer/peers" + gomock "go.uber.org/mock/gomock" +) + +// MockAppRequestNetwork is a mock of AppRequestNetwork interface. +type MockAppRequestNetwork struct { + ctrl *gomock.Controller + recorder *MockAppRequestNetworkMockRecorder +} + +// MockAppRequestNetworkMockRecorder is the mock recorder for MockAppRequestNetwork. +type MockAppRequestNetworkMockRecorder struct { + mock *MockAppRequestNetwork +} + +// NewMockAppRequestNetwork creates a new mock instance. +func NewMockAppRequestNetwork(ctrl *gomock.Controller) *MockAppRequestNetwork { + mock := &MockAppRequestNetwork{ctrl: ctrl} + mock.recorder = &MockAppRequestNetworkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAppRequestNetwork) EXPECT() *MockAppRequestNetworkMockRecorder { + return m.recorder +} + +// ConnectPeers mocks base method. +func (m *MockAppRequestNetwork) ConnectPeers(nodeIDs set.Set[ids.NodeID]) set.Set[ids.NodeID] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectPeers", nodeIDs) + ret0, _ := ret[0].(set.Set[ids.NodeID]) + return ret0 +} + +// ConnectPeers indicates an expected call of ConnectPeers. +func (mr *MockAppRequestNetworkMockRecorder) ConnectPeers(nodeIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectPeers", reflect.TypeOf((*MockAppRequestNetwork)(nil).ConnectPeers), nodeIDs) +} + +// ConnectToCanonicalValidators mocks base method. +func (m *MockAppRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*peers.ConnectedCanonicalValidators, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectToCanonicalValidators", subnetID) + ret0, _ := ret[0].(*peers.ConnectedCanonicalValidators) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ConnectToCanonicalValidators indicates an expected call of ConnectToCanonicalValidators. +func (mr *MockAppRequestNetworkMockRecorder) ConnectToCanonicalValidators(subnetID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectToCanonicalValidators", reflect.TypeOf((*MockAppRequestNetwork)(nil).ConnectToCanonicalValidators), subnetID) +} + +// GetSubnetID mocks base method. +func (m *MockAppRequestNetwork) GetSubnetID(blockchainID ids.ID) (ids.ID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubnetID", blockchainID) + ret0, _ := ret[0].(ids.ID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubnetID indicates an expected call of GetSubnetID. +func (mr *MockAppRequestNetworkMockRecorder) GetSubnetID(blockchainID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetID", reflect.TypeOf((*MockAppRequestNetwork)(nil).GetSubnetID), blockchainID) +} + +// RegisterAppRequest mocks base method. +func (m *MockAppRequestNetwork) RegisterAppRequest(requestID ids.RequestID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAppRequest", requestID) +} + +// RegisterAppRequest indicates an expected call of RegisterAppRequest. +func (mr *MockAppRequestNetworkMockRecorder) RegisterAppRequest(requestID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAppRequest", reflect.TypeOf((*MockAppRequestNetwork)(nil).RegisterAppRequest), requestID) +} + +// RegisterRequestID mocks base method. +func (m *MockAppRequestNetwork) RegisterRequestID(requestID uint32, numExpectedResponse int) chan message.InboundMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterRequestID", requestID, numExpectedResponse) + ret0, _ := ret[0].(chan message.InboundMessage) + return ret0 +} + +// RegisterRequestID indicates an expected call of RegisterRequestID. +func (mr *MockAppRequestNetworkMockRecorder) RegisterRequestID(requestID, numExpectedResponse any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterRequestID", reflect.TypeOf((*MockAppRequestNetwork)(nil).RegisterRequestID), requestID, numExpectedResponse) +} + +// Send mocks base method. +func (m *MockAppRequestNetwork) Send(msg message.OutboundMessage, nodeIDs set.Set[ids.NodeID], subnetID ids.ID, allower subnets.Allower) set.Set[ids.NodeID] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", msg, nodeIDs, subnetID, allower) + ret0, _ := ret[0].(set.Set[ids.NodeID]) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockAppRequestNetworkMockRecorder) Send(msg, nodeIDs, subnetID, allower any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAppRequestNetwork)(nil).Send), msg, nodeIDs, subnetID, allower) +} diff --git a/scripts/build.sh b/scripts/build.sh index 3acd64a9..a25a2ed6 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -2,11 +2,14 @@ # Copyright (C) 2024, Ava Labs, Inc. All rights reserved. # See the file LICENSE for licensing terms. +set -e errexit + # Root directory root=$( cd "$(dirname "${BASH_SOURCE[0]}")" cd .. && pwd ) +"$root"/scripts/generate.sh "$root"/scripts/build_relayer.sh "$root"/scripts/build_signature_aggregator.sh diff --git a/scripts/generate.sh b/scripts/generate.sh new file mode 100755 index 00000000..04b0518c --- /dev/null +++ b/scripts/generate.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +# Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +# See the file LICENSE for licensing terms. + +set -e errexit + +source "$root"/scripts/versions.sh +go install -v "go.uber.org/mock/mockgen@$(getDepVersion go.uber.org/mock)" +PATH="$PATH:$(go env GOPATH)/bin" go generate ./... diff --git a/signature-aggregator/aggregator/aggregator.go b/signature-aggregator/aggregator/aggregator.go index 979764a0..161066fc 100644 --- a/signature-aggregator/aggregator/aggregator.go +++ b/signature-aggregator/aggregator/aggregator.go @@ -343,6 +343,8 @@ func (s *SignatureAggregator) CreateSignedMessage( return nil, errNotEnoughSignatures } +// TODO: consider making this function private. its only reference seems to be +// within this module. func (s *SignatureAggregator) GetSubnetID(blockchainID ids.ID) (ids.ID, error) { s.subnetsMapLock.RLock() subnetID, ok := s.subnetIDsByBlockchainID[blockchainID] @@ -359,6 +361,8 @@ func (s *SignatureAggregator) GetSubnetID(blockchainID ids.ID) (ids.ID, error) { return subnetID, nil } +// TODO: consider making this function private. its only reference seems to be +// within this module. func (s *SignatureAggregator) SetSubnetID(blockchainID ids.ID, subnetID ids.ID) { s.subnetsMapLock.Lock() s.subnetIDsByBlockchainID[blockchainID] = subnetID diff --git a/signature-aggregator/aggregator/aggregator_test.go b/signature-aggregator/aggregator/aggregator_test.go new file mode 100644 index 00000000..13a6952c --- /dev/null +++ b/signature-aggregator/aggregator/aggregator_test.go @@ -0,0 +1,87 @@ +package aggregator + +import ( + "testing" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" + "github.com/ava-labs/awm-relayer/peers" + "github.com/ava-labs/awm-relayer/peers/mocks" + "github.com/ava-labs/awm-relayer/signature-aggregator/metrics" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +var sigAggMetrics *metrics.SignatureAggregatorMetrics +var messageCreator message.Creator + +func instantiateAggregator(t *testing.T) ( + *SignatureAggregator, + *mocks.MockAppRequestNetwork, +) { + mockNetwork := mocks.NewMockAppRequestNetwork(gomock.NewController(t)) + if sigAggMetrics == nil { + sigAggMetrics = metrics.NewSignatureAggregatorMetrics(prometheus.DefaultRegisterer) + } + if messageCreator == nil { + var err error + messageCreator, err = message.NewCreator( + logging.NoLog{}, + prometheus.DefaultRegisterer, + constants.DefaultNetworkCompressionType, + constants.DefaultNetworkMaximumInboundTimeout, + ) + require.Equal(t, err, nil) + } + aggregator, err := NewSignatureAggregator( + mockNetwork, + logging.NoLog{}, + 1024, + sigAggMetrics, + messageCreator, + ) + require.Equal(t, err, nil) + return aggregator, mockNetwork +} + +func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) { + aggregator, mockNetwork := instantiateAggregator(t) + msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) + require.Equal(t, err, nil) + mockNetwork.EXPECT().GetSubnetID(ids.Empty).Return(ids.Empty, nil) + mockNetwork.EXPECT().ConnectToCanonicalValidators(ids.Empty).Return( + &peers.ConnectedCanonicalValidators{ + ConnectedWeight: 0, + TotalValidatorWeight: 0, + ValidatorSet: []*warp.Validator{}, + }, + nil, + ) + _, err = aggregator.CreateSignedMessage(msg, ids.Empty, 80) + require.ErrorContains(t, err, "no signatures") +} + +func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) { + aggregator, mockNetwork := instantiateAggregator(t) + msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) + require.Equal(t, err, nil) + mockNetwork.EXPECT().GetSubnetID(ids.Empty).Return(ids.Empty, nil) + mockNetwork.EXPECT().ConnectToCanonicalValidators(ids.Empty).Return( + &peers.ConnectedCanonicalValidators{ + ConnectedWeight: 0, + TotalValidatorWeight: 1, + ValidatorSet: []*warp.Validator{}, + }, + nil, + ) + _, err = aggregator.CreateSignedMessage(msg, ids.Empty, 80) + require.ErrorContains( + t, + err, + "failed to connect to a threshold of stake", + ) +}