Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup node bootstrap after client download #1132

Merged
merged 3 commits into from
Aug 4, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Tidy up client snapshot service code & tests
  • Loading branch information
Alenar committed Aug 4, 2023
commit 7b65665ec19dda764273393d39f7e7089e253f81
178 changes: 97 additions & 81 deletions mithril-client/src/services/snapshot.rs
Original file line number Diff line number Diff line change
@@ -201,18 +201,18 @@ impl SnapshotService for MithrilClientSnapshotService {
async fn download(
&self,
snapshot_entity: &SignedEntity<Snapshot>,
pathdir: &Path,
download_dir: &Path,
genesis_verification_key: &str,
progress_target: ProgressDrawTarget,
) -> StdResult<PathBuf> {
debug!("Snapshot service: download.");

let unpack_dir = pathdir.join("db");
let db_dir = download_dir.join("db");
let progress_bar = MultiProgress::with_draw_target(progress_target);
progress_bar.println("1/7 - Checking local disk info…")?;
let unpacker = SnapshotUnpacker;

if let Err(e) = unpacker.check_prerequisites(&unpack_dir, snapshot_entity.artifact.size) {
if let Err(e) = unpacker.check_prerequisites(&db_dir, snapshot_entity.artifact.size) {
self.check_disk_space_error(e)?;
}

@@ -237,49 +237,53 @@ impl SnapshotService for MithrilClientSnapshotService {
.unwrap()
.with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
.progress_chars("#>-"));
let filepath = self
let snapshot_path = self
.snapshot_client
.download(&snapshot_entity.artifact, pathdir, pb)
.download(&snapshot_entity.artifact, download_dir, pb)
.await
.map_err(|e| format!("Could not download file in '{}': {e}", pathdir.display()))?;
.map_err(|e| {
format!(
"Could not download file in '{}': {e}",
download_dir.display()
)
})?;

progress_bar.println("5/7 - Unpacking the snapshot…")?;
let unpacker = unpacker.unpack_snapshot(&filepath, &unpack_dir);
let unpacker = unpacker.unpack_snapshot(&snapshot_path, &db_dir);
self.wait_spinner(&progress_bar, unpacker).await?;

progress_bar.println("6/7 - Computing the snapshot digest…")?;
let unpacked_snapshot_digest = self
.immutable_digester
.compute_digest(&unpack_dir, &certificate.beacon)
.compute_digest(&db_dir, &certificate.beacon)
.await
.map_err(|e| {
format!(
"Could not compute digest in '{}': {e}",
unpack_dir.display()
)
})?;
.map_err(|e| format!("Could not compute digest in '{}': {e}", db_dir.display()))?;

progress_bar.println("7/7 - Verifying the snapshot signature…")?;
let mut protocol_message = certificate.protocol_message.clone();
protocol_message.set_message_part(
ProtocolMessagePartKey::SnapshotDigest,
unpacked_snapshot_digest,
);
if protocol_message.compute_hash() != certificate.signed_message {
let expected_message = {
Alenar marked this conversation as resolved.
Show resolved Hide resolved
let mut protocol_message = certificate.protocol_message.clone();
protocol_message.set_message_part(
ProtocolMessagePartKey::SnapshotDigest,
unpacked_snapshot_digest,
);
protocol_message.compute_hash()
};

if expected_message != certificate.signed_message {
debug!("Digest verification failed, removing unpacked files & directory.");

if let Err(e) = std::fs::remove_dir_all(&unpack_dir) {
warn!("Error while removing unpacked files & directory: {e}.");
if let Err(error) = std::fs::remove_dir_all(&db_dir) {
warn!("Error while removing unpacked files & directory: {error}.");
}

return Err(SnapshotServiceError::CouldNotVerifySnapshot {
digest: snapshot_entity.artifact.digest.clone(),
path: filepath.canonicalize().unwrap(),
path: snapshot_path.canonicalize().unwrap(),
}
.into());
}

Ok(unpack_dir)
Ok(db_dir)
}
}

@@ -298,34 +302,44 @@ mod tests {
test_utils::fake_data,
};
use std::{
ffi::OsStr,
fs::{create_dir_all, File},
io::Write,
};

use crate::{
aggregator_client::{AggregatorClient, MockAggregatorHTTPClient},
dependencies::DependenciesBuilder,
services::mock::*,
FromSnapshotMessageAdapter,
};

use super::super::mock::*;

use super::*;

/// see [`archive_file_path`] to see where the dummy will be created
fn build_dummy_snapshot(digest: &str, data_expected: &str, test_dir: &Path) {
// create a fake file to archive
let data_file_path = {
let data_file_path = test_dir.join("db").join("test_data.txt");
create_dir_all(data_file_path.parent().unwrap()).unwrap();

let mut source_file = File::create(data_file_path.as_path()).unwrap();
write!(source_file, "{data_expected}").unwrap();

data_file_path
};

// create the archive
let archive_file_path = test_dir.join(format!("snapshot-{digest}"));
let data_file_path = test_dir.join(Path::new("db/test_data.txt"));
create_dir_all(data_file_path.parent().unwrap()).unwrap();
let mut source_file = File::create(data_file_path.as_path()).unwrap();
write!(source_file, "{data_expected}").unwrap();
let archive_file = File::create(archive_file_path).unwrap();
let archive_encoder = GzEncoder::new(&archive_file, Compression::default());
let mut archive_builder = tar::Builder::new(archive_encoder);
archive_builder
.append_dir_all(".", data_file_path.parent().unwrap())
.unwrap();
archive_builder.into_inner().unwrap().finish().unwrap();

// remove the fake file
let _ = std::fs::remove_dir_all(data_file_path.parent().unwrap());
}

@@ -361,6 +375,36 @@ mod tests {
}
}

fn get_mocks_for_snapshot_service_configured_to_make_download_succeed() -> (
MockAggregatorHTTPClient,
MockCertificateVerifierImpl,
DumbImmutableDigester,
) {
let mut http_client = MockAggregatorHTTPClient::new();
http_client.expect_probe().returning(|_| Ok(()));
http_client
.expect_download()
.returning(move |_, _, _| Ok(()))
.times(1);
http_client.expect_get_content().returning(|_| {
let mut message = CertificateMessage::dummy();
message.signed_message = message.protocol_message.compute_hash();
let message = serde_json::to_string(&message).unwrap();

Ok(message)
});

let mut certificate_verifier = MockCertificateVerifierImpl::new();
certificate_verifier
.expect_verify_certificate_chain()
.returning(|_, _, _| Ok(()))
.times(1);

let dumb_digester = DumbImmutableDigester::new("snapshot-digest-123", true);

(http_client, certificate_verifier, dumb_digester)
}

fn get_dep_builder(http_client: Arc<dyn AggregatorClient>) -> DependenciesBuilder {
let config_builder: ConfigBuilder<DefaultState> = ConfigBuilder::default();
let config = config_builder
@@ -472,40 +516,25 @@ mod tests {
async fn test_download_snapshot_ok() {
let test_path = std::env::temp_dir().join("test_download_snapshot_ok");
let _ = std::fs::remove_dir_all(&test_path);
let mut http_client = MockAggregatorHTTPClient::new();
http_client.expect_probe().returning(|_| Ok(()));
http_client
.expect_download()
.returning(move |_, _, _| Ok(()))
.times(1);
http_client.expect_get_content().returning(|_| {
let mut message = CertificateMessage::dummy();
message.signed_message = message.protocol_message.compute_hash();
let message = serde_json::to_string(&message).unwrap();

Ok(message)
});
let (http_client, certificate_verifier, digester) =
get_mocks_for_snapshot_service_configured_to_make_download_succeed();

let mut builder = get_dep_builder(Arc::new(http_client));
let mut certificate_verifier = MockCertificateVerifierImpl::new();
certificate_verifier
.expect_verify_certificate_chain()
.returning(|_, _, _| Ok(()))
.times(1);
builder.certificate_verifier = Some(Arc::new(certificate_verifier));
builder.immutable_digester = Some(Arc::new(DumbImmutableDigester::new(
"snapshot-digest-123",
true,
)));
let snapshot = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
builder.immutable_digester = Some(Arc::new(digester));
let snapshot_service = builder.get_snapshot_service().await.unwrap();

let (_, verifier) = setup_genesis();
let genesis_verification_key = verifier.to_verification_key();
let snapshot = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
build_dummy_snapshot(
"digest-10.tar.gz",
"1234567890".repeat(124).as_str(),
&test_path,
);

let (_, verifier) = setup_genesis();
let genesis_verification_key = verifier.to_verification_key();

let filepath = snapshot_service
.download(
&snapshot,
@@ -515,42 +544,27 @@ mod tests {
)
.await
.expect("Snapshot download should succeed.");
assert!(filepath.exists());
let unpack_dir = filepath
.parent()
.expect("Test downloaded file must be in a directory.")
.join("db");
assert!(unpack_dir.is_dir());
assert!(
filepath.is_dir(),
"Unpacked location must be in a directory."
);
assert_eq!(Some(OsStr::new("db")), filepath.file_name());
}

#[tokio::test]
async fn test_download_snapshot_invalid_digest() {
let test_path = std::env::temp_dir().join("test_download_snapshot_invalid_digest");
let _ = std::fs::remove_dir_all(&test_path);
let mut http_client = MockAggregatorHTTPClient::new();
http_client.expect_probe().returning(|_| Ok(()));
http_client
.expect_download()
.returning(move |_, _, _| Ok(()))
.times(1);
http_client.expect_get_content().returning(|_| {
let mut message = CertificateMessage::dummy();
message.signed_message = message.protocol_message.compute_hash();
let message = serde_json::to_string(&message).unwrap();

Ok(message)
});
let http_client = Arc::new(http_client);
let mut dep_builder = get_dep_builder(http_client);
let mut certificate_verifier = MockCertificateVerifierImpl::new();
certificate_verifier
.expect_verify_certificate_chain()
.returning(|_, _, _| Ok(()))
.times(1);
let (http_client, certificate_verifier, _) =
get_mocks_for_snapshot_service_configured_to_make_download_succeed();
let immutable_digester = DumbImmutableDigester::new("snapshot-digest-KO", true);

let mut dep_builder = get_dep_builder(Arc::new(http_client));
dep_builder.certificate_verifier = Some(Arc::new(certificate_verifier));
dep_builder.immutable_digester = Some(Arc::new(immutable_digester));
let snapshot_service = dep_builder.get_snapshot_service().await.unwrap();

let mut signed_entity = FromSnapshotMessageAdapter::adapt(get_snapshot_message());
signed_entity.artifact.digest = "digest-10".to_string();

@@ -561,6 +575,7 @@ mod tests {
"1234567890".repeat(124).as_str(),
&test_path,
);

let err = snapshot_service
.download(
&signed_entity,
@@ -600,14 +615,15 @@ mod tests {
let test_path = std::env::temp_dir().join("test_download_snapshot_dir_already_exists");
let _ = std::fs::remove_dir_all(&test_path);
create_dir_all(test_path.join("db")).unwrap();

let http_client = MockAggregatorHTTPClient::new();
let http_client = Arc::new(http_client);
let mut dep_builder = get_dep_builder(http_client);
let mut dep_builder = get_dep_builder(Arc::new(http_client));
let snapshot_service = dep_builder.get_snapshot_service().await.unwrap();

let (_, verifier) = setup_genesis();
let genesis_verification_key = verifier.to_verification_key();
let snapshot = FromSnapshotMessageAdapter::adapt(get_snapshot_message());

let err = snapshot_service
.download(
&snapshot,