From 708d5c8c5e60f4658f69f6eb74e71930bf63b30f Mon Sep 17 00:00:00 2001
From: Ralf Jung <post@ralfj.de>
Date: Fri, 28 Oct 2022 11:22:33 +0200
Subject: [PATCH] libtest: run all tests in their own thread, if supported by
 the host

---
 test/src/lib.rs     | 58 ++++++++++++++++++++++-----------------------
 test/src/options.rs |  7 ------
 test/src/tests.rs   | 26 +++++++-------------
 3 files changed, 37 insertions(+), 54 deletions(-)

diff --git a/test/src/lib.rs b/test/src/lib.rs
index 141f16d17..a822df38e 100644
--- a/test/src/lib.rs
+++ b/test/src/lib.rs
@@ -40,7 +40,7 @@ pub mod test {
         cli::{parse_opts, TestOpts},
         filter_tests,
         helpers::metrics::{Metric, MetricMap},
-        options::{Concurrent, Options, RunIgnored, RunStrategy, ShouldPanic},
+        options::{Options, RunIgnored, RunStrategy, ShouldPanic},
         run_test, test_main, test_main_static,
         test_result::{TestResult, TrFailed, TrFailedMsg, TrIgnored, TrOk},
         time::{TestExecTime, TestTimeOptions},
@@ -85,7 +85,7 @@ use event::{CompletedTest, TestEvent};
 use helpers::concurrency::get_concurrency;
 use helpers::exit_code::get_exit_code;
 use helpers::shuffle::{get_shuffle_seed, shuffle_tests};
-use options::{Concurrent, RunStrategy};
+use options::RunStrategy;
 use test_result::*;
 use time::TestExecTime;
 
@@ -235,6 +235,19 @@ where
         join_handle: Option<thread::JoinHandle<()>>,
     }
 
+    impl RunningTest {
+        fn join(self, completed_test: &mut CompletedTest) {
+            if let Some(join_handle) = self.join_handle {
+                if let Err(_) = join_handle.join() {
+                    if let TrOk = completed_test.result {
+                        completed_test.result =
+                            TrFailedMsg("panicked after reporting success".to_string());
+                    }
+                }
+            }
+        }
+    }
+
     // Use a deterministic hasher
     type TestMap =
         HashMap<TestId, RunningTest, BuildHasherDefault<collections::hash_map::DefaultHasher>>;
@@ -328,10 +341,10 @@ where
             let (id, test) = remaining.pop_front().unwrap();
             let event = TestEvent::TeWait(test.desc.clone());
             notify_about_test_event(event)?;
-            let join_handle =
-                run_test(opts, !opts.run_tests, id, test, run_strategy, tx.clone(), Concurrent::No);
-            assert!(join_handle.is_none());
-            let completed_test = rx.recv().unwrap();
+            let join_handle = run_test(opts, !opts.run_tests, id, test, run_strategy, tx.clone());
+            // Wait for the test to complete.
+            let mut completed_test = rx.recv().unwrap();
+            RunningTest { join_handle }.join(&mut completed_test);
 
             let event = TestEvent::TeResult(completed_test);
             notify_about_test_event(event)?;
@@ -345,15 +358,8 @@ where
 
                 let event = TestEvent::TeWait(desc.clone());
                 notify_about_test_event(event)?; //here no pad
-                let join_handle = run_test(
-                    opts,
-                    !opts.run_tests,
-                    id,
-                    test,
-                    run_strategy,
-                    tx.clone(),
-                    Concurrent::Yes,
-                );
+                let join_handle =
+                    run_test(opts, !opts.run_tests, id, test, run_strategy, tx.clone());
                 running_tests.insert(id, RunningTest { join_handle });
                 timeout_queue.push_back(TimeoutEntry { id, desc, timeout });
                 pending += 1;
@@ -385,14 +391,7 @@ where
 
             let mut completed_test = res.unwrap();
             let running_test = running_tests.remove(&completed_test.id).unwrap();
-            if let Some(join_handle) = running_test.join_handle {
-                if let Err(_) = join_handle.join() {
-                    if let TrOk = completed_test.result {
-                        completed_test.result =
-                            TrFailedMsg("panicked after reporting success".to_string());
-                    }
-                }
-            }
+            running_test.join(&mut completed_test);
 
             let event = TestEvent::TeResult(completed_test);
             notify_about_test_event(event)?;
