Skip to content

Commit

Permalink
Re-implement the ParallelRunNotifier to allow for more predictable pa…
Browse files Browse the repository at this point in the history
…rallel tests

We need to go through ridiculous lengths to accommodate the strictly
single-thread android instrumentation while still allowing JUnit 5 to run
its parallel tests, but I finally found a way to make it work.
  • Loading branch information
mannodermaus committed Apr 9, 2024
1 parent ed63e4e commit 6547e2f
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,12 @@ internal class AndroidJUnitPlatformRunnerListener(
private val notifier: RunNotifier
) : TestExecutionListener {

override fun testPlanExecutionStarted(testPlan: TestPlan) {
// No-op, but must be declared to avoid AbstractMethodError
}

override fun testPlanExecutionFinished(testPlan: TestPlan) {
override fun reportingEntryPublished(testIdentifier: TestIdentifier?, entry: ReportEntry?) {
// No-op, but must be declared to avoid AbstractMethodError
}

override fun reportingEntryPublished(testIdentifier: TestIdentifier?, entry: ReportEntry?) {
// No-op, but must be declared to avoid AbstractMethodError
override fun testPlanExecutionStarted(testPlan: TestPlan) {
notifier.fireTestSuiteStarted(testTree.suiteDescription)
}

override fun executionStarted(testIdentifier: TestIdentifier) {
Expand Down Expand Up @@ -71,12 +67,15 @@ internal class AndroidJUnitPlatformRunnerListener(
notifier.fireTestAssumptionFailed(toFailure(testExecutionResult, description))
} else if (status == TestExecutionResult.Status.FAILED) {
notifier.fireTestFailure(toFailure(testExecutionResult, description))
}
if (description.isTest) {
} else if (description.isTest) {
notifier.fireTestFinished(description)
}
}

override fun testPlanExecutionFinished(testPlan: TestPlan) {
notifier.fireTestSuiteFinished(testTree.suiteDescription)
}

private fun toFailure(
testExecutionResult: TestExecutionResult,
description: Description
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
@file:SuppressLint("RestrictedApi")

package de.mannodermaus.junit5.internal.runners.notification

import android.annotation.SuppressLint
import android.os.Bundle
import android.util.Log
import androidx.test.internal.runner.listener.InstrumentationResultPrinter
import de.mannodermaus.junit5.internal.LOG_TAG
import de.mannodermaus.junit5.internal.runners.notification.ParallelRunNotifier.EventThread.Event
import org.junit.runner.Description
import org.junit.runner.Result
import org.junit.runner.notification.Failure
import org.junit.runner.notification.RunListener
import org.junit.runner.notification.RunNotifier
import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit

/**
* Wrapping implementation of JUnit 4's run notifier for parallel test execution
* (i.e. when "junit.jupiter.execution.parallel.enabled" is active during the run).
* It unpacks the singular 'instrumentation result printer' assigned by Android
* into using one instance per test, preventing its mutable internals from being
* modified by concurrent threads at the same time.
* It unpacks the singular 'instrumentation result printer' assigned by AndroidX
* and reroutes its notification mechanism a way that makes parallel
* tests still execute in parallel, but allowing their results to be reported back
* in the strictly sequential order required by the instrumentation.
*/
internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotifier() {
companion object {
private val doneLock = Object()
private val nopPrinter = InstrumentationResultPrinter()
private val nopTestState = TestState("", Bundle(), 0)

// Reflective access is available via companion object
// to allow for shared storage of data across notifiers
private val reflection by lazy {
Expand All @@ -30,68 +43,130 @@ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotif
}
}

private val states = mutableMapOf<String, InstrumentationResultPrinter?>()
private data class TestState(
val testClass: String,
val testResult: Bundle,
val testResultCode: Int,
)

private val states = mutableMapOf<String, TestState>()

// Even though parallelism is the name of the game under the hood for this RunNotifier,
// the nature of the Android Instrumentation is very much bound to synchronous execution internally.
// Therefore, a single-threaded executor must be used to project the multithreaded notifications
// from JUnit 5 onto this legacy thread model, resulting in some funky test reporting
// but allowing the awesome performance benefits of parallel test execution!
private lateinit var eventThread: EventThread
private val executor = Executors.newSingleThreadExecutor()

// Original printer registered via Android instrumentation
private val printer = reflection?.initialize(delegate)
private val printer = reflection?.initialize(delegate) ?: nopPrinter

override fun fireTestSuiteStarted(description: Description) {
delegate.fireTestSuiteStarted(description)
}

override fun fireTestRunStarted(description: Description) {
delegate.fireTestRunStarted(description)
// Start asynchronous processing pipeline
eventThread = EventThread(
onProcessEvent = ::onProcessEvent,
onDone = ::onDone,
).also(EventThread::start)
}

override fun fireTestStarted(description: Description) {
synchronized(this) {
delegate.fireTestStarted(description)

// Notify original printer immediately,
// then freeze its state for the current method for later
printer?.testStarted(description)
states[description] = reflection?.copy(printer)
}
eventThread.enqueue(Event.Started(description))
}

override fun fireTestIgnored(description: Description) {
synchronized(this) {
delegate.fireTestIgnored(description)

printer?.testIgnored(description)
}
eventThread.enqueue(Event.Ignored(description))
}

override fun fireTestFailure(failure: Failure) {
delegate.fireTestFailure(failure)

states[failure.description]?.testFailure(failure)
eventThread.enqueue(Event.Finished(failure.description, testFailure = failure))
}

override fun fireTestAssumptionFailed(failure: Failure) {
delegate.fireTestAssumptionFailed(failure)

states[failure.description]?.testAssumptionFailure(failure)
eventThread.enqueue(Event.Finished(failure.description, assumptionFailure = failure))
}

override fun fireTestFinished(description: Description) {
synchronized(this) {
delegate.fireTestFinished(description)
eventThread.enqueue(Event.Finished(description))
}

states[description]?.testFinished(description)
states.remove(description)
override fun fireTestSuiteFinished(description: Description) {
synchronized(doneLock) {
// Request stopping of the asynchronous processing pipeline
eventThread.interruptPolitely(description)
doneLock.wait()
}
}

override fun fireTestRunFinished(result: Result) {
delegate.fireTestRunFinished(result)
}
/* Private */

override fun fireTestSuiteFinished(description: Description) {
delegate.fireTestSuiteFinished(description)
private fun onProcessEvent(event: Event) = executor.submit {
val description = event.description

when (event) {
is Event.Started -> {
delegate.fireTestStarted(description)
printer.testStarted(description)

// Persist the current printer state for this test
// (for later, when this test's finish event comes in)
states[description] = printer.captureTestState()
}

is Event.Ignored -> {
delegate.fireTestIgnored(description)
printer.testIgnored(description)
}

is Event.Finished -> {
// Restore the printer state to the current test case,
// then fire the relevant lifecycle methods of the delegate notifier
printer.restoreTestState(description)

// For failed test cases, always invoke the failure methods first,
// but invoke the 'finished' method pair for all cases
when {
event.testFailure != null -> {
delegate.fireTestFailure(event.testFailure)
printer.testFailure(event.testFailure)
delegate.fireTestFinished(description)
printer.testFinished(description)
}

event.assumptionFailure != null -> {
delegate.fireTestAssumptionFailed(event.assumptionFailure)
printer.testAssumptionFailure(event.assumptionFailure)
delegate.fireTestFinished(description)
printer.testFinished(description)
}

else -> {
delegate.fireTestFinished(description)
printer.testFinished(description)
}
}
}
}
}

/* Private */
private fun onDone(description: Description?) {
synchronized(doneLock) {
// Consume any pending asynchronous tasks
executor.shutdown()
executor.awaitTermination(15, TimeUnit.SECONDS)

if (description != null) {
delegate.fireTestSuiteFinished(description)
printer.testSuiteFinished(description)
}

// Unlocks the blockage from fireTestSuiteFinished(),
// allowing the test engine to properly finish this class
doneLock.notifyAll()
}
}

private operator fun <T> Map<String, T>.get(key: Description): T? {
return get(key.displayName)
Expand All @@ -105,14 +180,115 @@ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotif
remove(key.displayName)
}

private fun InstrumentationResultPrinter.captureTestState(): TestState {
return reflection?.captureTestState(this) ?: nopTestState
}

private fun InstrumentationResultPrinter.restoreTestState(description: Description) {
val state = requireNotNull(states[description])
reflection?.restoreTestState(this, state)
states.remove(description)
}

private class EventThread(
private val onProcessEvent: (Event) -> Unit,
private val onDone: (Description?) -> Unit,
) : Thread("ParallelRunNotifier.EventThread") {
sealed interface Event {
val description: Description

data class Started(override val description: Description) : Event
data class Finished(
override val description: Description,
val testFailure: Failure? = null,
val assumptionFailure: Failure? = null,
) : Event

data class Ignored(override val description: Description) : Event
}

private val startQueue = LinkedBlockingQueue<Event>()
private val ignoreQueue = mutableListOf<Event>()
private val finishQueue = LinkedBlockingDeque<Event>()
private var interruptionDescription: Description? = null

fun enqueue(event: Event) {
when (event) {
is Event.Started -> startQueue.offer(event)
is Event.Ignored -> ignoreQueue.add(event)
is Event.Finished -> finishQueue.offerFirst(event)
}
}

fun interruptPolitely(description: Description) {
interruptionDescription = description
interrupt()
}

private fun sendEvent(event: Event) {
onProcessEvent(event)
}

private fun sendDone() {
onDone(interruptionDescription)
}

override fun run() {
try {
while (true) {
// Accept the first incoming 'started' event
val startEvent = startQueue.take()
sendEvent(startEvent)

// Now wait until a suitable 'finished' event comes in
var finishEvent = finishQueue.take()
while (finishEvent.description != startEvent.description) {
finishQueue.offer(finishEvent)
finishEvent = finishQueue.take()
}

// If this point is reached, both event references point to the same test case.
// Allow the finish event to be processed, too
sendEvent(finishEvent)

// Take care of any new ignore events at this point
ignoreQueue.forEach(::sendEvent)
ignoreQueue.clear()
}
} catch (ignored: InterruptedException) {
// OK
while (startQueue.isNotEmpty()) {
val startEvent = startQueue.take()
sendEvent(startEvent)

if (finishQueue.isNotEmpty()) {
finishQueue
.firstOrNull { it.description == startEvent.description }
?.let { finishEvent ->
finishQueue.remove(finishEvent)
sendEvent(finishEvent)
}
}

ignoreQueue.forEach(::sendEvent)
ignoreQueue.clear()
}

sendDone()
}
}
}

@Suppress("UNCHECKED_CAST")
private class Reflection {
private val synchronizedRunListenerClass =
Class.forName("org.junit.runner.notification.SynchronizedRunListener")
private val synchronizedListenerDelegateField = synchronizedRunListenerClass
.getDeclaredField("listener").also { it.isAccessible = true }
private val runNotifierListenersField = RunNotifier::class.java
.getDeclaredField("listeners").also { it.isAccessible = true }
private fun <T : Any> Class<T>.field(name: String) = this.getDeclaredField(name).also { it.isAccessible = true }

private val synchronizedRunListenerClass = Class.forName("org.junit.runner.notification.SynchronizedRunListener")
private val synchronizedListenerDelegateField = synchronizedRunListenerClass.field("listener")
private val runNotifierListenersField = RunNotifier::class.java.field("listeners")
private val resultPrinterTestResultField = InstrumentationResultPrinter::class.java.field("testResult")
private val resultPrinterTestResultCodeField = InstrumentationResultPrinter::class.java.field("testResultCode")
private val resultPrinterTestClassField = InstrumentationResultPrinter::class.java.field("testClass")

private var cached: InstrumentationResultPrinter? = null

Expand Down Expand Up @@ -160,22 +336,18 @@ internal class ParallelRunNotifier(private val delegate: RunNotifier) : RunNotif
}
}

fun copy(original: InstrumentationResultPrinter?): InstrumentationResultPrinter? = try {
if (original != null) {
InstrumentationResultPrinter().also { copy ->
copy.instrumentation = original.instrumentation
fun captureTestState(printer: InstrumentationResultPrinter): TestState {
return TestState(
testClass = resultPrinterTestClassField.get(printer) as String,
testResult = resultPrinterTestResultField.get(printer) as Bundle,
testResultCode = resultPrinterTestResultCodeField.get(printer) as Int,
)
}

InstrumentationResultPrinter::class.java.declaredFields.forEach { field ->
field.isAccessible = true
field.set(copy, field.get(original))
}
}
} else {
null
}
} catch (e: Throwable) {
e.printStackTrace()
null
fun restoreTestState(printer: InstrumentationResultPrinter, state: TestState) {
resultPrinterTestClassField.set(printer, state.testClass)
resultPrinterTestResultField.set(printer, state.testResult)
resultPrinterTestResultCodeField.set(printer, state.testResultCode)
}
}
}
Expand Down
Loading

0 comments on commit 6547e2f

Please sign in to comment.