diff --git a/zio-http/jvm/src/test/scala/zio/http/FormSpec.scala b/zio-http/jvm/src/test/scala/zio/http/FormSpec.scala index 9e1b0117fc..1169919168 100644 --- a/zio-http/jvm/src/test/scala/zio/http/FormSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/FormSpec.scala @@ -271,6 +271,33 @@ object FormSpec extends ZIOHttpSpec { collected.get("file").get.asInstanceOf[FormField.Binary].data == bytes, ) }, + test("StreamingForm dynamically resizes") { + val N = 1000 + val expected = Chunk.fromArray(Array.fill(N)(scala.util.Random.nextInt()).map(_.toByte)) + val form = + Form( + Chunk( + FormField.binaryField( + name = "identifier", + data = Chunk(10.toByte), + mediaType = MediaType.application.`octet-stream`, + ), + FormField.StreamingBinary( + name = "blob", + data = ZStream.fromChunk(expected), + contentType = MediaType.application.`octet-stream`, + ), + ), + ) + val boundary = Boundary("X-INSOMNIA-BOUNDARY") + for { + formBytes <- form.multipartBytes(boundary).runCollect + formByteStream = ZStream.fromChunk(formBytes) + streamingForm = StreamingForm(formByteStream, boundary, 16) + out <- streamingForm.collectAll + res = out.get("blob").get.asInstanceOf[FormField.Binary].data + } yield assertTrue(res == expected) + } @@ timeout(3.seconds), test("decoding random form") { check(Gen.chunkOfBounded(2, 8)(formField)) { fields => for { diff --git a/zio-http/shared/src/main/scala/zio/http/StreamingForm.scala b/zio-http/shared/src/main/scala/zio/http/StreamingForm.scala index 51cc105691..610c535ae6 100644 --- a/zio-http/shared/src/main/scala/zio/http/StreamingForm.scala +++ b/zio-http/shared/src/main/scala/zio/http/StreamingForm.scala @@ -18,6 +18,8 @@ package zio.http import java.nio.charset.Charset +import scala.annotation.tailrec + import zio._ import zio.stacktracer.TracingImplicits.disableAutoTrace @@ -172,14 +174,31 @@ object StreamingForm { new State(FormState.fromBoundary(boundary), None, _inNonStreamingPart = false) } - private final class Buffer(bufferSize: Int) { - private val buffer: Array[Byte] = new Array[Byte](bufferSize) + private final class Buffer(initialSize: Int) { + private var buffer: Array[Byte] = new Array[Byte](initialSize) private var length: Int = 0 + private def ensureHasCapacity(requiredCapacity: Int): Unit = { + @tailrec + def calculateNewCapacity(existing: Int, required: Int): Int = { + val newCap = existing * 2 + if (newCap < required) calculateNewCapacity(newCap, required) + else newCap + } + + val l = buffer.length + if (l <= requiredCapacity) { + val newArray = Array.ofDim[Byte](calculateNewCapacity(l, requiredCapacity)) + java.lang.System.arraycopy(buffer, 0, newArray, 0, l) + buffer = newArray + } else () + } + def addByte( crlfBoundary: Chunk[Byte], byte: Byte, ): Chunk[Take[Nothing, Byte]] = { + ensureHasCapacity(length + crlfBoundary.length) buffer(length) = byte if (length < (crlfBoundary.length - 1)) { // Not enough bytes to check if we have the boundary