Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
djspiewak committed Feb 23, 2023
1 parent e848c2b commit 32d881f
Show file tree
Hide file tree
Showing 28 changed files with 651 additions and 673 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.concurrent.duration.FiniteDuration
// Can you imagine a thread pool on JS? Have fun trying to extend or instantiate
// this class. Unfortunately, due to the explicit branching, this type leaks
// into the shared source code of IOFiber.scala.
private[effect] sealed abstract class WorkStealingThreadPool private ()
private[effect] sealed abstract class WorkStealingThreadPool[+P] private ()
extends ExecutionContext {
def execute(runnable: Runnable): Unit
def reportFailure(cause: Throwable): Unit
Expand All @@ -38,12 +38,12 @@ private[effect] sealed abstract class WorkStealingThreadPool private ()
private[effect] def canExecuteBlockingCode(): Boolean
private[unsafe] def liveTraces(): (
Map[Runnable, Trace],
Map[WorkerThread, (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])],
Map[WorkerThread[Poller], (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])],
Map[Runnable, Trace])
}

private[unsafe] sealed abstract class WorkerThread private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool): Boolean
private[unsafe] sealed abstract class WorkerThread[+P] private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[Poller]): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
private[effect] sealed class FiberMonitor(
// A reference to the compute pool of the `IORuntime` in which this suspended fiber bag
// operates. `null` if the compute pool of the `IORuntime` is not a `WorkStealingThreadPool`.
private[this] val compute: WorkStealingThreadPool
private[this] val compute: WorkStealingThreadPool[Poller]
) extends FiberMonitorShared {

private[this] final val Bags = FiberMonitor.Bags
Expand All @@ -64,8 +64,8 @@ private[effect] sealed class FiberMonitor(
*/
def monitorSuspended(fiber: IOFiber[_]): WeakBag.Handle = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread]) {
val worker = thread.asInstanceOf[WorkerThread]
if (thread.isInstanceOf[WorkerThread[_]]) {
val worker = thread.asInstanceOf[WorkerThread[Poller]]
// Guard against tracking errors when multiple work stealing thread pools exist.
if (worker.isOwnedBy(compute)) {
worker.monitor(fiber)
Expand Down Expand Up @@ -111,7 +111,7 @@ private[effect] sealed class FiberMonitor(
val externalFibers = external.collect(justFibers)
val suspendedFibers = suspended.collect(justFibers)
val workersMapping: Map[
WorkerThread,
WorkerThread[_],
(Thread.State, Option[(IOFiber[_], Trace)], Map[IOFiber[_], Trace])] =
workers.map {
case (thread, (state, opt, set)) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package cats.effect.unsafe

trait Poller {
def poll(limitNanos: Long): Boolean
def interrupt(targetThread: Thread): Unit
def close(): Unit
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package cats.effect.unsafe

abstract class PollingRuntime[+P <: Poller] {
def buildPoller(reportFailure: Throwable => Unit): P
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,6 @@
package cats.effect
package unsafe

abstract class PollingSystem {

/**
* The user-facing interface.
*/
type GlobalPollingState <: AnyRef

/**
* The thread-local data structure used for polling.
*/
type Poller <: AnyRef

def makeGlobalPollingState(register: (Poller => Unit) => Unit): GlobalPollingState

def makePoller(): Poller

def closePoller(poller: Poller): Unit

/**
* @param nanos
* the maximum duration for which to block, where `nanos == -1` indicates to block
* indefinitely. ''However'', if `nanos == -1` and there are no remaining events to poll
* for, this method should return `false` immediately. This is unfortunate but necessary so
* that the `EventLoop` can yield to the Scala Native global `ExecutionContext` which is
* currently hard-coded into every test framework, including MUnit, specs2, and Weaver.
*
* @return
* whether poll should be called again (i.e., there are more events to be polled)
*/
def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean

def interrupt(targetThread: Thread, targetPoller: Poller): Unit

abstract class PollingSystem[+P <: Poller] {
def buildRuntime(): PollingRuntime[P]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package cats.effect.unsafe

import scala.concurrent.ExecutionContext

trait RuntimeContext[+P] extends ExecutionContext {
def register(cb: P => Unit): Unit
}
2 changes: 1 addition & 1 deletion core/jvm/src/main/scala/cats/effect/IOApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ trait IOApp {
*/
protected def runtimeConfig: unsafe.IORuntimeConfig = unsafe.IORuntimeConfig()

protected def pollingSystem: unsafe.PollingSystem = unsafe.SelectorSystem()
protected def pollingSystem: unsafe.PollingSystem[unsafe.Poller] = unsafe.SelectorSystem()

/**
* Controls the number of worker threads which will be allocated to the compute pool in the
Expand Down
11 changes: 0 additions & 11 deletions core/jvm/src/main/scala/cats/effect/IOCompanionPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ package cats.effect

import cats.effect.std.Console
import cats.effect.tracing.Tracing
import cats.effect.unsafe.WorkStealingThreadPool

import scala.reflect.ClassTag

import java.time.Instant
import java.util.concurrent.{CompletableFuture, CompletionStage}
Expand Down Expand Up @@ -144,12 +141,4 @@ private[effect] abstract class IOCompanionPlatform { this: IO.type =>
*/
def readLine: IO[String] =
Console[IO].readLine

def poller[Poller](implicit ct: ClassTag[Poller]): IO[Option[Poller]] =
IO.executionContext.map {
case wstp: WorkStealingThreadPool
if ct.runtimeClass.isInstance(wstp.globalPollingState) =>
Some(wstp.globalPollingState.asInstanceOf[Poller])
case _ => None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import scala.concurrent.ExecutionContext

private[unsafe] trait FiberMonitorCompanionPlatform {
def apply(compute: ExecutionContext): FiberMonitor = {
if (TracingConstants.isStackTracing && compute.isInstanceOf[WorkStealingThreadPool]) {
val wstp = compute.asInstanceOf[WorkStealingThreadPool]
if (TracingConstants.isStackTracing && compute.isInstanceOf[WorkStealingThreadPool[_]]) {
val wstp = compute.asInstanceOf[WorkStealingThreadPool[Poller]]
new FiberMonitor(wstp)
} else {
new FiberMonitor(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type

private[this] final val DefaultBlockerPrefix = "io-compute-blocker"

def createWorkStealingComputeThreadPool(
def createWorkStealingComputeThreadPool[P <: Poller](
threads: Int = Math.max(2, Runtime.getRuntime().availableProcessors()),
threadPrefix: String = "io-compute",
blockerThreadPrefix: String = DefaultBlockerPrefix,
runtimeBlockingExpiration: Duration = 60.seconds,
reportFailure: Throwable => Unit = _.printStackTrace(),
pollingSystem: PollingSystem = SelectorSystem()): (WorkStealingThreadPool, () => Unit) = {
pollingSystem: PollingSystem[P] = SelectorSystem()): (WorkStealingThreadPool[P], () => Unit) = {
val threadPool =
new WorkStealingThreadPool(
threads,
Expand Down Expand Up @@ -123,7 +123,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
threadPrefix: String,
blockerThreadPrefix: String,
runtimeBlockingExpiration: Duration,
reportFailure: Throwable => Unit): (WorkStealingThreadPool, () => Unit) =
reportFailure: Throwable => Unit): (WorkStealingThreadPool[Poller], () => Unit) =
createWorkStealingComputeThreadPool(
threads,
threadPrefix,
Expand All @@ -141,14 +141,14 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
threads: Int = Math.max(2, Runtime.getRuntime().availableProcessors()),
threadPrefix: String = "io-compute",
blockerThreadPrefix: String = DefaultBlockerPrefix)
: (WorkStealingThreadPool, () => Unit) =
: (WorkStealingThreadPool[Poller], () => Unit) =
createWorkStealingComputeThreadPool(threads, threadPrefix, blockerThreadPrefix)

@deprecated("bincompat shim for previous default method overload", "3.3.13")
def createDefaultComputeThreadPool(
self: () => IORuntime,
threads: Int,
threadPrefix: String): (WorkStealingThreadPool, () => Unit) =
threadPrefix: String): (WorkStealingThreadPool[Poller], () => Unit) =
createDefaultComputeThreadPool(self(), threads, threadPrefix)

def createDefaultBlockingExecutionContext(
Expand Down
8 changes: 4 additions & 4 deletions core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ import java.util.concurrent.ThreadLocalRandom
* `Unsafe`. And `Unsafe` is only really needed on JVM 8. JVM 9+ introduce much richer and
* better APIs and tools for building high-performance concurrent systems (e.g. `VarHandle`).
*/
private final class LocalQueue extends LocalQueuePadding {
private final class LocalQueue[P <: Poller] extends LocalQueuePadding {

import LocalQueueConstants._
import TracingConstants._
Expand Down Expand Up @@ -337,7 +337,7 @@ private final class LocalQueue extends LocalQueuePadding {
* @return
* a fiber to be executed directly
*/
def enqueueBatch(batch: Array[Runnable], worker: WorkerThread): Runnable = {
def enqueueBatch(batch: Array[Runnable], worker: WorkerThread[P]): Runnable = {
// A plain, unsynchronized load of the tail of the local queue.
val tl = tail

Expand Down Expand Up @@ -410,7 +410,7 @@ private final class LocalQueue extends LocalQueuePadding {
* the fiber at the head of the queue, or `null` if the queue is empty (in order to avoid
* unnecessary allocations)
*/
def dequeue(worker: WorkerThread): Runnable = {
def dequeue(worker: WorkerThread[P]): Runnable = {
// A plain, unsynchronized load of the tail of the local queue.
val tl = tail

Expand Down Expand Up @@ -487,7 +487,7 @@ private final class LocalQueue extends LocalQueuePadding {
* a reference to the first fiber to be executed by the stealing [[WorkerThread]], or `null`
* if the stealing was unsuccessful
*/
def stealInto(dst: LocalQueue, dstWorker: WorkerThread): Runnable = {
def stealInto(dst: LocalQueue[P], dstWorker: WorkerThread[P]): Runnable = {
// A plain, unsynchronized load of the tail of the destination queue, owned
// by the executing thread.
val dstTl = dst.tail
Expand Down
95 changes: 95 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/SelectorPoller.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package cats.effect
package unsafe

import java.nio.channels.SelectableChannel
import java.nio.channels.spi.AbstractSelector

final class SelectorPoller(
selector: AbstractSelector)(
reportFailure: Throwable => Unit)
extends Poller {

import SelectorPoller.CallbackNode

def close(): Unit =
selector.close()

def interrupt(targetThread: Thread): Unit = {
selector.wakeup()
()
}

def poll(nanos: Long): Boolean = {
val millis = if (nanos >= 0) nanos / 1000000 else -1

if (millis == 0) selector.selectNow()
else if (millis > 0) selector.select(millis)
else selector.select()

if (selector.isOpen()) { // closing selector interrupts select
val ready = selector.selectedKeys().iterator()
while (ready.hasNext()) {
val key = ready.next()
ready.remove()

val readyOps = key.readyOps()

var head: CallbackNode = null
var prev: CallbackNode = null
var node = key.attachment().asInstanceOf[CallbackNode]
while (node ne null) {
val next = node.next

if ((node.interest & readyOps) != 0) { // execute callback and drop this node
val cb = node.callback
if (cb != null) cb(readyOps)
if (prev ne null) prev.next = next
} else { // keep this node
prev = node
if (head eq null)
head = node
}

node = next
}

// reset interest in triggered ops
key.interestOps(key.interestOps() & ~readyOps)
key.attach(head)
}

!selector.keys().isEmpty()
} else false
}

def select(ch: SelectableChannel, ops: Int)(cb: Int => Unit): () => Unit = {
val key = ch.keyFor(selector)

val node = if (key eq null) { // not yet registered on this selector
val node = new CallbackNode(ops, cb, null)
ch.register(selector, ops, node)
node
} else { // existing key
// mixin the new interest
key.interestOps(key.interestOps() | ops)
val node =
new CallbackNode(ops, cb, key.attachment().asInstanceOf[CallbackNode])
key.attach(node)
node
}

{ () =>
// set all interest bits
node.interest = -1
// clear for gc
node.callback = null
}
}
}

private object SelectorPoller {
private[SelectorPoller] final class CallbackNode(
var interest: Int,
var callback: Int => Unit,
var next: CallbackNode)
}
Loading

0 comments on commit 32d881f

Please sign in to comment.