diff --git a/mirai-core-api/src/commonMain/kotlin/event/EventChannel.kt b/mirai-core-api/src/commonMain/kotlin/event/EventChannel.kt index 38df9260665..df087c1e785 100644 --- a/mirai-core-api/src/commonMain/kotlin/event/EventChannel.kt +++ b/mirai-core-api/src/commonMain/kotlin/event/EventChannel.kt @@ -479,9 +479,48 @@ public open class EventChannel @JvmOverloads internal con host: ListenerHost, coroutineContext: CoroutineContext = EmptyCoroutineContext, ) { + val jobOfListenerHost: Job? + val coroutineContext0 = if (host is SimpleListenerHost) { + val listenerCoroutineContext = host.coroutineContext + val listenerJob = listenerCoroutineContext[Job] + + val rsp = listenerCoroutineContext.minusKey(Job) + + coroutineContext + + (listenerCoroutineContext[CoroutineExceptionHandler] ?: EmptyCoroutineContext) + + val registerCancelHook = when { + listenerJob === null -> false + + // Registering cancellation hook is needless + // if [Job] of [EventChannel] is same as [Job] of [SimpleListenerHost] + (rsp[Job] ?: this.defaultCoroutineContext[Job]) === listenerJob -> false + + else -> true + } + + jobOfListenerHost = if (registerCancelHook) { + listenerCoroutineContext[Job] + } else { + null + } + rsp + } else { + jobOfListenerHost = null + coroutineContext + } for (method in host.javaClass.declaredMethods) { method.getAnnotation(EventHandler::class.java)?.let { - method.registerEventHandler(host, this, it, coroutineContext) + val listener = method.registerEventHandler(host, this, it, coroutineContext0) + // For [SimpleListenerHost.cancelAll] + jobOfListenerHost?.invokeOnCompletion { exception -> + listener.cancel( + when (exception) { + is CancellationException -> exception + is Throwable -> CancellationException(null, exception) + else -> null + } + ) + } } } } diff --git a/mirai-core-api/src/jvmTest/kotlin/event/JvmMethodEventsTest.kt b/mirai-core-api/src/jvmTest/kotlin/event/JvmMethodEventsTest.kt index a3194545760..29a893a22f4 100644 --- a/mirai-core-api/src/jvmTest/kotlin/event/JvmMethodEventsTest.kt +++ b/mirai-core-api/src/jvmTest/kotlin/event/JvmMethodEventsTest.kt @@ -12,13 +12,19 @@ package net.mamoe.mirai.event import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.cancel +import kotlinx.coroutines.isActive import kotlinx.coroutines.runBlocking import org.jetbrains.annotations.NotNull import org.junit.jupiter.api.Test import java.util.concurrent.atomic.AtomicInteger +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.test.fail internal class JvmMethodEventsTest : AbstractEventTest() { @@ -107,10 +113,13 @@ internal class JvmMethodEventsTest : AbstractEventTest() { class MyException : RuntimeException() class TestClass : SimpleListenerHost() { + var exceptionHandled: Boolean = false + override fun handleException(context: CoroutineContext, exception: Throwable) { assert(exception is ExceptionInEventHandlerException) assert(exception.event is TestEvent) assert(exception.rootCause is MyException) + exceptionHandled = true } @Suppress("unused") @@ -127,6 +136,35 @@ internal class JvmMethodEventsTest : AbstractEventTest() { TestEvent().broadcast() } cancel() // reset listeners + if (!exceptionHandled) { + fail("SimpleListenerHost.handleException not invoked") + } + } + + TestClass().run { + GlobalEventChannel.registerListenerHost(this) + + runBlocking { + TestEvent().broadcast() + } + cancel() // reset listeners + if (!exceptionHandled) { + fail("SimpleListenerHost.handleException not invoked") + } + } + + TestClass().run { + val scope = CoroutineScope(EmptyCoroutineContext) + scope.globalEventChannel().registerListenerHost(this) + + runBlocking { + TestEvent().broadcast() + } + cancel() // reset listeners + scope.cancel() + if (!exceptionHandled) { + fail("SimpleListenerHost.handleException not invoked") + } } } @@ -163,4 +201,57 @@ internal class JvmMethodEventsTest : AbstractEventTest() { assertEquals(1, this.getCalled()) } } + + @Test + fun testCancellation() { + class TestingListener : SimpleListenerHost() { + var handled: Boolean = false + + @EventHandler + fun handle(event: TestEvent) { + handled = true + } + } + runBlocking { + + TestingListener().runTesting { + TestEvent().broadcast() + assertTrue { handled } + } + + // registered listeners cancelled when parent scope cancelled + CoroutineScope(EmptyCoroutineContext).let { scope -> + TestingListener().runTesting(scope.globalEventChannel()) { listener -> + scope.cancel() + TestEvent().broadcast() + assertFalse { handled } + assertTrue { listener.isActive } + } + } + + // registered listeners cancelled when ListenerHost cancelled + CoroutineScope(EmptyCoroutineContext).let { scope -> + val listener = TestingListener() + listener.runTesting(scope.globalEventChannel()) { } + assertFalse { listener.isActive } + assertTrue { scope.isActive } + TestEvent().broadcast() + assertFalse { listener.handled } + scope.cancel() + } + + } + } + + private inline fun T.runTesting( + channel: EventChannel<*> = GlobalEventChannel, + block: T.(T) -> Unit + ) where T : SimpleListenerHost { + contract { + callsInPlace(block, InvocationKind.EXACTLY_ONCE) + } + channel.registerListenerHost(this) + block(this, this) + cancel() + } }