diff --git a/src/protocol/libp2p/kademlia/handle.rs b/src/protocol/libp2p/kademlia/handle.rs index 15903237..a48fd051 100644 --- a/src/protocol/libp2p/kademlia/handle.rs +++ b/src/protocol/libp2p/kademlia/handle.rs @@ -254,12 +254,12 @@ pub enum KademliaEvent { /// The type of the DHT records. #[derive(Debug, Clone)] pub enum RecordsType { - /// Record was found in the local store. + /// Record was found in the local store and [`Quorum::One`] was used. /// /// This contains only a single result. LocalStore(Record), - /// Records found in the network. + /// Records found in the network. This can include the locally found record. Network(Vec), } diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index e1321ba6..421904d1 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -1118,7 +1118,7 @@ impl Kademlia { .closest(&Key::new(key), self.replication_factor) .into(), quorum, - if record.is_some() { 1 } else { 0 }, + record.cloned(), ); } } diff --git a/src/protocol/libp2p/kademlia/query/get_record.rs b/src/protocol/libp2p/kademlia/query/get_record.rs index 12ea8293..019ece17 100644 --- a/src/protocol/libp2p/kademlia/query/get_record.rs +++ b/src/protocol/libp2p/kademlia/query/get_record.rs @@ -106,7 +106,11 @@ pub struct GetRecordContext { impl GetRecordContext { /// Create new [`GetRecordContext`]. - pub fn new(config: GetRecordConfig, in_peers: VecDeque) -> Self { + pub fn new( + config: GetRecordConfig, + in_peers: VecDeque, + found_records: Vec, + ) -> Self { let mut candidates = BTreeMap::new(); for candidate in &in_peers { @@ -123,7 +127,7 @@ impl GetRecordContext { candidates, pending: HashMap::new(), queried: HashSet::new(), - found_records: Vec::new(), + found_records, } } @@ -378,7 +382,7 @@ mod tests { #[test] fn completes_when_no_candidates() { let config = default_config(); - let mut context = GetRecordContext::new(config, VecDeque::new()); + let mut context = GetRecordContext::new(config, VecDeque::new(), Vec::new()); assert!(context.is_done()); let event = context.next_action().unwrap(); assert_eq!(event, QueryAction::QueryFailed { query: QueryId(0) }); @@ -387,7 +391,7 @@ mod tests { known_records: 1, ..default_config() }; - let mut context = GetRecordContext::new(config, VecDeque::new()); + let mut context = GetRecordContext::new(config, VecDeque::new(), Vec::new()); assert!(context.is_done()); let event = context.next_action().unwrap(); assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) }); @@ -405,7 +409,7 @@ mod tests { assert_eq!(in_peers_set.len(), 3); let in_peers = in_peers_set.iter().map(|peer| peer_to_kad(*peer)).collect(); - let mut context = GetRecordContext::new(config, in_peers); + let mut context = GetRecordContext::new(config, in_peers, Vec::new()); for num in 0..3 { let event = context.next_action().unwrap(); @@ -444,7 +448,7 @@ mod tests { assert_eq!(in_peers_set.len(), 3); let in_peers = [peer_a, peer_b, peer_c].iter().map(|peer| peer_to_kad(*peer)).collect(); - let mut context = GetRecordContext::new(config, in_peers); + let mut context = GetRecordContext::new(config, in_peers, Vec::new()); // Schedule peer queries. for num in 0..3 { diff --git a/src/protocol/libp2p/kademlia/query/mod.rs b/src/protocol/libp2p/kademlia/query/mod.rs index b933ec5b..9f14a6de 100644 --- a/src/protocol/libp2p/kademlia/query/mod.rs +++ b/src/protocol/libp2p/kademlia/query/mod.rs @@ -318,7 +318,7 @@ impl QueryEngine { target: RecordKey, candidates: VecDeque, quorum: Quorum, - count: usize, + local_record: Option, ) -> QueryId { tracing::debug!( target: LOG_TARGET, @@ -331,7 +331,7 @@ impl QueryEngine { let target = Key::new(target); let config = GetRecordConfig { local_peer_id: self.local_peer_id, - known_records: count, + known_records: if local_record.is_some() { 1 } else { 0 }, quorum, replication_factor: self.replication_factor, parallelism_factor: self.parallelism_factor, @@ -339,10 +339,18 @@ impl QueryEngine { target, }; + let found_records = local_record + .into_iter() + .map(|record| PeerRecord { + peer: self.local_peer_id, + record, + }) + .collect(); + self.queries.insert( query_id, QueryType::GetRecord { - context: GetRecordContext::new(config, candidates), + context: GetRecordContext::new(config, candidates, found_records), }, ); @@ -883,7 +891,7 @@ mod tests { ] .into(), Quorum::All, - 3, + None, ); for _ in 0..4 { diff --git a/tests/protocol/kademlia.rs b/tests/protocol/kademlia.rs index bd121c03..d3d62910 100644 --- a/tests/protocol/kademlia.rs +++ b/tests/protocol/kademlia.rs @@ -464,6 +464,108 @@ async fn get_record_retrieves_remote_records() { } } +#[tokio::test] +async fn get_record_retrieves_local_and_remote_records() { + let (kad_config1, mut kad_handle1) = KademliaConfigBuilder::new().build(); + let (kad_config2, mut kad_handle2) = KademliaConfigBuilder::new().build(); + + let config1 = ConfigBuilder::new() + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_kademlia(kad_config1) + .build(); + + let config2 = ConfigBuilder::new() + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_kademlia(kad_config2) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + // Let peers jnow about each other + kad_handle1 + .add_known_peer( + *litep2p2.local_peer_id(), + litep2p2.listen_addresses().cloned().collect(), + ) + .await; + kad_handle2 + .add_known_peer( + *litep2p1.local_peer_id(), + litep2p1.listen_addresses().cloned().collect(), + ) + .await; + + // Store the record on `litep2p1``. + let original_record = Record::new(vec![1, 2, 3], vec![0x01]); + let query1 = kad_handle1.put_record(original_record.clone()).await; + + let (mut peer1_stored, mut peer2_stored) = (false, false); + let mut query3 = None; + + loop { + tokio::select! { + _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => { + panic!("record was not retrieved in 10 secs") + } + event = litep2p1.next_event() => {} + event = litep2p2.next_event() => {} + event = kad_handle1.next() => {} + event = kad_handle2.next() => { + match event { + Some(KademliaEvent::IncomingRecord { record: got_record }) => { + assert_eq!(got_record.key, original_record.key); + assert_eq!(got_record.value, original_record.value); + assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(got_record.expires.is_some()); + + // Get record. + let query_id = kad_handle2 + .get_record(RecordKey::from(vec![1, 2, 3]), Quorum::All).await; + query3 = Some(query_id); + } + Some(KademliaEvent::GetRecordSuccess { query_id: _, records }) => { + match records { + RecordsType::LocalStore(_) => { + panic!("the record was retrieved only from peer2") + } + RecordsType::Network(records) => { + assert_eq!(records.len(), 2); + + // Locally retrieved record goes first. + assert_eq!(records[0].peer, *litep2p2.local_peer_id()); + assert_eq!(records[0].record.key, original_record.key); + assert_eq!(records[0].record.value, original_record.value); + assert_eq!(records[0].record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(records[0].record.expires.is_some()); + + // Remote record from peer 1. + assert_eq!(records[1].peer, *litep2p1.local_peer_id()); + assert_eq!(records[1].record.key, original_record.key); + assert_eq!(records[1].record.value, original_record.value); + assert_eq!(records[1].record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(records[1].record.expires.is_some()); + + break + } + } + } + Some(KademliaEvent::QueryFailed { query_id: _ }) => { + panic!("peer2 query failed") + } + _ => {} + } + } + } + } +} + #[tokio::test] async fn provider_retrieved_by_remote_node() { let (kad_config1, mut kad_handle1) = KademliaConfigBuilder::new().build();