From 393910c36da20b23332a9c928c0b268fdfcaddea Mon Sep 17 00:00:00 2001 From: Erlend Hamnaberg Date: Sat, 7 Dec 2024 12:37:23 +0100 Subject: [PATCH] ensure we drain the publisher if sending empty body --- .../http4s/jdkhttpclient/JdkHttpClient.scala | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/http4s/jdkhttpclient/JdkHttpClient.scala b/core/src/main/scala/org/http4s/jdkhttpclient/JdkHttpClient.scala index 6a53f46e..61aa71d1 100644 --- a/core/src/main/scala/org/http4s/jdkhttpclient/JdkHttpClient.scala +++ b/core/src/main/scala/org/http4s/jdkhttpclient/JdkHttpClient.scala @@ -24,6 +24,7 @@ import fs2.Chunk import fs2.Stream import fs2.concurrent.SignallingRef import fs2.interop.flow +import org.http4s.EmptyBody import org.http4s.Header import org.http4s.Headers import org.http4s.HttpVersion @@ -63,21 +64,24 @@ object JdkHttpClient { def convertRequest(req: Request[F]): Resource[F, HttpRequest] = flow.toPublisher(req.body.chunks.map(_.toByteBuffer)).evalMap { publisher => convertHttpVersionFromHttp4s[F](req.httpVersion).map { version => - def isStreaming = version match { + def consumeFully = (req.body ne EmptyBody) && (version match { case HttpClient.Version.HTTP_1_1 => req.isChunked case HttpClient.Version.HTTP_2 => req.contentLength.isEmpty - } + }) val rb = HttpRequest.newBuilder .method( req.method.name, - if (isStreaming) + if (consumeFully) BodyPublishers.fromPublisher(publisher) else req.contentLength match { case Some(length) if length > 0L => BodyPublishers.fromPublisher(publisher, length) - case _ => BodyPublishers.noBody + case _ => + // If we dont do this, we might block finalization + publisher.subscribe(DrainingSubscriber) + BodyPublishers.noBody } ) .uri(URI.create(req.uri.renderString)) @@ -305,4 +309,12 @@ object JdkHttpClient { "via", "warning" ).map(CIString(_)) + + private object DrainingSubscriber extends Flow.Subscriber[ByteBuffer] { + override def onSubscribe(subscription: Flow.Subscription): Unit = + subscription.request(Long.MaxValue) + override def onNext(item: ByteBuffer): Unit = () + override def onError(throwable: Throwable): Unit = () + override def onComplete(): Unit = () + } }