From d53aa857b595f70cb87cd88219942b3f44ee2878 Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Fri, 13 Dec 2024 10:32:33 -0800 Subject: [PATCH] [stm] optimizations (#924) Benchmark [results](https://jmh.morethan.io/?sources=https://gist.githubusercontent.com/fwbrasil/28a2f0957f839310a85752634904a99d/raw/e414912bcd3b7d89d43663298a40a5e3dd5c560d/jmh-result-baseline.json,https://gist.githubusercontent.com/fwbrasil/28a2f0957f839310a85752634904a99d/raw/e414912bcd3b7d89d43663298a40a5e3dd5c560d/jmh-result-candidate.json): ![image](https://github.com/user-attachments/assets/a6b59751-5fc9-473f-ad1a-605ca251d473) --- .../main/scala/kyo/bench/TRefMultiBench.scala | 2 +- .../shared/src/main/scala/kyo/Retry.scala | 11 +- .../shared/src/main/scala/kyo/Local.scala | 2 +- .../shared/src/main/scala/kyo/Var.scala | 12 +- kyo-stm/shared/src/main/scala/kyo/STM.scala | 144 +++++++++++++++--- kyo-stm/shared/src/main/scala/kyo/TID.scala | 10 +- kyo-stm/shared/src/main/scala/kyo/TRef.scala | 24 +-- .../shared/src/main/scala/kyo/TRefLog.scala | 28 ++-- .../shared/src/test/scala/kyo/STMTest.scala | 2 +- .../src/test/scala/kyo/TRefLogTest.scala | 10 +- 10 files changed, 172 insertions(+), 73 deletions(-) diff --git a/kyo-bench/src/main/scala/kyo/bench/TRefMultiBench.scala b/kyo-bench/src/main/scala/kyo/bench/TRefMultiBench.scala index 9516e4b55..26e17b14c 100644 --- a/kyo-bench/src/main/scala/kyo/bench/TRefMultiBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/TRefMultiBench.scala @@ -13,7 +13,7 @@ class TRefMultiBench(parallelism: Int) extends Bench.ForkOnly(parallelism): STM.runtime[IO].flatMap { stm => for - refs <- stm.commit(Seq.fill(parallelism)(stm.TVar.of(0)).sequence) + refs <- Seq.fill(parallelism)(stm.commit(stm.TVar.of(0))).sequence _ <- refs.map(ref => stm.commit(ref.modify(_ + 1))).parSequence_ result <- stm.commit(refs.traverse(_.get).map(_.sum)) yield result diff --git a/kyo-core/shared/src/main/scala/kyo/Retry.scala b/kyo-core/shared/src/main/scala/kyo/Retry.scala index 33d804e14..b726d381d 100644 --- a/kyo-core/shared/src/main/scala/kyo/Retry.scala +++ b/kyo-core/shared/src/main/scala/kyo/Retry.scala @@ -41,18 +41,17 @@ object Retry: SafeClassTag[E], Frame ): A < (Async & Abort[E] & S) = - Loop(schedule) { schedule => - Abort.run[E](v).map(_.fold { r => + Abort.run[E](v).map { + case Result.Success(r) => r + case error: Result.Error[?] => Clock.now.map { now => schedule.next(now).map { (delay, nextSchedule) => - Async.delay(delay)(Loop.continue(nextSchedule)) + Async.delay(delay)(Retry[E](nextSchedule)(v)) }.getOrElse { - Abort.get(r) + Abort.get(error) } } - }(Loop.done(_))) } - end apply end RetryOps /** Creates a RetryOps instance for the specified error type. diff --git a/kyo-prelude/shared/src/main/scala/kyo/Local.scala b/kyo-prelude/shared/src/main/scala/kyo/Local.scala index 2c0270449..db3e2a0aa 100644 --- a/kyo-prelude/shared/src/main/scala/kyo/Local.scala +++ b/kyo-prelude/shared/src/main/scala/kyo/Local.scala @@ -108,7 +108,7 @@ object Local: ContextEffect.suspendAndMap(tag, Map.empty)(map => f(map.getOrElse(this, default).asInstanceOf[A])) def let[B, S](value: A)(v: B < S)(using Frame) = - ContextEffect.handle(tag, Map(this -> value), _.updated(this, value.asInstanceOf[AnyRef]))(v) + ContextEffect.handle(tag, Map.empty[Local[?], AnyRef].updated(this, value), _.updated(this, value.asInstanceOf[AnyRef]))(v) def update[B, S](f: A => A)(v: B < S)(using Frame) = ContextEffect.handle( diff --git a/kyo-prelude/shared/src/main/scala/kyo/Var.scala b/kyo-prelude/shared/src/main/scala/kyo/Var.scala index 740c40975..e6816145a 100644 --- a/kyo-prelude/shared/src/main/scala/kyo/Var.scala +++ b/kyo-prelude/shared/src/main/scala/kyo/Var.scala @@ -167,7 +167,7 @@ object Var: runWith(state)(v)((state, result) => (state, result)) object isolate: - abstract private[kyo] class Base[V: Tag] extends Isolate[Var[V]]: + abstract private[kyo] class Base[V](using Tag[Var[V]]) extends Isolate[Var[V]]: type State = V def use[A, S2](f: V => A < S2)(using Frame) = Var.use(f) def resume[A: Flat, S2](state: State, v: A < (Var[V] & S2))(using Frame) = @@ -183,10 +183,10 @@ object Var: * @return * An isolate that updates the Var with its isolated value */ - def update[V: Tag]: Isolate[Var[V]] = + def update[V](using Tag[Var[V]]): Isolate[Var[V]] = new Base[V]: def restore[A: Flat, S2](state: V, v: A < S2)(using Frame) = - Var.set(state).andThen(v) + Var.setAndThen(state)(v) /** Creates an isolate that merges Var values using a combination function. * @@ -200,10 +200,10 @@ object Var: * @return * An isolate that merges Var values */ - def merge[V: Tag](f: (V, V) => V): Isolate[Var[V]] = + def merge[V](f: (V, V) => V)(using Tag[Var[V]]): Isolate[Var[V]] = new Base[V]: def restore[A: Flat, S2](state: V, v: A < S2)(using Frame) = - Var.use[V](prev => Var.set(f(prev, state)).andThen(v)) + Var.use[V](prev => Var.setAndThen(f(prev, state))(v)) /** Creates an isolate that keeps Var modifications local. * @@ -215,7 +215,7 @@ object Var: * @return * An isolate that discards Var modifications */ - def discard[V: Tag]: Isolate[Var[V]] = + def discard[V](using Tag[Var[V]]): Isolate[Var[V]] = new Base[V]: def restore[A: Flat, S2](state: V, v: A < S2)(using Frame) = v diff --git a/kyo-stm/shared/src/main/scala/kyo/STM.scala b/kyo-stm/shared/src/main/scala/kyo/STM.scala index dcf95c91b..9b59c60ee 100644 --- a/kyo-stm/shared/src/main/scala/kyo/STM.scala +++ b/kyo-stm/shared/src/main/scala/kyo/STM.scala @@ -1,10 +1,13 @@ package kyo +import java.util.Arrays +import kyo.Result.Fail import scala.annotation.tailrec +import scala.util.control.NoStackTrace /** A FailedTransaction exception that is thrown when a transaction fails to commit. Contains the frame where the failure occurred. */ -case class FailedTransaction(frame: Frame) extends Exception(frame.position.show) +case class FailedTransaction(frame: Frame) extends Exception(frame.position.show) with NoStackTrace /** Software Transactional Memory (STM) provides concurrent access to shared state using optimistic locking. Rather than acquiring locks * upfront, transactions execute speculatively and automatically retry if conflicts are detected during commit. While this enables better @@ -108,26 +111,89 @@ object STM: // New transaction without a parent, use regular commit flow Retry[FailedTransaction](retrySchedule) { TID.useNew { tid => - TRefLog.runWith(v) { (log, result) => - IO.Unsafe { - // Attempt to acquire locks and commit the transaction - val (locked, unlocked) = - // Sort references by identity to prevent deadlocks - log.toSeq.sortBy((ref, _) => ref.hashCode) - .span((ref, entry) => ref.lock(entry)) - - if unlocked.nonEmpty then - // Failed to acquire some locks - rollback and retry - locked.foreach((ref, entry) => ref.unlock(entry)) - Abort.fail(FailedTransaction(frame)) - else - // Successfully locked all references - commit changes - locked.foreach((ref, entry) => ref.commit(tid, entry)) - // Release all locks - locked.foreach((ref, entry) => ref.unlock(entry)) + Var.runWith(TRefLog.empty)(v) { (log, result) => + val logMap = log.toMap + logMap.size match + case 0 => + // Nothing to commit result - end if - } + case 1 => + // Fast-path for a single ref + IO.Unsafe { + val (ref, entry) = logMap.head + // No need to pre-validate since `lock` validates and + // there's a single ref + if ref.lock(entry) then + ref.commit(tid, entry) + ref.unlock(entry) + result + else + Abort.fail(FailedTransaction(frame)) + end if + } + case size => + // Commit multiple refs + IO.Unsafe { + // Flattened representation of the log + val array = new Array[Any](size * 2) + + try + def fail = throw new FailedTransaction(frame) + + var i = 0 + // Pre-validate and dump the log to the flat array + logMap.foreachEntry { (ref, entry) => + // This code uses exception throwing because + // foreachEntry is the only way to traverse the + // map without allocating tuples, so throwing + // is the workaround to short circuit + if !ref.validate(entry) then fail + array(i) = ref + array(i + 1) = entry + i += 2 + } + + // Sort references by identity to prevent deadlocks + quickSort(array, size) + + // Convenience accessors to the flat log + inline def ref(idx: Int) = array(idx * 2).asInstanceOf[TRef[Any]] + inline def entry(idx: Int) = array(idx * 2 + 1).asInstanceOf[TRefLog.Entry[Any]] + + @tailrec def lock(idx: Int): Int = + if idx == size then size + else if !ref(idx).lock(entry(idx)) then idx + else lock(idx + 1) + + @tailrec def unlock(idx: Int, upTo: Int): Unit = + if idx < upTo then + ref(idx).unlock(entry(idx)) + unlock(idx + 1, upTo) + + @tailrec def commit(idx: Int): Unit = + if idx < size then + ref(idx).commit(tid, entry(idx)) + commit(idx + 1) + + val acquired = lock(0) + if acquired != size then + // Failed to acquire some locks - rollback and retry + unlock(0, acquired) + fail + end if + + // Successfully locked all references - commit changes + commit(0) + + // Release all locks + unlock(0, size) + result + catch + case ex: FailedTransaction => + Abort.fail(ex) + end try + } + end match } } } @@ -135,7 +201,7 @@ object STM: // Nested transaction inherits parent's transaction context but isolates RefLog. // On success: changes propagate to parent. On failure: changes are rolled back // without affecting parent's state. - val result = TRefLog.isolate(v) + val result = TRefLog.isolate.run(v) // Can't return `result` directly since it has a pending STM effect // but it's safe to cast because, if there's a parent transaction, @@ -145,4 +211,40 @@ object STM: } end run + + private def quickSort(array: Array[Any], size: Int): Unit = + def swap(i: Int, j: Int): Unit = + val temp = array(i) + array(i) = array(j) + array(j) = temp + val temp2 = array(i + 1) + array(i + 1) = array(j + 1) + array(j + 1) = temp2 + end swap + + def getHash(idx: Int): Int = + array(idx * 2).hashCode() + + @tailrec def partitionLoop(low: Int, hi: Int, pivot: Int, i: Int, j: Int): Int = + if j >= hi then + swap(i * 2, pivot * 2) + i + else if getHash(j) < getHash(pivot) then + swap(i * 2, j * 2) + partitionLoop(low, hi, pivot, i + 1, j + 1) + else + partitionLoop(low, hi, pivot, i, j + 1) + + def partition(low: Int, hi: Int): Int = + partitionLoop(low, hi, hi, low, low) + + def loop(low: Int, hi: Int): Unit = + if low < hi then + val p = partition(low, hi) + loop(low, p - 1) + loop(p + 1, hi) + + if size > 0 then + loop(0, size - 1) + end quickSort end STM diff --git a/kyo-stm/shared/src/main/scala/kyo/TID.scala b/kyo-stm/shared/src/main/scala/kyo/TID.scala index e9795c034..173fc6fd1 100644 --- a/kyo-stm/shared/src/main/scala/kyo/TID.scala +++ b/kyo-stm/shared/src/main/scala/kyo/TID.scala @@ -5,20 +5,20 @@ private[kyo] object TID: // Unique transaction ID generation private val nextTid = AtomicLong.Unsafe.init(0)(using AllowUnsafe.embrace.danger) - private val tidLocal = Local.initIsolated(-1L) + private val tidLocal = Local.initIsolated[java.lang.Long](-1L) def next(using AllowUnsafe): Long = nextTid.incrementAndGet() - def useNew[A, S](f: Long => A < S)(using Frame): A < (S & IO) = + inline def useNew[A, S](inline f: Long => A < S)(using inline frame: Frame): A < (S & IO) = IO.Unsafe { val tid = nextTid.incrementAndGet() tidLocal.let(tid)(f(tid)) } - def use[A, S](f: Long => A < S)(using Frame): A < S = - tidLocal.use(f) + inline def use[A, S](inline f: Long => A < S)(using inline frame: Frame): A < S = + tidLocal.use(f(_)) - def useRequired[A, S](f: Long => A < S)(using Frame): A < S = + inline def useRequired[A, S](inline f: Long => A < S)(using inline frame: Frame): A < S = tidLocal.use { case -1L => bug("STM operation attempted outside of STM.run - this should be impossible due to effect typing") case tid => f(tid) diff --git a/kyo-stm/shared/src/main/scala/kyo/TRef.scala b/kyo-stm/shared/src/main/scala/kyo/TRef.scala index 285bd1ec2..3d11ec9fc 100644 --- a/kyo-stm/shared/src/main/scala/kyo/TRef.scala +++ b/kyo-stm/shared/src/main/scala/kyo/TRef.scala @@ -42,6 +42,7 @@ sealed trait TRef[A]: final def update[S](f: A => A < S)(using Frame): Unit < (STM & S) = use(f(_).map(set)) private[kyo] def state(using AllowUnsafe): Write[A] + private[kyo] def validate(entry: Entry[A])(using AllowUnsafe): Boolean private[kyo] def lock(entry: Entry[A])(using AllowUnsafe): Boolean private[kyo] def commit(tid: Long, entry: Entry[A])(using AllowUnsafe): Unit private[kyo] def unlock(entry: Entry[A])(using AllowUnsafe): Unit @@ -62,7 +63,7 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A]) private[kyo] def state(using AllowUnsafe): Write[A] = currentState def use[B, S](f: A => B < S)(using Frame): B < (STM & S) = - TRefLog.use { log => + Var.use[TRefLog] { log => log.get(this) match case Present(entry) => f(entry.value) @@ -76,7 +77,7 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A]) else // Append Read to the log and return value val entry = Read(state.tid, state.value) - TRefLog.setAndThen(log.put(this, entry))(f(state.value)) + Var.setAndThen(log.put(this, entry))(f(state.value)) end if } } @@ -84,11 +85,11 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A]) } def set(v: A)(using Frame): Unit < STM = - TRefLog.use { log => + Var.use[TRefLog] { log => log.get(this) match case Present(prev) => val entry = Write(prev.tid, v) - TRefLog.setDiscard(log.put(this, entry)) + Var.setDiscard(log.put(this, entry)) case Absent => TID.useRequired { tid => IO { @@ -99,15 +100,18 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A]) else // Append Write to the log val entry = Write(state.tid, v) - TRefLog.setDiscard(log.put(this, entry)) + Var.setDiscard(log.put(this, entry)) end if } } } + private[kyo] def validate(entry: Entry[A])(using AllowUnsafe): Boolean = + currentState.tid == entry.tid + private[kyo] def lock(entry: Entry[A])(using AllowUnsafe): Boolean = @tailrec def loop(): Boolean = - currentState.tid == entry.tid && { + validate(entry) && { val lockState = super.get() entry match case Read(tid, value) => @@ -119,9 +123,9 @@ final private class TRefImpl[A] private[kyo] (initialState: Write[A]) end match } val locked = loop() - if locked && currentState.tid != entry.tid then + if locked && !validate(entry) then // This branch handles the race condition where another fiber commits - // after the initial `currentState.tid == entry.tid` check but before the + // after the initial `validate(entry)` check but before the // lock is acquired. If that's the case, roll back the lock. unlock(entry) false @@ -160,10 +164,10 @@ object TRef: */ def init[A](value: A)(using Frame): TRef[A] < STM = TID.useRequired { tid => - TRefLog.use { log => + Var.use[TRefLog] { log => IO.Unsafe { val ref = TRef.Unsafe.init(tid, value) - TRefLog.setAndThen(log.put(ref, ref.state))(ref) + Var.setAndThen(log.put(ref, ref.state))(ref) } } } diff --git a/kyo-stm/shared/src/main/scala/kyo/TRefLog.scala b/kyo-stm/shared/src/main/scala/kyo/TRefLog.scala index 7f9a70626..daa94fb15 100644 --- a/kyo-stm/shared/src/main/scala/kyo/TRefLog.scala +++ b/kyo-stm/shared/src/main/scala/kyo/TRefLog.scala @@ -1,5 +1,8 @@ package kyo +import scala.collection.immutable.Map +import scala.collection.mutable.TreeMap + /** A log of transactional operations performed on TRefs within an STM transaction. * * TRefLog maintains a mapping from transactional references to their pending read/write operations within a transaction. It tracks both @@ -14,35 +17,26 @@ opaque type TRefLog = Map[TRef[Any], TRefLog.Entry[Any]] private[kyo] object TRefLog: + given tag: Tag[Var[TRefLog]] = Tag[Var[Map[TRef[Any], TRefLog.Entry[Any]]]] + val empty: TRefLog = Map.empty extension (self: TRefLog) def put[A](ref: TRef[A], entry: Entry[A]): TRefLog = - self.updated(ref.asInstanceOf[TRef[Any]], entry.asInstanceOf[Entry[Any]]) + val refAny = ref.asInstanceOf[TRef[Any]] + val entryAny = entry.asInstanceOf[TRefLog.Entry[Any]] + self.updated(refAny, entryAny) + end put def get[A](ref: TRef[A]): Maybe[Entry[A]] = val refAny = ref.asInstanceOf[TRef[Any]] Maybe.when(self.contains(refAny))(self(refAny).asInstanceOf[Entry[A]]) - def toSeq: Seq[(TRef[Any], Entry[Any])] = - self.toSeq + def toMap: Map[TRef[Any], TRefLog.Entry[Any]] = self end extension - def use[A, S](f: TRefLog => A < S)(using Frame): A < (S & Var[TRefLog]) = - Var.use(f) - - def isolate[A: Flat, S](v: A < (S & Var[TRefLog]))(using Frame): A < (S & Var[TRefLog]) = - Var.isolate.update[TRefLog].run(v) - - def runWith[A: Flat, B, S, S2](v: A < (S & Var[TRefLog]))(f: (TRefLog, A) => B < S2)(using Frame): B < (S & S2) = - Var.runWith(empty)(v)(f(_, _)) - - def setAndThen[A, S](log: TRefLog)(f: => A < S)(using Frame): A < (S & Var[TRefLog]) = - Var.setAndThen(log)(f) - - def setDiscard(log: TRefLog)(using Frame): Unit < Var[TRefLog] = - Var.setDiscard(log) + val isolate = Var.isolate.update[TRefLog](using tag) sealed abstract class Entry[A]: def tid: Long diff --git a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala index c28ec55d2..2a8ff0409 100644 --- a/kyo-stm/shared/src/test/scala/kyo/STMTest.scala +++ b/kyo-stm/shared/src/test/scala/kyo/STMTest.scala @@ -151,7 +151,7 @@ class STMTest extends Test: for _ <- ref.set(2) _ <- Var.set(1) - innerResult <- TRefLog.isolate { + innerResult <- TRefLog.isolate.run { for _ <- ref.set(3) _ <- Var.set(2) diff --git a/kyo-stm/shared/src/test/scala/kyo/TRefLogTest.scala b/kyo-stm/shared/src/test/scala/kyo/TRefLogTest.scala index 92698e877..90ddb34d5 100644 --- a/kyo-stm/shared/src/test/scala/kyo/TRefLogTest.scala +++ b/kyo-stm/shared/src/test/scala/kyo/TRefLogTest.scala @@ -9,7 +9,7 @@ class TTRefLogTest extends Test: "TTRefLog" - { "empty" in run { val log = TRefLog.empty - assert(log.toSeq.isEmpty) + assert(log.toMap.isEmpty) } "put" in run { @@ -17,9 +17,9 @@ class TTRefLogTest extends Test: val ref = new TRefImpl[Int](Write(0, 0)) val entry = Write(1, 42) val log = TRefLog.empty.put(ref, entry) - assert(log.toSeq.size == 1) - assert(log.toSeq.head._1 == ref) - assert(log.toSeq.head._2 == entry) + assert(log.toMap.size == 1) + assert(log.toMap.head._1 == ref) + assert(log.toMap.head._2 == entry) } } @@ -44,7 +44,7 @@ class TTRefLogTest extends Test: .put(ref1, entry1) .put(ref2, entry2) - val seq = log.toSeq + val seq = log.toMap.toSeq assert(seq.size == 2) assert(seq.contains((ref1, entry1))) assert(seq.contains((ref2, entry2)))