diff --git a/io/src/main/scala/fs2/io/tls/TLSEngine.scala b/io/src/main/scala/fs2/io/tls/TLSEngine.scala index 73a3d736c0..37ae753d93 100644 --- a/io/src/main/scala/fs2/io/tls/TLSEngine.scala +++ b/io/src/main/scala/fs2/io/tls/TLSEngine.scala @@ -128,12 +128,13 @@ private[tls] object TLSEngine { private def read0(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] = // Check if the initial handshake has finished -- if so, read; otherwise, handshake and then read - dequeueUnwrap(maxBytes).flatMap { out => + unwrapThenTakeUnwrapped(maxBytes, timeout).flatMap { out => if (out.isEmpty) initialHandshakeDone.ifM( read1(maxBytes, timeout), - write(Chunk.empty, None) >> dequeueUnwrap(maxBytes).flatMap { out => - if (out.isEmpty) read1(maxBytes, timeout) else Applicative[F].pure(out) + write(Chunk.empty, None) >> unwrapThenTakeUnwrapped(maxBytes, timeout).flatMap { + out => + if (out.isEmpty) read1(maxBytes, timeout) else Applicative[F].pure(out) } ) else Applicative[F].pure(out) @@ -161,7 +162,7 @@ private[tls] object TLSEngine { case SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING => unwrapBuffer.inputRemains .map(_ > 0 && result.bytesConsumed > 0) - .ifM(unwrap(maxBytes, timeout), dequeueUnwrap(maxBytes)) + .ifM(unwrap(maxBytes, timeout), takeUnwrapped(maxBytes)) case SSLEngineResult.HandshakeStatus.FINISHED => unwrap(maxBytes, timeout) case _ => @@ -172,17 +173,23 @@ private[tls] object TLSEngine { ) } case SSLEngineResult.Status.BUFFER_UNDERFLOW => - dequeueUnwrap(maxBytes) + takeUnwrapped(maxBytes) case SSLEngineResult.Status.BUFFER_OVERFLOW => unwrapBuffer.expandOutput >> unwrap(maxBytes, timeout) case SSLEngineResult.Status.CLOSED => - stopWrap >> stopUnwrap >> dequeueUnwrap(maxBytes) + stopWrap >> stopUnwrap >> takeUnwrapped(maxBytes) } } - private def dequeueUnwrap(maxBytes: Int): F[Option[Chunk[Byte]]] = + private def takeUnwrapped(maxBytes: Int): F[Option[Chunk[Byte]]] = unwrapBuffer.output(maxBytes).map(out => if (out.isEmpty) None else Some(out)) + private def unwrapThenTakeUnwrapped( + maxBytes: Int, + timeout: Option[FiniteDuration] + ): F[Option[Chunk[Byte]]] = + unwrapBuffer.inputRemains.map(_ > 0).ifM(unwrap(maxBytes, timeout), takeUnwrapped(maxBytes)) + /** Determines what to do next given the result of a handshake operation. * Must be called with `handshakeSem`. */