Skip to content

Commit

Permalink
feat: support multiple allowIPs for a remote connection from one of K…
Browse files Browse the repository at this point in the history
…MS servers
  • Loading branch information
jaeseung-bae committed Aug 29, 2023
1 parent 449aa31 commit f6c01ac
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 27 deletions.
2 changes: 1 addition & 1 deletion cmd/ostracon/commands/show_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) {
}
privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) {
config.PrivValidatorListenAddr = addr
config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")]
config.PrivValidatorRemoteAddresses = []string{addr[:strings.Index(addr, ":")]}
require.NoFileExists(t, config.PrivValidatorKeyFile())
output, err := captureStdout(func() {
err := showValidator(ShowValidatorCmd, nil, config)
Expand Down
10 changes: 6 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,12 @@ type BaseConfig struct { //nolint: maligned
// example) tcp://0.0.0.0:26659
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`

// Validator's remote address(without port) to allow a connection
// ostracon only allow a connection from this address
// example) 10.0.0.7
PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"`
// Validator's remote address to allow a connection
// Comma separated list of addresses to allow
// ostracon only allows a connection from these addresses separated by a comma
// example) 127.0.0.1
// example) 127.0.0.1,192.168.1.2
PrivValidatorRemoteAddresses []string `mapstructure:"priv_validator_raddrs"`

// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`
Expand Down
6 changes: 4 additions & 2 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"
# Validator's remote address to allow a connection
# ostracon only allow a connection from this address
# Comma separated list of addresses to allow
# ostracon only allows a connection from these addresses separated by a comma
# example) 127.0.0.1
priv_validator_raddr = "127.0.0.1"
# example) 127.0.0.1,192.168.1.2
priv_validator_raddrs = "127.0.0.1"
# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"
Expand Down
2 changes: 1 addition & 1 deletion node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,7 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error {
}

func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr)
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddresses)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
Expand Down
18 changes: 12 additions & 6 deletions privval/internal/ip_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ import (
"fmt"
"github.com/Finschia/ostracon/libs/log"
"net"
"strings"
)

type IpFilter struct {
allowAddr string
allowList []string
log log.Logger
}

func NewIpFilter(addr string, l log.Logger) *IpFilter {
func NewIpFilter(allowAddresses []string, l log.Logger) *IpFilter {
return &IpFilter{
allowAddr: addr,
allowList: allowAddresses,
log: l,
}
}
Expand All @@ -26,11 +27,11 @@ func (f *IpFilter) Filter(addr net.Addr) net.Addr {
}

func (f *IpFilter) String() string {
return f.allowAddr
return strings.Join(f.allowList, ",")
}

func (f *IpFilter) isAllowedAddr(addr net.Addr) bool {
if len(f.allowAddr) == 0 {
if len(f.allowList) == 0 {
return false
}
hostAddr, _, err := net.SplitHostPort(addr.String())
Expand All @@ -40,5 +41,10 @@ func (f *IpFilter) isAllowedAddr(addr net.Addr) bool {
}
return false
}
return f.allowAddr == hostAddr
for _, address := range f.allowList {
if address == hostAddr {
return true
}
}
return false
}
54 changes: 49 additions & 5 deletions privval/internal/ip_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package internal
import (
"github.com/stretchr/testify/assert"
"net"
"strings"
"testing"
)

Expand Down Expand Up @@ -71,21 +72,64 @@ func TestFilterRemoteConnectionByIP(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cut := NewIpFilter(tt.fields.allowIP, nil)
cut := NewIpFilter([]string{tt.fields.allowIP}, nil)
assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name)
})
}
}

func TestFilterRemoteConnectionByIPWithMultipleAllowIPs(t *testing.T) {
type fields struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}
tests := []struct {
name string
fields fields
}{
{
"should allow the one in the allow list",
struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}{[]string{"127.0.0.1", "192.168.1.1"}, addrStub{"192.168.1.1:45678"}, addrStub{"192.168.1.1:45678"}},
},
{
"should not allow any ip which is not in the allow list",
struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}{[]string{"127.0.0.1", "192.168.1.1"}, addrStub{"10.0.0.2:45678"}, nil},
},
{
"should works for IPv6 with one of correct ip in the allow list",
struct {
allowList []string
remoteAddr net.Addr
expected net.Addr
}{[]string{"2001:db8::1", "2001:db8::2"}, addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cut := NewIpFilter(tt.fields.allowList, nil)
assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name)
})
}
}

func TestIpFilterShouldSetAllowAddress(t *testing.T) {
expected := "192.168.0.1"
expected := []string{"192.168.0.1"}

cut := NewIpFilter(expected, nil)

assert.Equal(t, expected, cut.allowAddr)
assert.Equal(t, expected, cut.allowList)
}

func TestIpFilterStringShouldReturnsIP(t *testing.T) {
expected := "127.0.0.1"
assert.Equal(t, expected, NewIpFilter(expected, nil).String())
expected := []string{"127.0.0.1", "192.168.1.10"}
assert.Equal(t, strings.Join(expected, ","), NewIpFilter(expected, nil).String())
}
7 changes: 3 additions & 4 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListene
}

// SignerListenerEndpointAllowAddress sets the address to allow
// connections from the only allowed address
//
func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption {
// connections from the only allowed addresses
func SignerListenerEndpointAllowAddress(protocol string, allowedAddresses []string) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) {
if protocol == "tcp" || len(protocol) == 0 {
sl.connFilter = internal.NewIpFilter(addr, sl.Logger)
sl.connFilter = internal.NewIpFilter(allowedAddresses, sl.Logger)
return
}
sl.connFilter = internal.NewNullObject()
Expand Down
4 changes: 2 additions & 2 deletions privval/signer_listener_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ func getMockEndpoints(
}

func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1"))
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", []string{"127.0.0.1"}))
_, ok := cut.connFilter.(*internal.IpFilter)
assert.True(t, ok)
}

func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01"))
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", []string{"don't care"}))
_, ok := cut.connFilter.(*internal.NullObject)
assert.True(t, ok)
}
4 changes: 2 additions & 2 deletions privval/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func IsConnTimeout(err error) bool {
}

// NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address
func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*SignerListenerEndpoint, error) {
func NewSignerListener(logger log.Logger, listenAddr string, remoteAddresses []string) (*SignerListenerEndpoint, error) {
var listener net.Listener

protocol, address := tmnet.ProtocolAndAddress(listenAddr)
Expand All @@ -47,7 +47,7 @@ func NewSignerListener(logger log.Logger, listenAddr, remoteAddr string) (*Signe
)
}

pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddr))
pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener, SignerListenerEndpointAllowAddress(protocol, remoteAddresses))

return pve, nil
}
Expand Down

0 comments on commit f6c01ac

Please sign in to comment.