From 6547e2f3c642bf9610606b6a00d2ce7135dd1a4b Mon Sep 17 00:00:00 2001 From: Marcel Schnelle Date: Sun, 7 Apr 2024 22:21:42 +0900 Subject: [PATCH] Re-implement the ParallelRunNotifier to allow for more predictable parallel tests We need to go through ridiculous lengths to accommodate the strictly single-thread android instrumentation while still allowing JUnit 5 to run its parallel tests, but I finally found a way to make it work. --- .../AndroidJUnitPlatformRunnerListener.kt | 17 +- .../notification/ParallelRunNotifier.kt | 290 ++++++++++++++---- .../sample/TestRunningOnJUnit5.kt | 49 +-- 3 files changed, 267 insertions(+), 89 deletions(-) diff --git a/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/AndroidJUnitPlatformRunnerListener.kt b/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/AndroidJUnitPlatformRunnerListener.kt index 00ed722c..e422b855 100644 --- a/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/AndroidJUnitPlatformRunnerListener.kt +++ b/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/AndroidJUnitPlatformRunnerListener.kt @@ -21,16 +21,12 @@ internal class AndroidJUnitPlatformRunnerListener( private val notifier: RunNotifier ) : TestExecutionListener { - override fun testPlanExecutionStarted(testPlan: TestPlan) { - // No-op, but must be declared to avoid AbstractMethodError - } - - override fun testPlanExecutionFinished(testPlan: TestPlan) { + override fun reportingEntryPublished(testIdentifier: TestIdentifier?, entry: ReportEntry?) { // No-op, but must be declared to avoid AbstractMethodError } - override fun reportingEntryPublished(testIdentifier: TestIdentifier?, entry: ReportEntry?) { - // No-op, but must be declared to avoid AbstractMethodError + override fun testPlanExecutionStarted(testPlan: TestPlan) { + notifier.fireTestSuiteStarted(testTree.suiteDescription) } override fun executionStarted(testIdentifier: TestIdentifier) { @@ -71,12 +67,15 @@ internal class AndroidJUnitPlatformRunnerListener( notifier.fireTestAssumptionFailed(toFailure(testExecutionResult, description)) } else if (status == TestExecutionResult.Status.FAILED) { notifier.fireTestFailure(toFailure(testExecutionResult, description)) - } - if (description.isTest) { + } else if (description.isTest) { notifier.fireTestFinished(description) } } + override fun testPlanExecutionFinished(testPlan: TestPlan) { + notifier.fireTestSuiteFinished(testTree.suiteDescription) + } + private fun toFailure( testExecutionResult: TestExecutionResult, description: Description diff --git a/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/notification/ParallelRunNotifier.kt b/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/notification/ParallelRunNotifier.kt index 480db7b8..8dbd71f5 100644 --- a/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/notification/ParallelRunNotifier.kt +++ b/instrumentation/runner/src/main/kotlin/de/mannodermaus/junit5/internal/runners/notification/ParallelRunNotifier.kt @@ -1,23 +1,36 @@ +@file:SuppressLint("RestrictedApi") + package de.mannodermaus.junit5.internal.runners.notification +import android.annotation.SuppressLint +import android.os.Bundle import android.util.Log import androidx.test.internal.runner.listener.InstrumentationResultPrinter import de.mannodermaus.junit5.internal.LOG_TAG +import de.mannodermaus.junit5.internal.runners.notification.ParallelRunNotifier.EventThread.Event import org.junit.runner.Description -import org.junit.runner.Result import org.junit.runner.notification.Failure import org.junit.runner.notification.RunListener import org.junit.runner.notification.RunNotifier +import java.util.concurrent.Executors +import java.util.concurrent.LinkedBlockingDeque +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit /** * Wrapping implementation of JUnit 4's run notifier for parallel test execution * (i.e. when "junit.jupiter.execution.parallel.enabled" is active during the run). - * It unpacks the singular 'instrumentation result printer' assigned by Android - * into using one instance per test, preventing its mutable internals from being - * modified by concurrent threads at the same time. + * It unpacks the singular 'instrumentation result printer' assigned by AndroidX + * and reroutes its notification mechanism a way that makes parallel + * tests still execute in parallel, but allowing their results to be reported back + * in the strictly sequential order required by the instrumentation. */ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotifier() { companion object { + private val doneLock = Object() + private val nopPrinter = InstrumentationResultPrinter() + private val nopTestState = TestState("", Bundle(), 0) + // Reflective access is available via companion object // to allow for shared storage of data across notifiers private val reflection by lazy { @@ -30,68 +43,130 @@ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotif } } - private val states = mutableMapOf() + private data class TestState( + val testClass: String, + val testResult: Bundle, + val testResultCode: Int, + ) + + private val states = mutableMapOf() + + // Even though parallelism is the name of the game under the hood for this RunNotifier, + // the nature of the Android Instrumentation is very much bound to synchronous execution internally. + // Therefore, a single-threaded executor must be used to project the multithreaded notifications + // from JUnit 5 onto this legacy thread model, resulting in some funky test reporting + // but allowing the awesome performance benefits of parallel test execution! + private lateinit var eventThread: EventThread + private val executor = Executors.newSingleThreadExecutor() // Original printer registered via Android instrumentation - private val printer = reflection?.initialize(delegate) + private val printer = reflection?.initialize(delegate) ?: nopPrinter override fun fireTestSuiteStarted(description: Description) { delegate.fireTestSuiteStarted(description) - } - override fun fireTestRunStarted(description: Description) { - delegate.fireTestRunStarted(description) + // Start asynchronous processing pipeline + eventThread = EventThread( + onProcessEvent = ::onProcessEvent, + onDone = ::onDone, + ).also(EventThread::start) } override fun fireTestStarted(description: Description) { - synchronized(this) { - delegate.fireTestStarted(description) - - // Notify original printer immediately, - // then freeze its state for the current method for later - printer?.testStarted(description) - states[description] = reflection?.copy(printer) - } + eventThread.enqueue(Event.Started(description)) } override fun fireTestIgnored(description: Description) { - synchronized(this) { - delegate.fireTestIgnored(description) - - printer?.testIgnored(description) - } + eventThread.enqueue(Event.Ignored(description)) } override fun fireTestFailure(failure: Failure) { - delegate.fireTestFailure(failure) - - states[failure.description]?.testFailure(failure) + eventThread.enqueue(Event.Finished(failure.description, testFailure = failure)) } override fun fireTestAssumptionFailed(failure: Failure) { - delegate.fireTestAssumptionFailed(failure) - - states[failure.description]?.testAssumptionFailure(failure) + eventThread.enqueue(Event.Finished(failure.description, assumptionFailure = failure)) } override fun fireTestFinished(description: Description) { - synchronized(this) { - delegate.fireTestFinished(description) + eventThread.enqueue(Event.Finished(description)) + } - states[description]?.testFinished(description) - states.remove(description) + override fun fireTestSuiteFinished(description: Description) { + synchronized(doneLock) { + // Request stopping of the asynchronous processing pipeline + eventThread.interruptPolitely(description) + doneLock.wait() } } - override fun fireTestRunFinished(result: Result) { - delegate.fireTestRunFinished(result) - } + /* Private */ - override fun fireTestSuiteFinished(description: Description) { - delegate.fireTestSuiteFinished(description) + private fun onProcessEvent(event: Event) = executor.submit { + val description = event.description + + when (event) { + is Event.Started -> { + delegate.fireTestStarted(description) + printer.testStarted(description) + + // Persist the current printer state for this test + // (for later, when this test's finish event comes in) + states[description] = printer.captureTestState() + } + + is Event.Ignored -> { + delegate.fireTestIgnored(description) + printer.testIgnored(description) + } + + is Event.Finished -> { + // Restore the printer state to the current test case, + // then fire the relevant lifecycle methods of the delegate notifier + printer.restoreTestState(description) + + // For failed test cases, always invoke the failure methods first, + // but invoke the 'finished' method pair for all cases + when { + event.testFailure != null -> { + delegate.fireTestFailure(event.testFailure) + printer.testFailure(event.testFailure) + delegate.fireTestFinished(description) + printer.testFinished(description) + } + + event.assumptionFailure != null -> { + delegate.fireTestAssumptionFailed(event.assumptionFailure) + printer.testAssumptionFailure(event.assumptionFailure) + delegate.fireTestFinished(description) + printer.testFinished(description) + } + + else -> { + delegate.fireTestFinished(description) + printer.testFinished(description) + } + } + } + } } - /* Private */ + private fun onDone(description: Description?) { + synchronized(doneLock) { + // Consume any pending asynchronous tasks + executor.shutdown() + executor.awaitTermination(15, TimeUnit.SECONDS) + + if (description != null) { + delegate.fireTestSuiteFinished(description) + printer.testSuiteFinished(description) + } + + // Unlocks the blockage from fireTestSuiteFinished(), + // allowing the test engine to properly finish this class + doneLock.notifyAll() + } + } private operator fun Map.get(key: Description): T? { return get(key.displayName) @@ -105,14 +180,115 @@ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotif remove(key.displayName) } + private fun InstrumentationResultPrinter.captureTestState(): TestState { + return reflection?.captureTestState(this) ?: nopTestState + } + + private fun InstrumentationResultPrinter.restoreTestState(description: Description) { + val state = requireNotNull(states[description]) + reflection?.restoreTestState(this, state) + states.remove(description) + } + + private class EventThread( + private val onProcessEvent: (Event) -> Unit, + private val onDone: (Description?) -> Unit, + ) : Thread("ParallelRunNotifier.EventThread") { + sealed interface Event { + val description: Description + + data class Started(override val description: Description) : Event + data class Finished( + override val description: Description, + val testFailure: Failure? = null, + val assumptionFailure: Failure? = null, + ) : Event + + data class Ignored(override val description: Description) : Event + } + + private val startQueue = LinkedBlockingQueue() + private val ignoreQueue = mutableListOf() + private val finishQueue = LinkedBlockingDeque() + private var interruptionDescription: Description? = null + + fun enqueue(event: Event) { + when (event) { + is Event.Started -> startQueue.offer(event) + is Event.Ignored -> ignoreQueue.add(event) + is Event.Finished -> finishQueue.offerFirst(event) + } + } + + fun interruptPolitely(description: Description) { + interruptionDescription = description + interrupt() + } + + private fun sendEvent(event: Event) { + onProcessEvent(event) + } + + private fun sendDone() { + onDone(interruptionDescription) + } + + override fun run() { + try { + while (true) { + // Accept the first incoming 'started' event + val startEvent = startQueue.take() + sendEvent(startEvent) + + // Now wait until a suitable 'finished' event comes in + var finishEvent = finishQueue.take() + while (finishEvent.description != startEvent.description) { + finishQueue.offer(finishEvent) + finishEvent = finishQueue.take() + } + + // If this point is reached, both event references point to the same test case. + // Allow the finish event to be processed, too + sendEvent(finishEvent) + + // Take care of any new ignore events at this point + ignoreQueue.forEach(::sendEvent) + ignoreQueue.clear() + } + } catch (ignored: InterruptedException) { + // OK + while (startQueue.isNotEmpty()) { + val startEvent = startQueue.take() + sendEvent(startEvent) + + if (finishQueue.isNotEmpty()) { + finishQueue + .firstOrNull { it.description == startEvent.description } + ?.let { finishEvent -> + finishQueue.remove(finishEvent) + sendEvent(finishEvent) + } + } + + ignoreQueue.forEach(::sendEvent) + ignoreQueue.clear() + } + + sendDone() + } + } + } + @Suppress("UNCHECKED_CAST") private class Reflection { - private val synchronizedRunListenerClass = - Class.forName("org.junit.runner.notification.SynchronizedRunListener") - private val synchronizedListenerDelegateField = synchronizedRunListenerClass - .getDeclaredField("listener").also { it.isAccessible = true } - private val runNotifierListenersField = RunNotifier::class.java - .getDeclaredField("listeners").also { it.isAccessible = true } + private fun Class.field(name: String) = this.getDeclaredField(name).also { it.isAccessible = true } + + private val synchronizedRunListenerClass = Class.forName("org.junit.runner.notification.SynchronizedRunListener") + private val synchronizedListenerDelegateField = synchronizedRunListenerClass.field("listener") + private val runNotifierListenersField = RunNotifier::class.java.field("listeners") + private val resultPrinterTestResultField = InstrumentationResultPrinter::class.java.field("testResult") + private val resultPrinterTestResultCodeField = InstrumentationResultPrinter::class.java.field("testResultCode") + private val resultPrinterTestClassField = InstrumentationResultPrinter::class.java.field("testClass") private var cached: InstrumentationResultPrinter? = null @@ -160,22 +336,18 @@ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotif } } - fun copy(original: InstrumentationResultPrinter?): InstrumentationResultPrinter? = try { - if (original != null) { - InstrumentationResultPrinter().also { copy -> - copy.instrumentation = original.instrumentation + fun captureTestState(printer: InstrumentationResultPrinter): TestState { + return TestState( + testClass = resultPrinterTestClassField.get(printer) as String, + testResult = resultPrinterTestResultField.get(printer) as Bundle, + testResultCode = resultPrinterTestResultCodeField.get(printer) as Int, + ) + } - InstrumentationResultPrinter::class.java.declaredFields.forEach { field -> - field.isAccessible = true - field.set(copy, field.get(original)) - } - } - } else { - null - } - } catch (e: Throwable) { - e.printStackTrace() - null + fun restoreTestState(printer: InstrumentationResultPrinter, state: TestState) { + resultPrinterTestClassField.set(printer, state.testClass) + resultPrinterTestResultField.set(printer, state.testResult) + resultPrinterTestResultCodeField.set(printer, state.testResultCode) } } } diff --git a/instrumentation/sample/src/androidTest/kotlin/de/mannodermaus/sample/TestRunningOnJUnit5.kt b/instrumentation/sample/src/androidTest/kotlin/de/mannodermaus/sample/TestRunningOnJUnit5.kt index 94c1068c..5b602887 100644 --- a/instrumentation/sample/src/androidTest/kotlin/de/mannodermaus/sample/TestRunningOnJUnit5.kt +++ b/instrumentation/sample/src/androidTest/kotlin/de/mannodermaus/sample/TestRunningOnJUnit5.kt @@ -1,6 +1,7 @@ package de.mannodermaus.sample import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Assumptions import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.junit.jupiter.api.parallel.Execution @@ -10,28 +11,34 @@ import org.junit.jupiter.params.provider.ValueSource @Execution(ExecutionMode.CONCURRENT) class TestRunningOnJUnit5 { - @Test - fun junit5_1() { - Thread.sleep(1000) - Assertions.assertEquals(4, 2 + 2) - } + @Test + fun junit5_1() { + Thread.sleep(1000) + Assertions.assertEquals(4, 2 + 2) + } - @Disabled - @Test - fun junit5_2() { - Thread.sleep(2000) - Assertions.assertEquals(4, 2 + 2) - } + @Disabled + @Test + fun junit5_2() { + Thread.sleep(2000) + Assertions.assertEquals(4, 2 + 2) + } - @Test - fun junit5_3() { - Thread.sleep(3000) - Assertions.assertEquals(4, 2 + 2) - } + @Test + fun junit5_3() { + Thread.sleep(3000) + Assertions.assertEquals(4, 2 + 2) + } - @ValueSource(ints = [1, 2, 3]) - @ParameterizedTest - fun junit5_parameterized(value: Int) { - Thread.sleep(value * 1000L) - } + @Test + fun junit5_4() { + Assumptions.assumeTrue(false, "Failed assumption on purpose") + Assertions.assertEquals(4, 2 + 2) + } + + @ValueSource(ints = [1, 2, 3]) + @ParameterizedTest + fun junit5_parameterized(value: Int) { + Thread.sleep(value * 1000L) + } }