Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass along whether message is a response to the handler #42

Merged
merged 1 commit into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion protocol/blockfetch/blockfetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackCon
return b
}

func (b *BlockFetch) messageHandler(msg protocol.Message) error {
func (b *BlockFetch) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
case MESSAGE_TYPE_START_BATCH:
Expand Down
2 changes: 1 addition & 1 deletion protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf
return c
}

func (c *ChainSync) messageHandler(msg protocol.Message) error {
func (c *ChainSync) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
case MESSAGE_TYPE_AWAIT_REPLY:
Expand Down
2 changes: 1 addition & 1 deletion protocol/handshake/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func New(options protocol.ProtocolOptions, allowedVersions []uint16) *Handshake
return h
}

func (h *Handshake) handleMessage(msg protocol.Message) error {
func (h *Handshake) handleMessage(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
case MESSAGE_TYPE_PROPOSE_VERSIONS:
Expand Down
2 changes: 1 addition & 1 deletion protocol/keepalive/keepalive.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *KeepAliveCallbackConf
return k
}

func (k *KeepAlive) messageHandler(msg protocol.Message) error {
func (k *KeepAlive) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
case MESSAGE_TYPE_KEEP_ALIVE:
Expand Down
2 changes: 1 addition & 1 deletion protocol/localtxsubmission/localtxsubmission.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *CallbackConfig) *Loca
return l
}

func (l *LocalTxSubmission) messageHandler(msg protocol.Message) error {
func (l *LocalTxSubmission) messageHandler(msg protocol.Message, isResponse bool) error {
var err error
switch msg.Type() {
case MESSAGE_TYPE_SUBMIT_TX:
Expand Down
11 changes: 7 additions & 4 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type ProtocolOptions struct {
Role ProtocolRole
}

type MessageHandlerFunc func(Message) error
type MessageHandlerFunc func(Message, bool) error
type MessageFromCborFunc func(uint, []byte) (Message, error)

func New(config ProtocolConfig) *Protocol {
Expand Down Expand Up @@ -108,6 +108,7 @@ func (p *Protocol) SendError(err error) {

func (p *Protocol) recvLoop() {
leftoverData := false
isResponse := false
for {
var err error
// Don't grab the next segment from the muxer if we still have data in the buffer
Expand All @@ -116,6 +117,8 @@ func (p *Protocol) recvLoop() {
segment := <-p.recvChan
// Add segment payload to buffer
p.recvBuffer.Write(segment.Payload)
// Save whether it's a response
isResponse = segment.IsResponse()
}
leftoverData = false
// Decode message into generic list until we can determine what type of message it is
Expand All @@ -140,7 +143,7 @@ func (p *Protocol) recvLoop() {
p.config.ErrorChan <- fmt.Errorf("%s: received unknown message type: %#v", p.config.Name, tmpMsg)
}
// Handle message
if err := p.handleMessage(msg); err != nil {
if err := p.handleMessage(msg, isResponse); err != nil {
p.config.ErrorChan <- err
}
if numBytesRead < p.recvBuffer.Len() {
Expand Down Expand Up @@ -183,7 +186,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) {
return newState, nil
}

func (p *Protocol) handleMessage(msg Message) error {
func (p *Protocol) handleMessage(msg Message, isResponse bool) error {
// Lock the state to prevent collisions
p.stateMutex.Lock()
if err := p.checkCurrentState(); err != nil {
Expand All @@ -197,5 +200,5 @@ func (p *Protocol) handleMessage(msg Message) error {
p.state = newState
p.stateMutex.Unlock()
// Call handler function
return p.config.MessageHandlerFunc(msg)
return p.config.MessageHandlerFunc(msg, isResponse)
}