-
Notifications
You must be signed in to change notification settings - Fork 311
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add compressor, implement compression in HttpClient
- Loading branch information
Showing
9 changed files
with
362 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
131 changes: 131 additions & 0 deletions
131
core/src/main/scala/sttp/client4/internal/compression/Compressor.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package sttp.client4.internal.compression | ||
|
||
import sttp.client4._ | ||
import sttp.model.Encodings | ||
|
||
import Compressor._ | ||
import java.io.FileInputStream | ||
import java.nio.ByteBuffer | ||
import java.util.zip.DeflaterInputStream | ||
import java.util.zip.Deflater | ||
import java.io.ByteArrayOutputStream | ||
|
||
private[client4] trait Compressor { | ||
def encoding: String | ||
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] | ||
} | ||
|
||
private[client4] object GZipDefaultCompressor extends Compressor { | ||
val encoding: String = Encodings.Gzip | ||
|
||
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] = | ||
body match { | ||
case NoBody => NoBody | ||
case StringBody(s, encoding, defaultContentType) => | ||
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) | ||
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) | ||
case ByteBufferBody(b, defaultContentType) => | ||
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) | ||
case InputStreamBody(b, defaultContentType) => | ||
InputStreamBody(GZIPCompressingInputStream(b), defaultContentType) | ||
case StreamBody(b) => streamsNotSupported | ||
case FileBody(f, defaultContentType) => | ||
InputStreamBody(GZIPCompressingInputStream(new FileInputStream(f.toFile)), defaultContentType) | ||
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported | ||
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported | ||
} | ||
|
||
private def byteArray(bytes: Array[Byte]): Array[Byte] = { | ||
val bos = new java.io.ByteArrayOutputStream() | ||
val gzip = new java.util.zip.GZIPOutputStream(bos) | ||
gzip.write(bytes) | ||
gzip.close() | ||
bos.toByteArray() | ||
} | ||
} | ||
|
||
private[client4] object DeflateDefaultCompressor extends Compressor { | ||
val encoding: String = Encodings.Deflate | ||
|
||
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] = | ||
body match { | ||
case NoBody => NoBody | ||
case StringBody(s, encoding, defaultContentType) => | ||
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) | ||
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) | ||
case ByteBufferBody(b, defaultContentType) => | ||
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) | ||
case InputStreamBody(b, defaultContentType) => | ||
InputStreamBody(DeflaterInputStream(b), defaultContentType) | ||
case StreamBody(b) => streamsNotSupported | ||
case FileBody(f, defaultContentType) => | ||
InputStreamBody(DeflaterInputStream(new FileInputStream(f.toFile)), defaultContentType) | ||
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported | ||
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported | ||
} | ||
|
||
private def byteArray(bytes: Array[Byte]): Array[Byte] = { | ||
val deflater = new Deflater() | ||
try { | ||
deflater.setInput(bytes) | ||
deflater.finish() | ||
val byteArrayOutputStream = new ByteArrayOutputStream() | ||
val readBuffer = new Array[Byte](1024) | ||
|
||
while (!deflater.finished()) { | ||
val readCount = deflater.deflate(readBuffer) | ||
if (readCount > 0) { | ||
byteArrayOutputStream.write(readBuffer, 0, readCount) | ||
} | ||
} | ||
|
||
byteArrayOutputStream.toByteArray | ||
} finally deflater.end() | ||
} | ||
} | ||
|
||
private[client4] object Compressor { | ||
def compressIfNeeded[T, R]( | ||
request: GenericRequest[T, R], | ||
compressors: List[Compressor] | ||
): (GenericRequestBody[R], Option[Long]) = | ||
request.options.compressRequestBody match { | ||
case Some(encoding) => | ||
val compressedBody = compressors.find(_.encoding.equalsIgnoreCase(encoding)) match { | ||
case Some(compressor) => compressor(request.body, encoding) | ||
case None => throw new IllegalArgumentException(s"Unsupported encoding: $encoding") | ||
} | ||
|
||
val contentLength = calculateContentLength(compressedBody) | ||
(compressedBody, contentLength) | ||
|
||
case None => (request.body, request.contentLength) | ||
} | ||
|
||
private def calculateContentLength[R](body: GenericRequestBody[R]): Option[Long] = body match { | ||
case NoBody => None | ||
case StringBody(b, e, _) => Some(b.getBytes(e).length.toLong) | ||
case ByteArrayBody(b, _) => Some(b.length.toLong) | ||
case ByteBufferBody(b, _) => None | ||
case InputStreamBody(b, _) => None | ||
case FileBody(f, _) => Some(f.toFile.length()) | ||
case StreamBody(_) => None | ||
case MultipartStreamBody(parts) => None | ||
case BasicMultipartBody(parts) => None | ||
} | ||
|
||
private[compression] def compressingMultipartBodiesNotSupported: Nothing = | ||
throw new IllegalArgumentException("Multipart bodies cannot be compressed") | ||
|
||
private[compression] def streamsNotSupported: Nothing = | ||
throw new IllegalArgumentException("Streams are not supported") | ||
|
||
private[compression] def byteBufferToArray(inputBuffer: ByteBuffer): Array[Byte] = | ||
if (inputBuffer.hasArray()) { | ||
inputBuffer.array() | ||
} else { | ||
val inputBytes = new Array[Byte](inputBuffer.remaining()) | ||
inputBuffer.get(inputBytes) | ||
inputBytes | ||
} | ||
} |
137 changes: 137 additions & 0 deletions
137
core/src/main/scala/sttp/client4/internal/compression/GZIPCompressingInputStream.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
package sttp.client4.internal.compression | ||
|
||
import java.io.{ByteArrayInputStream, IOException, InputStream} | ||
import java.util.zip.{CRC32, Deflater} | ||
|
||
// based on: | ||
// https://github.com/http4k/http4k/blob/master/core/core/src/main/kotlin/org/http4k/filter/Gzip.kt#L124 | ||
// https://stackoverflow.com/questions/11036280/compress-an-inputstream-with-gzip | ||
private[client4] class GZIPCompressingInputStream( | ||
source: InputStream, | ||
compressionLevel: Int = java.util.zip.Deflater.DEFAULT_COMPRESSION | ||
) extends InputStream { | ||
|
||
private object State extends Enumeration { | ||
type State = Value | ||
val HEADER, DATA, FINALISE, TRAILER, DONE = Value | ||
} | ||
|
||
import State._ | ||
|
||
private val GZIP_MAGIC = 0x8b1f | ||
private val HEADER_DATA: Array[Byte] = Array( | ||
GZIP_MAGIC.toByte, | ||
(GZIP_MAGIC >> 8).toByte, | ||
Deflater.DEFLATED.toByte, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0 | ||
) | ||
private val INITIAL_BUFFER_SIZE = 8192 | ||
|
||
private val deflater = new Deflater(Deflater.DEFLATED, true) | ||
deflater.setLevel(compressionLevel) | ||
|
||
private val crc = new CRC32() | ||
private var trailer: ByteArrayInputStream = _ | ||
private val header = new ByteArrayInputStream(HEADER_DATA) | ||
|
||
private var deflationBuffer: Array[Byte] = new Array[Byte](INITIAL_BUFFER_SIZE) | ||
private var stage: State = HEADER | ||
|
||
override def read(): Int = { | ||
val readBytes = new Array[Byte](1) | ||
var bytesRead = 0 | ||
while (bytesRead == 0) | ||
bytesRead = read(readBytes, 0, 1) | ||
if (bytesRead != -1) readBytes(0) & 0xff else -1 | ||
} | ||
|
||
@throws[IOException] | ||
override def read(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = stage match { | ||
case HEADER => | ||
val bytesRead = header.read(readBuffer, readOffset, readLength) | ||
if (header.available() == 0) stage = DATA | ||
bytesRead | ||
|
||
case DATA => | ||
if (!deflater.needsInput) { | ||
deflatePendingInput(readBuffer, readOffset, readLength) | ||
} else { | ||
if (deflationBuffer.length < readLength) { | ||
deflationBuffer = new Array[Byte](readLength) | ||
} | ||
|
||
val bytesRead = source.read(deflationBuffer, 0, readLength) | ||
if (bytesRead <= 0) { | ||
stage = FINALISE | ||
deflater.finish() | ||
0 | ||
} else { | ||
crc.update(deflationBuffer, 0, bytesRead) | ||
deflater.setInput(deflationBuffer, 0, bytesRead) | ||
deflatePendingInput(readBuffer, readOffset, readLength) | ||
} | ||
} | ||
|
||
case FINALISE => | ||
if (deflater.finished()) { | ||
stage = TRAILER | ||
val crcValue = crc.getValue.toInt | ||
val totalIn = deflater.getTotalIn | ||
trailer = createTrailer(crcValue, totalIn) | ||
0 | ||
} else { | ||
deflater.deflate(readBuffer, readOffset, readLength, Deflater.FULL_FLUSH) | ||
} | ||
|
||
case TRAILER => | ||
val bytesRead = trailer.read(readBuffer, readOffset, readLength) | ||
if (trailer.available() == 0) stage = DONE | ||
bytesRead | ||
|
||
case DONE => -1 | ||
} | ||
|
||
private def deflatePendingInput(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = { | ||
var bytesCompressed = 0 | ||
while (!deflater.needsInput && readLength - bytesCompressed > 0) | ||
bytesCompressed += deflater.deflate( | ||
readBuffer, | ||
readOffset + bytesCompressed, | ||
readLength - bytesCompressed, | ||
Deflater.FULL_FLUSH | ||
) | ||
bytesCompressed | ||
} | ||
|
||
private def createTrailer(crcValue: Int, totalIn: Int): ByteArrayInputStream = | ||
new ByteArrayInputStream( | ||
Array( | ||
(crcValue >> 0).toByte, | ||
(crcValue >> 8).toByte, | ||
(crcValue >> 16).toByte, | ||
(crcValue >> 24).toByte, | ||
(totalIn >> 0).toByte, | ||
(totalIn >> 8).toByte, | ||
(totalIn >> 16).toByte, | ||
(totalIn >> 24).toByte | ||
) | ||
) | ||
|
||
override def available(): Int = if (stage == DONE) 0 else 1 | ||
|
||
@throws[IOException] | ||
override def close(): Unit = { | ||
source.close() | ||
deflater.end() | ||
if (trailer != null) trailer.close() | ||
header.close() | ||
} | ||
|
||
crc.reset() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.