Skip to content

Commit

Permalink
Add compressor, implement compression in HttpClient
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Dec 27, 2024
1 parent b2a981e commit ccd17c0
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 27 deletions.
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ val pekkoStreams = "org.apache.pekko" %% "pekko-stream" % pekkoStreamVersion
val scalaTest = libraryDependencies ++= Seq("freespec", "funsuite", "flatspec", "wordspec", "shouldmatchers").map(m =>
"org.scalatest" %%% s"scalatest-$m" % "3.2.19" % Test
)
val scalaTestPlusScalaCheck = libraryDependencies += "org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test

val zio1Version = "1.0.18"
val zio2Version = "2.1.14"
Expand Down Expand Up @@ -318,7 +319,8 @@ lazy val core = (projectMatrix in file("core"))
"com.softwaremill.sttp.shared" %%% "core" % sttpSharedVersion,
"com.softwaremill.sttp.shared" %%% "ws" % sttpSharedVersion
),
scalaTest
scalaTest,
scalaTestPlusScalaCheck
)
.settings(testServerSettings)
.jvmPlatform(
Expand Down
131 changes: 131 additions & 0 deletions core/src/main/scala/sttp/client4/internal/compression/Compressor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package sttp.client4.internal.compression

import sttp.client4._
import sttp.model.Encodings

import Compressor._
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.util.zip.DeflaterInputStream
import java.util.zip.Deflater
import java.io.ByteArrayOutputStream

private[client4] trait Compressor {
def encoding: String
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R]
}

private[client4] object GZipDefaultCompressor extends Compressor {
val encoding: String = Encodings.Gzip

def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
body match {
case NoBody => NoBody
case StringBody(s, encoding, defaultContentType) =>
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
case ByteBufferBody(b, defaultContentType) =>
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
case InputStreamBody(b, defaultContentType) =>
InputStreamBody(GZIPCompressingInputStream(b), defaultContentType)
case StreamBody(b) => streamsNotSupported
case FileBody(f, defaultContentType) =>
InputStreamBody(GZIPCompressingInputStream(new FileInputStream(f.toFile)), defaultContentType)
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
}

private def byteArray(bytes: Array[Byte]): Array[Byte] = {
val bos = new java.io.ByteArrayOutputStream()
val gzip = new java.util.zip.GZIPOutputStream(bos)
gzip.write(bytes)
gzip.close()
bos.toByteArray()
}
}

private[client4] object DeflateDefaultCompressor extends Compressor {
val encoding: String = Encodings.Deflate

def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
body match {
case NoBody => NoBody
case StringBody(s, encoding, defaultContentType) =>
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
case ByteBufferBody(b, defaultContentType) =>
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
case InputStreamBody(b, defaultContentType) =>
InputStreamBody(DeflaterInputStream(b), defaultContentType)
case StreamBody(b) => streamsNotSupported
case FileBody(f, defaultContentType) =>
InputStreamBody(DeflaterInputStream(new FileInputStream(f.toFile)), defaultContentType)
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
}

private def byteArray(bytes: Array[Byte]): Array[Byte] = {
val deflater = new Deflater()
try {
deflater.setInput(bytes)
deflater.finish()
val byteArrayOutputStream = new ByteArrayOutputStream()
val readBuffer = new Array[Byte](1024)

while (!deflater.finished()) {
val readCount = deflater.deflate(readBuffer)
if (readCount > 0) {
byteArrayOutputStream.write(readBuffer, 0, readCount)
}
}

byteArrayOutputStream.toByteArray
} finally deflater.end()
}
}

