Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add duplicate message validation to mempool #202

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
mempool_rx,
app_config.consensus.num_shards,
node.shard_senders.clone(),
node.shard_stores.clone(),
);
tokio::spawn(async move { mempool.run().await });

Expand Down
93 changes: 82 additions & 11 deletions src/mempool/mempool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use crate::storage::store::engine::{MempoolMessage, Senders};
use crate::storage::{
store::{
engine::{MempoolMessage, Senders},
stores::Stores,
},
trie::merkle_trie::{self, TrieKey},
};

use super::routing::{MessageRouter, ShardRouter};
use tracing::error;
Expand All @@ -21,6 +27,7 @@ impl Default for Config {

pub struct Mempool {
shard_senders: HashMap<u32, Senders>,
shard_stores: HashMap<u32, Stores>,
message_router: Box<dyn MessageRouter>,
num_shards: u32,
mempool_rx: mpsc::Receiver<MempoolMessage>,
Expand All @@ -31,27 +38,91 @@ impl Mempool {
mempool_rx: mpsc::Receiver<MempoolMessage>,
num_shards: u32,
shard_senders: HashMap<u32, Senders>,
shard_stores: HashMap<u32, Stores>,
) -> Self {
Mempool {
shard_senders,
shard_stores,
num_shards,
mempool_rx,
message_router: Box::new(ShardRouter {}),
}
}

fn message_exists_in_trie(&mut self, fid: u64, trie_key: Vec<u8>) -> bool {
let shard = self.message_router.route_message(fid, self.num_shards);
let stores = self.shard_stores.get_mut(&shard);
match stores {
None => {
error!("Error finding store for shard: {}", shard);
false
}
Some(stores) => {
// TODO(aditi): The engine reloads its ref to the trie on commit but we maintain a separate ref to the trie here.
stores.trie.reload(&stores.db).unwrap();
match stores.trie.exists(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, yeah, this is going to be an issue. We shouldn't let this leak out of the trie. Maybe one approach is to make the trie non clonable, and the have a readonly trie that always reads from the db isntead of storing anything in memory.

But at a higher level, we shouldn't use the trie for the duplicate check, it would be too slow. Read from the db directly and we can have a separate caching layer here if required.

&merkle_trie::Context::new(),
&stores.db,
trie_key.as_ref(),
) {
Err(err) => {
error!("Error finding key in trie: {}", err);
false
}
Ok(exists) => exists,
}
}
}
}

fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
let fid = message.fid();
match message {
MempoolMessage::UserMessage(message) => {
self.message_exists_in_trie(fid, TrieKey::for_message(message))
}
MempoolMessage::ValidatorMessage(validator_message) => {
if let Some(onchain_event) = &validator_message.on_chain_event {
return self
.message_exists_in_trie(fid, TrieKey::for_onchain_event(&onchain_event));
}

if let Some(fname_transfer) = &validator_message.fname_transfer {
if let Some(proof) = &fname_transfer.proof {
let name = String::from_utf8(proof.name.clone()).unwrap();
return self.message_exists_in_trie(
fid,
TrieKey::for_fname(fname_transfer.id, &name),
);
}
}
false
}
}
}

pub fn message_is_valid(&mut self, message: &MempoolMessage) -> bool {
if self.is_message_already_merged(message) {
return false;
}

return true;
}

pub async fn run(&mut self) {
while let Some(message) = self.mempool_rx.recv().await {
let fid = message.fid();
let shard = self.message_router.route_message(fid, self.num_shards);
let senders = self.shard_senders.get(&shard);
match senders {
None => {
error!("Unable to find shard to send message to")
}
Some(senders) => {
if let Err(err) = senders.messages_tx.send(message).await {
error!("Unable to send message to engine: {}", err.to_string())
if self.message_is_valid(&message) {
let fid = message.fid();
let shard = self.message_router.route_message(fid, self.num_shards);
let senders = self.shard_senders.get(&shard);
match senders {
None => {
error!("Unable to find shard to send message to")
}
Some(senders) => {
if let Err(err) = senders.messages_tx.send(message).await {
error!("Unable to send message to engine: {}", err.to_string())
}
}
}
}
Expand Down
46 changes: 46 additions & 0 deletions src/mempool/mempool_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#[cfg(test)]
mod tests {
use std::collections::HashMap;

use tokio::sync::mpsc;

use crate::{
mempool::mempool::Mempool,
storage::store::{
engine::{MempoolMessage, ShardEngine},
test_helper,
},
utils::factory::messages_factory,
};

use self::test_helper::{default_custody_address, default_signer};

fn setup() -> (ShardEngine, Mempool) {
let (_mempool_tx, mempool_rx) = mpsc::channel(100);
let (engine, _) = test_helper::new_engine();
let mut shard_senders = HashMap::new();
shard_senders.insert(1, engine.get_senders());
let mut shard_stores = HashMap::new();
shard_stores.insert(1, engine.get_stores());
let mempool = Mempool::new(mempool_rx, 1, shard_senders, shard_stores);
(engine, mempool)
}

#[tokio::test]
async fn test_duplicate_message_is_invalid() {
let (mut engine, mut mempool) = setup();
test_helper::register_user(
1234,
default_signer(),
default_custody_address(),
&mut engine,
)
.await;
let cast = messages_factory::casts::create_cast_add(1234, "hello", None, None);
let valid = mempool.message_is_valid(&MempoolMessage::UserMessage(cast.clone()));
assert!(valid);
test_helper::commit_message(&mut engine, &cast).await;
let valid = mempool.message_is_valid(&MempoolMessage::UserMessage(cast.clone()));
assert!(!valid)
}
}
3 changes: 3 additions & 0 deletions src/mempool/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
pub mod mempool;
pub mod routing;

#[cfg(test)]
mod mempool_test;
2 changes: 1 addition & 1 deletion src/network/server_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ mod tests {
assert_eq!(message_router.route_message(SHARD2_FID, 2), 2);

let (mempool_tx, mempool_rx) = mpsc::channel(1000);
let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone());
let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone(), stores.clone());
tokio::spawn(async move { mempool.run().await });

(
Expand Down
7 changes: 6 additions & 1 deletion tests/consensus_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ impl NodeForTest {
let grpc_shard_stores = node.shard_stores.clone();
let grpc_shard_senders = node.shard_senders.clone();
let (mempool_tx, mempool_rx) = mpsc::channel(100);
let mut mempool = Mempool::new(mempool_rx, num_shards, node.shard_senders.clone());
let mut mempool = Mempool::new(
mempool_rx,
num_shards,
node.shard_senders.clone(),
node.shard_stores.clone(),
);
tokio::spawn(async move { mempool.run().await });

tokio::spawn(async move {
Expand Down