@@ -405,8 +404,10 @@ where
         for (id, b) in filtered_benchs {
             let event = TestEvent::TeWait(b.desc.clone());
             notify_about_test_event(event)?;
-            run_test(opts, false, id, b, run_strategy, tx.clone(), Concurrent::No);
-            let completed_test = rx.recv().unwrap();
+            let join_handle = run_test(opts, false, id, b, run_strategy, tx.clone());
+            // Wait for the test to complete.
+            let mut completed_test = rx.recv().unwrap();
+            RunningTest { join_handle }.join(&mut completed_test);
 
             let event = TestEvent::TeResult(completed_test);
             notify_about_test_event(event)?;
@@ -480,7 +481,6 @@ pub fn run_test(
     test: TestDescAndFn,
     strategy: RunStrategy,
     monitor_ch: Sender<CompletedTest>,
-    concurrency: Concurrent,
 ) -> Option<thread::JoinHandle<()>> {
     let TestDescAndFn { desc, testfn } = test;
 
@@ -498,7 +498,6 @@ pub fn run_test(
     struct TestRunOpts {
         pub strategy: RunStrategy,
         pub nocapture: bool,
-        pub concurrency: Concurrent,
         pub time: Option<time::TestTimeOptions>,
     }
 
@@ -509,7 +508,6 @@ pub fn run_test(
         testfn: Box<dyn FnOnce() -> Result<(), String> + Send>,
         opts: TestRunOpts,
     ) -> Option<thread::JoinHandle<()>> {
-        let concurrency = opts.concurrency;
         let name = desc.name.clone();
 
         let runtest = move || match opts.strategy {
@@ -536,7 +534,7 @@ pub fn run_test(
         // the test synchronously, regardless of the concurrency
         // level.
         let supports_threads = !cfg!(target_os = "emscripten") && !cfg!(target_family = "wasm");
-        if concurrency == Concurrent::Yes && supports_threads {
+        if supports_threads {
             let cfg = thread::Builder::new().name(name.as_slice().to_owned());
             let mut runtest = Arc::new(Mutex::new(Some(runtest)));
             let runtest2 = runtest.clone();
@@ -557,7 +555,7 @@ pub fn run_test(
     }
 
     let test_run_opts =
-        TestRunOpts { strategy, nocapture: opts.nocapture, concurrency, time: opts.time_options };
+        TestRunOpts { strategy, nocapture: opts.nocapture, time: opts.time_options };
 
     match testfn {
         DynBenchFn(benchfn) => {
diff --git a/test/src/options.rs b/test/src/options.rs
index baf36b5f1..75ec0b616 100644
--- a/test/src/options.rs
+++ b/test/src/options.rs
@@ -1,12 +1,5 @@
 //! Enums denoting options for test execution.
 
-/// Whether to execute tests concurrently or not
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-pub enum Concurrent {
-    Yes,
-    No,
-}
-
 /// Number of times to run a benchmarked function
 #[derive(Clone, PartialEq, Eq)]
 pub enum BenchMode {
diff --git a/test/src/tests.rs b/test/src/tests.rs
index b54be64ef..7b2e6707f 100644
--- a/test/src/tests.rs
+++ b/test/src/tests.rs
@@ -102,7 +102,7 @@ pub fn do_not_run_ignored_tests() {
         testfn: DynTestFn(Box::new(f)),
     };
     let (tx, rx) = channel();
-    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
     assert_ne!(result, TrOk);
 }
@@ -125,7 +125,7 @@ pub fn ignored_tests_result_in_ignored() {
         testfn: DynTestFn(Box::new(f)),
     };
     let (tx, rx) = channel();
-    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
     assert_eq!(result, TrIgnored);
 }
@@ -150,7 +150,7 @@ fn test_should_panic() {
         testfn: DynTestFn(Box::new(f)),
     };
     let (tx, rx) = channel();
-    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
     assert_eq!(result, TrOk);
 }
@@ -175,7 +175,7 @@ fn test_should_panic_good_message() {
         testfn: DynTestFn(Box::new(f)),
     };
     let (tx, rx) = channel();
-    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
     assert_eq!(result, TrOk);
 }
@@ -205,7 +205,7 @@ fn test_should_panic_bad_message() {
         testfn: DynTestFn(Box::new(f)),
     };
     let (tx, rx) = channel();
-    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
     assert_eq!(result, TrFailedMsg(failed_msg.to_string()));
 }
@@ -239,7 +239,7 @@ fn test_should_panic_non_string_message_type() {
         testfn: DynTestFn(Box::new(f)),
     };
     let (tx, rx) = channel();
-    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
     assert_eq!(result, TrFailedMsg(failed_msg));
 }
@@ -267,15 +267,7 @@ fn test_should_panic_but_succeeds() {
             testfn: DynTestFn(Box::new(f)),
         };
         let (tx, rx) = channel();
-        run_test(
-            &TestOpts::new(),
-            false,
-            TestId(0),
-            desc,
-            RunStrategy::InProcess,
-            tx,
-            Concurrent::No,
-        );
+        run_test(&TestOpts::new(), false, TestId(0), desc, RunStrategy::InProcess, tx);
         let result = rx.recv().unwrap().result;
         assert_eq!(
             result,
@@ -306,7 +298,7 @@ fn report_time_test_template(report_time: bool) -> Option<TestExecTime> {
 
     let test_opts = TestOpts { time_options, ..TestOpts::new() };
     let (tx, rx) = channel();
-    run_test(&test_opts, false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&test_opts, false, TestId(0), desc, RunStrategy::InProcess, tx);
     let exec_time = rx.recv().unwrap().exec_time;
     exec_time
 }
@@ -345,7 +337,7 @@ fn time_test_failure_template(test_type: TestType) -> TestResult {
 
     let test_opts = TestOpts { time_options: Some(time_options), ..TestOpts::new() };
     let (tx, rx) = channel();
-    run_test(&test_opts, false, TestId(0), desc, RunStrategy::InProcess, tx, Concurrent::No);
+    run_test(&test_opts, false, TestId(0), desc, RunStrategy::InProcess, tx);
     let result = rx.recv().unwrap().result;
 
     result