Skip to content

Commit

Permalink
Transfer tick state when replacing a blocked thread
Browse files Browse the repository at this point in the history
  • Loading branch information
armanbilge committed Jan 24, 2025
1 parent 62e956b commit 6e9a121
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
val thread =
new WorkerThread(
index,
0,
queue,
parkedSignal,
externalQueue,
Expand All @@ -174,6 +175,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
system,
poller,
metrics,
new WorkerThread.TransferState,
this)

workerThreads.set(i, thread)
Expand Down
47 changes: 29 additions & 18 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import java.lang.Long.MIN_VALUE
import java.util.concurrent.{LinkedTransferQueue, ThreadLocalRandom}
import java.util.concurrent.atomic.AtomicBoolean

import WorkerThread.Metrics
import WorkerThread.{Metrics, TransferState}

/**
* Implementation of the worker thread at the heart of the [[WorkStealingThreadPool]].
Expand All @@ -45,6 +45,7 @@ import WorkerThread.Metrics
*/
private[effect] final class WorkerThread[P <: AnyRef](
idx: Int,
private[this] var tick: Int,
// Local queue instance with exclusive write access.
private[this] var queue: LocalQueue,
// The state of the `WorkerThread` (parked/unparked).
Expand All @@ -58,6 +59,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
private[this] val system: PollingSystem.WithPoller[P],
private[this] var _poller: P,
private[this] var metrics: Metrics,
private[this] var transferState: TransferState,
// Reference to the `WorkStealingThreadPool` in which this thread operates.
pool: WorkStealingThreadPool[P])
extends Thread
Expand Down Expand Up @@ -109,7 +111,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
*/
private[this] var _active: Runnable = _

private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue()
private val stateTransfer: LinkedTransferQueue[TransferState] = new LinkedTransferQueue()
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[effect] var currentIOFiber: IOFiber[?] = _
Expand Down Expand Up @@ -313,8 +315,6 @@ private[effect] final class WorkerThread[P <: AnyRef](
random = ThreadLocalRandom.current()
val rnd = random

var iteration = 0

val done = pool.done

/*
Expand Down Expand Up @@ -673,15 +673,16 @@ private[effect] final class WorkerThread[P <: AnyRef](
_active = null
_poller = null.asInstanceOf[P]
metrics = null
transferState = null

// Add this thread to the cached threads data structure, to be picked up
// by another thread in the future.
pool.cachedThreads.add(this)
try {
val len = runtimeBlockingExpiration.length
val unit = runtimeBlockingExpiration.unit
var newIdx: Integer = indexTransfer.poll(len, unit)
if (newIdx eq null) {
var newState = stateTransfer.poll(len, unit)
if (newState eq null) {
// The timeout elapsed and no one woke up this thread. Try to remove
// the thread from the cached threads data structure.
if (pool.cachedThreads.remove(this)) {
Expand All @@ -692,12 +693,12 @@ private[effect] final class WorkerThread[P <: AnyRef](
// Someone else concurrently stole this thread from the cached
// data structure and will transfer the data soon. Time to wait
// for it again.
newIdx = indexTransfer.take()
init(newIdx)
newState = stateTransfer.take()
init(newState)
}
} else {
// Some other thread woke up this thread. Time to take its place.
init(newIdx)
init(newState)
}
} catch {
case _: InterruptedException =>
Expand All @@ -706,13 +707,9 @@ private[effect] final class WorkerThread[P <: AnyRef](
// exit.
return
}

// Reset the state of the thread for resumption.
blocking = false
iteration = 1
}

((iteration & ExternalQueueTicksMask): @switch) match {
((tick & ExternalQueueTicksMask): @switch) match {
case 0 =>
if (pool.blockedThreadDetectionEnabled) {
// TODO prefetch pool.workerThread or Thread.State.BLOCKED ?
Expand Down Expand Up @@ -822,7 +819,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
// Continue executing fibers from the local queue.
}

iteration += 1
tick += 1
}
}

Expand Down Expand Up @@ -895,7 +892,9 @@ private[effect] final class WorkerThread[P <: AnyRef](
val idx = index
pool.replaceWorker(idx, cached)
// Transfer the data structures to the cached thread and wake it up.
cached.indexTransfer.transfer(idx)
transferState.index = idx
transferState.tick = tick + 1
cached.stateTransfer.transfer(transferState)
} else {
// Spawn a new `WorkerThread`, a literal clone of this one. It is safe to
// transfer ownership of the local queue and the parked signal to the new
Expand All @@ -911,6 +910,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
val clone =
new WorkerThread(
idx,
tick + 1,
queue,
parked,
external,
Expand All @@ -919,6 +919,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
system,
_poller,
metrics,
transferState,
pool)
// Make sure the clone gets our old name:
val clonePrefix = pool.threadPrefix
Expand All @@ -942,18 +943,23 @@ private[effect] final class WorkerThread[P <: AnyRef](
thunk
}

private[this] def init(newIdx: Int): Unit = {
private[this] def init(newState: TransferState): Unit = {
val newIdx = newState.index
_index = newIdx
tick = newState.tick
queue = pool.localQueues(newIdx)
sleepers = pool.sleepers(newIdx)
parked = pool.parkedSignals(newIdx)
fiberBag = pool.fiberBags(newIdx)
_poller = pool.pollers(newIdx)
metrics = pool.metrices(newIdx)
transferState = newState

// Reset the name of the thread to the regular prefix.
val prefix = pool.threadPrefix
setName(s"$prefix-$newIdx")
setName(s"$prefix-${_index}")

blocking = false
}

/**
Expand All @@ -973,6 +979,11 @@ private[effect] final class WorkerThread[P <: AnyRef](

private[effect] object WorkerThread {

private[unsafe] final class TransferState {
var index: Int = _
var tick: Int = _
}

final class Metrics {
private[this] var idleTime: Long = 0
def getIdleTime(): Long = idleTime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala

val test = mkBlockingWork *>
IO.pollers.map(_.head.asInstanceOf[DummyPoller]).flatMap { poller =>
poller.poll.as(true)
poller.poll.replicateA_(100).as(true)
}

test.unsafeRunTimed(1.second) must beSome(beTrue)
Expand Down

0 comments on commit 6e9a121

Please sign in to comment.