diff --git a/cmd/ostracon/commands/show_validator_test.go b/cmd/ostracon/commands/show_validator_test.go index 5f708d428..4cc5ecf93 100644 --- a/cmd/ostracon/commands/show_validator_test.go +++ b/cmd/ostracon/commands/show_validator_test.go @@ -80,7 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) { } privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) { config.PrivValidatorListenAddr = addr - config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")] + config.PrivValidatorRemoteAddresses = []string{addr[:strings.Index(addr, ":")]} require.NoFileExists(t, config.PrivValidatorKeyFile()) output, err := captureStdout(func() { err := showValidator(ShowValidatorCmd, nil, config) diff --git a/config/config.go b/config/config.go index 052f01a85..1b457679b 100644 --- a/config/config.go +++ b/config/config.go @@ -245,10 +245,12 @@ type BaseConfig struct { //nolint: maligned // example) tcp://0.0.0.0:26659 PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"` - // Validator's remote address(without port) to allow a connection - // ostracon only allow a connection from this address - // example) 10.0.0.7 - PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"` + // Validator's remote address to allow a connection + // Comma separated list of addresses to allow + // ostracon only allows a connection from these addresses separated by a comma + // example) 127.0.0.1 + // example) 127.0.0.1,192.168.1.2 + PrivValidatorRemoteAddresses []string `mapstructure:"priv_validator_raddrs"` // A JSON file containing the private key to use for p2p authenticated encryption NodeKey string `mapstructure:"node_key_file"` diff --git a/config/toml.go b/config/toml.go index 34dc59204..2072ac9ad 100644 --- a/config/toml.go +++ b/config/toml.go @@ -160,9 +160,11 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}" priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}" # Validator's remote address to allow a connection -# ostracon only allow a connection from this address +# Comma separated list of addresses to allow +# ostracon only allows a connection from these addresses separated by a comma # example) 127.0.0.1 -priv_validator_raddr = "127.0.0.1" +# example) 127.0.0.1,192.168.1.2 +priv_validator_raddrs = "127.0.0.1" # Path to the JSON file containing the private key to use for node authentication in the p2p protocol node_key_file = "{{ js .BaseConfig.NodeKey }}" diff --git a/node/node.go b/node/node.go index 7bad342bd..6cf8156aa 100644 --- a/node/node.go +++ b/node/node.go @@ -1524,7 +1524,7 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error { } func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) { - pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr) + pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddresses) if err != nil { return nil, fmt.Errorf("failed to start private validator: %w", err) } diff --git a/node/node_test.go b/node/node_test.go index f7d0e4f16..0a8787393 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -170,7 +170,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { if err != nil { return } - config.BaseConfig.PrivValidatorRemoteAddr = addrPart + config.BaseConfig.PrivValidatorRemoteAddresses = []string{addrPart} dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey()) dialerEndpoint := privval.NewSignerDialerEndpoint( diff --git a/privval/internal/ip_filter.go b/privval/internal/ip_filter.go index e9f7cceb8..c513fba68 100644 --- a/privval/internal/ip_filter.go +++ b/privval/internal/ip_filter.go @@ -4,16 +4,17 @@ import ( "fmt" "github.com/Finschia/ostracon/libs/log" "net" + "strings" ) type IpFilter struct { - allowAddr string + allowList []string log log.Logger } -func NewIpFilter(addr string, l log.Logger) *IpFilter { +func NewIpFilter(allowAddresses []string, l log.Logger) *IpFilter { return &IpFilter{ - allowAddr: addr, + allowList: allowAddresses, log: l, } } @@ -26,11 +27,11 @@ func (f *IpFilter) Filter(addr net.Addr) net.Addr { } func (f *IpFilter) String() string { - return f.allowAddr + return strings.Join(f.allowList, ",") } func (f *IpFilter) isAllowedAddr(addr net.Addr) bool { - if len(f.allowAddr) == 0 { + if len(f.allowList) == 0 { return false } hostAddr, _, err := net.SplitHostPort(addr.String()) @@ -40,5 +41,10 @@ func (f *IpFilter) isAllowedAddr(addr net.Addr) bool { } return false } - return f.allowAddr == hostAddr + for _, address := range f.allowList { + if address == hostAddr { + return true + } + } + return false } diff --git a/privval/internal/ip_filter_test.go b/privval/internal/ip_filter_test.go index 6257fa7ca..9d7965a14 100644 --- a/privval/internal/ip_filter_test.go +++ b/privval/internal/ip_filter_test.go @@ -3,6 +3,7 @@ package internal import ( "github.com/stretchr/testify/assert" "net" + "strings" "testing" ) @@ -71,21 +72,64 @@ func TestFilterRemoteConnectionByIP(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cut := NewIpFilter(tt.fields.allowIP, nil) + cut := NewIpFilter([]string{tt.fields.allowIP}, nil) + assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name) + }) + } +} + +func TestFilterRemoteConnectionByIPWithMultipleAllowIPs(t *testing.T) { + type fields struct { + allowList []string + remoteAddr net.Addr + expected net.Addr + } + tests := []struct { + name string + fields fields + }{ + { + "should allow the one in the allow list", + struct { + allowList []string + remoteAddr net.Addr + expected net.Addr + }{[]string{"127.0.0.1", "192.168.1.1"}, addrStub{"192.168.1.1:45678"}, addrStub{"192.168.1.1:45678"}}, + }, + { + "should not allow any ip which is not in the allow list", + struct { + allowList []string + remoteAddr net.Addr + expected net.Addr + }{[]string{"127.0.0.1", "192.168.1.1"}, addrStub{"10.0.0.2:45678"}, nil}, + }, + { + "should works for IPv6 with one of correct ip in the allow list", + struct { + allowList []string + remoteAddr net.Addr + expected net.Addr + }{[]string{"2001:db8::1", "2001:db8::2"}, addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cut := NewIpFilter(tt.fields.allowList, nil) assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name) }) } } func TestIpFilterShouldSetAllowAddress(t *testing.T) { - expected := "192.168.0.1" + expected := []string{"192.168.0.1"} cut := NewIpFilter(expected, nil) - assert.Equal(t, expected, cut.allowAddr) + assert.Equal(t, expected, cut.allowList) } func TestIpFilterStringShouldReturnsIP(t *testing.T) { - expected := "127.0.0.1" - assert.Equal(t, expected, NewIpFilter(expected, nil).String()) + expected := []string{"127.0.0.1", "192.168.1.10"} + assert.Equal(t, strings.Join(expected, ","), NewIpFilter(expected, nil).String()) } diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index c2c96642d..44b3a3c37 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -26,12 +26,11 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene } // SignerListenerEndpointAllowAddress sets the address to allow -// connections from the only allowed address -// -func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption { +// connections from the only allowed addresses +func SignerListenerEndpointAllowAddress(protocol string, allowedAddresses []string) SignerListenerEndpointOption { return func(sl *SignerListenerEndpoint) { if protocol == "tcp" || len(protocol) == 0 { - sl.connFilter = internal.NewIpFilter(addr, sl.Logger) + sl.connFilter = internal.NewIpFilter(allowedAddresses, sl.Logger) return } sl.connFilter = internal.NewNullObject() diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index 317bab82f..690c8d836 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -216,13 +216,13 @@ func getMockEndpoints( } func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) { - cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1")) + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", []string{"127.0.0.1"})) _, ok := cut.connFilter.(*internal.IpFilter) assert.True(t, ok) } func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) { - cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01")) + cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", []string{"don't care"})) _, ok := cut.connFilter.(*internal.NullObject) assert.True(t, ok) } diff --git a/privval/utils.go b/privval/utils.go index 34607235c..aceddc5f2 100644 --- a/privval/utils.go +++ b/privval/utils.go @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool { } // NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address -func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) { +func NewSignerListener(logger log.Logger, listenAddr string, remoteAddresses []string) (*SignerListenerEndpoint, error) { var listener net.Listener protocol, address := tmnet.ProtocolAndAddress(listenAddr) @@ -47,7 +47,7 @@ func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*Signe ) } - pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddr)) + pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddresses)) return pve, nil }