diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala index a473ed8a65..e56d5f9dcd 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/WebSocketAppHandler.scala @@ -47,13 +47,14 @@ private[zio] final class WebSocketAppHandler( event: ChannelEvent[JWebSocketFrame], close: Boolean = false, ): Unit = { - zExec.runUninterruptible(ctx, NettyRuntime.noopEnsuring)( - queue.offer(event.map(frameFromNetty)) *> - (onComplete match { - case Some(promise) if close => promise.succeed(ChannelState.Invalid) - case _ => ZIO.unit - }), - ) + // IMPORTANT: Offering to the queue must be run synchronously to avoid messages being added in the wrong order + // Since the queue is unbounded, this will not block the event loop + // TODO: We need to come up with a design that doesn't involve running an effect to offer to the queue + zExec.unsafeRunSync(queue.offer(event.map(frameFromNetty))) + onComplete match { + case Some(promise) if close => promise.unsafe.done(Exit.succeed(ChannelState.Invalid)) + case _ => () + } } override def channelRead0(ctx: ChannelHandlerContext, msg: JWebSocketFrame): Unit = @@ -68,9 +69,8 @@ private[zio] final class WebSocketAppHandler( override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { dispatch(ctx, ChannelEvent.exceptionCaught(cause)) onComplete match { - case Some(promise) => - promise.fail(cause) - case None => + case Some(promise) => promise.unsafe.done(Exit.fail(cause)) + case None => () } } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 4d3e8723a0..b0fa975c5c 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -168,29 +168,30 @@ private[zio] final case class ServerInboundHandler( runtime: NettyRuntime, response: Response, jRequest: HttpRequest, - ): Option[Task[Unit]] = { + ): Task[Option[Task[Unit]]] = { response.body match { case WebsocketBody(socketApp) if response.status == Status.SwitchingProtocols => - upgradeToWebSocket(ctx, jRequest, socketApp, runtime) - None + upgradeToWebSocket(ctx, jRequest, socketApp, runtime).as(None) case _ => - val jResponse = NettyResponseEncoder.encode(response) + ZIO.attempt { + val jResponse = NettyResponseEncoder.encode(response) - if (!jResponse.isInstanceOf[FullHttpResponse]) { + if (!jResponse.isInstanceOf[FullHttpResponse]) { - // We MUST get the content length from the headers BEFORE we call writeAndFlush otherwise netty will mutate - // the headers and remove `content-length` since there is no content - val contentLength = - jResponse.headers().get(HttpHeaderNames.CONTENT_LENGTH) match { - case null => None - case value => Some(value.toLong) - } + // We MUST get the content length from the headers BEFORE we call writeAndFlush otherwise netty will mutate + // the headers and remove `content-length` since there is no content + val contentLength = + jResponse.headers().get(HttpHeaderNames.CONTENT_LENGTH) match { + case null => None + case value => Some(value.toLong) + } - ctx.writeAndFlush(jResponse) - NettyBodyWriter.writeAndFlush(response.body, contentLength, ctx, isResponseCompressible(jRequest)) - } else { - ctx.writeAndFlush(jResponse) - None + ctx.writeAndFlush(jResponse) + NettyBodyWriter.writeAndFlush(response.body, contentLength, ctx, isResponseCompressible(jRequest)) + } else { + ctx.writeAndFlush(jResponse) + None + } } } } @@ -266,51 +267,51 @@ private[zio] final case class ServerInboundHandler( } - // TODO: reimplement it on server settings level -// private def setServerTime(time: ServerTime, response: Response, jResponse: HttpResponse): Unit = { -// val _ = -// if (response.addServerTime) -// jResponse.headers().set(HttpHeaderNames.DATE, time.refreshAndGet()) -// } - /* * Checks if the response requires to switch protocol to websocket. Returns * true if it can, otherwise returns false */ - @tailrec private def upgradeToWebSocket( ctx: ChannelHandlerContext, jReq: HttpRequest, webSocketApp: WebSocketApp[Any], runtime: NettyRuntime, - ): Unit = { + ): Task[Unit] = { jReq match { case jReq: FullHttpRequest => - val queue = - runtime.unsafeRunSync { - Queue.unbounded[WebSocketChannelEvent].tap { queue => + Queue + .unbounded[WebSocketChannelEvent] + .tap { queue => + ZIO.suspend { val nettyChannel = NettyChannel.make[JWebSocketFrame](ctx.channel()) val webSocketChannel = WebSocketChannel.make(nettyChannel, queue) webSocketApp.handler.runZIO(webSocketChannel).ignoreLogged.forkDaemon } } - ctx - .channel() - .pipeline() - .addLast( - new WebSocketServerProtocolHandler( - NettySocketProtocol.serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig)).build(), - ), - ) - .addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None)) - - val retained = jReq.retainedDuplicate() - val _ = ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(retained) } - - case jReq: HttpRequest => - val fullRequest = new DefaultFullHttpRequest(jReq.protocolVersion(), jReq.method(), jReq.uri()) - fullRequest.headers().setAll(jReq.headers()) - upgradeToWebSocket(ctx: ChannelHandlerContext, fullRequest, webSocketApp, runtime) + .flatMap { queue => + ZIO.attempt { + ctx + .channel() + .pipeline() + .addLast( + new WebSocketServerProtocolHandler( + NettySocketProtocol + .serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig)) + .build(), + ), + ) + .addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None)) + + val retained = jReq.retainedDuplicate() + val _ = ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(retained) } + } + } + case jReq: HttpRequest => + ZIO.suspend { + val fullRequest = new DefaultFullHttpRequest(jReq.protocolVersion(), jReq.method(), jReq.uri()) + fullRequest.headers().setAll(jReq.headers()) + upgradeToWebSocket(ctx: ChannelHandlerContext, fullRequest, webSocketApp, runtime) + } } } @@ -338,18 +339,18 @@ private[zio] final case class ServerInboundHandler( }, ) }.flatMap { response => - ZIO.attempt { + ZIO.suspend { if (response ne null) { val done = attemptFastWrite(ctx, response) if (!done) attemptFullWrite(ctx, runtime, response, jReq) else - None + ZIO.none } else { if (ctx.channel().isOpen) { writeNotFound(ctx, jReq) } - None + ZIO.none } }.foldCauseZIO( cause => ZIO.attempt(attemptFastWrite(ctx, withDefaultErrorResponse(cause.squash))),