Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: whitelist connection gater #31

Merged
merged 23 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/tss/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func main() {
nil,
p2pConf.ExternalIP,
os.Getenv("PASSWORD"),
[]string{},
true,
)
if nil != err {
log.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions keygen/ecdsa/keygen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ func (s *TssECDSAKeygenTestSuite) SetUpTest(c *C) {
buf, err := base64.StdEncoding.DecodeString(testPriKeyArr[i])
c.Assert(err, IsNil)
if i == 0 {
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "")
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf[:]), IsNil)
s.comms[i] = comm
continue
}
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "")
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf[:]), IsNil)
s.comms[i] = comm
Expand Down
4 changes: 2 additions & 2 deletions keygen/eddsa/keygen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ func (s *EddsaKeygenTestSuite) SetUpTest(c *C) {
buf, err := base64.StdEncoding.DecodeString(testPriKeyArr[i])
c.Assert(err, IsNil)
if i == 0 {
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "")
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
continue
}
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "")
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
Expand Down
4 changes: 2 additions & 2 deletions keysign/ecdsa/keysign_old_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,13 @@ func (s *TssECDSAKeysignOldTestSuite) SetUpTest(c *C) {
buf, err := base64.StdEncoding.DecodeString(testPriKeyArr[i])
c.Assert(err, IsNil)
if i == 0 {
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "")
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
continue
}
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "")
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
Expand Down
4 changes: 2 additions & 2 deletions keysign/ecdsa/keysign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ func (s *TssECDSAKeysignTestSuite) SetUpTest(c *C) {
buf, err := base64.StdEncoding.DecodeString(testPriKeyArr[i])
c.Assert(err, IsNil)
if i == 0 {
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "")
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
continue
}
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "")
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
Expand Down
4 changes: 2 additions & 2 deletions keysign/eddsa/keysign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ func (s *EddsaKeysignTestSuite) SetUpTest(c *C) {
buf, err := base64.StdEncoding.DecodeString(testPriKeyArr[i])
c.Assert(err, IsNil)
if i == 0 {
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "")
comm, err := p2p.NewCommunication("asgard", nil, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
continue
}
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "")
comm, err := p2p.NewCommunication("asgard", []maddr.Multiaddr{multiAddr}, ports[i], "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(buf), IsNil)
s.comms[i] = comm
Expand Down
28 changes: 27 additions & 1 deletion p2p/communication.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,19 @@ type Communication struct {
BroadcastMsgChan chan *messages.BroadcastMsgChan
externalAddr maddr.Multiaddr
streamMgr *StreamMgr
whitelistedPeers []string
disableWhitelist bool
}

// NewCommunication create a new instance of Communication
func NewCommunication(rendezvous string, bootstrapPeers []maddr.Multiaddr, port int, externalIP string) (*Communication, error) {
func NewCommunication(
rendezvous string,
bootstrapPeers []maddr.Multiaddr,
port int,
externalIP string,
whitelistedPeers []string,
disableWhitelist bool,
) (*Communication, error) {
addr, err := maddr.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", port))
if err != nil {
return nil, fmt.Errorf("fail to create listen addr: %w", err)
Expand All @@ -90,6 +99,8 @@ func NewCommunication(rendezvous string, bootstrapPeers []maddr.Multiaddr, port
BroadcastMsgChan: make(chan *messages.BroadcastMsgChan, 1024),
externalAddr: externalAddr,
streamMgr: NewStreamMgr(),
whitelistedPeers: whitelistedPeers,
disableWhitelist: disableWhitelist,
}, nil
}

Expand Down Expand Up @@ -301,6 +312,7 @@ func (c *Communication) startChannel(privKeyBytes []byte) error {
libp2p.AddrsFactory(addressFactory),
libp2p.ResourceManager(m),
libp2p.ConnectionManager(cmgr),
libp2p.ConnectionGater(NewWhitelistConnectionGater(c.whitelistedPeers, c.disableWhitelist, c.logger)),
)
if err != nil {
return fmt.Errorf("fail to create p2p host: %w", err)
Expand Down Expand Up @@ -338,6 +350,20 @@ func (c *Communication) startChannel(privKeyBytes []byte) error {
// This is like telling your friends to meet you at the Eiffel Tower.
routingDiscovery := discovery_routing.NewRoutingDiscovery(kademliaDHT)
discovery_util.Advertise(ctx, routingDiscovery, c.rendezvous)

// Create a goroutine to shut down the DHT after 5 minutes
go func() {
select {
case <-time.After(5 * time.Minute):
c.logger.Info().Msg("Closing Kademlia DHT after 5 minutes")
if err := kademliaDHT.Close(); err != nil {
c.logger.Error().Err(err).Msg("Failed to close Kademlia DHT")
}
case <-ctx.Done():
c.logger.Info().Msg("Context done, not waiting for 5 minutes to close DHT")
}
}()

err = c.bootStrapConnectivityCheck()
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions p2p/communication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type CommunicationTestSuite struct{}
var _ = Suite(&CommunicationTestSuite{})

func (CommunicationTestSuite) TestBasicCommunication(c *C) {
comm, err := NewCommunication("rendezvous", nil, 6668, "")
comm, err := NewCommunication("rendezvous", nil, 6668, "", []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm, NotNil)
comm.SetSubscribe(messages.TSSKeyGenMsg, "hello", make(chan *Message))
Expand Down Expand Up @@ -47,15 +47,15 @@ func (CommunicationTestSuite) TestEstablishP2pCommunication(c *C) {
c.Assert(err, IsNil)
privKey, err := base64.StdEncoding.DecodeString(bootstrapPrivKey)
c.Assert(err, IsNil)
comm, err := NewCommunication("commTest", nil, 2220, fakeExternalIP)
comm, err := NewCommunication("commTest", nil, 2220, fakeExternalIP, []string{}, true)
c.Assert(err, IsNil)
c.Assert(comm.Start(privKey), IsNil)

defer comm.Stop()
sk1, _, err := crypto.GenerateSecp256k1Key(rand.Reader)
sk1raw, _ := sk1.Raw()
c.Assert(err, IsNil)
comm2, err := NewCommunication("commTest", []maddr.Multiaddr{validMultiAddr}, 2221, "")
comm2, err := NewCommunication("commTest", []maddr.Multiaddr{validMultiAddr}, 2221, "", []string{}, true)
c.Assert(err, IsNil)
err = comm2.Start(sk1raw)
c.Assert(err, IsNil)
Expand All @@ -69,14 +69,14 @@ func (CommunicationTestSuite) TestEstablishP2pCommunication(c *C) {
invalidAddr := "/ip4/127.0.0.1/tcp/2220/p2p/" + id.String()
invalidMultiAddr, err := maddr.NewMultiaddr(invalidAddr)
c.Assert(err, IsNil)
comm3, err := NewCommunication("commTest", []maddr.Multiaddr{invalidMultiAddr}, 2222, "")
comm3, err := NewCommunication("commTest", []maddr.Multiaddr{invalidMultiAddr}, 2222, "", []string{}, true)
c.Assert(err, IsNil)
err = comm3.Start(sk1raw)
c.Assert(err, ErrorMatches, "fail to connect to bootstrap peer: fail to connect to any peer")
defer comm3.Stop()

// we connect to one invalid and one valid address
comm4, err := NewCommunication("commTest", []maddr.Multiaddr{invalidMultiAddr, validMultiAddr}, 2223, "")
comm4, err := NewCommunication("commTest", []maddr.Multiaddr{invalidMultiAddr, validMultiAddr}, 2223, "", []string{}, true)
c.Assert(err, IsNil)
err = comm4.Start(sk1raw)
c.Assert(err, IsNil)
Expand Down
69 changes: 69 additions & 0 deletions p2p/whitelist_connection_gater.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package p2p

import (
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/rs/zerolog"

"github.com/libp2p/go-libp2p/core/peer"
maddr "github.com/multiformats/go-multiaddr"
)

type WhitelistConnectionGater struct {
whitelistedPeers map[string]bool
logger zerolog.Logger
disableWhitelist bool
}

func NewWhitelistConnectionGater(whitelistedPeers []string, disableWhitelist bool, logger zerolog.Logger) *WhitelistConnectionGater {
gater := &WhitelistConnectionGater{
disableWhitelist: disableWhitelist,
logger: logger,
whitelistedPeers: make(map[string]bool),
}

for _, p := range whitelistedPeers {
logger.Info().Msgf("Adding peer %s to whitelist", p)
gater.whitelistedPeers[p] = true
}

return gater
}

func (wg *WhitelistConnectionGater) InterceptPeerDial(p peer.ID) (allow bool) {
wg.logger.Info().Msgf("InterceptPeerDial %s", p.String())
if !wg.disableWhitelist {
wg.logger.Info().Msgf("peer allowed %t", wg.whitelistedPeers[p.String()])
return wg.whitelistedPeers[p.String()]
}
return true
}

func (wg *WhitelistConnectionGater) InterceptAddrDial(p peer.ID, m maddr.Multiaddr) (allow bool) {
wg.logger.Info().Msgf("InterceptAddrDial %s %s", p.String(), m.String())
if !wg.disableWhitelist {
wg.logger.Info().Msgf("peer allowed %t", wg.whitelistedPeers[p.String()])
return wg.whitelistedPeers[p.String()]
}
// Not checking addresses here, just allowing based on peer ID
return true
}

func (wg *WhitelistConnectionGater) InterceptAccept(m network.ConnMultiaddrs) (allow bool) {
return true
}

func (wg *WhitelistConnectionGater) InterceptSecured(direction network.Direction, p peer.ID, m network.ConnMultiaddrs) (allow bool) {
wg.logger.Info().Msgf("InterceptSecured %s", p.String())
if !wg.disableWhitelist {
wg.logger.Info().Msgf("peer allowed %t", wg.whitelistedPeers[p.String()])
return wg.whitelistedPeers[p.String()]
}
// _, allow = wg.whitelistedPeers[p]
return true
}

func (wg *WhitelistConnectionGater) InterceptUpgraded(network.Conn) (bool, control.DisconnectReason) {
// Allow connection upgrades
return true, 0
}
24 changes: 23 additions & 1 deletion tss/tss.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func NewTss(
preParams *bkeygen.LocalPreParams,
externalIP string,
tssPassword string,
whitelistedPeers []string,
disableWhitelist bool,
) (*TssServer, error) {
pk := coskey.PubKey{
Key: priKey.PubKey().Bytes()[:],
Expand All @@ -74,6 +76,10 @@ func NewTss(
return nil, fmt.Errorf("fail to create file state manager")
}

if !disableWhitelist && len(whitelistedPeers) == 0 {
return nil, fmt.Errorf("whitelisted peers missing")
}

var bootstrapPeers []maddr.Multiaddr
savedPeers, err := stateManager.RetrieveP2PAddresses()
if err != nil {
Expand All @@ -82,7 +88,23 @@ func NewTss(
bootstrapPeers = savedPeers
bootstrapPeers = append(bootstrapPeers, cmdBootstrapPeers...)
}
comm, err := p2p.NewCommunication(rendezvous, bootstrapPeers, p2pPort, externalIP)

// TODO: make this cleaner
var whitelistedBootstrapPeers []maddr.Multiaddr
for _, b := range bootstrapPeers {
peer, err := peer.AddrInfoFromP2pAddr(b)
if err != nil {
return nil, err
}

for _, w := range whitelistedPeers {
if w == peer.ID.String() {
whitelistedBootstrapPeers = append(whitelistedBootstrapPeers, b)
}
}
}

comm, err := p2p.NewCommunication(rendezvous, whitelistedBootstrapPeers, p2pPort, externalIP, whitelistedPeers, disableWhitelist)
if err != nil {
return nil, fmt.Errorf("fail to create communication layer: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion tss/tss_4nodes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ func (s *FourNodeTestSuite) getTssServer(c *C, index int, conf common.TssConfig,
} else {
peerIDs = nil
}
instance, err := NewTss(peerIDs, s.ports[index], priKey, "Asgard", baseHome, conf, s.preParams[index], "", "password")
instance, err := NewTss(peerIDs, s.ports[index], priKey, "Asgard", baseHome, conf, s.preParams[index], "", "password", []string{}, true)
c.Assert(err, IsNil)
return instance
}
Expand Down
2 changes: 1 addition & 1 deletion tss/tss_4nodes_zeta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (s *FourNodeScaleZetaSuite) getTssServer(c *C, index int, conf common.TssCo
} else {
peerIDs = nil
}
instance, err := NewTss(peerIDs, s.ports[index], priKey, "Zeta", baseHome, conf, s.preParams[index], "", "password")
instance, err := NewTss(peerIDs, s.ports[index], priKey, "Zeta", baseHome, conf, s.preParams[index], "", "password", []string{}, true)
c.Assert(err, IsNil)
return instance
}
Loading