Skip to content

Commit

Permalink
Merge #19107: p2p: Move all header verification into the network laye…
Browse files Browse the repository at this point in the history
…r, extend logging

deb5271 Remove header checks out of net_processing (Troy Giorshev)
52d4ae4 Give V1TransportDeserializer CChainParams& member (Troy Giorshev)
5bceef6 Change CMessageHeader Constructor (Troy Giorshev)
1ca20c1 Add doxygen comment for ReceiveMsgBytes (Troy Giorshev)
890b1d7 Move checksum check from net_processing to net (Troy Giorshev)
2716647 Give V1TransportDeserializer an m_node_id member (Troy Giorshev)

Pull request description:

  Inspired by #15206 and #15197, this PR moves all message header verification from the message processing layer and into the network/transport layer.

  In the previous PRs there is a change in behavior, where we would disconnect from peers upon a single failed checksum check.  In various discussions there was concern over whether this was the right choice, and some expressed a desire to see how this would look if it was made to be a pure refactor.

  For more context, see https://bitcoincore.reviews/15206.html#l-81.

  This PR improves the separation between the p2p layers, helping improvements like [BIP324](bitcoin/bitcoin#18242) and #18989.

ACKs for top commit:
  ryanofsky:
    Code review ACK deb5271 just rebase due to conflict on adjacent line
  jnewbery:
    Code review ACK deb5271.

Tree-SHA512: 1a3b7ae883b020cfee1bef968813e04df651ffdad9dd961a826bd80654f2c98676ce7f4721038a1b78d8790e4cebe8060419e3d8affc97ce2b9b4e4b72e6fa9f
  • Loading branch information
fanquake committed Sep 29, 2020
2 parents e36aa35 + deb5271 commit 6af9b31
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 111 deletions.
86 changes: 58 additions & 28 deletions src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <net.h>

#include <banman.h>
#include <chainparams.h>
#include <clientversion.h>
#include <consensus/consensus.h>
#include <crypto/sha256.h>
Expand Down Expand Up @@ -607,6 +606,16 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap)
}
#undef X

/**
* Receive bytes from the buffer and deserialize them into messages.
*
* @param[in] pch A pointer to the raw data
* @param[in] nBytes Size of the data
* @param[out] complete Set True if at least one message has been
* deserialized and is ready to be processed
* @return True if the peer should stay connected,
* False if the peer should be disconnected from.
*/
bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete)
{
complete = false;
Expand All @@ -617,25 +626,35 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
while (nBytes > 0) {
// absorb network data
int handled = m_deserializer->Read(pch, nBytes);
if (handled < 0) return false;
if (handled < 0) {
// Serious header problem, disconnect from the peer.
return false;
}

pch += handled;
nBytes -= handled;

if (m_deserializer->Complete()) {
// decompose a transport agnostic CNetMessage from the deserializer
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time);
uint32_t out_err_raw_size{0};
Optional<CNetMessage> result{m_deserializer->GetMessage(time, out_err_raw_size)};
if (!result) {
// Message deserialization failed. Drop the message but don't disconnect the peer.
// store the size of the corrupt message
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
continue;
}

//store received bytes per message command
//to prevent a memory DOS, only allow valid commands
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command);
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(result->m_command);
if (i == mapRecvBytesPerMsgCmd.end())
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
assert(i != mapRecvBytesPerMsgCmd.end());
i->second += msg.m_raw_message_size;
i->second += result->m_raw_message_size;

// push the message to the process queue,
vRecvMsg.push_back(std::move(msg));
vRecvMsg.push_back(std::move(*result));

complete = true;
}
Expand All @@ -662,11 +681,19 @@ int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
hdrbuf >> hdr;
}
catch (const std::exception&) {
LogPrint(BCLog::NET, "HEADER ERROR - UNABLE TO DESERIALIZE, peer=%d\n", m_node_id);
return -1;
}

// Check start string, network magic
if (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) {
LogPrint(BCLog::NET, "HEADER ERROR - MESSAGESTART (%s, %u bytes), received %s, peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, HexStr(hdr.pchMessageStart), m_node_id);
return -1;
}

// reject messages larger than MAX_SIZE or MAX_PROTOCOL_MESSAGE_LENGTH
if (hdr.nMessageSize > MAX_SIZE || hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
LogPrint(BCLog::NET, "HEADER ERROR - SIZE (%s, %u bytes), peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, m_node_id);
return -1;
}

Expand Down Expand Up @@ -701,36 +728,39 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
return data_hash;
}

CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time)
Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, uint32_t& out_err_raw_size)
{
// decompose a single CNetMessage from the TransportDeserializer
CNetMessage msg(std::move(vRecv));
Optional<CNetMessage> msg(std::move(vRecv));

// store state about valid header, netmagic and checksum
msg.m_valid_header = hdr.IsValid(message_start);
msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
uint256 hash = GetMessageHash();
// store command string, time, and sizes
msg->m_command = hdr.GetCommand();
msg->m_time = time;
msg->m_message_size = hdr.nMessageSize;
msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;

// store command string, payload size
msg.m_command = hdr.GetCommand();
msg.m_message_size = hdr.nMessageSize;
msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
uint256 hash = GetMessageHash();

// We just received a message off the wire, harvest entropy from the time (and the message checksum)
RandAddEvent(ReadLE32(hash.begin()));

msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0);
if (!msg.m_valid_checksum) {
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n",
SanitizeString(msg.m_command), msg.m_message_size,
// Check checksum and header command string
if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) {
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
SanitizeString(msg->m_command), msg->m_message_size,
HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
HexStr(hdr.pchChecksum));
}

