From b09199eafc0073a30a4f0e3d484a9b2583fbc0bb Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Sun, 1 Dec 2024 19:26:20 -0800 Subject: [PATCH] [core] STM effect (#878) See scaladocs for more info --- build.sbt | 12 + .../shared/src/test/scala/kyo/FiberTest.scala | 4 +- kyo-core/shared/src/test/scala/kyo/Test.scala | 6 +- .../shared/src/main/scala/kyo/Kyo.scala | 39 + .../shared/src/main/scala/kyo/Var.scala | 17 +- .../shared/src/test/scala/kyo/KyoTest.scala | 57 ++ .../shared/src/test/scala/kyo/VarTest.scala | 7 + .../shared/src/main/scala/kyo/RefLog.scala | 40 + kyo-stm/shared/src/main/scala/kyo/STM.scala | 148 ++++ kyo-stm/shared/src/main/scala/kyo/TID.scala | 26 + kyo-stm/shared/src/main/scala/kyo/TMap.scala | 297 ++++++++ kyo-stm/shared/src/main/scala/kyo/TRef.scala | 189 +++++ .../src/test/scala/kyo/RefLogTest.scala | 54 ++ .../shared/src/test/scala/kyo/STMTest.scala | 717 ++++++++++++++++++ .../shared/src/test/scala/kyo/TMapTest.scala | 500 ++++++++++++ .../shared/src/test/scala/kyo/TRefTest.scala | 38 + kyo-stm/shared/src/test/scala/kyo/Test.scala | 32 + 17 files changed, 2179 insertions(+), 4 deletions(-) create mode 100644 kyo-stm/shared/src/main/scala/kyo/RefLog.scala create mode 100644 kyo-stm/shared/src/main/scala/kyo/STM.scala create mode 100644 kyo-stm/shared/src/main/scala/kyo/TID.scala create mode 100644 kyo-stm/shared/src/main/scala/kyo/TMap.scala create mode 100644 kyo-stm/shared/src/main/scala/kyo/TRef.scala create mode 100644 kyo-stm/shared/src/test/scala/kyo/RefLogTest.scala create mode 100644 kyo-stm/shared/src/test/scala/kyo/STMTest.scala create mode 100644 kyo-stm/shared/src/test/scala/kyo/TMapTest.scala create mode 100644 kyo-stm/shared/src/test/scala/kyo/TRefTest.scala create mode 100644 kyo-stm/shared/src/test/scala/kyo/Test.scala diff --git a/build.sbt b/build.sbt index 9ecf31f92..7301e885d 100644 --- a/build.sbt +++ b/build.sbt @@ -92,6 +92,7 @@ lazy val kyoJVM = project `kyo-prelude`.jvm, `kyo-core`.jvm, `kyo-direct`.jvm, + `kyo-stm`.jvm, `kyo-stats-registry`.jvm, `kyo-stats-otel`.jvm, `kyo-cache`.jvm, @@ -120,6 +121,7 @@ lazy val kyoJS = project `kyo-prelude`.js, `kyo-core`.js, `kyo-direct`.js, + `kyo-stm`.js, `kyo-stats-registry`.js, `kyo-sttp`.js, `kyo-test`.js, @@ -250,6 +252,16 @@ lazy val `kyo-direct` = .nativeSettings(`native-settings`) .jsSettings(`js-settings`) +lazy val `kyo-stm` = + crossProject(JSPlatform, JVMPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Full) + .in(file("kyo-stm")) + .dependsOn(`kyo-core`) + .settings(`kyo-settings`) + .jvmSettings(mimaCheck(false)) + .jsSettings(`js-settings`) + lazy val `kyo-stats-registry` = crossProject(JSPlatform, JVMPlatform, NativePlatform) .withoutSuffixFor(JVMPlatform) diff --git a/kyo-core/shared/src/test/scala/kyo/FiberTest.scala b/kyo-core/shared/src/test/scala/kyo/FiberTest.scala index 1aefeb25d..4bcbda499 100644 --- a/kyo-core/shared/src/test/scala/kyo/FiberTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/FiberTest.scala @@ -157,8 +157,8 @@ class FiberTest extends Test: adder <- LongAdder.init result <- Fiber.race(Seq( - Async.delay(15.millis)(adder.increment.andThen(24)), - Async.delay(5.millis)((adder.increment.andThen(42))) + Async.delay(1.second)(adder.increment.andThen(24)), + Async.delay(1.millis)((adder.increment.andThen(42))) )).map(_.get) _ <- Async.sleep(50.millis) executed <- adder.get diff --git a/kyo-core/shared/src/test/scala/kyo/Test.scala b/kyo-core/shared/src/test/scala/kyo/Test.scala index 935812c00..9f6d6b998 100644 --- a/kyo-core/shared/src/test/scala/kyo/Test.scala +++ b/kyo-core/shared/src/test/scala/kyo/Test.scala @@ -16,7 +16,11 @@ abstract class Test extends AsyncFreeSpec with BaseKyoTest[Abort[Any] & Async & def run(v: Future[Assertion] < (Abort[Any] & Async & Resource)): Future[Assertion] = import AllowUnsafe.embrace.danger - val a = Async.run(Abort.run(Resource.run(v)).map(_.fold(e => throw new IllegalStateException(s"Test aborted with $e"))(identity))) + val a = Async.run(Abort.run(Resource.run(v)).map(_.fold { + _.getFailure match + case ex: Throwable => throw ex + case e => throw new IllegalStateException(s"Test aborted with $e") + }(identity))) val b = a.map(_.toFuture).map(_.flatten) IO.Unsafe.run(b).eval end run diff --git a/kyo-prelude/shared/src/main/scala/kyo/Kyo.scala b/kyo-prelude/shared/src/main/scala/kyo/Kyo.scala index e44d2e1d8..c9be08e72 100644 --- a/kyo-prelude/shared/src/main/scala/kyo/Kyo.scala +++ b/kyo-prelude/shared/src/main/scala/kyo/Kyo.scala @@ -289,6 +289,45 @@ object Kyo: end match end collectDiscard + /** Finds the first element in a sequence that satisfies a predicate. + * + * @param seq + * The input sequence + * @param f + * The effect-producing predicate function + * @return + * A new effect that produces Maybe of the first matching element + */ + def findFirst[A, B, S](seq: Seq[A])(f: Safepoint ?=> A => Maybe[B] < S)(using Frame, Safepoint): Maybe[B] < S = + seq.knownSize match + case 0 => Maybe.empty + case 1 => f(seq(0)) + case _ => + seq match + case seq: List[A] => + Loop(seq) { seq => + seq match + case Nil => Loop.done(Maybe.empty) + case head :: tail => + f(head).map { + case Absent => Loop.continue(tail) + case Present(v) => Loop.done(Maybe(v)) + } + } + case seq => + val indexed = toIndexed(seq) + val size = indexed.size + Loop.indexed { idx => + if idx == size then Loop.done(Maybe.empty) + else + f(indexed(idx)).map { + case Absent => Loop.continue + case Present(v) => Loop.done(Maybe(v)) + } + } + end match + end findFirst + /** Takes elements from a sequence while a predicate holds true. * * @param seq diff --git a/kyo-prelude/shared/src/main/scala/kyo/Var.scala b/kyo-prelude/shared/src/main/scala/kyo/Var.scala index 5c92785a6..740c40975 100644 --- a/kyo-prelude/shared/src/main/scala/kyo/Var.scala +++ b/kyo-prelude/shared/src/main/scala/kyo/Var.scala @@ -60,6 +60,21 @@ object Var: inline def set[V](inline value: V)(using inline tag: Tag[Var[V]], inline frame: Frame): V < Var[V] = ArrowEffect.suspend[Unit](tag, value: Op[V]) + /** Sets a new value and then executes another computation. + * + * @param value + * The new value to set in the Var + * @param f + * The computation to execute after setting the value + * @return + * The result of the computation after setting the new value + */ + private[kyo] inline def setAndThen[V, A, S](inline value: V)(inline f: => A < S)(using + inline tag: Tag[Var[V]], + inline frame: Frame + ): A < (Var[V] & S) = + ArrowEffect.suspendAndMap[Unit](tag, value: Op[V])(_ => f) + /** Sets a new value and returns `Unit`. * * @param value @@ -98,7 +113,7 @@ object Var: inline def updateDiscard[V](inline f: V => V)(using inline tag: Tag[Var[V]], inline frame: Frame): Unit < Var[V] = ArrowEffect.suspendAndMap[Unit](tag, (v => f(v)): Update[V])(_ => ()) - private inline def runWith[V, A: Flat, S, B, S2](state: V)(v: A < (Var[V] & S))( + private[kyo] inline def runWith[V, A: Flat, S, B, S2](state: V)(v: A < (Var[V] & S))( inline f: (V, A) => B < S2 )(using inline tag: Tag[Var[V]], inline frame: Frame): B < (S & S2) = ArrowEffect.handleState(tag, state, v)( diff --git a/kyo-prelude/shared/src/test/scala/kyo/KyoTest.scala b/kyo-prelude/shared/src/test/scala/kyo/KyoTest.scala index 8633fb2bc..5fcc9b8d7 100644 --- a/kyo-prelude/shared/src/test/scala/kyo/KyoTest.scala +++ b/kyo-prelude/shared/src/test/scala/kyo/KyoTest.scala @@ -286,4 +286,61 @@ class KyoTest extends Test: assert(result.eval == 10) } } + + "findFirst" - { + "empty sequence" in { + assert(Kyo.findFirst(Seq.empty[Int])(v => Maybe(v)).eval == Maybe.empty) + } + + "single element - found" in { + assert(Kyo.findFirst(Seq(1))(v => Maybe(v)).eval == Maybe(1)) + } + + "single element - not found" in { + assert(Kyo.findFirst(Seq(1))(v => Maybe.empty).eval == Maybe.empty) + } + + "multiple elements - first match" in { + assert(Kyo.findFirst(Seq(1, 2, 3))(v => if v > 0 then Maybe(v) else Maybe.empty).eval == Maybe(1)) + } + + "multiple elements - middle match" in { + assert(Kyo.findFirst(Seq(1, 2, 3))(v => if v == 2 then Maybe(v) else Maybe.empty).eval == Maybe(2)) + } + + "multiple elements - no match" in { + assert(Kyo.findFirst(Seq(1, 2, 3))(v => if v > 5 then Maybe(v) else Maybe.empty).eval == Maybe.empty) + } + + "works with effects" in { + var count = 0 + val result = Env.run(42)( + Kyo.findFirst(Seq(1, 2, 3)) { v => + Env.use[Int] { env => + count += 1 + if v == env then Maybe(v) else Maybe.empty + } + } + ).eval + assert(result == Maybe.empty) + assert(count == 3) + } + + "short circuits" in { + var count = 0 + val result = Kyo.findFirst(Seq(1, 2, 3, 4, 5)) { v => + count += 1 + if v == 2 then Maybe(v) else Maybe.empty + }.eval + assert(result == Maybe(2)) + assert(count == 2) + } + + "works with different sequence types" in { + val pred = (v: Int) => if v == 2 then Maybe(v) else Maybe.empty + assert(Kyo.findFirst(List(1, 2, 3))(pred).eval == Maybe(2)) + assert(Kyo.findFirst(Vector(1, 2, 3))(pred).eval == Maybe(2)) + assert(Kyo.findFirst(Chunk(1, 2, 3))(pred).eval == Maybe(2)) + } + } end KyoTest diff --git a/kyo-prelude/shared/src/test/scala/kyo/VarTest.scala b/kyo-prelude/shared/src/test/scala/kyo/VarTest.scala index a380e592c..7a8eceb38 100644 --- a/kyo-prelude/shared/src/test/scala/kyo/VarTest.scala +++ b/kyo-prelude/shared/src/test/scala/kyo/VarTest.scala @@ -51,6 +51,13 @@ class VarTest extends Test: assert(r == (2, 3)) } + "setAndThen" in { + val result = Var.run(1) { + Var.setAndThen(2)(Var.use[Int](_ * 2)) + }.eval + assert(result == 4) + } + "scope" - { "should not affect the outer state" in { val result = Var.run(42)( diff --git a/kyo-stm/shared/src/main/scala/kyo/RefLog.scala b/kyo-stm/shared/src/main/scala/kyo/RefLog.scala new file mode 100644 index 000000000..ef1c40752 --- /dev/null +++ b/kyo-stm/shared/src/main/scala/kyo/RefLog.scala @@ -0,0 +1,40 @@ +package kyo + +/** A log of transactional operations performed on TRefs within an STM transaction. + * + * RefLog maintains a mapping from transactional references to their pending read/write operations within a transaction. It tracks both + * read entries (which record the version of data read) and write entries (which contain the new values to be committed). + * + * This type is used internally by the STM implementation and should not be accessed directly by application code. + * + * @note + * This is a private implementation detail of the STM system + */ +opaque type RefLog = Map[TRef[Any], RefLog.Entry[Any]] + +private[kyo] object RefLog: + + given Tag[RefLog] = Tag[Map[TRef[Any], Entry[Any]]] + + val empty: RefLog = Map.empty + + extension (self: RefLog) + + def put[A](ref: TRef[A], entry: Entry[A]): RefLog = + self.updated(ref.asInstanceOf[TRef[Any]], entry.asInstanceOf[Entry[Any]]) + + def get[A](ref: TRef[A]): Maybe[Entry[A]] = + val refAny = ref.asInstanceOf[TRef[Any]] + Maybe.when(self.contains(refAny))(self(refAny).asInstanceOf[Entry[A]]) + + def toSeq: Seq[(TRef[Any], Entry[Any])] = + self.toSeq + end extension + + sealed abstract class Entry[A]: + def tid: Long + def value: A + + case class Read[A](tid: Long, value: A) extends Entry[A] + case class Write[A](tid: Long, value: A) extends Entry[A] +end RefLog diff --git a/kyo-stm/shared/src/main/scala/kyo/STM.scala b/kyo-stm/shared/src/main/scala/kyo/STM.scala new file mode 100644 index 000000000..9c65b94b2 --- /dev/null +++ b/kyo-stm/shared/src/main/scala/kyo/STM.scala @@ -0,0 +1,148 @@ +package kyo + +import scala.annotation.tailrec + +/** A FailedTransaction exception that is thrown when a transaction fails to commit. Contains the frame where the failure occurred. + */ +case class FailedTransaction(frame: Frame) extends Exception(frame.position.show) + +/** Software Transactional Memory (STM) provides concurrent access to shared state using optimistic locking. Rather than acquiring locks + * upfront, transactions execute speculatively and automatically retry if conflicts are detected during commit. While this enables better + * composability than manual locking, applications must be designed to handle potentially frequent transaction retries. + * + * > IMPORTANT: Transactions are atomic, isolated, and composable but may retry multiple times before success. Side effects (like I/O) + * inside transactions must be used with caution as they will be re-executed on retry. Pure operations that only modify transactional + * references are safe and encouraged, while external side effects should be performed after the transaction commits. + * + * The core operations are: + * - TRef.init and TRef.initNow create transactional references that can be shared between threads + * - TRef.get and TRef.set read and modify references within transactions + * - STM.run executes transactions that either fully commit or rollback + * - STM.retry and STM.retryIf provide manual control over transaction retry behavior + * - Configurable retry schedules via STM.run's retrySchedule parameter + * + * The implementation uses optimistic execution with lock-based validation during commit: + * - Transactions execute without acquiring locks, tracking reads and writes in a local log + * - During commit, read-write locks are acquired on affected TRefs to ensure consistency: + * - Multiple readers can hold shared locks on a TRef during commit + * - Writers require an exclusive lock during commit + * - No global locks are used - operations on different refs can commit independently + * - Lock acquisition is ordered by TRef identity to prevent deadlocks + * - Early conflict detection aborts transactions that would fail validation + * + * STM is most effective for operations that rarely conflict and complete quickly. Long-running transactions or high contention scenarios + * may face performance challenges from repeated retries. The approach particularly excels at read-heavy workloads due to its support for + * concurrent readers, while write-heavy workloads may experience more contention due to the need for exclusive write access. The + * fine-grained locking strategy means that transactions only conflict if they actually touch the same references, allowing for high + * concurrency when different transactions operate on different refs. + */ +opaque type STM <: (Var[RefLog] & Abort[FailedTransaction] & Async) = + Var[RefLog] & Abort[FailedTransaction] & Async + +object STM: + + /** The default retry schedule for failed transactions */ + val defaultRetrySchedule = Schedule.fixed(1.millis * 0.5).take(20) + + /** Forces a transaction retry by aborting the current transaction and rolling back all changes. This is useful when a transaction + * detects that it cannot proceed due to invalid state. + * + * @return + * Nothing, as this operation always aborts the transaction + */ + def retry(using frame: Frame): Nothing < STM = Abort.fail(FailedTransaction(frame)) + + /** Conditionally retries a transaction based on a boolean condition. If the condition is true, the transaction will be retried. + * Otherwise, execution continues normally. + * + * @param cond + * The condition that determines whether to retry + */ + def retryIf(cond: Boolean)(using frame: Frame): Unit < STM = Abort.when(cond)(FailedTransaction(frame)) + + /** Executes a transactional computation with explicit state isolation. This version of run supports additional effects beyond Abort and + * Async through the provided isolate, which ensures proper state management during transaction retries and rollbacks. + * + * @param isolate + * The isolation scope for the transaction + * @param retrySchedule + * The schedule for retrying failed transactions + * @param v + * The transactional computation to run + * @return + * The result of the computation if successful + */ + def run[E, A: Flat, S](isolate: Isolate[S], retrySchedule: Schedule = defaultRetrySchedule)(v: A < (STM & Abort[E] & Async & S))( + using frame: Frame + ): A < (S & Async & Abort[E | FailedTransaction]) = + isolate.use { st => + run(retrySchedule)(isolate.resume(st, v)).map(isolate.restore(_, _)) + } + + /** Executes a transactional computation with default retry behavior. This version only supports Abort and Async effects within the + * transaction, but provides a simpler interface when additional effect isolation is not needed. + * + * @param v + * The transactional computation to run + * @return + * The result of the computation if successful + */ + def run[E, A: Flat](v: A < (STM & Abort[E] & Async))(using frame: Frame): A < (Async & Abort[E | FailedTransaction]) = + run(defaultRetrySchedule)(v) + + /** Executes a transactional computation with custom retry behavior. Like the version above, this only supports Abort and Async effects + * but allows configuring how transaction conflicts are retried. + * + * @param retrySchedule + * The schedule for retrying failed transactions + * @param v + * The transactional computation to run + * @return + * The result of the computation if successful + */ + def run[E, A: Flat](retrySchedule: Schedule)(v: A < (STM & Abort[E] & Async))( + using frame: Frame + ): A < (Async & Abort[E | FailedTransaction]) = + TID.use { + case -1L => + // New transaction without a parent, use regular commit flow + Retry[FailedTransaction](retrySchedule) { + TID.useNew { tid => + Var.runWith(RefLog.empty)(v) { (log, result) => + IO.Unsafe { + // Attempt to acquire locks and commit the transaction + val (locked, unlocked) = + // Sort references by identity to prevent deadlocks + log.toSeq.sortBy((ref, _) => ref.hashCode) + .span((ref, entry) => ref.lock(entry)) + + if unlocked.nonEmpty then + // Failed to acquire some locks - rollback and retry + locked.foreach((ref, entry) => ref.unlock(entry)) + Abort.fail(FailedTransaction(frame)) + else + // Successfully locked all references - commit changes + locked.foreach((ref, entry) => ref.commit(tid, entry)) + // Release all locks + locked.foreach((ref, entry) => ref.unlock(entry)) + result + end if + } + } + } + } + case parent => + // Nested transaction inherits parent's transaction context but isolates RefLog. + // On success: changes propagate to parent. On failure: changes are rolled back + // without affecting parent's state. + val result = Var.isolate.update[RefLog].run(v) + + // Can't return `result` directly since it has a pending STM effect + // but it's safe to cast because, if there's a parent transaction, + // then there's a frame upper in the stack that will handle the + // STM effect in the parent transaction's `run`. + result.asInstanceOf[A < (Async & Abort[E | FailedTransaction])] + } + + end run +end STM diff --git a/kyo-stm/shared/src/main/scala/kyo/TID.scala b/kyo-stm/shared/src/main/scala/kyo/TID.scala new file mode 100644 index 000000000..e9795c034 --- /dev/null +++ b/kyo-stm/shared/src/main/scala/kyo/TID.scala @@ -0,0 +1,26 @@ +package kyo + +private[kyo] object TID: + + // Unique transaction ID generation + private val nextTid = AtomicLong.Unsafe.init(0)(using AllowUnsafe.embrace.danger) + + private val tidLocal = Local.initIsolated(-1L) + + def next(using AllowUnsafe): Long = nextTid.incrementAndGet() + + def useNew[A, S](f: Long => A < S)(using Frame): A < (S & IO) = + IO.Unsafe { + val tid = nextTid.incrementAndGet() + tidLocal.let(tid)(f(tid)) + } + + def use[A, S](f: Long => A < S)(using Frame): A < S = + tidLocal.use(f) + + def useRequired[A, S](f: Long => A < S)(using Frame): A < S = + tidLocal.use { + case -1L => bug("STM operation attempted outside of STM.run - this should be impossible due to effect typing") + case tid => f(tid) + } +end TID diff --git a/kyo-stm/shared/src/main/scala/kyo/TMap.scala b/kyo-stm/shared/src/main/scala/kyo/TMap.scala new file mode 100644 index 000000000..b37f4958e --- /dev/null +++ b/kyo-stm/shared/src/main/scala/kyo/TMap.scala @@ -0,0 +1,297 @@ +package kyo + +/** A transactional map implementation that provides atomic operations on key-value pairs within STM transactions. Internally represented as + * `TRef[Map[K, TRef[V]]]`, where each value is wrapped in its own transactional reference. + * + * TMap is designed to minimize contention in concurrent scenarios through this nested TRef structure. Since each value has its own `TRef`, + * operations on different keys can commit independently. Only structural changes like `put` with new keys or `remove` need to lock the + * main map during commit, while value updates through `updateWith` or `put` to existing keys only lock that specific value's TRef. + * + * This architecture is particularly effective at reducing retries in concurrent scenarios by limiting the scope of conflicts between + * transactions. Updates to different keys can proceed in parallel, while reads can commit concurrently with writes to different keys. The + * implementation maintains strong consistency guarantees while allowing maximum concurrency for non-conflicting operations. + * + * Operations that modify existing values have low contention characteristics, while structural modifications experience higher contention. + * This makes TMap particularly well-suited for scenarios with high concurrent access to different keys where most operations are reads or + * updates to existing values. + */ +opaque type TMap[K, V] = TRef[Map[K, TRef[V]]] + +object TMap: + + given [K, V]: Flat[TMap[K, V]] = Flat.derive[TRef[Map[K, TRef[V]]]] + + /** Creates a new transactional map within an STM transaction. + * + * @param entries + * The initial key-value pairs to populate the map + * @return + * A new transactional map containing the entries, within the STM effect + */ + def init[K, V](entries: (K, V)*)(using Frame): TMap[K, V] < STM = + Kyo.foreach(entries)((k, v) => TRef.init(v).map((k, _))).map(r => TRef.init(r.toMap)) + + /** Creates a new transactional map outside of any transaction. + * + * WARNING: This operation: + * - Cannot be rolled back + * - Is not part of any transaction + * - Will cause any containing transaction to retry if used within one, since it creates references with newer transaction IDs + * + * Use this only for static initialization or when you specifically need non-transactional creation. For most cases, prefer `init`. + * + * @param entries + * The initial key-value pairs to populate the map + * @return + * A new transactional map containing the entries, within the IO effect + */ + def initNow[K, V](entries: (K, V)*)(using Frame): TMap[K, V] < IO = + Kyo.foreach(entries)((k, v) => TRef.initNow(v).map((k, _))).map(r => TRef.initNow(r.toMap)) + + /** Initializes a new transactional map from an existing Map. + * + * @param map + * the initial map to copy entries from + * @return + * a new transactional map containing the provided map's entries + */ + def init[K, V](map: Map[K, V])(using Frame): TMap[K, V] < STM = + init(map.toSeq*) + + extension [K, V](self: TMap[K, V]) + + /** Applies a function to the value associated with a key if it exists. + * + * @param key + * the key to look up + * @param f + * the function to apply to the value if found + * @return + * the result of applying the function + */ + def use[A, S](key: K)(f: Maybe[V] => A < S)(using Frame): A < (STM & S) = + self.use { map => + if map.contains(key) then + map(key).use(v => f(Maybe(v))) + else + f(Maybe.empty) + } + + /** Returns the current size of the map. + * + * @return + * the number of key-value pairs in the map + */ + def size(using Frame): Int < STM = self.use(_.size) + + /** Checks if the map is empty. + * + * @return + * true if the map contains no entries, false otherwise + */ + def isEmpty(using Frame): Boolean < STM = self.use(_.isEmpty) + + /** Checks if the map is non-empty. + * + * @return + * true if the map contains at least one entry, false otherwise + */ + def nonEmpty(using Frame): Boolean < STM = self.use(_.nonEmpty) + + /** Removes all entries from the map. + */ + def clear(using Frame): Unit < STM = self.set(Map.empty) + + /** Retrieves the value associated with a key. + * + * @param key + * the key to look up + * @return + * the value if found, or Maybe.empty if not present + */ + def get(key: K)(using Frame): Maybe[V] < STM = + use(key)(identity) + + /** Retrieves the value for a key, or evaluates a default if not found. + * + * @param key + * the key to look up + * @param orElse + * the default value to compute if key is not found + * @return + * the value associated with the key or the computed default + */ + inline def getOrElse[A, S](key: K, inline orElse: => V < S)(using inline frame: Frame): V < (STM & S) = + self.use(key) { + case Absent => orElse + case Present(v) => v + } + + /** Adds a new key-value pair to the map. + * + * @param key + * the key to add + * @param value + * the value to associate with the key + */ + def put(key: K, value: V)(using Frame): Unit < STM = + self.use { map => + if map.contains(key) then + map(key).set(value) + else + TRef.init(value).map { ref => + self.update(_.updated(key, ref)) + } + } + + /** Checks if a key exists in the map. + * + * @param key + * the key to check + * @return + * true if the key exists, false otherwise + */ + def contains(key: K)(using Frame): Boolean < STM = + self.use(!_.isEmpty) + + /** Updates the value associated with a key based on its current value. + * + * @param key + * the key to update + * @param f + * the function to transform the current value + */ + def updateWith[S](key: K)(f: Maybe[V] => Maybe[V] < S)(using Frame): Unit < (STM & S) = + use(key) { currentValue => + f(currentValue).map { + case Absent => self.update(_ - key) + case Present(v) => put(key, v) + } + } + + /** Removes a key and returns its associated value if it existed. + * + * @param key + * the key to remove + * @return + * the value that was associated with the key, if any + */ + def remove(key: K)(using Frame): Maybe[V] < STM = + use(key) { + case Absent => Absent + case Present(value) => + self.update(_ - key).andThen(Maybe(value)) + } + + /** Removes a key without returning its value. + * + * @param key + * the key to remove + */ + def removeDiscard(key: K)(using Frame): Unit < STM = + use(key) { + case Absent => () + case Present(value) => + self.update(_ - key) + } + + /** Removes multiple keys from the map. + * + * @param keys + * the sequence of keys to remove + */ + def removeAll(keys: Seq[K])(using Frame): Unit < STM = + self.use { map => + self.set(map.removedAll(keys)) + } + + /** Returns an iterable of all keys in the map. + * + * @return + * iterable containing all keys + */ + def keys(using Frame): Iterable[K] < STM = + self.use(_.keys) + + /** Returns an iterable of all values in the map. + * + * @return + * iterable containing all values + */ + def values(using Frame): Iterable[V] < STM = + self.use { map => + Kyo.collect(map.values.toSeq.map(_.get)) + } + + /** Returns an iterable of all key-value pairs in the map. + * + * @return + * iterable containing all entries + */ + def entries(using Frame): Iterable[(K, V)] < STM = + self.use { map => + Kyo.collect( + map.toSeq.map { case (k, ref) => + ref.get.map((k, _)) + } + ) + } + + /** Removes entries that don't satisfy the given predicate. + * + * @param p + * the predicate function to test entries against + */ + def filter[S](p: (K, V) => Boolean < S)(using Frame): Unit < (STM & S) = + self.use { map => + Kyo.foreachDiscard(map.toSeq) { (key, ref) => + ref.use { value => + p(key, value).map { + case true => () + case false => removeDiscard(key) + } + } + } + } + + /** Folds over the entries in the map to produce a result. + * + * @param acc + * the initial accumulator value + * @param f + * the function to combine the accumulator with each entry + * @return + * the final accumulated result + */ + def fold[A, B, S](acc: A)(f: (A, K, V) => A < S)(using Frame): A < (STM & S) = + self.use { map => + Kyo.foldLeft(map.toSeq)(acc) { + case (acc, (key, ref)) => + ref.use(v => f(acc, key, v)) + } + } + + /** Finds the first entry that satisfies the given predicate. + * + * @param f + * the function to test entries + * @return + * the first result that matches, if any + */ + def findFirst[A, S](f: (K, V) => Maybe[A] < S)(using Frame): Maybe[A] < (STM & S) = + self.use { map => + Kyo.findFirst(map.toSeq) { (key, ref) => + ref.use(f(key, _)) + } + } + + /** Creates an immutable snapshot of the current map state. + * + * @return + * a Map containing the current entries + */ + def snapshot(using Frame): Map[K, V] < STM = + entries.map(_.toMap) + + end extension +end TMap diff --git a/kyo-stm/shared/src/main/scala/kyo/TRef.scala b/kyo-stm/shared/src/main/scala/kyo/TRef.scala new file mode 100644 index 000000000..0a522e310 --- /dev/null +++ b/kyo-stm/shared/src/main/scala/kyo/TRef.scala @@ -0,0 +1,189 @@ +package kyo + +import java.util.concurrent.atomic.AtomicInteger +import kyo.RefLog.* +import scala.annotation.tailrec + +/** A transactional reference that can be modified within STM transactions. Provides atomic read and write operations with strong + * consistency guarantees. + * + * @param id + * Unique identifier for this reference + * @param state + * The current state of the reference + */ +sealed trait TRef[A]: + + /** Applies a function to the current value of the reference within a transaction. + * + * @param f + * A function that transforms the current value of type A into a result of type B, with effects S + * @return + * The result of type B with combined STM and S effects + */ + def use[B, S](f: A => B < S)(using Frame): B < (STM & S) + + /** Sets a new value for the reference within a transaction. + * + * @param v + * The new value to set + */ + def set(v: A)(using Frame): Unit < STM + + /** Gets the current value of the reference within a transaction. + * + * @return + * The current value + */ + final def get(using Frame): A < STM = use(identity) + + /** Updates the reference's value by applying a function to the current value within a transaction. + * + * @param f + * The function to transform the current value into the new value + * @return + * Unit, as this is a modification operation + */ + final def update[S](f: A => A < S)(using Frame): Unit < (STM & S) = use(f(_).map(set)) + + private[kyo] def state(using AllowUnsafe): Write[A] + private[kyo] def lock(entry: Entry[A])(using AllowUnsafe): Boolean + private[kyo] def commit(tid: Long, entry: Entry[A])(using AllowUnsafe): Unit + private[kyo] def unlock(entry: Entry[A])(using AllowUnsafe): Unit + +end TRef + +/** Implementation of a transactional reference. Extends AtomicInteger to avoid an extra allocation for lock state management. + * + * @param initialState + * The initial value and transaction ID for this reference + */ +final private class TRefImpl[A] private[kyo] (initialState: Write[A]) + extends AtomicInteger(0) // Atomic super class to keep the lock state + with TRef[A]: + + @volatile private var currentState = initialState + + private[kyo] def state(using AllowUnsafe): Write[A] = currentState + + def use[B, S](f: A => B < S)(using Frame): B < (STM & S) = + Var.use[RefLog] { log => + log.get(this) match + case Present(entry) => + f(entry.value) + case Absent => + TID.useRequired { tid => + IO { + val state = currentState + if state.tid > tid then + // Early retry if the TRef is concurrently modified + STM.retry + else + // Append Read to the log and return value + val entry = Read(state.tid, state.value) + Var.setAndThen(log.put(this, entry))(f(state.value)) + end if + } + } + end match + } + + def set(v: A)(using Frame): Unit < STM = + Var.use[RefLog] { log => + log.get(this) match + case Present(prev) => + val entry = Write(prev.tid, v) + Var.setDiscard(log.put(this, entry)) + case Absent => + TID.useRequired { tid => + IO { + val state = currentState + if state.tid > tid then + // Early retry if the TRef is concurrently modified + STM.retry + else + // Append Write to the log + val entry = Write(state.tid, v) + Var.setDiscard(log.put(this, entry)) + end if + } + } + } + + @tailrec private[kyo] def lock(entry: Entry[A])(using AllowUnsafe): Boolean = + currentState.tid == entry.tid && { + val lockState = super.get() + entry match + case Read(tid, value) => + // Read locks can stack if no write lock + lockState != Int.MaxValue && (super.compareAndSet(lockState, lockState + 1) || lock(entry)) + case Write(tid, value) => + // Write lock requires no existing locks + lockState == 0 && (super.compareAndSet(lockState, Int.MaxValue) || lock(entry)) + end match + } + + private[kyo] def commit(tid: Long, entry: Entry[A])(using AllowUnsafe): Unit = + entry match + case Write(_, value) => + // Only need to commit Write entries + currentState = Write(tid, value) + case _ => + + private[kyo] def unlock(entry: Entry[A])(using AllowUnsafe): Unit = + entry match + case Read(tid, value) => + // Release read lock + discard(super.decrementAndGet()) + case Write(tid, value) => + // Release write lock + super.set(0) + end match + end unlock +end TRefImpl + +object TRef: + + /** Creates a new transactional reference within an STM transaction. + * + * @param value + * The initial value for the reference + * @return + * A new transactional reference containing the value, within the STM effect + */ + def init[A](value: A)(using Frame): TRef[A] < STM = + TID.useRequired { tid => + Var.use[RefLog] { log => + IO.Unsafe { + val ref = TRef.Unsafe.init(tid, value) + Var.setAndThen(log.put(ref, ref.state))(ref) + } + } + } + + /** Creates a new transactional reference outside of any transaction. + * + * WARNING: This operation: + * - Cannot be rolled back + * - Is not part of any transaction + * - Will cause any containing transaction to retry if used within one, since it creates a reference with a newer transaction ID + * + * Use this only for static initialization or when you specifically need non-transactional creation. For most cases, prefer `init`. + * + * @param value + * The initial value for the reference + * @return + * A new transactional reference containing the value, within the IO effect + */ + def initNow[A](value: A)(using Frame): TRef[A] < IO = + IO.Unsafe(TRef.Unsafe.initNow(value)) + + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */ + object Unsafe: + def initNow[A](value: A)(using AllowUnsafe): TRef[A] = + init(TID.next, value) + + private[kyo] def init[A](tid: Long, value: A)(using AllowUnsafe): TRef[A] = + new TRefImpl(Write(tid, value)) + end Unsafe +end TRef diff --git a/kyo-stm/shared/src/test/scala/kyo/RefLogTest.scala b/kyo-stm/shared/src/test/scala/kyo/RefLogTest.scala new file mode 100644 index 000000000..595784729 --- /dev/null +++ b/kyo-stm/shared/src/test/scala/kyo/RefLogTest.scala @@ -0,0 +1,54 @@ +package kyo + +import kyo.RefLog.* + +class RefLogTest extends Test: + + given [A, B]: CanEqual[A, B] = CanEqual.derived + + "RefLog" - { + "empty" in run { + val log = RefLog.empty + assert(log.toSeq.isEmpty) + } + + "put" in run { + IO { + val ref = new TRefImpl[Int](Write(0, 0)) + val entry = Write(1, 42) + val log = RefLog.empty.put(ref, entry) + assert(log.toSeq.size == 1) + assert(log.toSeq.head._1 == ref) + assert(log.toSeq.head._2 == entry) + } + } + + "get" in run { + IO { + val ref = new TRefImpl[Int](Write(0, 0)) + val entry = Write(1, 42) + val log = RefLog.empty.put(ref, entry) + assert(log.get(ref) == Maybe(entry)) + assert(log.get(new TRefImpl[Int](Write(0, 0))).isEmpty) + } + } + + "toSeq" in run { + IO { + val ref1 = new TRefImpl[Int](Write(0, 0)) + val ref2 = new TRefImpl[Int](Write(0, 0)) + val entry1 = Write(1, 42) + val entry2 = Read(1, 24) + + val log = RefLog.empty + .put(ref1, entry1) + .put(ref2, entry2) + + val seq = log.toSeq + assert(seq.size == 2) + assert(seq.contains((ref1, entry1))) + assert(seq.contains((ref2, entry2))) + } + } + } +end RefLogTest diff --git a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala new file mode 100644 index 000000000..b3332ffcd --- /dev/null +++ b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala @@ -0,0 +1,717 @@ +package kyo + +class STMTest extends Test: + + "Transaction isolation" - { + "concurrent modifications" in run { + for + ref <- TRef.initNow(0) + fibers <- Async.parallelUnbounded(List.fill(100)(STM.run(ref.update(_ + 1)))) + value <- STM.run(ref.get) + yield assert(value == 100) + } + + "no dirty reads" in run { + for + ref <- TRef.initNow(0) + start <- Latch.init(1) + continue <- Latch.init(1) + fiber <- Async.run { + STM.run { + for + _ <- ref.set(42) + _ <- start.release + _ <- continue.await + yield () + } + } + _ <- start.await + before <- STM.run(ref.get) + _ <- continue.release + _ <- fiber.get + after <- STM.run(ref.get) + yield assert(before == 0 && after == 42) + } + + "independent transactions don't interfere" in run { + for + ref1 <- TRef.initNow(10) + ref2 <- TRef.initNow(20) + _ <- STM.run(ref1.set(30)) + result <- STM.run { + for + v <- ref2.get + _ <- ref2.set(v + 5) + yield v + } + final1 <- STM.run(ref1.get) + final2 <- STM.run(ref2.get) + yield assert(result == 20 && final1 == 30 && final2 == 25) + } + + } + + "Retry behavior" - { + "explicit retry" in run { + for + ref <- TRef.initNow(0) + result <- Abort.run { + STM.run { + for + v <- ref.get + _ <- STM.retryIf(v == 0) + yield v + } + } + yield assert(result.isFail) + } + + "retry with schedule" in run { + for + ref <- TRef.initNow(0) + counter <- AtomicInt.init(0) + result <- Abort.run { + STM.run(Schedule.repeat(3)) { + for + _ <- counter.incrementAndGet + v <- ref.get + _ <- STM.retryIf(v == 0) + yield v + } + } + count <- counter.get + yield assert(result.isFail && count == 4) + } + } + + "with isolates" - { + + "with Var effect" in run { + Var.run(42) { + for + ref <- TRef.initNow(0) + result <- STM.run(Var.isolate.update) { + for + _ <- ref.set(1) + _ <- Var.set(100) + v1 <- ref.get + v2 <- Var.get[Int] + yield (v1, v2) + } + finalRef <- STM.run(ref.get) + finalVar <- Var.get[Int] + yield assert(result == (1, 100) && finalRef == 1 && finalVar == 100) + } + } + + "with Emit effect" in run { + for + ref <- TRef.initNow(0) + result <- Emit.run { + STM.run(Emit.isolate.merge[Int]) { + for + _ <- ref.set(1) + _ <- Emit(42) + v <- ref.get + _ <- Emit(v) + yield v + } + } + finalValue <- STM.run(ref.get) + yield assert(result == (Chunk(42, 1), 1) && finalValue == 1) + } + + "rollback on failure preserves effect isolation" in run { + val ex = new Exception("Test failure") + for + ref <- TRef.initNow(0) + result <- + Emit.run { + Abort.run { + STM.run(Emit.isolate.merge[Int]) { + for + _ <- ref.set(42) + _ <- Emit(1) + _ <- Abort.fail(ex) + _ <- Emit(2) + yield "unreachable" + } + } + } + finalValue <- STM.run(ref.get) + yield assert(result == (Chunk.empty, Result.fail(ex)) && finalValue == 0) + end for + } + + "with nested Var isolations" in run { + Var.run(0) { + for + ref <- TRef.initNow(1) + result <- STM.run(Var.isolate.update) { + for + _ <- ref.set(2) + _ <- Var.set(1) + innerResult <- Var.isolate.update.run { + for + _ <- ref.set(3) + _ <- Var.set(2) + v1 <- ref.get + v2 <- Var.get[Int] + yield (v1, v2) + } + outerVar <- Var.get[Int] + finalRef <- ref.get + yield (innerResult, outerVar, finalRef) + } + finalVar <- Var.get[Int] + yield assert(result == ((3, 2), 2, 3) && finalVar == 2) + } + } + + "with Memo effect" in run { + var count = 0 + val f = Memo[Int, Int, Any] { x => + count += 1 + x * 2 + } + + Memo.run { + for + ref <- TRef.initNow(1) + result <- STM.run(Memo.isolate.merge) { + for + _ <- ref.set(2) + v1 <- f(2) + _ <- ref.set(3) + v2 <- f(2) + refVal <- ref.get + yield (v1, v2, refVal) + } + v3 <- f(2) + finalValue <- STM.run(ref.get) + yield assert(result == (4, 4, 3) && count == 1 && v3 == 4 && finalValue == 3) + } + } + + "rollback preserves all effect isolations" in run { + val ex = new Exception("Test failure") + Var.run(0) { + for + ref <- TRef.initNow(0) + result <- Emit.run { + Abort.run { + STM.run(Emit.isolate.merge[Int].andThen(Var.isolate.update)) { + for + _ <- ref.set(1) + _ <- Emit(1) + _ <- Var.set(1) + _ <- Abort.fail(ex) + _ <- ref.set(2) + _ <- Emit(2) + _ <- Var.set(2) + yield "unreachable" + } + } + } + finalRef <- STM.run(ref.get) + finalVar <- Var.get[Int] + yield assert(result == (Chunk.empty, Result.fail(ex)) && finalRef == 0 && finalVar == 0) + } + } + + } + + "Nested transactions" - { + + "nested transactions share the same transaction context" in run { + for + ref <- TRef.initNow(0) + result <- STM.run { + for + _ <- ref.set(1) + innerResult <- STM.run { + for + v1 <- ref.get + _ <- ref.set(v1 + 1) + v2 <- ref.get + yield v2 + } + finalValue <- ref.get + yield (innerResult, finalValue) + } + outsideValue <- STM.run(ref.get) + yield assert(result == (2, 2) && outsideValue == 2) + } + + "nested transaction rollbacks affect outer transaction" in run { + for + ref <- TRef.initNow(0) + _ <- + STM.run { + for + _ <- ref.set(1) + result <- + Abort.run { + STM.run { + for + _ <- ref.set(2) + _ <- STM.retry + yield () + } + } + yield assert(result.isFail) + } + finalValue <- STM.run(ref.get) + yield assert(finalValue == 1) + } + + "multiple levels of nesting maintain consistency" in run { + for + ref <- TRef.initNow(0) + result <- STM.run { + for + _ <- ref.set(1) + v1 <- STM.run { + for + _ <- ref.set(2) + v2 <- STM.run(ref.get) + _ <- ref.set(3) + yield v2 + } + v3 <- ref.get + yield (v1, v3) + } + finalValue <- STM.run(ref.get) + yield assert(result == (2, 3) && finalValue == 3) + } + + "nested transactions see parent modifications" in run { + for + ref <- TRef.initNow(0) + result <- STM.run { + for + _ <- ref.set(1) + v1 <- ref.get + nestedResult <- STM.run { + for + v2 <- ref.get + _ <- ref.set(2) + v3 <- ref.get + yield (v2, v3) + } + v4 <- ref.get + yield (v1, nestedResult, v4) + } + yield assert(result == (1, (1, 2), 2)) + } + + "nested transaction modifications are visible to parent" in run { + for + ref1 <- TRef.initNow(0) + ref2 <- TRef.initNow(0) + result <- STM.run { + for + _ <- ref1.set(1) + nestedResult <- STM.run { + for + v1 <- ref1.get + _ <- ref2.set(2) + yield v1 + } + v2 <- ref2.get + _ <- ref1.set(3) + yield (nestedResult, v2) + } + finalValues <- STM.run { + for + v1 <- ref1.get + v2 <- ref2.get + yield (v1, v2) + } + yield assert(result == (1, 2) && finalValues == (3, 2)) + } + + "sequential nested transactions see previous changes" in run { + for + ref <- TRef.initNow(0) + result <- STM.run { + for + _ <- ref.set(1) + r1 <- STM.run { + for + v1 <- ref.get + _ <- ref.set(2) + yield v1 + } + r2 <- STM.run { + for + v2 <- ref.get + _ <- ref.set(3) + yield v2 + } + r3 <- ref.get + yield (r1, r2, r3) + } + yield assert(result == (1, 2, 3)) + } + + "nested transaction rollback preserves parent changes" in run { + for + ref <- TRef.initNow(0) + result <- STM.run { + for + _ <- ref.set(1) + r1 <- ref.get + r2 <- Abort.run { + STM.run { + for + v <- ref.get + _ <- ref.set(2) + _ <- STM.retry + yield v + } + } + r3 <- ref.get + _ <- ref.set(3) + yield (r1, r2.isFail, r3) + } + finalValue <- STM.run(ref.get) + yield assert(result == (1, true, 1) && finalValue == 3) + } + } + + "Error handling" - { + + "transaction rollback on failure" in run { + for + ref <- TRef.initNow(42) + result <- Abort.run { + STM.run { + for + _ <- ref.set(100) + _ <- Abort.fail(new Exception("Test failure")) + yield () + } + } + value <- STM.run(ref.get) + yield assert(result.isFail && value == 42) + } + + "multiple refs rollback on failure" in run { + for + ref1 <- TRef.initNow(10) + ref2 <- TRef.initNow(20) + result <- Abort.run { + STM.run { + for + _ <- ref1.set(30) + _ <- ref2.set(40) + _ <- Abort.fail(new Exception("Multi-ref failure")) + yield () + } + } + value1 <- STM.run(ref1.get) + value2 <- STM.run(ref2.get) + yield assert(result.isFail && value1 == 10 && value2 == 20) + } + + "nested transaction rollback on inner failure" in run { + for + ref <- TRef.initNow(1) + result <- Abort.run { + STM.run { + for + _ <- ref.set(2) + _ <- STM.run { + for + _ <- ref.set(3) + _ <- Abort.fail(new Exception("Inner failure")) + yield () + } + yield () + } + } + value <- STM.run(ref.get) + yield assert(result.isFail && value == 1) + } + + "partial updates within transaction are atomic" in run { + for + ref1 <- TRef.initNow("initial1") + ref2 <- TRef.initNow("initial2") + result <- Abort.run { + STM.run { + for + _ <- ref1.set("updated1") + v1 <- ref1.get + _ <- STM.retryIf(v1 == "updated1") + _ <- ref2.set("updated2") + yield () + } + } + value1 <- STM.run(ref1.get) + value2 <- STM.run(ref2.get) + yield assert( + result.isFail && + value1 == "initial1" && + value2 == "initial2" + ) + } + + "exception in update function rolls back" in run { + for + ref <- TRef.initNow(0) + result <- Abort.run { + STM.run { + ref.update { x => + if x == 0 then throw new Exception("Update failure") + else x + 1 + } + } + } + value <- STM.run(ref.get) + yield assert(result.isPanic && value == 0) + } + } + + "Concurrency" - { + + val repeats = 100 + + "concurrent updates" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + ref <- TRef.initNow(0) + _ <- Async.parallelUnbounded((1 to size).map(_ => STM.run(ref.update(_ + 1)))) + value <- STM.run(ref.get) + yield assert(value == size)) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "concurrent reads and writes" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + ref <- TRef.initNow(0) + latch <- Latch.init(1) + writeFiber <- Async.run( + latch.await.andThen(Async.parallelUnbounded((1 to size).map(_ => STM.run(ref.update(_ + 1))))) + ) + readFiber <- Async.run( + latch.await.andThen(Async.parallelUnbounded((1 to size).map(_ => STM.run(ref.get)))) + ) + _ <- latch.release + _ <- writeFiber.get + reads <- readFiber.get + value <- STM.run(ref.get) + yield assert(value == size && reads.forall(_ <= size))) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "concurrent nested transactions" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + ref <- TRef.initNow(0) + _ <- Async.parallelUnbounded((1 to size).map { _ => + STM.run { + for + _ <- ref.update(_ + 1) + _ <- STM.run { + for + v <- ref.get + _ <- ref.set(v + 1) + yield () + } + yield () + } + }) + value <- STM.run(ref.get) + yield assert(value == size * 2)) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "dining philosophers" in run { + val philosophers = 5 + (for + forks <- Kyo.fill(philosophers)(TRef.initNow(true)) + _ <- Async.parallelUnbounded( + (0 until philosophers).map { i => + val leftFork = forks(i) + val rightFork = forks((i + 1) % philosophers) + Async.parallelUnbounded((1 to 10).map { _ => + STM.run { + for + leftAvailable <- leftFork.get + _ <- STM.retryIf(!leftAvailable) + _ <- leftFork.set(false) + + rightAvailable <- rightFork.get + _ <- STM.retryIf(!rightAvailable) + _ <- rightFork.set(false) + + _ <- leftFork.set(true) + _ <- rightFork.set(true) + yield () + } + }) + } + ) + finalStates <- Kyo.collect(forks.map(fork => STM.run(fork.get))) + yield assert(finalStates.forall(identity))) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "bank account transfers" in run { + (for + account1 <- TRef.initNow(500) + account2 <- TRef.initNow(300) + account3 <- TRef.initNow(200) + + transfers = List( + STM.run { + for + balance <- account1.get + amount = 250 + _ <- STM.retryIf(balance < amount) + _ <- account1.update(_ - amount) + _ <- account2.update(_ + amount) + yield () + }, + STM.run { + for + balance <- account2.get + amount = 200 + _ <- STM.retryIf(balance < amount) + _ <- account2.update(_ - amount) + _ <- account3.update(_ + amount) + yield () + }, + STM.run { + for + balance <- account3.get + amount = 150 + _ <- STM.retryIf(balance < amount) + _ <- account3.update(_ - amount) + _ <- account1.update(_ + amount) + yield () + } + ) + + _ <- Async.parallelUnbounded(transfers) + final1 <- STM.run(account1.get) + final2 <- STM.run(account2.get) + final3 <- STM.run(account3.get) + yield + assert(final1 + final2 + final3 == 1000) + assert(final1 >= 0) + assert(final2 >= 0) + assert(final3 >= 0) + ) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "circular account transfers" in run { + (for + account1 <- TRef.initNow(300) + account2 <- TRef.initNow(200) + account3 <- TRef.initNow(100) + + circularTransfers = (1 to 5).flatMap(_ => + List( + STM.run { + for + balance <- account1.get + amount = 80 + _ <- STM.retryIf(balance < amount) + _ <- account1.update(_ - amount) + _ <- account2.update(_ + amount) + yield () + }, + STM.run { + for + balance <- account2.get + amount = 60 + _ <- STM.retryIf(balance < amount) + _ <- account2.update(_ - amount) + _ <- account3.update(_ + amount) + yield () + }, + STM.run { + for + balance <- account3.get + amount = 40 + _ <- STM.retryIf(balance < amount) + _ <- account3.update(_ - amount) + _ <- account1.update(_ + amount) + yield () + } + ) + ) + + _ <- Async.parallelUnbounded(circularTransfers) + final1 <- STM.run(account1.get) + final2 <- STM.run(account2.get) + final3 <- STM.run(account3.get) + yield + assert(final1 + final2 + final3 == 600) + assert(final1 >= 0) + assert(final2 >= 0) + assert(final3 >= 0) + ) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + } + + "async transaction nesting" - { + "nested transactions with async boundary should fail gracefully" in run { + for + ref <- TRef.initNow(0) + result <- Abort.run { + STM.run { + for + _ <- ref.set(1) + fiber <- Async.run { + STM.run { + for + _ <- ref.set(2) + v <- ref.get + yield v + } + } + _ <- fiber.get + yield + // The transaction will keep failing until it reaches the + // retry limit because the ref is changed by the nested + // fiber concurrently. The transactions in the nested + // fibers executed on each try succeed, updating the ref + // to 2. + () + } + } + value <- STM.run(ref.get) + yield assert(result.isFail && value == 2) + } + + "transaction ID should not leak across async boundaries" in run { + for + ref <- TRef.initNow(0) + (parentTid, childTid) <- + STM.run { + TID.use { parentTid => + Async.run { + STM.run(TID.use(identity)) + }.map(_.get).map { childTid => + (parentTid, childTid) + } + } + } + yield assert(parentTid != childTid) + } + } + +end STMTest diff --git a/kyo-stm/shared/src/test/scala/kyo/TMapTest.scala b/kyo-stm/shared/src/test/scala/kyo/TMapTest.scala new file mode 100644 index 000000000..bac180f31 --- /dev/null +++ b/kyo-stm/shared/src/test/scala/kyo/TMapTest.scala @@ -0,0 +1,500 @@ +package kyo + +import kyo.debug.Debug + +class TMapTest extends Test: + + "Basic operations" - { + "init and get" in run { + STM.run { + for + map <- TMap.init("key" -> 42) + value <- map.get("key") + yield value + }.map { value => + assert(value == Maybe(42)) + } + } + + "init from Map" in run { + val initial = Map("a" -> 1, "b" -> 2, "c" -> 3) + STM.run { + for + map <- TMap.init(initial) + snapshot <- map.snapshot + yield snapshot + }.map { snapshot => + assert(snapshot == initial) + } + } + + "add and contains" in run { + STM.run { + for + map <- TMap.init[String, Int]() + _ <- map.put("key", 42) + exists <- map.contains("key") + missing <- map.contains("nonexistent") + value <- map.get("key") + yield (exists, missing, value) + }.map { case (exists, missing, value) => + assert(exists && missing && value == Maybe(42)) + } + } + + "size and empty checks" in run { + STM.run { + for + map <- TMap.init[String, Int]() + empty1 <- map.isEmpty + _ <- map.put("key", 42) + size <- map.size + empty2 <- map.isEmpty + nonEmpty <- map.nonEmpty + yield (empty1, size, empty2, nonEmpty) + }.map { case (empty1, size, empty2, nonEmpty) => + assert(empty1 && size == 1 && !empty2 && nonEmpty) + } + } + + "remove operations" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2) + value <- map.remove("a") + missing <- map.remove("nonexistent") + _ <- map.removeDiscard("b") + size <- map.size + yield (value, missing, size) + }.map { case (value, missing, size) => + assert(value == Maybe(1) && missing.isEmpty && size == 0) + } + } + + "update operations" in run { + STM.run { + for + map <- TMap.init("key" -> 10) + _ <- map.updateWith("key")(v => Maybe(v.getOrElse(0) + 1)) + value1 <- map.get("key") + _ <- map.updateWith("key")(_ => Maybe.empty) + value2 <- map.get("key") + yield (value1, value2) + }.map { case (value1, value2) => + assert(value1 == Maybe(11) && value2.isEmpty) + } + } + + "clear" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2, "c" -> 3) + _ <- map.clear + size <- map.size + empty <- map.isEmpty + yield (size, empty) + }.map { case (size, empty) => + assert(size == 0 && empty) + } + } + } + + "Collection operations" - { + "keys" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2, "c" -> 3) + keys <- map.keys + yield keys + }.map { keys => + assert(keys.toSet == Set("a", "b", "c")) + } + } + + "values" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2, "c" -> 3) + values <- map.values + yield values + }.map { values => + assert(values.toSet == Set(1, 2, 3)) + } + } + + "entries" in run { + val initial = Map("a" -> 1, "b" -> 2, "c" -> 3) + STM.run { + for + map <- TMap.init(initial) + snapshot <- map.snapshot + yield snapshot + }.map { snapshot => + assert(snapshot == initial) + } + } + + "filter" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2, "c" -> 3) + _ <- map.filter((_, v) => v % 2 == 1) + snapshot <- map.snapshot + yield snapshot + }.map { snapshot => + assert(snapshot == Map("a" -> 1, "c" -> 3)) + } + } + + "fold" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2, "c" -> 3) + sum <- map.fold(0)((acc, _, v) => acc + v) + concat <- map.fold("")((acc, k, v) => acc + k + v) + yield (sum, concat) + }.map { case (sum, concat) => + assert(sum == 6 && concat == "a1b2c3") + } + } + + "findFirst" in run { + STM.run { + for + map <- TMap.init("a" -> 1, "b" -> 2, "c" -> 3) + found <- map.findFirst((k, v) => if v % 2 == 0 then Maybe(k) else Maybe.empty) + missing <- map.findFirst((_, v) => if v > 10 then Maybe(v) else Maybe.empty) + yield (found, missing) + }.map { case (found, missing) => + assert(found == Maybe("b") && missing.isEmpty) + } + } + + "snapshot" in run { + val initial = Map("a" -> 1, "b" -> 2, "c" -> 3) + STM.run { + for + map <- TMap.init(initial) + snapshot <- map.snapshot + yield snapshot + }.map { result => + assert(result == initial) + } + } + } + + "Error handling" - { + "rollback on direct failure" in run { + for + map <- STM.run(TMap.init[String, Int]("initial" -> 42)) + result <- Abort.run { + STM.run { + for + _ <- map.put("key1", 100) + _ <- map.put("key2", 200) + _ <- Abort.fail(new Exception("Test failure")) + yield () + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isFail && + snapshot == Map("initial" -> 42) + ) + } + + "rollback on nested transaction failure" in run { + for + map <- STM.run(TMap.init[String, Int]()) + result <- Abort.run { + STM.run { + for + _ <- map.put("outer", 1) + _ <- STM.run { + for + _ <- map.put("inner", 2) + _ <- Abort.fail(new Exception("Nested failure")) + yield () + } + yield () + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isFail && + snapshot.isEmpty + ) + } + + "partial updates remain atomic" in run { + for + map <- STM.run(TMap.init[String, Int]()) + result <- Abort.run { + STM.run { + for + _ <- map.put("key1", 100) + _ <- map.updateWith("key1") { _ => Maybe(200) } + _ <- STM.retry + _ <- map.put("key2", 300) + yield () + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isFail && + snapshot.isEmpty + ) + } + + "exception in update function rolls back" in run { + for + map <- STM.run(TMap.init[String, Int]("test" -> 42)) + result <- Abort.run { + STM.run { + map.updateWith("test") { _ => + throw new Exception("Update failure") + Maybe(100) + } + } + } + value <- STM.run(map.get("test")) + yield assert( + result.isPanic && + value == Maybe(42) + ) + } + + "filter operation rollback" in run { + for + map <- STM.run(TMap.init[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)) + result <- Abort.run { + STM.run { + for + _ <- map.filter { (k, v) => + if k == "b" then throw new Exception("Filter failure") + true + } + yield () + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isPanic && + snapshot == Map("a" -> 1, "b" -> 2, "c" -> 3) + ) + } + + "fold operation rollback" in run { + for + map <- STM.run(TMap.init[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)) + result <- Abort.run { + STM.run { + map.fold(0) { (acc, k, v) => + if acc > 2 then throw new Exception("Fold failure") + acc + v + } + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isPanic && + snapshot == Map("a" -> 1, "b" -> 2, "c" -> 3) + ) + } + + "findFirst operation rollback" in run { + for + map <- STM.run(TMap.init[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)) + result <- Abort.run { + STM.run { + map.findFirst { (k, v) => + if k == "b" then throw new Exception("Find failure") + Maybe.empty + } + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isPanic && + snapshot == Map("a" -> 1, "b" -> 2, "c" -> 3) + ) + } + + "multiple operations rollback on failure" in run { + for + map <- STM.run(TMap.init[String, Int]()) + result <- Abort.run { + STM.run { + for + _ <- map.put("key1", 100) + _ <- map.removeDiscard("key1") + _ <- map.put("key2", 200) + _ <- Abort.fail(new Exception("Multi-op failure")) + _ <- map.put("key3", 300) + yield () + } + } + snapshot <- STM.run(map.snapshot) + yield assert( + result.isFail && + snapshot.isEmpty + ) + } + + "nested effects with rollback" in run { + Var.run(0) { + for + map <- STM.run(TMap.init[String, Int]("start" -> 0)) + result <- Abort.run { + STM.run(Var.isolate.update) { + for + _ <- map.put("key1", 100) + _ <- Var.set(1) + _ <- Abort.fail(new Exception("Nested effect failure")) + _ <- map.put("key2", 200) + yield () + } + } + snapshot <- STM.run(map.snapshot) + varValue <- Var.get[Int] + yield assert( + result.isFail && + snapshot == Map("start" -> 0) && + varValue == 0 + ) + } + } + } + + "Concurrency" - { + val repeats = 100 + + "concurrent modifications" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + map <- STM.run(TMap.init[Int, Int]()) + _ <- Async.parallelUnbounded((1 to size).map(i => STM.run(map.put(i, i)))) + snapshot <- STM.run(map.snapshot) + yield assert( + snapshot.size == size && + snapshot.forall((k, v) => k == v && k >= 1 && k <= size) + )) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "concurrent reads and writes" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + map <- STM.run(TMap.init[Int, Int]()) + latch <- Latch.init(1) + + writeFiber <- Async.run( + latch.await.andThen( + Async.parallelUnbounded( + (1 to size).map(i => + STM.run(map.put(i, i * 2)) + ) + ) + ) + ) + + readFiber <- Async.run( + latch.await.andThen( + Async.parallelUnbounded( + (1 to size).map(i => + STM.run(map.get(i)) + ) + ) + ) + ) + + _ <- latch.release + _ <- writeFiber.get + reads <- readFiber.get + snapshot <- STM.run(map.snapshot) + yield + assert(snapshot.size == size) + assert(snapshot.forall((k, v) => v == k * 2)) + assert(reads.forall(maybeVal => maybeVal.isEmpty || maybeVal.exists(_ % 2 == 0))) + ) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "concurrent updates" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + map <- STM.run(TMap.init[Int, Int]()) + _ <- STM.run { + Kyo.foreachDiscard((1 to size))(i => map.put(i, 1)) + } + _ <- Async.parallelUnbounded( + (1 to 10).map(_ => + STM.run { + Kyo.foreachDiscard((1 to size)) { i => + map.updateWith(i)(v => Maybe(v.getOrElse(0) + 1)) + } + } + ) + ) + snapshot <- STM.run(map.snapshot) + yield assert( + snapshot.size == size && + snapshot.forall((_, v) => v == 11) // Initial 1 + 10 increments + )) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "concurrent removals" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + map <- STM.run(TMap.init[Int, Int]()) + _ <- STM.run { + Kyo.foreachDiscard((1 to size))(i => map.put(i, i)) + } + _ <- Async.parallelUnbounded( + (1 to size).map(i => + STM.run(map.removeDiscard(i)) + ) + ) + snapshot <- STM.run(map.snapshot) + yield assert(snapshot.isEmpty)) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + + "concurrent bulk operations" in run { + (for + size <- Choice.get(Seq(1, 10, 100)) + map <- STM.run(TMap.init[Int, Int]()) + _ <- STM.run { + Kyo.foreachDiscard((1 to size))(i => map.put(i, i)) + } + + filterOps = Async.parallelUnbounded( + (1 to 5).map(_ => + STM.run(map.filter((k, v) => v % 2 == 0)) + ) + ) + + foldOps = Async.parallelUnbounded( + (1 to 5).map(_ => + STM.run(map.fold(0)((acc, _, v) => acc + v)) + ) + ) + + _ <- filterOps + sums <- foldOps + snapshot <- STM.run(map.snapshot) + yield + assert(snapshot.forall((_, v) => v % 2 == 0)) + assert(sums.forall(_ == snapshot.values.sum)) + ) + .pipe(Choice.run, _.unit, Loop.repeat(repeats)) + .andThen(succeed) + } + } + +end TMapTest diff --git a/kyo-stm/shared/src/test/scala/kyo/TRefTest.scala b/kyo-stm/shared/src/test/scala/kyo/TRefTest.scala new file mode 100644 index 000000000..8a9261e58 --- /dev/null +++ b/kyo-stm/shared/src/test/scala/kyo/TRefTest.scala @@ -0,0 +1,38 @@ +package kyo + +import kyo.debug.Debug + +class TRefTest extends Test: + + "init and get" in run { + for + ref <- TRef.initNow(42) + value <- STM.run(ref.get) + yield assert(value == 42) + } + + "set and get" in run { + for + ref <- TRef.initNow(42) + _ <- STM.run(ref.set(100)) + value <- STM.run(ref.get) + yield assert(value == 100) + } + + "multiple operations in transaction" in run { + for + ref1 <- TRef.initNow(10) + ref2 <- TRef.initNow(20) + result <- STM.run { + for + v1 <- ref1.get + v2 <- ref2.get + _ <- ref1.set(v2) + _ <- ref2.set(v1) + r1 <- ref1.get + r2 <- ref2.get + yield (r1, r2) + } + yield assert(result == (20, 10)) + } +end TRefTest diff --git a/kyo-stm/shared/src/test/scala/kyo/Test.scala b/kyo-stm/shared/src/test/scala/kyo/Test.scala new file mode 100644 index 000000000..9f6d6b998 --- /dev/null +++ b/kyo-stm/shared/src/test/scala/kyo/Test.scala @@ -0,0 +1,32 @@ +package kyo + +import kyo.internal.BaseKyoTest +import kyo.kernel.Platform +import org.scalatest.NonImplicitAssertions +import org.scalatest.Tag +import org.scalatest.freespec.AsyncFreeSpec +import scala.concurrent.ExecutionContext +import scala.concurrent.Future + +abstract class Test extends AsyncFreeSpec with BaseKyoTest[Abort[Any] & Async & Resource] with NonImplicitAssertions: + + private def runWhen(cond: => Boolean) = if cond then "" else "org.scalatest.Ignore" + object jvmOnly extends Tag(runWhen(kyo.kernel.Platform.isJVM)) + object jsOnly extends Tag(runWhen(kyo.kernel.Platform.isJS)) + + def run(v: Future[Assertion] < (Abort[Any] & Async & Resource)): Future[Assertion] = + import AllowUnsafe.embrace.danger + val a = Async.run(Abort.run(Resource.run(v)).map(_.fold { + _.getFailure match + case ex: Throwable => throw ex + case e => throw new IllegalStateException(s"Test aborted with $e") + }(identity))) + val b = a.map(_.toFuture).map(_.flatten) + IO.Unsafe.run(b).eval + end run + + type Assertion = org.scalatest.compatible.Assertion + def success = succeed + + override given executionContext: ExecutionContext = Platform.executionContext +end Test