private[client4] object Compressor {
def compressIfNeeded[T, R](
request: GenericRequest[T, R],
compressors: List[Compressor]
): (GenericRequestBody[R], Option[Long]) =
request.options.compressRequestBody match {
case Some(encoding) =>
val compressedBody = compressors.find(_.encoding.equalsIgnoreCase(encoding)) match {
case Some(compressor) => compressor(request.body, encoding)
case None => throw new IllegalArgumentException(s"Unsupported encoding: $encoding")
}

val contentLength = calculateContentLength(compressedBody)
(compressedBody, contentLength)

case None => (request.body, request.contentLength)
}

private def calculateContentLength[R](body: GenericRequestBody[R]): Option[Long] = body match {
case NoBody => None
case StringBody(b, e, _) => Some(b.getBytes(e).length.toLong)
case ByteArrayBody(b, _) => Some(b.length.toLong)
case ByteBufferBody(b, _) => None
case InputStreamBody(b, _) => None
case FileBody(f, _) => Some(f.toFile.length())
case StreamBody(_) => None
case MultipartStreamBody(parts) => None
case BasicMultipartBody(parts) => None
}

private[compression] def compressingMultipartBodiesNotSupported: Nothing =
throw new IllegalArgumentException("Multipart bodies cannot be compressed")

private[compression] def streamsNotSupported: Nothing =
throw new IllegalArgumentException("Streams are not supported")

private[compression] def byteBufferToArray(inputBuffer: ByteBuffer): Array[Byte] =
if (inputBuffer.hasArray()) {
inputBuffer.array()
} else {
val inputBytes = new Array[Byte](inputBuffer.remaining())
inputBuffer.get(inputBytes)
inputBytes
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package sttp.client4.internal.compression

import java.io.{ByteArrayInputStream, IOException, InputStream}
import java.util.zip.{CRC32, Deflater}

// based on:
// https://github.com/http4k/http4k/blob/master/core/core/src/main/kotlin/org/http4k/filter/Gzip.kt#L124
// https://stackoverflow.com/questions/11036280/compress-an-inputstream-with-gzip
private[client4] class GZIPCompressingInputStream(
source: InputStream,
compressionLevel: Int = java.util.zip.Deflater.DEFAULT_COMPRESSION
) extends InputStream {

private object State extends Enumeration {
type State = Value
val HEADER, DATA, FINALISE, TRAILER, DONE = Value
}

import State._

private val GZIP_MAGIC = 0x8b1f
private val HEADER_DATA: Array[Byte] = Array(
GZIP_MAGIC.toByte,
(GZIP_MAGIC >> 8).toByte,
Deflater.DEFLATED.toByte,
0,
0,
0,
0,
0,
0,
0
)
private val INITIAL_BUFFER_SIZE = 8192

private val deflater = new Deflater(Deflater.DEFLATED, true)
deflater.setLevel(compressionLevel)

private val crc = new CRC32()
private var trailer: ByteArrayInputStream = _
private val header = new ByteArrayInputStream(HEADER_DATA)

private var deflationBuffer: Array[Byte] = new Array[Byte](INITIAL_BUFFER_SIZE)
private var stage: State = HEADER

override def read(): Int = {
val readBytes = new Array[Byte](1)
var bytesRead = 0
while (bytesRead == 0)
bytesRead = read(readBytes, 0, 1)
if (bytesRead != -1) readBytes(0) & 0xff else -1
}

@throws[IOException]
override def read(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = stage match {
case HEADER =>
val bytesRead = header.read(readBuffer, readOffset, readLength)
if (header.available() == 0) stage = DATA
bytesRead

case DATA =>
if (!deflater.needsInput) {
deflatePendingInput(readBuffer, readOffset, readLength)
} else {
if (deflationBuffer.length < readLength) {
deflationBuffer = new Array[Byte](readLength)
}

val bytesRead = source.read(deflationBuffer, 0, readLength)
if (bytesRead <= 0) {
stage = FINALISE
deflater.finish()
0
} else {
crc.update(deflationBuffer, 0, bytesRead)
deflater.setInput(deflationBuffer, 0, bytesRead)
deflatePendingInput(readBuffer, readOffset, readLength)
}
}

case FINALISE =>
if (deflater.finished()) {
stage = TRAILER
val crcValue = crc.getValue.toInt
val totalIn = deflater.getTotalIn
trailer = createTrailer(crcValue, totalIn)
0
} else {
deflater.deflate(readBuffer, readOffset, readLength, Deflater.FULL_FLUSH)
}

case TRAILER =>
val bytesRead = trailer.read(readBuffer, readOffset, readLength)
if (trailer.available() == 0) stage = DONE
bytesRead

case DONE => -1
}

private def deflatePendingInput(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = {
var bytesCompressed = 0
while (!deflater.needsInput && readLength - bytesCompressed > 0)
bytesCompressed += deflater.deflate(
readBuffer,
readOffset + bytesCompressed,
readLength - bytesCompressed,
Deflater.FULL_FLUSH
)
bytesCompressed
}

private def createTrailer(crcValue: Int, totalIn: Int): ByteArrayInputStream =
new ByteArrayInputStream(
Array(
(crcValue >> 0).toByte,
(crcValue >> 8).toByte,
(crcValue >> 16).toByte,
(crcValue >> 24).toByte,
(totalIn >> 0).toByte,
(totalIn >> 8).toByte,
(totalIn >> 16).toByte,
(totalIn >> 24).toByte
)
)

override def available(): Int = if (stage == DONE) 0 else 1

@throws[IOException]
override def close(): Unit = {
source.close()
deflater.end()
if (trailer != null) trailer.close()
header.close()
}

crc.reset()
}
2 changes: 1 addition & 1 deletion core/src/main/scala/sttp/client4/request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import sttp.attributes.AttributeMap
* ability to send and receive streaming bodies) or [[sttp.capabilities.WebSockets]] (the ability to handle websocket
* requests).
*/
trait GenericRequest[+T, -R] extends RequestBuilder[GenericRequest[T, R]] with RequestMetadata {
sealed trait GenericRequest[+T, -R] extends RequestBuilder[GenericRequest[T, R]] with RequestMetadata {
def body: GenericRequestBody[R]
def response: ResponseAsDelegate[T, R]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
package sttp.client4.httpclient

import sttp.capabilities.{Effect, Streams}
import sttp.capabilities.Effect
import sttp.capabilities.Streams
import sttp.client4.Backend
import sttp.client4.BackendOptions
import sttp.client4.BackendOptions.Proxy
import sttp.client4.GenericBackend
import sttp.client4.GenericRequest
import sttp.client4.MultipartBody
import sttp.client4.Response
import sttp.client4.SttpClientException
import sttp.client4.httpclient.HttpClientBackend.EncodingHandler
import sttp.client4.internal.SttpToJavaConverters.toJavaFunction
import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer}
import sttp.client4.internal.ws.SimpleQueue
import sttp.client4.{
Backend,
BackendOptions,
GenericBackend,
GenericRequest,
MultipartBody,
Response,
SttpClientException
}
import sttp.model.HttpVersion.{HTTP_1_1, HTTP_2}
import sttp.client4.internal.httpclient.BodyFromHttpClient
import sttp.client4.internal.httpclient.BodyToHttpClient
import sttp.model._
import sttp.model.HttpVersion.HTTP_1_1
import sttp.model.HttpVersion.HTTP_2
import sttp.monad.MonadError
import sttp.monad.syntax._
import sttp.ws.WebSocket

import java.net.Authenticator
import java.net.Authenticator.RequestorType
import java.net.http.{HttpClient, HttpRequest, HttpResponse, WebSocket => JWebSocket}
import java.net.{Authenticator, PasswordAuthentication}
import java.net.PasswordAuthentication
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse
import java.net.http.{WebSocket => JWebSocket}
import java.time.{Duration => JDuration}
import java.util.concurrent.{Executor, ThreadPoolExecutor}
import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import java.util.function
import scala.collection.JavaConverters._

Expand Down Expand Up @@ -117,7 +122,7 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B](
resBody.left
.map { is =>
encoding
.filterNot(e => code.equals(StatusCode.NoContent) || request.autoDecompressionDisabled || e.isEmpty)
.filterNot(e => code.equals(StatusCode.NoContent) || !request.autoDecompressionEnabled || e.isEmpty)
.map(e => customEncodingHandler.applyOrElse((is, e), standardEncoding.tupled))
.getOrElse(is)
}
Expand Down Expand Up @@ -166,16 +171,16 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B](
}
override def close(): F[Unit] =
if (closeClient) {
monad.eval(
client
monad.eval {
val _ = client
.executor()
.map[Unit](new function.Function[Executor, Unit] {
override def apply(t: Executor): Unit = t match {
case tpe: ThreadPoolExecutor => tpe.shutdown()
case _ => ()
}
})
)
}
} else {
monad.unit(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ class HttpClientSyncBackend private (
val isOpen: AtomicBoolean = new AtomicBoolean(false)
val responseCell = new ArrayBlockingQueue[Either[Throwable, () => Response[T]]](1)

def fillCellError(t: Throwable): Unit = responseCell.add(Left(t)): Unit
def fillCell(wr: () => Response[T]): Unit = responseCell.add(Right(wr)): Unit
def fillCellError(t: Throwable): Unit = { val _ = responseCell.add(Left(t)) }
def fillCell(wr: () => Response[T]): Unit = { val _ = responseCell.add(Right(wr)) }

val listener = new DelegatingWebSocketListener(
new AddToQueueListener(queue, isOpen),
ws => {
val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, _.get(): Unit)
val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, cf => { val _ = cf.get() })
val baseResponse = Response((), StatusCode.SwitchingProtocols, "", Nil, Nil, request.onlyMetadata)
val body = () => bodyFromHttpClient(Right(webSocket), request.response, baseResponse)
fillCell(() => baseResponse.copy(body = body()))
Expand Down
Loading

0 comments on commit ccd17c0

Please sign in to comment.