From 0b58f8604ab82f2b3bc3da72682a417dd1bbff24 Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Fri, 13 Dec 2024 12:30:41 -0800 Subject: [PATCH 1/2] [prelude][core][stm] fix issue with preemption in nested async computations --- .../main/scala/kyo/scheduler/IOPromise.scala | 15 ++--- .../src/main/scala/kyo/scheduler/IOTask.scala | 3 + .../shared/src/test/scala/kyo/AsyncTest.scala | 55 +++++++++++++++++++ .../test/scala/kyo/kernel/SafepointTest.scala | 28 +++++++--- .../src/main/scala/kyo/kernel/Safepoint.scala | 12 +++- .../shared/src/test/scala/kyo/STMTest.scala | 24 ++++++++ 6 files changed, 120 insertions(+), 17 deletions(-) diff --git a/kyo-core/shared/src/main/scala/kyo/scheduler/IOPromise.scala b/kyo-core/shared/src/main/scala/kyo/scheduler/IOPromise.scala index 46cbb4601..53e428741 100644 --- a/kyo-core/shared/src/main/scala/kyo/scheduler/IOPromise.scala +++ b/kyo-core/shared/src/main/scala/kyo/scheduler/IOPromise.scala @@ -216,14 +216,15 @@ private[kyo] class IOPromise[+E, +A](init: State[E, A]) extends Safepoint.Interc blockLoop(this) end block + protected def stateString(): String = + state.get() match + case p: Pending[?, ?] => s"Pending(waiters = ${p.waiters})" + case l: Linked[?, ?] => s"Linked(promise = ${l.p})" + case r => s"Done(result = ${r.asInstanceOf[Result[Any, Any]].show})" + override def toString = - val stateString = - state.get() match - case p: Pending[?, ?] => s"Pending(waiters = ${p.waiters})" - case l: Linked[?, ?] => s"Linked(promise = ${l.p})" - case r => s"Done(result = ${r.asInstanceOf[Result[Any, Any]].show})" - s"IOPromise(state = ${stateString})" - end toString + s"IOPromise(state = ${stateString()})" + end IOPromise private[kyo] object IOPromise: diff --git a/kyo-core/shared/src/main/scala/kyo/scheduler/IOTask.scala b/kyo-core/shared/src/main/scala/kyo/scheduler/IOTask.scala index 2303ce215..36278f91e 100644 --- a/kyo-core/shared/src/main/scala/kyo/scheduler/IOTask.scala +++ b/kyo-core/shared/src/main/scala/kyo/scheduler/IOTask.scala @@ -87,6 +87,9 @@ sealed private[kyo] class IOTask[Ctx, E, A] private ( private inline def nullResult = null.asInstanceOf[A < Ctx & Async & Abort[E]] + override def toString = + s"IOTask(state = ${stateString()}, preempt = ${{ shouldPreempt() }}, finalizers = ${finalizers.size()}, curr = ${curr})" + end IOTask object IOTask: diff --git a/kyo-core/shared/src/test/scala/kyo/AsyncTest.scala b/kyo-core/shared/src/test/scala/kyo/AsyncTest.scala index 30b807a99..2452ab880 100644 --- a/kyo-core/shared/src/test/scala/kyo/AsyncTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/AsyncTest.scala @@ -1014,4 +1014,59 @@ class AsyncTest extends Test: } } + "preemption is properly handled in nested Async computations" - { + "simple" in run { + Async.run(Async.run(Async.delay(100.millis)(42))).map(_.get).map(_.get).map { result => + assert(result == 42) + } + } + "with nested eval" in run { + import AllowUnsafe.embrace.danger + val task = IO.Unsafe.evalOrThrow(Async.run(Async.delay(100.millis)(42))) + Async.run(task).map(_.get).map(_.get).map { result => + assert(result == 42) + } + } + "with multiple nested evals" in run { + import AllowUnsafe.embrace.danger + val innerTask = IO.Unsafe.evalOrThrow(Async.run(Async.delay(100.millis)(42))) + val middleTask = IO.Unsafe.evalOrThrow(Async.run(innerTask)) + val outerTask = IO.Unsafe.evalOrThrow(Async.run(middleTask)) + Async.run(outerTask).map(_.get).map(_.get).map(_.get).map(_.get).map { result => + assert(result == 42) + } + } + "with eval inside async computation" in run { + import AllowUnsafe.embrace.danger + Async.run { + Async.delay(100.millis) { + IO.Unsafe.evalOrThrow(Async.run(42)).get + } + }.map(_.get).map { result => + assert(result == 42) + } + } + "with interleaved evals and delays" in run { + import AllowUnsafe.embrace.danger + val task1 = IO.Unsafe.evalOrThrow(Async.run(Async.delay(100.millis)(1))) + val task2 = Async.delay(100.millis) { + IO.Unsafe.evalOrThrow(Async.run(task1)).get + } + val task3 = IO.Unsafe.evalOrThrow(Async.run(task2)) + Async.run(task3).map(_.get).map(_.get).map(_.get).map { result => + assert(result == 1) + } + } + "with race" in run { + Async.run { + Async.race( + Async.run(Async.delay(100.millis)(1)).map(_.get), + Async.run(Async.delay(200.millis)(2)).map(_.get) + ) + }.map(_.get).map { result => + assert(result == 1) + } + } + } + end AsyncTest diff --git a/kyo-prelude/jvm/src/test/scala/kyo/kernel/SafepointTest.scala b/kyo-prelude/jvm/src/test/scala/kyo/kernel/SafepointTest.scala index a85b2a225..95bd47b47 100644 --- a/kyo-prelude/jvm/src/test/scala/kyo/kernel/SafepointTest.scala +++ b/kyo-prelude/jvm/src/test/scala/kyo/kernel/SafepointTest.scala @@ -193,10 +193,22 @@ class SafepointTest extends Test: executed = true true - Safepoint.immediate(interceptor)((1: Int < Any).map(_ + 1).eval) + Safepoint.immediate(interceptor)((1: Int < Any).map(_ + 1)).eval assert(executed) } + "eval removes the interceptor" in { + var executed = false + val interceptor = new TestInterceptor: + def ensure(f: () => Unit): Unit = () + def enter(frame: Frame, value: Any): Boolean = + executed = true + true + + Safepoint.immediate(interceptor)((1: Int < Any).map(_ + 1).eval) + assert(!executed) + } + "restore previous interceptor" in { var count = 0 val interceptor1 = new TestInterceptor: @@ -212,7 +224,7 @@ class SafepointTest extends Test: true Safepoint.immediate(interceptor1) { - Safepoint.immediate(interceptor2)((1: Int < Any).map(_ + 1).eval) + Safepoint.immediate(interceptor2)((1: Int < Any).map(_ + 1)) }.eval assert(count == 11) @@ -508,8 +520,8 @@ class SafepointTest extends Test: assert(interceptor.ensuresAdded.size == 1) assert(interceptor.ensuresRemoved.isEmpty) 42 - }.eval - } + } + }.eval assert(interceptor.ensuresAdded.size == 1) assert(interceptor.ensuresRemoved.size == 1) @@ -525,8 +537,8 @@ class SafepointTest extends Test: Safepoint.immediate(interceptor) { testEnsure { 42 - }.eval - } + } + }.eval assert(interceptor.ensuresRemoved.size == 1) } @@ -544,8 +556,8 @@ class SafepointTest extends Test: 42 } } - }.eval - } + } + }.eval assert(interceptor.ensuresAdded.size == 3) assert(interceptor.ensuresRemoved.size == 3) diff --git a/kyo-prelude/shared/src/main/scala/kyo/kernel/Safepoint.scala b/kyo-prelude/shared/src/main/scala/kyo/kernel/Safepoint.scala index 759993e54..d84d32f4f 100644 --- a/kyo-prelude/shared/src/main/scala/kyo/kernel/Safepoint.scala +++ b/kyo-prelude/shared/src/main/scala/kyo/kernel/Safepoint.scala @@ -43,6 +43,10 @@ final class Safepoint private () extends Trace.Owner: interceptor = newInterceptor state = state.withInterceptor(newInterceptor != null) + override def toString(): String = + val currentState = state + s"Safepoint(depth=${currentState.depth}, threadId=${currentState.threadId}, interceptor=${interceptor})" + end Safepoint object Safepoint: @@ -166,8 +170,12 @@ object Safepoint: private[kernel] inline def eval[A]( inline f: Safepoint ?=> A )(using inline frame: Frame): A = - val self = Safepoint.get - self.withNewTrace(f(using self)) + val self = Safepoint.get + val prevInterceptor = self.interceptor + self.setInterceptor(null) + try self.withNewTrace(f(using self)) + finally + self.interceptor = prevInterceptor end eval private[kernel] inline def handle[V, A, S](value: V)( diff --git a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala index 2a8ff0409..3d871a5d4 100644 --- a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala +++ b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala @@ -1,5 +1,7 @@ package kyo +import scala.concurrent.Future + class STMTest extends Test: "Transaction isolation" - { @@ -715,4 +717,26 @@ class STMTest extends Test: } } + "bug #925" in run { + def unsafeToFuture[A: Flat](a: => A < (Async & Abort[Throwable])): Future[A] = + import kyo.AllowUnsafe.embrace.danger + IO.Unsafe.evalOrThrow( + Async.run(a).map(_.toFuture) + ) + end unsafeToFuture + + val ex = new Exception + + val faultyTransaction: Int < STM = TRef.init(42).map { r => + throw ex + r.get + } + + val task = Async.runAndBlock(Duration.Infinity)(Async.fromFuture(unsafeToFuture(STM.run(faultyTransaction)))) + + Abort.run(task).map { result => + assert(result == Result.fail(ex)) + } + } + end STMTest From 6e059db37c34d133651cb12341e7232f5444b5dd Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Fri, 13 Dec 2024 13:03:18 -0800 Subject: [PATCH 2/2] fix js build --- kyo-stm/shared/src/test/scala/kyo/STMTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala index 3d871a5d4..b4a29dbe5 100644 --- a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala +++ b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala @@ -717,7 +717,7 @@ class STMTest extends Test: } } - "bug #925" in run { + "bug #925" in runJVM { def unsafeToFuture[A: Flat](a: => A < (Async & Abort[Throwable])): Future[A] = import kyo.AllowUnsafe.embrace.danger IO.Unsafe.evalOrThrow(