diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala b/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala index eb129f166e..6425d501a6 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala @@ -77,6 +77,8 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent]( val _ = ctx.channel().config().setAutoRead(previousAutoRead) } + protected def onLastMessage(): Unit = () + override def channelRead0( ctx: ChannelHandlerContext, msg: HttpContent, @@ -87,6 +89,12 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent]( val isLast = msg.isInstanceOf[LastHttpContent] val content = ByteBufUtil.getBytes(msg.content()) + if (isLast) { + readingDone = true + ctx.channel().pipeline().remove(this) + onLastMessage() + } + state match { case State.Buffering => // `connect` method hasn't been called yet, add all incoming content to the buffer @@ -103,13 +111,7 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent]( callback(Chunk.fromArray(content), isLast) } - if (isLast) { - readingDone = true - ctx.channel().pipeline().remove(this) - } else { - ctx.read() - } - () + if (!isLast) ctx.read(): Unit } } @@ -137,6 +139,8 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent]( } object AsyncBodyReader { + private val FnUnit = () => () + sealed trait State object State { diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala index b898944d4b..28e704771e 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala @@ -35,18 +35,16 @@ final class ClientResponseStreamHandler( private implicit val unsafe: Unsafe = Unsafe.unsafe + override def onLastMessage(): Unit = + if (keepAlive) + onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status))) + else + onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid)) + override def channelRead0(ctx: ChannelHandlerContext, msg: HttpContent): Unit = { val isLast = msg.isInstanceOf[LastHttpContent] super.channelRead0(ctx, msg) - - if (isLast) { - if (keepAlive) - onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status))) - else { - onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid)) - ctx.close(): Unit - } - } + if (isLast && !keepAlive) ctx.close(): Unit } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit =