diff --git a/pkg/network/message.go b/pkg/network/message.go index efa58ad29b..f1904a7d37 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -159,7 +159,7 @@ func (m *Message) decodePayload() error { case CMDTX: p = &transaction.Transaction{Network: m.Network} case CMDMerkleBlock: - p = &payload.MerkleBlock{} + p = &payload.MerkleBlock{Network: m.Network} case CMDPing, CMDPong: p = &payload.Ping{} case CMDNotFound: @@ -196,9 +196,6 @@ func (m *Message) Bytes() ([]byte, error) { if err := m.Encode(w.BinWriter); err != nil { return nil, err } - if w.Err != nil { - return nil, w.Err - } return w.Bytes(), nil } diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index dd59aac16b..58ffdb3eff 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -1,12 +1,17 @@ package network import ( + "errors" + "math/rand" "testing" "time" + "github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" @@ -72,8 +77,7 @@ func TestEncodeDecodeHeaders(t *testing.T) { func TestEncodeDecodeGetAddr(t *testing.T) { // NullPayload should be handled properly - expected := NewMessage(CMDGetAddr, payload.NewNullPayload()) - testserdes.EncodeDecode(t, expected, &Message{}) + testEncodeDecode(t, CMDGetAddr, payload.NewNullPayload()) } func TestEncodeDecodeNil(t *testing.T) { @@ -88,11 +92,239 @@ func TestEncodeDecodeNil(t *testing.T) { } func TestEncodeDecodePing(t *testing.T) { - expected := NewMessage(CMDPing, payload.NewPing(123, 456)) - testserdes.EncodeDecode(t, expected, &Message{}) + testEncodeDecode(t, CMDPing, payload.NewPing(123, 456)) } func TestEncodeDecodeInventory(t *testing.T) { - expected := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}})) - testserdes.EncodeDecode(t, expected, &Message{}) + testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}})) +} + +func TestEncodeDecodeAddr(t *testing.T) { + const count = 3 + p := payload.NewAddressList(count) + p.Addrs[0] = &payload.AddressAndTime{ + Timestamp: rand.Uint32(), + Capabilities: capability.Capabilities{{ + Type: capability.FullNode, + Data: &capability.Node{StartHeight: rand.Uint32()}, + }}, + } + p.Addrs[1] = &payload.AddressAndTime{ + Timestamp: rand.Uint32(), + Capabilities: capability.Capabilities{{ + Type: capability.TCPServer, + Data: &capability.Server{Port: uint16(rand.Uint32())}, + }}, + } + p.Addrs[2] = &payload.AddressAndTime{ + Timestamp: rand.Uint32(), + Capabilities: capability.Capabilities{{ + Type: capability.WSServer, + Data: &capability.Server{Port: uint16(rand.Uint32())}, + }}, + } + testEncodeDecode(t, CMDAddr, p) +} + +func TestEncodeDecodeBlock(t *testing.T) { + t.Run("good", func(t *testing.T) { + testEncodeDecode(t, CMDBlock, newDummyBlock(1)) + }) + t.Run("invalid state root enabled setting", func(t *testing.T) { + expected := NewMessage(CMDBlock, newDummyBlock(1)) + expected.Network = netmode.UnitTestNet + data, err := testserdes.Encode(expected) + require.NoError(t, err) + require.Error(t, testserdes.Decode(data, &Message{Network: netmode.UnitTestNet, StateRootInHeader: true})) + }) +} + +func TestEncodeDecodeGetBlock(t *testing.T) { + t.Run("good, Count>0", func(t *testing.T) { + testEncodeDecode(t, CMDGetBlocks, &payload.GetBlocks{ + HashStart: random.Uint256(), + Count: int16(rand.Uint32() >> 17), + }) + }) + t.Run("good, Count=-1", func(t *testing.T) { + testEncodeDecode(t, CMDGetBlocks, &payload.GetBlocks{ + HashStart: random.Uint256(), + Count: -1, + }) + }) + t.Run("bad, Count=-2", func(t *testing.T) { + testEncodeDecodeFail(t, CMDGetBlocks, &payload.GetBlocks{ + HashStart: random.Uint256(), + Count: -2, + }) + }) +} + +func TestEnodeDecodeGetHeaders(t *testing.T) { + testEncodeDecode(t, CMDGetHeaders, &payload.GetBlockByIndex{ + IndexStart: rand.Uint32(), + Count: payload.MaxHeadersAllowed, + }) +} + +func TestEncodeDecodeGetBlockByIndex(t *testing.T) { + t.Run("good, Count>0", func(t *testing.T) { + testEncodeDecode(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{ + IndexStart: rand.Uint32(), + Count: payload.MaxHeadersAllowed, + }) + }) + t.Run("bad, Count too big", func(t *testing.T) { + testEncodeDecodeFail(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{ + IndexStart: rand.Uint32(), + Count: payload.MaxHeadersAllowed + 1, + }) + }) + t.Run("good, Count=-1", func(t *testing.T) { + testEncodeDecode(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{ + IndexStart: rand.Uint32(), + Count: -1, + }) + }) + t.Run("bad, Count=-2", func(t *testing.T) { + testEncodeDecodeFail(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{ + IndexStart: rand.Uint32(), + Count: -2, + }) + }) +} + +func TestEncodeDecodeTransaction(t *testing.T) { + testEncodeDecode(t, CMDTX, newDummyTx()) +} + +func TestEncodeDecodeMerkleBlock(t *testing.T) { + base := &block.Base{ + PrevHash: random.Uint256(), + Timestamp: rand.Uint64(), + Script: transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + Network: netmode.UnitTestNet, + } + base.Hash() + t.Run("good", func(t *testing.T) { + testEncodeDecode(t, CMDMerkleBlock, &payload.MerkleBlock{ + Network: netmode.UnitTestNet, + Base: base, + TxCount: 1, + Hashes: []util.Uint256{random.Uint256()}, + Flags: []byte{0}, + }) + }) + t.Run("bad, invalid TxCount", func(t *testing.T) { + testEncodeDecodeFail(t, CMDMerkleBlock, &payload.MerkleBlock{ + Base: base, + TxCount: 2, + Hashes: []util.Uint256{random.Uint256()}, + Flags: []byte{0}, + }) + }) +} + +func TestEncodeDecodeNotFound(t *testing.T) { + testEncodeDecode(t, CMDNotFound, &payload.Inventory{ + Type: payload.TXType, + Hashes: []util.Uint256{random.Uint256()}, + }) +} + +func TestInvalidMessages(t *testing.T) { + t.Run("CMDBlock, empty payload", func(t *testing.T) { + testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{}) + }) + t.Run("send decompressed with flag", func(t *testing.T) { + m := NewMessage(CMDTX, newDummyTx()) + data, err := testserdes.Encode(m) + require.NoError(t, err) + require.True(t, m.Flags&Compressed == 0) + data[0] |= byte(Compressed) + require.Error(t, testserdes.Decode(data, &Message{Network: netmode.UnitTestNet})) + }) + t.Run("invalid command", func(t *testing.T) { + testEncodeDecodeFail(t, CommandType(0xFF), &payload.Version{Magic: netmode.UnitTestNet}) + }) + t.Run("very big payload size", func(t *testing.T) { + m := NewMessage(CMDBlock, nil) + w := io.NewBufBinWriter() + w.WriteB(byte(m.Flags)) + w.WriteB(byte(m.Command)) + w.WriteVarBytes(make([]byte, payload.MaxSize+1)) + require.NoError(t, w.Err) + require.Error(t, testserdes.Decode(w.Bytes(), &Message{Network: netmode.UnitTestNet})) + }) + t.Run("fail to encode message if payload can't be serialized", func(t *testing.T) { + m := NewMessage(CMDBlock, failSer(true)) + _, err := m.Bytes() + require.Error(t, err) + + // good otherwise + m = NewMessage(CMDBlock, failSer(false)) + _, err = m.Bytes() + require.NoError(t, err) + }) + t.Run("trimmed payload", func(t *testing.T) { + m := NewMessage(CMDBlock, newDummyBlock(0)) + data, err := testserdes.Encode(m) + require.NoError(t, err) + data = data[:len(data)-1] + require.Error(t, testserdes.Decode(data, &Message{Network: netmode.UnitTestNet})) + }) +} + +type failSer bool + +func (f failSer) EncodeBinary(r *io.BinWriter) { + if f { + r.Err = errors.New("unserializable payload") + } +} + +func (failSer) DecodeBinary(w *io.BinReader) {} + +func newDummyBlock(txCount int) *block.Block { + b := block.New(netmode.UnitTestNet, false) + b.PrevHash = random.Uint256() + b.Timestamp = rand.Uint64() + b.Script.InvocationScript = random.Bytes(2) + b.Script.VerificationScript = random.Bytes(3) + b.Transactions = make([]*transaction.Transaction, txCount) + for i := range b.Transactions { + b.Transactions[i] = newDummyTx() + } + b.Hash() + return b +} + +func newDummyTx() *transaction.Transaction { + tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), int64(rand.Uint64()>>1)) + tx.Signers = []transaction.Signer{{Account: random.Uint160()}} + tx.Size() + tx.Hash() + return tx +} + +func testEncodeDecode(t *testing.T, cmd CommandType, p payload.Payload) *Message { + expected := NewMessage(cmd, p) + expected.Network = netmode.UnitTestNet + actual := &Message{Network: netmode.UnitTestNet} + testserdes.EncodeDecode(t, expected, actual) + return actual +} + +func testEncodeDecodeFail(t *testing.T, cmd CommandType, p payload.Payload) *Message { + expected := NewMessage(cmd, p) + expected.Network = netmode.UnitTestNet + data, err := testserdes.Encode(expected) + require.NoError(t, err) + + actual := &Message{Network: netmode.UnitTestNet} + require.Error(t, testserdes.Decode(data, actual)) + return actual } diff --git a/pkg/network/payload/address_test.go b/pkg/network/payload/address_test.go index 9baad63ea5..4bf58d0b69 100644 --- a/pkg/network/payload/address_test.go +++ b/pkg/network/payload/address_test.go @@ -72,3 +72,22 @@ func TestEncodeDecodeBadAddressList(t *testing.T) { err = testserdes.DecodeBinary(bin, newAL) require.Error(t, err) } + +func TestGetTCPAddress(t *testing.T) { + t.Run("bad, no capability", func(t *testing.T) { + p := &AddressAndTime{} + copy(p.IP[:], net.IPv4(1, 1, 1, 1)) + p.Capabilities = append(p.Capabilities, capability.Capability{ + Type: capability.TCPServer, + Data: &capability.Server{Port: 123}, + }) + s, err := p.GetTCPAddress() + require.NoError(t, err) + require.Equal(t, "1.1.1.1:123", s) + }) + t.Run("bad, no capability", func(t *testing.T) { + p := &AddressAndTime{} + s, err := p.GetTCPAddress() + fmt.Println(s, err) + }) +} diff --git a/pkg/network/payload/inventory_test.go b/pkg/network/payload/inventory_test.go index 4a15d74a97..6f359e9dca 100644 --- a/pkg/network/payload/inventory_test.go +++ b/pkg/network/payload/inventory_test.go @@ -1,12 +1,14 @@ package payload import ( + "strings" "testing" "github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" . "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestInventoryEncodeDecode(t *testing.T) { @@ -27,3 +29,17 @@ func TestEmptyInv(t *testing.T) { assert.Equal(t, []byte{byte(TXType), 0}, data) assert.Equal(t, 0, len(msgInv.Hashes)) } + +func TestValid(t *testing.T) { + require.True(t, TXType.Valid()) + require.True(t, BlockType.Valid()) + require.True(t, ConsensusType.Valid()) + require.False(t, InventoryType(0xFF).Valid()) +} + +func TestString(t *testing.T) { + require.Equal(t, "TX", TXType.String()) + require.Equal(t, "block", BlockType.String()) + require.Equal(t, "consensus", ConsensusType.String()) + require.True(t, strings.Contains(InventoryType(0xFF).String(), "unknown")) +} diff --git a/pkg/network/payload/merkleblock.go b/pkg/network/payload/merkleblock.go index 9867e4ab0f..b534990857 100644 --- a/pkg/network/payload/merkleblock.go +++ b/pkg/network/payload/merkleblock.go @@ -1,6 +1,9 @@ package payload import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -9,6 +12,7 @@ import ( // MerkleBlock represents a merkle block packet payload. type MerkleBlock struct { *block.Base + Network netmode.Magic TxCount int Hashes []util.Uint256 Flags []byte @@ -16,7 +20,7 @@ type MerkleBlock struct { // DecodeBinary implements Serializable interface. func (m *MerkleBlock) DecodeBinary(br *io.BinReader) { - m.Base = &block.Base{} + m.Base = &block.Base{Network: m.Network} m.Base.DecodeBinary(br) txCount := int(br.ReadVarUint()) @@ -26,6 +30,9 @@ func (m *MerkleBlock) DecodeBinary(br *io.BinReader) { } m.TxCount = txCount br.ReadArray(&m.Hashes, m.TxCount) + if txCount != len(m.Hashes) { + br.Err = errors.New("invalid tx count") + } m.Flags = br.ReadVarBytes((txCount + 7) / 8) }