Skip to content

Commit

Permalink
Merge pull request #36 from cloudstruct/feature/muxer-handshake-race-…
Browse files Browse the repository at this point in the history
…condition

Fix race condition around handshake and muxer protocol registration
  • Loading branch information
agaffney authored Mar 8, 2022
2 parents d838151 + d193a75 commit 75a95e7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
20 changes: 18 additions & 2 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,31 @@ import (
"encoding/binary"
"fmt"
"io"
"net"
"sync"
)

const (
// Magic number chosen to represent unknown protocols
PROTOCOL_UNKNOWN uint16 = 0xabcd

// Handshake protocol ID
PROTOCOL_HANDSHAKE = 0
)

type Muxer struct {
conn io.ReadWriteCloser
conn net.Conn
sendMutex sync.Mutex
startChan chan bool
ErrorChan chan error
protocolSenders map[uint16]chan *Segment
protocolReceivers map[uint16]chan *Segment
}

func New(conn io.ReadWriteCloser) *Muxer {
func New(conn net.Conn) *Muxer {
m := &Muxer{
conn: conn,
startChan: make(chan bool, 1),
ErrorChan: make(chan error, 10),
protocolSenders: make(map[uint16]chan *Segment),
protocolReceivers: make(map[uint16]chan *Segment),
Expand All @@ -32,6 +38,10 @@ func New(conn io.ReadWriteCloser) *Muxer {
return m
}

func (m *Muxer) Start() {
m.startChan <- true
}

func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
// Generate channels
senderChan := make(chan *Segment, 10)
Expand Down Expand Up @@ -69,6 +79,7 @@ func (m *Muxer) Send(msg *Segment) error {
}

func (m *Muxer) readLoop() {
started := false
for {
header := SegmentHeader{}
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
Expand All @@ -83,6 +94,11 @@ func (m *Muxer) readLoop() {
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
m.ErrorChan <- err
}
// Wait until the muxer is started to process anything other than handshake messages
if !started && msg.GetProtocolId() != PROTOCOL_HANDSHAKE {
<-m.startChan
started = true
}
// Send message payload to proper receiver
recvChan := m.protocolReceivers[msg.GetProtocolId()]
if recvChan == nil {
Expand Down
10 changes: 10 additions & 0 deletions ouroboros.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type Ouroboros struct {
muxer *muxer.Muxer
ErrorChan chan error
sendKeepAlives bool
delayMuxerStart bool
// Mini-protocols
Handshake *handshake.Handshake
ChainSync *chainsync.ChainSync
Expand All @@ -39,6 +40,7 @@ type OuroborosOptions struct {
Server bool
UseNodeToNodeProtocol bool
SendKeepAlives bool
DelayMuxerStart bool
ChainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig
BlockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig
KeepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig
Expand All @@ -57,6 +59,7 @@ func New(options *OuroborosOptions) (*Ouroboros, error) {
localTxSubmissionCallbackConfig: options.LocalTxSubmissionCallbackConfig,
ErrorChan: options.ErrorChan,
sendKeepAlives: options.SendKeepAlives,
delayMuxerStart: options.DelayMuxerStart,
}
if o.ErrorChan == nil {
o.ErrorChan = make(chan error, 10)
Expand All @@ -69,6 +72,10 @@ func New(options *OuroborosOptions) (*Ouroboros, error) {
return o, nil
}

func (o *Ouroboros) Muxer() *muxer.Muxer {
return o.muxer
}

// Convenience function for creating a connection if you didn't provide one when
// calling New()
func (o *Ouroboros) Dial(proto string, address string) error {
Expand Down Expand Up @@ -134,5 +141,8 @@ func (o *Ouroboros) setupConnection() error {
o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig)
o.LocalTxSubmission = localtxsubmission.New(protoOptions, o.localTxSubmissionCallbackConfig)
}
if !o.delayMuxerStart {
o.muxer.Start()
}
return nil
}

0 comments on commit 75a95e7

Please sign in to comment.