// store receive time
msg.m_time = time;

// reset the network deserializer (prepare for the next message)
HexStr(hdr.pchChecksum),
m_node_id);
out_err_raw_size = msg->m_raw_message_size;
msg = nullopt;
} else if (!hdr.IsCommandValid()) {
LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n",
hdr.GetCommand(), msg->m_message_size, m_node_id);
out_err_raw_size = msg->m_raw_message_size;
msg = nullopt;
}

// Always reset the network deserializer (prepare for the next message)
Reset();
return msg;
}
Expand Down Expand Up @@ -2850,7 +2880,7 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn
LogPrint(BCLog::NET, "Added connection peer=%d\n", id);
}

m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION));
m_deserializer = MakeUnique<V1TransportDeserializer>(V1TransportDeserializer(Params(), GetId(), SER_NETWORK, INIT_PROTO_VERSION));
m_serializer = MakeUnique<V1TransportSerializer>(V1TransportSerializer());
}

Expand Down
25 changes: 15 additions & 10 deletions src/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
#include <addrman.h>
#include <amount.h>
#include <bloom.h>
#include <chainparams.h>
#include <compat.h>
#include <crypto/siphash.h>
#include <hash.h>
#include <limitedmap.h>
#include <netaddress.h>
#include <net_permissions.h>
#include <netaddress.h>
#include <optional.h>
#include <policy/feerate.h>
#include <protocol.h>
#include <random.h>
Expand Down Expand Up @@ -712,11 +714,8 @@ class CNetMessage {
public:
CDataStream m_recv; //!< received message data
std::chrono::microseconds m_time{0}; //!< time of message receipt
bool m_valid_netmagic = false;
bool m_valid_header = false;
bool m_valid_checksum = false;
uint32_t m_message_size{0}; //!< size of the payload
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
uint32_t m_message_size{0}; //!< size of the payload
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
std::string m_command;

CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {}
Expand All @@ -740,13 +739,15 @@ class TransportDeserializer {
// read and deserialize data
virtual int Read(const char *data, unsigned int bytes) = 0;
// decomposes a message from the context
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0;
virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0;
virtual ~TransportDeserializer() {}
};

class V1TransportDeserializer final : public TransportDeserializer
{
private:
const CChainParams& m_chain_params;
const NodeId m_node_id; // Only for logging
mutable CHash256 hasher;
mutable uint256 data_hash;
bool in_data; // parsing header (false) or data (true)
Expand All @@ -772,8 +773,12 @@ class V1TransportDeserializer final : public TransportDeserializer
}

public:

V1TransportDeserializer(const CMessageHeader::MessageStartChars& pchMessageStartIn, int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) {
V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
: m_chain_params(chain_params),
m_node_id(node_id),
hdrbuf(nTypeIn, nVersionIn),
vRecv(nTypeIn, nVersionIn)
{
Reset();
}

Expand All @@ -793,7 +798,7 @@ class V1TransportDeserializer final : public TransportDeserializer
if (ret < 0) Reset();
return ret;
}
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override;
Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
};

