Skip to content

Commit

Permalink
Refactor ServerRunner to use Resource, fix unsafe .allocated cases (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-klass authored Jan 31, 2025
1 parent ba21ead commit 32ccf15
Show file tree
Hide file tree
Showing 20 changed files with 152 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,22 @@ import org.http4s._
import org.slf4j.LoggerFactory
import org.typelevel.ci.CIString
import scodec.bits.ByteVector
import sttp.tapir.client.tests.HttpServer._

import scala.concurrent.ExecutionContext

object HttpServer {
object HttpServer extends ResourceApp.Forever {
type Port = Int

def main(args: Array[String]): Unit = {
def run(args: List[String]): Resource[IO, Unit] = {
val port = args.headOption.map(_.toInt).getOrElse(51823)
new HttpServer(port).start()
new HttpServer(port).build.void
}
}

class HttpServer(port: Port) {
class HttpServer(port: HttpServer.Port) {

private val logger = LoggerFactory.getLogger(getClass)

private var stopServer: IO[Unit] = _

//

private object numParam extends QueryParamDecoderMatcher[Int]("num")
Expand Down Expand Up @@ -212,23 +209,11 @@ class HttpServer(port: Port) {

//

def start(): Unit = {
val (_, _stopServer) = BlazeServerBuilder[IO]
def build: Resource[IO, server.Server] = BlazeServerBuilder[IO]
.withExecutionContext(ExecutionContext.global)
.bindHttp(port)
.withHttpWebSocketApp(app)
.resource
.map(_.address.getPort)
.allocated
.unsafeRunSync()

stopServer = _stopServer

logger.info(s"Server on port $port started")
}

def close(): Unit = {
stopServer.unsafeRunSync()
logger.info(s"Server on port $port stopped")
}
.evalTap(_ => IO(logger.info(s"Server on port $port started")))
.onFinalize(IO(logger.info(s"Server on port $port stopped")))
}
3 changes: 1 addition & 2 deletions doc/tutorials/07_cats_effect.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,7 @@ object HelloWorldTapir extends IOApp:
.bindHttp(8080, "localhost")
.withHttpApp(Router("/" -> allRoutes).orNotFound)
.resource
.use(_ => IO.never)
.as(ExitCode.Success)
.useForever
```

Hence, we first generate endpoint descriptions, which correspond to exposing the Swagger UI (containing the generated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,4 @@ object ProxyHttp4sFs2Server extends IOApp:
.bindHttp(8080, "localhost")
.withHttpApp(Router("/" -> routes).orNotFound)
.resource
} yield ())
.use { _ => IO.never }
.as(ExitCode.Success)
} yield ()).useForever
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,4 @@ object StreamingHttp4sFs2ServerOrError extends IOApp:
.bindHttp(8080, "localhost")
.withHttpApp(Router("/" -> userDataRoutes).orNotFound)
.resource
.use { _ => IO.never }
.as(ExitCode.Success)
.useForever
3 changes: 1 addition & 2 deletions generated-doc/out/tutorials/07_cats_effect.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,7 @@ object HelloWorldTapir extends IOApp:
.bindHttp(8080, "localhost")
.withHttpApp(Router("/" -> allRoutes).orNotFound)
.resource
.use(_ => IO.never)
.as(ExitCode.Success)
.useForever
```

Hence, we first generate endpoint descriptions, which correspond to exposing the Swagger UI (containing the generated
Expand Down
48 changes: 21 additions & 27 deletions perf-tests/src/main/scala/sttp/tapir/perf/apis/ServerRunner.scala
Original file line number Diff line number Diff line change
@@ -1,48 +1,42 @@
package sttp.tapir.perf.apis

import cats.effect.{ExitCode, IO, IOApp}
import cats.effect.{IO, Resource, ResourceApp}

import scala.reflect.runtime.universe

trait ServerRunner {
def start: IO[ServerRunner.KillSwitch]
def runServer: Resource[IO, Unit]
}

/** Can be used as a Main object to run a single server using its short name. Running perfTests/runMain
* [[sttp.tapir.perf.apis.ServerRunner]] will load special javaOptions configured in build.sbt, enabling recording JFR metrics. This is
* useful when you want to guarantee that the server runs in a different JVM than test runner, so that memory and CPU metrics are recorded
* only in the scope of the server JVM.
*/
object ServerRunner extends IOApp {
type KillSwitch = IO[Unit]
val NoopKillSwitch = IO.pure(IO.unit)
object ServerRunner extends ResourceApp.Forever {

private val runtimeMirror = universe.runtimeMirror(getClass.getClassLoader)
private val requireArg: Resource[IO, Unit] = Resource.raiseError(
new IllegalArgumentException(s"Unspecified server name. Use one of: ${TypeScanner.allServers}"): Throwable
)
private def notInstantiated(name: ServerName)(e: Throwable): IO[ServerRunner] = IO.raiseError(
new IllegalArgumentException(
s"ERROR! Could not find object ${name.fullName} or it doesn't extend ServerRunner", e
)
)

def run(args: List[String]): IO[ExitCode] = {
val shortServerName = args.headOption.getOrElse {
throw new IllegalArgumentException(s"Unspecified server name. Use one of: ${TypeScanner.allServers}")
}
for {
killSwitch <- startServerByTypeName(ServerName.fromShort(shortServerName))
_ <- IO.never.guarantee(killSwitch)
} yield ExitCode.Success
}
def run(args: List[String]): Resource[IO, Unit] =
args.headOption.map(ServerName.fromShort).map(startServerByTypeName).getOrElse(requireArg)

def startServerByTypeName(serverName: ServerName): IO[ServerRunner.KillSwitch] = {
def startServerByTypeName(serverName: ServerName): Resource[IO, Unit] =
serverName match {
case ExternalServerName => NoopKillSwitch
case _ =>
try {
case ExternalServerName => Resource.unit
case _ => Resource.eval(
IO({
val moduleSymbol = runtimeMirror.staticModule(serverName.fullName)
val moduleMirror = runtimeMirror.reflectModule(moduleSymbol)
val instance: ServerRunner = moduleMirror.instance.asInstanceOf[ServerRunner]
instance.start
} catch {
case e: Throwable =>
IO.raiseError(
new IllegalArgumentException(s"ERROR! Could not find object ${serverName.fullName} or it doesn't extend ServerRunner", e)
)
}
moduleMirror.instance.asInstanceOf[ServerRunner]
}).handleErrorWith(notInstantiated(serverName))
).flatMap(_.runServer)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import io.github.classgraph.ClassGraph

import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

import sttp.tapir.perf.Common._

Expand Down
21 changes: 8 additions & 13 deletions perf-tests/src/main/scala/sttp/tapir/perf/http4s/Http4s.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,21 @@ object Tapir extends Endpoints {
object server {
val maxConnections = 65536
val connectorPoolSize: Int = Math.max(2, Runtime.getRuntime.availableProcessors() / 4)
def runServer(
router: WebSocketBuilder2[IO] => HttpRoutes[IO]
): IO[ServerRunner.KillSwitch] =
def runServer(router: WebSocketBuilder2[IO] => HttpRoutes[IO]): Resource[IO, Unit] =
BlazeServerBuilder[IO]
.bindHttp(Port, "localhost")
.withHttpWebSocketApp(wsb => router(wsb).orNotFound)
.withMaxConnections(maxConnections)
.withConnectorPoolSize(connectorPoolSize)
.resource
.allocated
.map(_._2)
.map(_.flatTap { _ =>
IO.println("Http4s server closed.")
})
.map(_ => ())
.onFinalize(IO.println("Http4s server closed."))
}

object TapirServer extends ServerRunner { override def start = server.runServer(Tapir.router(1)) }
object TapirMultiServer extends ServerRunner { override def start = server.runServer(Tapir.router(128)) }
object TapirServer extends ServerRunner { override def runServer = server.runServer(Tapir.router(1)) }
object TapirMultiServer extends ServerRunner { override def runServer = server.runServer(Tapir.router(128)) }
object TapirInterceptorMultiServer extends ServerRunner {
override def start = server.runServer(Tapir.router(128, withServerLog = true))
override def runServer = server.runServer(Tapir.router(128, withServerLog = true))
}
object VanillaServer extends ServerRunner { override def start = server.runServer(Vanilla.router(1)) }
object VanillaMultiServer extends ServerRunner { override def start = server.runServer(Vanilla.router(128)) }
object VanillaServer extends ServerRunner { override def runServer = server.runServer(Vanilla.router(1)) }
object VanillaMultiServer extends ServerRunner { override def runServer = server.runServer(Vanilla.router(128)) }
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import sttp.tapir.perf.apis._
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.cats.NettyCatsServer
import sttp.tapir.server.netty.cats.NettyCatsServerOptions
import sttp.ws.WebSocketFrame
import sttp.capabilities.fs2.Fs2Streams

import scala.concurrent.duration._
Expand All @@ -33,27 +32,26 @@ object NettyCats {
Tapir.wsResponseStream.evalMap(_ => IO.realTime.map(_.toMillis)).concurrently(in.as(()))
}
)
def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = {
def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): Resource[IO, Unit] = {
val declaredPort = Port
val declaredHost = "0.0.0.0"
(for {
for {
dispatcher <- Dispatcher.parallel[IO]
serverOptions = buildOptions(NettyCatsServerOptions.customiseInterceptors(dispatcher), withServerLog)
server <- NettyCatsServer.io()
_ <-
Resource.make(
server
.port(declaredPort)
.host(declaredHost)
.addEndpoints(wsServerEndpoint :: endpoints)
.start()
)(binding => binding.stop())
} yield ()).allocated.map(_._2)
server <- NettyCatsServer.io().map(_.options(serverOptions))
_ <- Resource.make(
server
.port(declaredPort)
.host(declaredHost)
.addEndpoints(wsServerEndpoint :: endpoints)
.start()
)(_.stop())
} yield ()
}
}

object TapirServer extends ServerRunner { override def start = NettyCats.runServer(Tapir.genEndpointsIO(1)) }
object TapirMultiServer extends ServerRunner { override def start = NettyCats.runServer(Tapir.genEndpointsIO(128)) }
object TapirServer extends ServerRunner { override def runServer = NettyCats.runServer(Tapir.genEndpointsIO(1)) }
object TapirMultiServer extends ServerRunner { override def runServer = NettyCats.runServer(Tapir.genEndpointsIO(128)) }
object TapirInterceptorMultiServer extends ServerRunner {
override def start = NettyCats.runServer(Tapir.genEndpointsIO(128), withServerLog = true)
override def runServer = NettyCats.runServer(Tapir.genEndpointsIO(128), withServerLog = true)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sttp.tapir.perf.netty.future

import cats.effect.IO
import cats.effect.{IO, Resource}
import sttp.tapir.perf.apis._
import sttp.tapir.perf.Common._
import sttp.tapir.server.netty.{NettyFutureServer, NettyFutureServerBinding, NettyFutureServerOptions}
Expand All @@ -14,7 +14,7 @@ object Tapir extends Endpoints

object NettyFuture {

def runServer(endpoints: List[ServerEndpoint[Any, Future]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = {
def runServer(endpoints: List[ServerEndpoint[Any, Future]], withServerLog: Boolean = false): Resource[IO, Unit] = {
val declaredPort = Port
val declaredHost = "0.0.0.0"
val serverOptions = buildOptions(NettyFutureServerOptions.customiseInterceptors, withServerLog)
Expand All @@ -29,13 +29,12 @@ object NettyFuture {
.start()
)
)

serverBinding.map(b => IO.fromFuture(IO(b.stop())))
Resource.make(serverBinding)(b => IO.fromFuture(IO(b.stop()))).map(_ => ())
}
}

object TapirServer extends ServerRunner { override def start = NettyFuture.runServer(Tapir.genEndpointsFuture(1)) }
object TapirMultiServer extends ServerRunner { override def start = NettyFuture.runServer(Tapir.genEndpointsFuture(128)) }
object TapirServer extends ServerRunner { override def runServer = NettyFuture.runServer(Tapir.genEndpointsFuture(1)) }
object TapirMultiServer extends ServerRunner { override def runServer = NettyFuture.runServer(Tapir.genEndpointsFuture(128)) }
object TapirInterceptorMultiServer extends ServerRunner {
override def start = NettyFuture.runServer(Tapir.genEndpointsFuture(128), withServerLog = true)
override def runServer = NettyFuture.runServer(Tapir.genEndpointsFuture(128), withServerLog = true)
}
32 changes: 17 additions & 15 deletions perf-tests/src/main/scala/sttp/tapir/perf/nima/Nima.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sttp.tapir.perf.nima

import cats.effect.IO
import cats.effect.{IO, Resource}
import io.helidon.webserver.WebServer
import sttp.shared.Identity
import sttp.tapir.perf.apis._
Expand All @@ -14,27 +14,29 @@ object Tapir extends Endpoints {

object Nima {

def runServer(endpoints: List[ServerEndpoint[Any, Identity]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = {
def runServer(endpoints: List[ServerEndpoint[Any, Identity]], withServerLog: Boolean = false): Resource[IO, Unit] = {
val declaredPort = Port
val serverOptions = buildOptions(NimaServerOptions.customiseInterceptors, withServerLog)
// Starting Nima server

val handler = NimaServerInterpreter(serverOptions).toHandler(endpoints)
val server = WebServer
.builder()
.routing { builder =>
builder.any(handler)
()
}
.port(declaredPort)
.build()
.start()
IO(IO { val _ = server.stop() })
val startServer = IO {
WebServer
.builder()
.routing { builder =>
builder.any(handler)
()
}
.port(declaredPort)
.build()
.start()
}
Resource.make(startServer)(server => IO(server.stop()).void).map(_ => ())
}
}

object TapirServer extends ServerRunner { override def start = Nima.runServer(Tapir.genEndpointsNId(1)) }
object TapirMultiServer extends ServerRunner { override def start = Nima.runServer(Tapir.genEndpointsNId(128)) }
object TapirServer extends ServerRunner { override def runServer = Nima.runServer(Tapir.genEndpointsNId(1)) }
object TapirMultiServer extends ServerRunner { override def runServer = Nima.runServer(Tapir.genEndpointsNId(128)) }
object TapirInterceptorMultiServer extends ServerRunner {
override def start = Nima.runServer(Tapir.genEndpointsNId(128), withServerLog = true)
override def runServer = Nima.runServer(Tapir.genEndpointsNId(128), withServerLog = true)
}
Loading

0 comments on commit 32ccf15

Please sign in to comment.