From 57fec5249c5384604779766b9cadd35e2463abea Mon Sep 17 00:00:00 2001 From: Andrew Gaffney Date: Tue, 8 Mar 2022 20:57:34 -0600 Subject: [PATCH] Pass along whether message is a response to the handler Fixes #40 --- protocol/blockfetch/blockfetch.go | 2 +- protocol/chainsync/chainsync.go | 2 +- protocol/handshake/handshake.go | 2 +- protocol/keepalive/keepalive.go | 2 +- protocol/localtxsubmission/localtxsubmission.go | 2 +- protocol/protocol.go | 11 +++++++---- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index e554e5fa..02e8b2d8 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -102,7 +102,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackCon return b } -func (b *BlockFetch) messageHandler(msg protocol.Message) error { +func (b *BlockFetch) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { case MESSAGE_TYPE_START_BATCH: diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index bbf86309..96fd5cc1 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -137,7 +137,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf return c } -func (c *ChainSync) messageHandler(msg protocol.Message) error { +func (c *ChainSync) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { case MESSAGE_TYPE_AWAIT_REPLY: diff --git a/protocol/handshake/handshake.go b/protocol/handshake/handshake.go index df8aae9f..622442f9 100644 --- a/protocol/handshake/handshake.go +++ b/protocol/handshake/handshake.go @@ -75,7 +75,7 @@ func New(options protocol.ProtocolOptions, allowedVersions []uint16) *Handshake return h } -func (h *Handshake) handleMessage(msg protocol.Message) error { +func (h *Handshake) handleMessage(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { case MESSAGE_TYPE_PROPOSE_VERSIONS: diff --git a/protocol/keepalive/keepalive.go b/protocol/keepalive/keepalive.go index dbf60f24..22bbdac8 100644 --- a/protocol/keepalive/keepalive.go +++ b/protocol/keepalive/keepalive.go @@ -85,7 +85,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *KeepAliveCallbackConf return k } -func (k *KeepAlive) messageHandler(msg protocol.Message) error { +func (k *KeepAlive) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { case MESSAGE_TYPE_KEEP_ALIVE: diff --git a/protocol/localtxsubmission/localtxsubmission.go b/protocol/localtxsubmission/localtxsubmission.go index 3e0e6116..19d1cf3e 100644 --- a/protocol/localtxsubmission/localtxsubmission.go +++ b/protocol/localtxsubmission/localtxsubmission.go @@ -82,7 +82,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *Loca return l } -func (l *LocalTxSubmission) messageHandler(msg protocol.Message) error { +func (l *LocalTxSubmission) messageHandler(msg protocol.Message, isResponse bool) error { var err error switch msg.Type() { case MESSAGE_TYPE_SUBMIT_TX: diff --git a/protocol/protocol.go b/protocol/protocol.go index 517c1e6a..ed158b6c 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -54,7 +54,7 @@ type ProtocolOptions struct { Role ProtocolRole } -type MessageHandlerFunc func(Message) error +type MessageHandlerFunc func(Message, bool) error type MessageFromCborFunc func(uint, []byte) (Message, error) func New(config ProtocolConfig) *Protocol { @@ -108,6 +108,7 @@ func (p *Protocol) SendError(err error) { func (p *Protocol) recvLoop() { leftoverData := false + isResponse := false for { var err error // Don't grab the next segment from the muxer if we still have data in the buffer @@ -116,6 +117,8 @@ func (p *Protocol) recvLoop() { segment := <-p.recvChan // Add segment payload to buffer p.recvBuffer.Write(segment.Payload) + // Save whether it's a response + isResponse = segment.IsResponse() } leftoverData = false // Decode message into generic list until we can determine what type of message it is @@ -140,7 +143,7 @@ func (p *Protocol) recvLoop() { p.config.ErrorChan <- fmt.Errorf("%s: received unknown message type: %#v", p.config.Name, tmpMsg) } // Handle message - if err := p.handleMessage(msg); err != nil { + if err := p.handleMessage(msg, isResponse); err != nil { p.config.ErrorChan <- err } if numBytesRead < p.recvBuffer.Len() { @@ -183,7 +186,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) { return newState, nil } -func (p *Protocol) handleMessage(msg Message) error { +func (p *Protocol) handleMessage(msg Message, isResponse bool) error { // Lock the state to prevent collisions p.stateMutex.Lock() if err := p.checkCurrentState(); err != nil { @@ -197,5 +200,5 @@ func (p *Protocol) handleMessage(msg Message) error { p.state = newState p.stateMutex.Unlock() // Call handler function - return p.config.MessageHandlerFunc(msg) + return p.config.MessageHandlerFunc(msg, isResponse) }