Skip to content

Commit

Permalink
Merge pull request #2032 from input-output-hk/jpraynaud/fix-certifica…
Browse files Browse the repository at this point in the history
…te-chain-chaining

Fix: computation of the chaining of the certificates in tests
  • Loading branch information
jpraynaud authored Oct 22, 2024
2 parents e431657 + 199d6c5 commit 0cbc482
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion mithril-common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mithril-common"
version = "0.4.73"
version = "0.4.74"
description = "Common types, interfaces, and utilities for Mithril nodes."
authors = { workspace = true }
edition = { workspace = true }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@ use thiserror::Error;

use crate::{entities::Certificate, StdError};

#[cfg(test)]
use mockall::automock;

/// [CertificateRetriever] related errors.
#[derive(Debug, Error)]
#[error("Error when retrieving certificate")]
pub struct CertificateRetrieverError(#[source] pub StdError);

/// CertificateRetriever is in charge of retrieving a [Certificate] given its hash
#[cfg_attr(test, automock)]
#[cfg_attr(test, mockall::automock)]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait CertificateRetriever: Sync + Send {
Expand Down
46 changes: 14 additions & 32 deletions mithril-common/src/certificate_chain/certificate_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ mod tests {
use super::CertificateRetriever;
use super::*;

use crate::certificate_chain::CertificateRetrieverError;
use crate::certificate_chain::{CertificateRetrieverError, FakeCertificaterRetriever};
use crate::crypto_helper::{tests_setup::*, ProtocolClerk};
use crate::test_utils::{MithrilFixtureBuilder, TestLogger};

Expand Down Expand Up @@ -527,24 +527,19 @@ mod tests {
let certificates_per_epoch = 2;
let (fake_certificates, genesis_verifier) =
setup_certificate_chain(total_certificates, certificates_per_epoch);
let mut mock_certificate_retriever = MockCertificateRetrieverImpl::new();
let certificate_retriever =
FakeCertificaterRetriever::from_certificates(&fake_certificates);
let verifier =
MithrilCertificateVerifier::new(TestLogger::stdout(), Arc::new(certificate_retriever));
let certificate_to_verify = fake_certificates[0].clone();
for fake_certificate in fake_certificates.into_iter().skip(1) {
mock_certificate_retriever
.expect_get_certificate_details()
.returning(move |_| Ok(fake_certificate.clone()))
.times(1);
}
let verifier = MithrilCertificateVerifier::new(
TestLogger::stdout(),
Arc::new(mock_certificate_retriever),
);

let verify = verifier
.verify_certificate_chain(
certificate_to_verify,
&genesis_verifier.to_verification_key(),
)
.await;

verify.expect("unexpected error");
}

Expand All @@ -555,23 +550,13 @@ mod tests {
let (mut fake_certificates, genesis_verifier) =
setup_certificate_chain(total_certificates, certificates_per_epoch);
let index_certificate_fail = (total_certificates / 2) as usize;
fake_certificates[index_certificate_fail].hash = "tampered-hash".to_string();
let mut mock_certificate_retriever = MockCertificateRetrieverImpl::new();
fake_certificates[index_certificate_fail].signed_message = "tampered-message".to_string();
let certificate_retriever =
FakeCertificaterRetriever::from_certificates(&fake_certificates);
let verifier =
MithrilCertificateVerifier::new(TestLogger::stdout(), Arc::new(certificate_retriever));
let certificate_to_verify = fake_certificates[0].clone();
for fake_certificate in fake_certificates
.into_iter()
.skip(1)
.take(index_certificate_fail)
{
mock_certificate_retriever
.expect_get_certificate_details()
.returning(move |_| Ok(fake_certificate.clone()))
.times(1);
}
let verifier = MithrilCertificateVerifier::new(
TestLogger::stdout(),
Arc::new(mock_certificate_retriever),
);

let error = verifier
.verify_certificate_chain(
certificate_to_verify,
Expand All @@ -584,10 +569,7 @@ mod tests {
.expect("Can not downcast to `CertificateVerifierError`.");

assert!(
matches!(
error,
CertificateVerifierError::CertificateChainPreviousHashUnmatch
),
matches!(error, CertificateVerifierError::CertificateHashUnmatch),
"unexpected error type: {error:?}"
);
}
Expand Down
77 changes: 77 additions & 0 deletions mithril-common/src/certificate_chain/fake_certificate_retriever.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//! A module used for a fake implementation of a certificate chain retriever
//!
use anyhow::anyhow;
use async_trait::async_trait;
use std::collections::HashMap;
use tokio::sync::RwLock;

use crate::entities::Certificate;

use super::{CertificateRetriever, CertificateRetrieverError};

/// A fake [CertificateRetriever] that returns a [Certificate] given its hash
pub struct FakeCertificaterRetriever {
certificates_map: RwLock<HashMap<String, Certificate>>,
}

impl FakeCertificaterRetriever {
/// Create a new [FakeCertificaterRetriever]
pub fn from_certificates(certificates: &[Certificate]) -> Self {
let certificates_map = certificates
.iter()
.map(|certificate| (certificate.hash.clone(), certificate.clone()))
.collect::<HashMap<_, _>>();
let certificates_map = RwLock::new(certificates_map);

Self { certificates_map }
}
}

#[async_trait]
impl CertificateRetriever for FakeCertificaterRetriever {
async fn get_certificate_details(
&self,
certificate_hash: &str,
) -> Result<Certificate, CertificateRetrieverError> {
let certificates_map = self.certificates_map.read().await;
certificates_map
.get(certificate_hash)
.cloned()
.ok_or_else(|| CertificateRetrieverError(anyhow!("Certificate not found")))
}
}

#[cfg(test)]
mod tests {
use crate::test_utils::fake_data;

use super::*;

#[tokio::test]
async fn fake_certificate_retriever_retrieves_existing_certificate() {
let certificate = fake_data::certificate("certificate-hash-123".to_string());
let certificate_hash = certificate.hash.clone();
let certificate_retriever =
FakeCertificaterRetriever::from_certificates(&[certificate.clone()]);

let retrieved_certificate = certificate_retriever
.get_certificate_details(&certificate_hash)
.await
.expect("Should retrieve certificate");

assert_eq!(retrieved_certificate, certificate);
}

#[tokio::test]
async fn test_fake_certificate_fails_retrieving_unknow_certificate() {
let certificate = fake_data::certificate("certificate-hash-123".to_string());
let certificate_retriever = FakeCertificaterRetriever::from_certificates(&[certificate]);

let retrieved_certificate = certificate_retriever
.get_certificate_details("certificate-hash-not-found")
.await;

retrieved_certificate.expect_err("get_certificate_details shoudl fail");
}
}
7 changes: 7 additions & 0 deletions mithril-common/src/certificate_chain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
mod certificate_genesis;
mod certificate_retriever;
mod certificate_verifier;
cfg_test_tools! {
mod fake_certificate_retriever;
}

pub use certificate_genesis::CertificateGenesisProducer;
pub use certificate_retriever::{CertificateRetriever, CertificateRetrieverError};
pub use certificate_verifier::{
CertificateVerifier, CertificateVerifierError, MithrilCertificateVerifier,
};

cfg_test_tools! {
pub use fake_certificate_retriever::FakeCertificaterRetriever;
}
124 changes: 90 additions & 34 deletions mithril-common/src/test_utils/certificate_chain_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,29 +419,58 @@ impl<'a> CertificateChainBuilder<'a> {
certificate
}

fn update_certificate_previous_hash(
&self,
certificate: Certificate,
previous_certificate: Option<&Certificate>,
) -> Certificate {
let mut certificate = certificate;
certificate.previous_hash = previous_certificate
.map(|c| c.hash.to_string())
.unwrap_or_default();
certificate.hash = certificate.compute_hash();

certificate
}

fn fetch_previous_certificate_from_chain<'b>(
&self,
certificate: &Certificate,
certificates_chained: &'b [Certificate],
) -> Option<&'b Certificate> {
let is_certificate_first_of_epoch = certificates_chained
.last()
.map(|c| c.epoch != certificate.epoch)
.unwrap_or(true);

certificates_chained
.iter()
.rev()
.filter(|c| {
if is_certificate_first_of_epoch {
// The previous certificate of the first certificate of an epoch
// is the first certificate of the previous epoch
c.epoch == certificate.epoch.previous().unwrap()
} else {
// The previous certificate of not the first certificate of an epoch
// is the first certificate of the epoch
c.epoch == certificate.epoch
}
})
.last()
}

// Returns the chained certificates in reverse order
// The latest certificate of the chain is the first in the vector
fn compute_chained_certificates(&self, certificates: Vec<Certificate>) -> Vec<Certificate> {
fn update_certificate_previous_hash(
certificate: Certificate,
previous_certificate: Option<&Certificate>,
) -> Certificate {
let mut certificate = certificate;
certificate.previous_hash = previous_certificate
.map(|c| c.hash.to_string())
.unwrap_or_default();
certificate.hash = certificate.compute_hash();

certificate
}

let mut certificates_chained: Vec<Certificate> =
certificates
.into_iter()
.fold(Vec::new(), |mut certificates_chained, certificate| {
let previous_certificate_maybe = certificates_chained.last();
let certificate =
update_certificate_previous_hash(certificate, previous_certificate_maybe);
let previous_certificate_maybe = self
.fetch_previous_certificate_from_chain(&certificate, &certificates_chained);
let certificate = self
.update_certificate_previous_hash(certificate, previous_certificate_maybe);
certificates_chained.push(certificate);

certificates_chained
Expand Down Expand Up @@ -760,32 +789,59 @@ mod test {

#[test]
fn builds_certificate_chain_correctly_chained() {
let certificates = vec![
Certificate {
epoch: Epoch(1),
..fake_data::certificate("cert-1".to_string())
},
Certificate {
epoch: Epoch(2),
..fake_data::certificate("cert-2".to_string())
},
fn create_fake_certificate(epoch: Epoch, index_in_epoch: u64) -> Certificate {
Certificate {
epoch: Epoch(3),
..fake_data::certificate("cert-3".to_string())
},
epoch,
signed_message: format!("certificate-{}-{index_in_epoch}", *epoch),
..fake_data::certificate("cert-fake".to_string())
}
}

let certificates = vec![
create_fake_certificate(Epoch(1), 1),
create_fake_certificate(Epoch(2), 1),
create_fake_certificate(Epoch(2), 2),
create_fake_certificate(Epoch(3), 1),
create_fake_certificate(Epoch(4), 1),
create_fake_certificate(Epoch(4), 2),
create_fake_certificate(Epoch(4), 3),
];

let certificates_chained =
let mut certificates_chained =
CertificateChainBuilder::default().compute_chained_certificates(certificates);
certificates_chained.reverse();

assert_eq!("", certificates_chained[2].previous_hash);
let certificate_chained_1_1 = &certificates_chained[0];
let certificate_chained_2_1 = &certificates_chained[1];
let certificate_chained_2_2 = &certificates_chained[2];
let certificate_chained_3_1 = &certificates_chained[3];
let certificate_chained_4_1 = &certificates_chained[4];
let certificate_chained_4_2 = &certificates_chained[5];
let certificate_chained_4_3 = &certificates_chained[6];
assert_eq!("", certificate_chained_1_1.previous_hash);
assert_eq!(
certificate_chained_2_1.previous_hash,
certificate_chained_1_1.hash
);
assert_eq!(
certificate_chained_2_2.previous_hash,
certificate_chained_2_1.hash
);
assert_eq!(
certificate_chained_3_1.previous_hash,
certificate_chained_2_1.hash
);
assert_eq!(
certificate_chained_4_1.previous_hash,
certificate_chained_3_1.hash
);
assert_eq!(
certificates_chained[2].hash,
certificates_chained[1].previous_hash
certificate_chained_4_2.previous_hash,
certificate_chained_4_1.hash
);
assert_eq!(
certificates_chained[1].hash,
certificates_chained[0].previous_hash
certificate_chained_4_3.previous_hash,
certificate_chained_4_1.hash
);
}

Expand Down

0 comments on commit 0cbc482

Please sign in to comment.