Skip to content

Commit

Permalink
Fix some mutexes not being properly RAII-fyed
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-Lessa committed Feb 4, 2024
1 parent bedfdea commit 0aca68b
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 49 deletions.
14 changes: 7 additions & 7 deletions src/contract/contractmanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Address ContractManager::deriveContractAddress() const {
}

Bytes ContractManager::getDeployedContracts() const {
std::shared_lock lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
std::vector<std::string> names;
std::vector<Address> addresses;
for (const auto& [address, contract] : this->contracts_) {
Expand Down Expand Up @@ -157,7 +157,7 @@ const Bytes ContractManager::callContract(const ethCallInfo& callInfo) const {
const auto& [from, to, gasLimit, gasPrice, value, functor, data] = callInfo;
if (to == this->getContractAddress()) return this->ethCallView(callInfo);
if (to == ProtocolContractAddresses.at("rdPoS")) return rdpos_->ethCallView(callInfo);
std::shared_lock lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
if (!this->contracts_.contains(to)) {
throw std::runtime_error(std::string(__func__) + "(Bytes): Contract does not exist");
}
Expand All @@ -167,7 +167,7 @@ const Bytes ContractManager::callContract(const ethCallInfo& callInfo) const {
bool ContractManager::isPayable(const ethCallInfo& callInfo) const {
const auto& address = std::get<1>(callInfo);
const auto& functor = std::get<5>(callInfo);
std::shared_lock lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
auto it = this->contracts_.find(address);
if (it == this->contracts_.end()) return false;
return it->second->isPayableFunction(functor);
Expand Down Expand Up @@ -198,7 +198,7 @@ bool ContractManager::validateCallContractWithTx(const ethCallInfo& callInfo) {
return true;
}

std::shared_lock lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
if (!this->contracts_.contains(to)) {
this->callLogger_.reset();
return false;
Expand All @@ -219,20 +219,20 @@ bool ContractManager::isContractCall(const TxBlock &tx) const {
for (const auto& [protocolName, protocolAddress] : ProtocolContractAddresses) {
if (tx.getTo() == protocolAddress) return true;
}
std::shared_lock lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
return this->contracts_.contains(tx.getTo());
}

bool ContractManager::isContractAddress(const Address &address) const {
std::shared_lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
for (const auto& [protocolName, protocolAddress] : ProtocolContractAddresses) {
if (address == protocolAddress) return true;
}
return this->contracts_.contains(address);
}

std::vector<std::pair<std::string, Address>> ContractManager::getContracts() const {
std::shared_lock lock(this->contractsMutex_);
std::shared_lock<std::shared_mutex> lock(this->contractsMutex_);
std::vector<std::pair<std::string, Address>> contracts;
for (const auto& [address, contract] : this->contracts_) {
contracts.push_back({contract->getContractName(), address});
Expand Down
5 changes: 4 additions & 1 deletion src/core/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ class State {
const std::unordered_map<Hash, TxBlock, SafeHash> getMempool() const;

/// Get the mempool's current size.
inline const size_t getMempoolSize() const { std::shared_lock (this->stateMutex_); return this->mempool_.size(); }
inline const size_t getMempoolSize() const {
std::shared_lock<std::shared_mutex> lock (this->stateMutex_);
return this->mempool_.size();
}

/**
* Validate the next block given the current state and its transactions.
Expand Down
30 changes: 15 additions & 15 deletions src/core/storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ StorageStatus Storage::blockExists(const uint64_t& height) {
auto it = this->blockHashByHeight_.find(height);
if (it != this->blockHashByHeight_.end()) {
if (this->blockByHash_.contains(it->second)) return StorageStatus::OnChain;
std::shared_lock lock(this->cacheLock_);
std::shared_lock<std::shared_mutex> lock(this->cacheLock_);
if (this->cachedBlocks_.contains(it->second)) return StorageStatus::OnCache;
return StorageStatus::OnDB;
} else {
Expand All @@ -247,15 +247,15 @@ const std::shared_ptr<const Block> Storage::getBlock(const Hash& hash) {
return nullptr;
}
case StorageStatus::OnChain: {
std::shared_lock lock(this->chainLock_);
std::shared_lock<std::shared_mutex> lock(this->chainLock_);
return this->blockByHash_.find(hash)->second;
}
case StorageStatus::OnCache: {
std::shared_lock lock(this->cacheLock_);
std::shared_lock<std::shared_mutex> lock(this->cacheLock_);
return this->cachedBlocks_[hash];
}
case StorageStatus::OnDB: {
std::unique_lock lock(this->cacheLock_);
std::unique_lock<std::shared_mutex> lock(this->cacheLock_);
this->cachedBlocks_.insert({hash, std::make_shared<Block>(
this->db_->get(hash.get(), DBPrefix::blocks), this->options_->getChainID()
)});
Expand All @@ -275,16 +275,16 @@ const std::shared_ptr<const Block> Storage::getBlock(const uint64_t& height) {
return nullptr;
}
case StorageStatus::OnChain: {
std::shared_lock lock(this->chainLock_);
std::shared_lock<std::shared_mutex> lock(this->chainLock_);
return this->blockByHash_.find(this->blockHashByHeight_.find(height)->second)->second;
}
case StorageStatus::OnCache: {
std::shared_lock lock(this->cacheLock_);
std::shared_lock<std::shared_mutex> lock(this->cacheLock_);
Hash hash = this->blockHashByHeight_.find(height)->second;
return this->cachedBlocks_.find(hash)->second;
}
case StorageStatus::OnDB: {
std::unique_lock lock(this->cacheLock_);
std::unique_lock<std::shared_mutex> lock(this->cacheLock_);
Hash hash = this->blockHashByHeight_.find(height)->second;
auto blockData = this->db_->get(hash.get(), DBPrefix::blocks);
this->cachedBlocks_.insert({hash, std::make_shared<Block>(blockData, this->options_->getChainID())});
Expand Down Expand Up @@ -325,7 +325,7 @@ const std::tuple<
return {std::make_shared<const TxBlock>(transaction), blockHash, blockIndex, blockHeight};
}
case StorageStatus::OnCache: {
std::shared_lock(this->cacheLock_);
std::shared_lock<std::shared_mutex> lock(this->cacheLock_);
return this->cachedTxs_[tx];
}
case StorageStatus::OnDB: {
Expand All @@ -336,7 +336,7 @@ const std::tuple<
uint64_t blockHeight = Utils::bytesToUint64(txDataView.subspan(36,8));
Bytes blockData(this->db_->get(blockHash.get(), DBPrefix::blocks));
auto Tx = this->getTxFromBlockWithIndex(blockData, blockIndex);
std::unique_lock(this->cacheLock_);
std::unique_lock<std::shared_mutex> lock(this->cacheLock_);
this->cachedTxs_.insert({tx, {std::make_shared<const TxBlock>(Tx), blockHash, blockIndex, blockHeight}});
return this->cachedTxs_[tx];
}
Expand All @@ -353,7 +353,7 @@ const std::tuple<
return { nullptr, Hash(), 0, 0 };
}
case StorageStatus::OnChain: {
std::shared_lock lock(this->chainLock_);
std::shared_lock<std::shared_mutex> lock(this->chainLock_);
auto txHash = this->blockByHash_[blockHash]->getTxs()[blockIndex].hash();
const auto& [txBlockHash, txBlockIndex, txBlockHeight] = this->txByHash_[txHash];
if (txBlockHash != blockHash || txBlockIndex != blockIndex) {
Expand All @@ -363,14 +363,14 @@ const std::tuple<
return {std::make_shared<const TxBlock>(transaction), txBlockHash, txBlockIndex, txBlockHeight};
}
case StorageStatus::OnCache: {
std::shared_lock lock(this->cacheLock_);
std::shared_lock<std::shared_mutex> lock(this->cacheLock_);
auto txHash = this->cachedBlocks_[blockHash]->getTxs()[blockIndex].hash();
return this->cachedTxs_[txHash];
}
case StorageStatus::OnDB: {
Bytes blockData = this->db_->get(blockHash.get(), DBPrefix::blocks);
auto tx = this->getTxFromBlockWithIndex(blockData, blockIndex);
std::unique_lock lock(this->cacheLock_);
std::unique_lock<std::shared_mutex> lock(this->cacheLock_);
auto blockHeight = this->blockHeightByHash_[blockHash];
this->cachedTxs_.insert({tx.hash(), {std::make_shared<TxBlock>(tx), blockHash, blockIndex, blockHeight}});
return this->cachedTxs_[tx.hash()];
Expand All @@ -388,15 +388,15 @@ const std::tuple<
return { nullptr, Hash(), 0, 0 };
}
case StorageStatus::OnChain: {
std::shared_lock lock(this->chainLock_);
std::shared_lock<std::shared_mutex> lock(this->chainLock_);
auto blockHash = this->blockHashByHeight_.find(blockHeight)->second;
auto txHash = this->blockByHash_[blockHash]->getTxs()[blockIndex].hash();
const auto& [txBlockHash, txBlockIndex, txBlockHeight] = this->txByHash_[txHash];
const auto transaction = this->blockByHash_[blockHash]->getTxs()[blockIndex];
return {std::make_shared<TxBlock>(transaction), txBlockHash, txBlockIndex, txBlockHeight};
}
case StorageStatus::OnCache: {
std::shared_lock lock(this->cacheLock_);
std::shared_lock<std::shared_mutex> lock(this->cacheLock_);
auto blockHash = this->blockHashByHeight_.find(blockHeight)->second;
auto txHash = this->cachedBlocks_[blockHash]->getTxs()[blockIndex].hash();
return this->cachedTxs_[txHash];
Expand All @@ -405,7 +405,7 @@ const std::tuple<
auto blockHash = this->blockHashByHeight_.find(blockHeight)->second;
Bytes blockData = this->db_->get(blockHash.get(), DBPrefix::blocks);
auto tx = this->getTxFromBlockWithIndex(blockData, blockIndex);
std::unique_lock lock(this->cacheLock_);
std::unique_lock<std::shared_mutex> lock(this->cacheLock_);
auto blockHeight = this->blockHeightByHash_[blockHash];
this->cachedTxs_.insert({tx.hash(), { std::make_shared<TxBlock>(tx), blockHash, blockIndex, blockHeight}});
return this->cachedTxs_[tx.hash()];
Expand Down
15 changes: 7 additions & 8 deletions src/net/p2p/discovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ namespace P2P {
std::unique_lock lock(this->requestedNodesMutex_);
for (auto it = this->requestedNodes_.begin(); it != this->requestedNodes_.end();) {
if (std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::high_resolution_clock::now().time_since_epoch()).count() - it->second > 60
) {
std::chrono::high_resolution_clock::now().time_since_epoch()
).count() - it->second > 60) {
this->requestedNodes_.erase(it++);
} else {
it++;
Expand All @@ -25,8 +25,8 @@ namespace P2P {
std::unordered_set<NodeID, SafeHash>, std::unordered_set<NodeID, SafeHash>
> DiscoveryWorker::listConnectedNodes() {
std::pair<std::unordered_set<NodeID, SafeHash>, std::unordered_set<NodeID ,SafeHash>> connectedNodes;
std::shared_lock requestedNodesLock(this->requestedNodesMutex_);
std::shared_lock sessionsLock(this->manager_.sessionsMutex_);
std::shared_lock<std::shared_mutex> requestedNodesLock(this->requestedNodesMutex_);
std::shared_lock<std::shared_mutex> sessionsLock(this->manager_.sessionsMutex_);
for (const auto& [nodeId, session] : this->manager_.sessions_) {
// Skip nodes that were already requested in the last 60 seconds
if (this->requestedNodes_.contains(nodeId)) continue;
Expand All @@ -51,13 +51,12 @@ namespace P2P {

bool DiscoveryWorker::discoverLoop() {
bool discoveryPass = false;

Logger::logToDebug(LogType::INFO, Log::P2PDiscoveryWorker, __func__, "Discovery thread started");
while (!this->stopWorker_) {
// Check if we reached connection limit
{
std::this_thread::sleep_for(std::chrono::seconds(1));
std::shared_lock lock(this->manager_.sessionsMutex_);
std::shared_lock<std::shared_mutex> lock(this->manager_.sessionsMutex_);
if (this->manager_.sessions_.size() >= this->manager_.minConnections()) {
// If we don't have at least 11 connections, we don't sleep discovery.
// This is to make sure that local_testnet can quickly start up a new
Expand Down Expand Up @@ -95,9 +94,9 @@ namespace P2P {
if (this->stopWorker_) return true;

// Add requested node to list of requested nodes
std::unique_lock(this->requestedNodesMutex_);
std::unique_lock lock(this->requestedNodesMutex_);
this->requestedNodes_[nodeId] = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::high_resolution_clock::now().time_since_epoch()
std::chrono::high_resolution_clock::now().time_since_epoch()
).count();
}
discoveryPass = true;
Expand Down
33 changes: 15 additions & 18 deletions src/net/p2p/managerbase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ namespace P2P {

std::shared_ptr<Request> ManagerBase::sendRequestTo(const NodeID &nodeId, const std::shared_ptr<const Message>& message) {
if (this->closed_) return nullptr;
std::shared_lock lockSession(this->sessionsMutex_); // ManagerBase::sendRequestTo doesn't change sessions_ map.
std::shared_lock<std::shared_mutex> lockSession(this->sessionsMutex_); // ManagerBase::sendRequestTo doesn't change sessions_ map.
if(!sessions_.contains(nodeId)) {
lockSession.unlock(); // Unlock before calling logToDebug to avoid waiting for the lock in the logToDebug function.
Logger::logToDebug(LogType::ERROR, Log::P2PManager, __func__, "Session does not exist at " + nodeId.first.to_string() + ":" + std::to_string(nodeId.second));
return nullptr;
}
auto session = sessions_[nodeId];
// We can only request ping, info and requestNode to discovery nodes
if (session->hostType() == NodeType::DISCOVERY_NODE && (message->command() == CommandType::Info ||
message->command() == CommandType::RequestValidatorTxs)) {
if (session->hostType() == NodeType::DISCOVERY_NODE &&
(message->command() == CommandType::Info || message->command() == CommandType::RequestValidatorTxs)
) {
lockSession.unlock(); // Unlock before calling logToDebug to avoid waiting for the lock in the logToDebug function.
Logger::logToDebug(LogType::INFO, Log::P2PManager, __func__, "Session is discovery, cannot send message");
return nullptr;
Expand All @@ -81,10 +82,11 @@ namespace P2P {
return requests_[message->id()];
}

// ManagerBase::answerSession doesn't change sessions_ map, but we still need to
// be sure that the session io_context doesn't get deleted while we are using it.
void ManagerBase::answerSession(std::weak_ptr<Session> session, const std::shared_ptr<const Message>& message) {
if (this->closed_) return;
std::shared_lock<std::shared_mutex> lockSession(this->sessionsMutex_); // ManagerBase::answerSession doesn't change sessions_ map.
// But we still need to be sure that the session io_context doesn't get deleted while we are using it.
std::shared_lock<std::shared_mutex> lockSession(this->sessionsMutex_);
if (auto ptr = session.lock()) {
ptr->write(message);
} else {
Expand Down Expand Up @@ -114,10 +116,8 @@ namespace P2P {

std::vector<NodeID> ManagerBase::getSessionsIDs() const {
std::vector<NodeID> nodes;
std::shared_lock lock(this->sessionsMutex_);
for (auto& session : sessions_) {
nodes.push_back(session.first);
}
std::shared_lock<std::shared_mutex> lock(this->sessionsMutex_);
for (auto& session : sessions_) nodes.push_back(session.first);
return nodes;
}

Expand All @@ -136,12 +136,9 @@ namespace P2P {
void ManagerBase::connectToServer(const boost::asio::ip::address& address, uint16_t port) {
if (this->closed_) return;
if (address == this->server_->getLocalAddress() && port == this->serverPort_) return; /// Cannot connect to itself.

{
std::shared_lock(this->sessionsMutex_);
if (this->sessions_.contains({address, port})) {
return; // Node is already connected
}
std::shared_lock<std::shared_mutex> lock(this->sessionsMutex_);
if (this->sessions_.contains({address, port})) return; // Node is already connected
}
this->clientfactory_->connectToServer(address, port);
}
Expand All @@ -151,8 +148,8 @@ namespace P2P {
Utils::logToFile("Pinging " + nodeId.first.to_string() + ":" + std::to_string(nodeId.second));
auto requestPtr = sendRequestTo(nodeId, request);
if (requestPtr == nullptr) throw std::runtime_error(
"Failed to send ping to " + nodeId.first.to_string() + ":" + std::to_string(nodeId.second)
);
"Failed to send ping to " + nodeId.first.to_string() + ":" + std::to_string(nodeId.second)
);
requestPtr->answerFuture().wait();
}

Expand All @@ -177,10 +174,10 @@ namespace P2P {
return AnswerDecoder::requestNodes(*answerPtr);
} catch (std::exception &e) {
Logger::logToDebug(LogType::ERROR, Log::P2PParser, __func__,
"Request to " + nodeId.first.to_string() + ":" + std::to_string(nodeId.second) + " failed with error: " + e.what()
"Request to " + nodeId.first.to_string() + ":" + std::to_string(nodeId.second) + " failed with error: " + e.what()
);
return {};
}
}

}

0 comments on commit 0aca68b

Please sign in to comment.