/** The TransportSerializer prepares messages for the network transport
Expand Down
32 changes: 1 addition & 31 deletions src/net_processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3815,14 +3815,6 @@ bool PeerManager::MaybeDiscourageAndDisconnect(CNode& pnode)

bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgProc)
{
//
// Message format
// (4) message start
// (12) command
// (4) size
// (4) checksum
// (x) data
//
bool fMoreWork = false;

if (!pfrom->vRecvGetData.empty())
Expand Down Expand Up @@ -3863,35 +3855,13 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP
CNetMessage& msg(msgs.front());

msg.SetVersion(pfrom->GetCommonVersion());
// Check network magic
if (!msg.m_valid_netmagic) {
LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
pfrom->fDisconnect = true;
return false;
}

// Check header
if (!msg.m_valid_header)
{
LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
return fMoreWork;
}
const std::string& msg_type = msg.m_command;

// Message size
unsigned int nMessageSize = msg.m_message_size;

// Checksum
CDataStream& vRecv = msg.m_recv;
if (!msg.m_valid_checksum)
{
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__,
SanitizeString(msg_type), nMessageSize, pfrom->GetId());
return fMoreWork;
}

try {
ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, interruptMsgProc);
ProcessMessage(*pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc);
if (interruptMsgProc)
return false;
if (!pfrom->vRecvGetData.empty())
Expand Down
33 changes: 11 additions & 22 deletions src/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ const static std::string allNetMessageTypes[] = {
};
const static std::vector<std::string> allNetMessageTypesVec(allNetMessageTypes, allNetMessageTypes+ARRAYLEN(allNetMessageTypes));

CMessageHeader::CMessageHeader(const MessageStartChars& pchMessageStartIn)
CMessageHeader::CMessageHeader()
{
memcpy(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE);
memset(pchMessageStart, 0, MESSAGE_START_SIZE);
memset(pchCommand, 0, sizeof(pchCommand));
nMessageSize = -1;
memset(pchChecksum, 0, CHECKSUM_SIZE);
Expand All @@ -111,31 +111,20 @@ std::string CMessageHeader::GetCommand() const
return std::string(pchCommand, pchCommand + strnlen(pchCommand, COMMAND_SIZE));
}

bool CMessageHeader::IsValid(const MessageStartChars& pchMessageStartIn) const
bool CMessageHeader::IsCommandValid() const
{
// Check start string
if (memcmp(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE) != 0)
return false;

// Check the command string for errors
for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; p1++)
{
if (*p1 == 0)
{
for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; ++p1) {
if (*p1 == 0) {
// Must be all zeros after the first zero
for (; p1 < pchCommand + COMMAND_SIZE; p1++)
if (*p1 != 0)
for (; p1 < pchCommand + COMMAND_SIZE; ++p1) {
if (*p1 != 0) {
return false;
}
else if (*p1 < ' ' || *p1 > 0x7E)
}
}
} else if (*p1 < ' ' || *p1 > 0x7E) {
return false;
}

// Message size
if (nMessageSize > MAX_SIZE)
{
LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) nMessageSize > MAX_SIZE\n", GetCommand(), nMessageSize);
return false;
}
}

return true;
Expand Down
4 changes: 2 additions & 2 deletions src/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ class CMessageHeader
static constexpr size_t HEADER_SIZE = MESSAGE_START_SIZE + COMMAND_SIZE + MESSAGE_SIZE_SIZE + CHECKSUM_SIZE;
typedef unsigned char MessageStartChars[MESSAGE_START_SIZE];

explicit CMessageHeader(const MessageStartChars& pchMessageStartIn);
explicit CMessageHeader();

/** Construct a P2P message header from message-start characters, a command and the size of the message.
* @note Passing in a `pszCommand` longer than COMMAND_SIZE will result in a run-time assertion error.
*/
CMessageHeader(const MessageStartChars& pchMessageStartIn, const char* pszCommand, unsigned int nMessageSizeIn);

std::string GetCommand() const;
bool IsValid(const MessageStartChars& messageStart) const;
bool IsCommandValid() const;

SERIALIZE_METHODS(CMessageHeader, obj) { READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize, obj.pchChecksum); }

Expand Down
5 changes: 2 additions & 3 deletions src/test/fuzz/deserialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,9 @@ void test_one_input(const std::vector<uint8_t>& buffer)
DeserializeFromFuzzingInput(buffer, s);
AssertEqualAfterSerializeDeserialize(s);
#elif MESSAGEHEADER_DESERIALIZE
const CMessageHeader::MessageStartChars pchMessageStart = {0x00, 0x00, 0x00, 0x00};
CMessageHeader mh(pchMessageStart);
CMessageHeader mh;
DeserializeFromFuzzingInput(buffer, mh);
(void)mh.IsValid(pchMessageStart);
(void)mh.IsCommandValid();
#elif ADDRESS_DESERIALIZE
CAddress a;
DeserializeFromFuzzingInput(buffer, a);
Expand Down
20 changes: 9 additions & 11 deletions src/test/fuzz/p2p_transport_deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ void initialize()

void test_one_input(const std::vector<uint8_t>& buffer)
{
V1TransportDeserializer deserializer{Params().MessageStart(), SER_NETWORK, INIT_PROTO_VERSION};
// Construct deserializer, with a dummy NodeId
V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION};
const char* pch = (const char*)buffer.data();
size_t n_bytes = buffer.size();
while (n_bytes > 0) {
Expand All @@ -31,16 +32,13 @@ void test_one_input(const std::vector<uint8_t>& buffer)
n_bytes -= handled;
if (deserializer.Complete()) {
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time);
assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE);
assert(msg.m_raw_message_size <= buffer.size());
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
assert(msg.m_time == m_time);
if (msg.m_valid_header) {
assert(msg.m_valid_netmagic);
}
if (!msg.m_valid_netmagic) {
assert(!msg.m_valid_header);
uint32_t out_err_raw_size{0};
Optional<CNetMessage> result{deserializer.GetMessage(m_time, out_err_raw_size)};
if (result) {
assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE);
assert(result->m_raw_message_size <= buffer.size());
assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size);
assert(result->m_time == m_time);
}
}
}
Expand Down
Loading

0 comments on commit 6af9b31

Please sign in to comment.