Skip to content

Commit

Permalink
Merge pull request #1564 from webern/no-thread-bombing
Browse files Browse the repository at this point in the history
pubsys: limit threads during validate-repo
  • Loading branch information
webern authored May 7, 2021
2 parents 6cceb1f + 6dee545 commit 867a7ca
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 21 deletions.
87 changes: 87 additions & 0 deletions tools/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions tools/pubsys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ futures = "0.3.5"
indicatif = "0.15.0"
lazy_static = "1.4"
log = "0.4"
num_cpus = "1"
parse-datetime = { path = "../../sources/parse-datetime" }
rayon = "1"
# Need to bring in reqwest with a TLS feature so tough can support TLS repos.
reqwest = { version = "0.11.1", default-features = false, features = ["rustls-tls", "blocking"] }
rusoto_core = { version = "0.46.0", default-features = false, features = ["rustls"] }
Expand Down
53 changes: 32 additions & 21 deletions tools/pubsys/src/repo/validate_repo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use crate::Args;
use log::{info, trace};
use pubsys_config::InfraConfig;
use snafu::{OptionExt, ResultExt};
use std::cmp::min;
use std::fs::File;
use std::io;
use std::path::PathBuf;
use std::thread::spawn;
use std::sync::mpsc;
use structopt::StructOpt;
use tough::{Repository, RepositoryLoader};
use url::Url;
Expand Down Expand Up @@ -38,13 +39,25 @@ pub(crate) struct ValidateRepoArgs {
validate_targets: bool,
}

/// Retrieves listed targets and attempts to download them for validation purposes
/// If we are on a machine with a large number of cores, then we limit the number of simultaneous
/// downloads to this arbitrarily chosen maximum.
const MAX_DOWNLOAD_THREADS: usize = 16;

/// Retrieves listed targets and attempts to download them for validation purposes. We use a Rayon
/// thread pool instead of tokio for async execution because `reqwest::blocking` creates a tokio
/// runtime (and multiple tokio runtimes are not supported).
fn retrieve_targets(repo: &Repository) -> Result<(), Error> {
let targets = &repo.targets().signed.targets;
let thread_pool = rayon::ThreadPoolBuilder::new()
.num_threads(min(num_cpus::get(), MAX_DOWNLOAD_THREADS))
.build()
.context(error::ThreadPool)?;

// create the channels through which our download results will be passed
let (tx, rx) = mpsc::channel();

let mut tasks = Vec::new();
for target in targets.keys().cloned() {
let target = target.to_string();
let tx = tx.clone();
let mut reader = repo
.read_target(&target)
.with_context(|| repo_error::ReadTarget {
Expand All @@ -54,24 +67,22 @@ fn retrieve_targets(repo: &Repository) -> Result<(), Error> {
target: target.to_string(),
})?;
info!("Downloading target: {}", target);
// TODO - limit threads https://github.com/bottlerocket-os/bottlerocket/issues/1522
tasks.push(spawn(move || {
// tough's `Read` implementation validates the target as it's being downloaded
io::copy(&mut reader, &mut io::sink()).context(error::TargetDownload {
target: target.to_string(),
thread_pool.spawn(move || {
tx.send({
// tough's `Read` implementation validates the target as it's being downloaded
io::copy(&mut reader, &mut io::sink()).context(error::TargetDownload {
target: target.to_string(),
})
})
}));
// inability to send on this channel is unrecoverable
.unwrap();
});
}
// close all senders
drop(tx);

// ensure that we join all threads before checking the results
let mut results = Vec::new();
for task in tasks {
let result = task.join().map_err(|e| error::Error::Join {
// the join function is returning an error type that does not implement error or display
inner: format!("{:?}", e),
})?;
results.push(result);
}
// block and await all downloads
let results: Vec<Result<u64, error::Error>> = rx.into_iter().collect();

// check all results and return the first error we see
for result in results {
Expand Down Expand Up @@ -164,8 +175,8 @@ mod error {
#[snafu(display("Missing target: {}", target))]
TargetMissing { target: String },

#[snafu(display("Failed to join thread: {}", inner))]
Join { inner: String },
#[snafu(display("Unable to create thread pool: {}", source))]
ThreadPool { source: rayon::ThreadPoolBuildError },
}
}
pub(crate) use error::Error;

0 comments on commit 867a7ca

Please sign in to comment.