Skip to content

Commit

Permalink
chore: run virtual threads with other scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed May 27, 2024
1 parent dc26b6f commit 17dbf69
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
43 changes: 38 additions & 5 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,27 @@ import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream
import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{ExecutorService, ThreadFactory}
import java.util.concurrent.{Executor, ExecutorService, ThreadFactory}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup

import cask.util.Logger.Console.globalLogger

/**
* Create a virtual thread executor with the given executor as the scheduler.
* */
def createVirtualThreadExecutor(executor: Executor): Option[Executor] = {
(for {
factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor))
executor <- Try(createNewThreadPerTaskExecutor(factory))
} yield executor).toOption
}

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
Expand All @@ -24,26 +37,46 @@ object Util {
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
throw new UnsupportedOperationException("Failed to create virtual thread executor", e)
globalLogger.exception(e)
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}
}

/**
* Create a virtual thread factory, returns null when failed.
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def createVirtualThreadFactory(prefix: String): ThreadFactory =
def createVirtualThreadFactory(prefix: String,
executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
lookup
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case _: Throwable => null
case NonFatal(e) =>
globalLogger.exception(e)
//--add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory", e)
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
Expand Down
11 changes: 9 additions & 2 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cask.main

import cask.endpoints.{WebsocketResult, WsHandler}
import cask.model._
import cask.internal.{DispatchTrie, Util, ThreadBlockingHandler}
import cask.internal.{DispatchTrie, ThreadBlockingHandler, Util}
import Response.Raw
import cask.router.{Decorator, EndpointMetadata, EntryPoint, Result}
import cask.util.Logger
Expand Down Expand Up @@ -62,9 +62,16 @@ abstract class Main {
null
}

private def screenExecutor(executor: Executor): Executor = {
if (executor eq null) executor
else if (System.getProperty("cask.virtualThread.enabled", "true").toBoolean) {
Util.createVirtualThreadExecutor(executor).getOrElse(executor)
} else executor
}

def defaultHandler: HttpHandler = {
val mainHandler = new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError)
val executor = handlerExecutor()
val executor = screenExecutor(handlerExecutor())
if (handlerExecutor ne null) {
new ThreadBlockingHandler(executor, mainHandler)
} else new BlockingHandler(mainHandler)
Expand Down
9 changes: 1 addition & 8 deletions example/compress/app/src/Compress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@ package app

import cask.internal.{ThreadBlockingHandler, Util}

import java.util.concurrent.Executor
import java.util.concurrent.{Executor, Executors}

object Compress extends cask.MainRoutes {

protected override val handlerExecutor: Executor = {
if (System.getProperty("cask.virtualThread.enabled", "false").toBoolean) {
Util.createNewThreadPerTaskExecutor(
Util.createVirtualThreadFactory("cask-handler-executor"))
} else null
}

@cask.decorators.compress
@cask.get("/")
def hello(): String = {
Expand Down

0 comments on commit 17dbf69

Please sign in to comment.