Skip to content

Commit

Permalink
Implement streaming of multipart field data (#2899)
Browse files Browse the repository at this point in the history
* Implement streaming of multipart field data

* Only stream form field data for `fields` method

* Make Scala 2.12 happy
  • Loading branch information
kyri-petrou authored Jun 11, 2024
1 parent d1cb10d commit dc0e883
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 121 deletions.
24 changes: 24 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/FormSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,30 @@ object FormSpec extends ZIOHttpSpec {
)
}
} @@ samples(10),
test("output stream maintains the same chunk structure as the input stream") {
  // The field data is fed in as two separate upstream chunks; after a round
  // trip through multipart encoding and StreamingForm parsing, the data must
  // come back split into the same two chunks.
  val firstChunk  = Chunk[Byte](1, 2)
  val secondChunk = Chunk[Byte](3)
  val boundary    = Boundary("X-INSOMNIA-BOUNDARY")

  val blobField: FormField = FormField.StreamingBinary(
    name = "blob",
    data = ZStream.fromChunk(firstChunk) ++ ZStream.fromChunk(secondChunk),
    contentType = MediaType.application.`octet-stream`,
  )

  val encodedBytes = Form(Chunk(blobField)).multipartBytes(boundary)
  val parsedForm   = StreamingForm(encodedBytes, boundary)

  // Only streaming binary fields contribute bytes; any other field is dropped.
  val receivedChunks = parsedForm.fields.flatMap {
    case binary: FormField.StreamingBinary => binary.data
    case _                                 => ZStream.empty
  }
    .mapChunks(Chunk.single) // expose every emitted chunk as a single element
    .filter(_.nonEmpty)
    .runCollect

  receivedChunks.map(actual => assertTrue(actual == Chunk(firstChunk, secondChunk)))
},
) @@ sequential

def spec =
Expand Down
210 changes: 89 additions & 121 deletions zio-http/shared/src/main/scala/zio/http/StreamingForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import scala.annotation.tailrec
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.stream.{Take, ZChannel, ZStream}
import zio.stream.{Take, ZStream}

import zio.http.StreamingForm.{Buffer, ZStreamOps}
import zio.http.StreamingForm.Buffer
import zio.http.internal.{FormAST, FormState}

final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: Boundary, bufferSize: Int = 8192) {
Expand All @@ -35,7 +35,7 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
* Runs the streaming form and collects all parts in memory, returning a Form
*/
def collectAll(implicit trace: Trace): ZIO[Any, Throwable, Form] =
fields.mapZIO {
streamFormFields(bufferUpToBoundary = true).mapZIO {
case sb: FormField.StreamingBinary =>
sb.collect
case other: FormField =>
Expand All @@ -45,36 +45,44 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
}

/**
* Streams the fields of this multipart form as they are parsed from `source`.
*
* Delegates to `streamFormFields` with `bufferUpToBoundary = false`; with the
* flag off, `Buffer.addByte` may hand over the bytes buffered so far once the
* last byte of an input chunk has been processed, instead of only when a
* boundary has been found.
*/
def fields(implicit trace: Trace): ZStream[Any, Throwable, FormField] =
streamFormFields(bufferUpToBoundary = false)

private def streamFormFields(
bufferUpToBoundary: Boolean,
)(implicit trace: Trace): ZStream[Any, Throwable, FormField] =
ZStream.unwrapScoped {
implicit val unsafe: Unsafe = Unsafe.unsafe

for {
runtime <- ZIO.runtime[Any]
buffer <- ZIO.succeed(new Buffer(bufferSize))
buffer <- ZIO.succeed(new Buffer(bufferSize, crlfBoundary, bufferUpToBoundary))
abort <- Promise.make[Nothing, Unit]
fieldQueue <- Queue.bounded[Take[Throwable, FormField]](4)
state = initialState
reader =
source
.mapAccumImmediate(initialState) { (state, byte) =>
def handleBoundary(ast: Chunk[FormAST]): (StreamingForm.State, Option[FormField]) =
if (state.inNonStreamingPart) {
FormField.fromFormAST(ast, charset) match {
case Right(formData) =>
buffer.reset()
(state.reset, Some(formData))
case Left(e) => throw e.asException
}
} else {
buffer.reset()
(state.reset, None)
source.runForeachChunk { bytes =>
def handleBoundary(ast: Chunk[FormAST]): Option[FormField] =
if (state.inNonStreamingPart) {
FormField.fromFormAST(ast, charset) match {
case Right(formData) =>
buffer.reset()
state.reset
Some(formData)
case Left(e) => throw e.asException
}
} else {
buffer.reset()
state.reset
None
}

def handleByte(byte: Byte, isLastByte: Boolean): Option[FormField] = {
state.formState match {
case formState: FormState.FormStateBuffer =>
val nextFormState = formState.append(byte)
state.currentQueue match {
case Some(queue) =>
val takes = buffer.addByte(crlfBoundary, byte)
val takes = buffer.addByte(byte, isLastByte)
if (takes.nonEmpty) {
runtime.unsafe.run(queue.offerAll(takes).raceFirst(abort.await)).getOrThrowFiberFailure()
}
Expand All @@ -98,31 +106,43 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
streamingFormData <- FormField
.incomingStreamingBinary(newFormState.tree, newQueue)
.mapError(_.asException)
nextState = state.withCurrentQueue(newQueue)
} yield (nextState, Some(streamingFormData))
_ = state.withCurrentQueue(newQueue)
} yield Some(streamingFormData)
}.getOrThrowFiberFailure()
} else {
val nextState = state.withInNonStreamingPart(true)
(nextState, None)
val _ = state.withInNonStreamingPart(true)
None
}
} else {
(state, None)
None
}
case FormState.BoundaryEncapsulated(ast) =>
handleBoundary(ast)
case FormState.BoundaryClosed(ast) =>
handleBoundary(ast)
}
case _ =>
(state, None)
None
}
}
.mapZIO { field =>
fieldQueue.offer(Take.single(field))

val builder = Chunk.newBuilder[FormField]
val it = bytes.iterator
var hasNext = it.hasNext
while (hasNext) {
val byte = it.next()
hasNext = it.hasNext
handleByte(byte, !hasNext) match {
case Some(field) => builder += field
case _ => ()
}
}
val fields = builder.result()
fieldQueue.offer(Take.chunk(fields)).when(fields.nonEmpty)
}
// FIXME: .blocking here is temporary until we figure out a better way to avoid running effects (via runtime.unsafe.run) inside the runForeachChunk callback
_ <- ZIO
.blocking(reader.runDrain)
.blocking(reader)
.catchAllCause(cause => fieldQueue.offer(Take.failCause(cause)))
.ensuring(fieldQueue.offer(Take.end))
.forkScoped
Expand All @@ -140,7 +160,7 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
// Starting parser state for this form, obtained from the companion object's
// `StreamingForm.initialState` using this instance's `boundary`.
private def initialState: StreamingForm.State =
StreamingForm.initialState(boundary)

// CR LF (bytes 13, 10) followed by the encapsulation boundary bytes, as a Chunk.
// This is the byte sequence the buffer compares incoming data against to detect
// the end of a part's content.
private val crlfBoundary: Chunk[Byte] = Chunk[Byte](13, 10) ++ boundary.encapsulationBoundaryBytes
// CR LF (bytes 13, 10) followed by the encapsulation boundary bytes, as a plain
// array. Being a `def`, every call builds a new array; it is handed to the
// `Buffer` constructor, which matches buffered bytes against it.
private def crlfBoundary: Array[Byte] = Array[Byte](13, 10) ++ boundary.encapsulationBoundaryBytes.toArray[Byte]
}

object StreamingForm {
Expand Down Expand Up @@ -174,9 +194,10 @@ object StreamingForm {
new State(FormState.fromBoundary(boundary), None, _inNonStreamingPart = false)
}

private final class Buffer(initialSize: Int) {
private var buffer: Array[Byte] = new Array[Byte](initialSize)
private var length: Int = 0
private final class Buffer(bufferSize: Int, crlfBoundary: Array[Byte], bufferUpToBoundary: Boolean) {
private var buffer: Array[Byte] = Array.ofDim(bufferSize)
private var index: Int = 0
private val boundarySize = crlfBoundary.length

private def ensureHasCapacity(requiredCapacity: Int): Unit = {
@tailrec
Expand All @@ -194,104 +215,51 @@ object StreamingForm {
} else ()
}

def addByte(
crlfBoundary: Chunk[Byte],
byte: Byte,
): Chunk[Take[Nothing, Byte]] = {
ensureHasCapacity(length + crlfBoundary.length)
buffer(length) = byte
if (length < (crlfBoundary.length - 1)) {
// Not enough bytes to check if we have the boundary
length += 1
Chunk.empty
} else {
var foundBoundary = true
var i = 0
while (i < crlfBoundary.length && foundBoundary) {
if (buffer(length - i) != crlfBoundary(crlfBoundary.length - 1 - i)) {
foundBoundary = false
}
i += 1
}

if (foundBoundary) {
// We have found the boundary
val preBoundary =
Chunk.fromArray(Chunk.fromArray(buffer).take(length + 1 - crlfBoundary.length).toArray[Byte])
length = 0
Chunk(Take.chunk(preBoundary), Take.end)
} else {
// We don't have the boundary
if (length < (buffer.length - 2)) {
length += 1
Chunk.empty
} else {
val preBoundary =
Chunk.fromArray(Chunk.fromArray(buffer).take(length + 1 - crlfBoundary.length).toArray[Byte])
for (i <- crlfBoundary.indices) {
buffer(i) = buffer(length + 1 - crlfBoundary.length + i)
}
length = crlfBoundary.length
Chunk(Take.chunk(preBoundary))
}
/**
 * Checks whether the bytes buffered so far end with the beginning of the
 * boundary.
 *
 * Candidate suffixes of `buffer(0 .. idx)` are tried from the shortest (one
 * byte) up to `boundarySize` bytes; the result is `true` as soon as one of
 * them equals the prefix of `crlfBoundary` of the same length.
 *
 * @param idx
 *   index of the most recently written byte in `buffer`
 */
private def matchesPartialBoundary(idx: Int): Boolean = {
  // Compares buffer(start .. start + len - 1) against crlfBoundary(0 .. len - 1),
  // stopping at the first mismatching byte.
  @tailrec
  def suffixIsPrefix(start: Int, offset: Int, len: Int): Boolean =
    if (offset == len) true
    else if (buffer(start + offset) != crlfBoundary(offset)) false
    else suffixIsPrefix(start, offset + 1, len)

  var candidateLen = 1
  var matched      = false
  // A candidate may be neither longer than the boundary nor longer than the buffered data.
  while (candidateLen <= boundarySize && candidateLen - 1 <= idx && !matched) {
    matched = suffixIsPrefix(idx + 1 - candidateLen, 0, candidateLen)
    candidateLen += 1
  }
  matched
}

/** Rewinds the write position to the start of the buffer; the backing array itself is left as is. */
def reset(): Unit = length = 0
}
def addByte(byte: Byte, isLastByte: Boolean): Chunk[Take[Nothing, Byte]] = {
val idx = index
ensureHasCapacity(idx + boundarySize + 1)
buffer(idx) = byte
index += 1

implicit class ZStreamOps[R, E, A](self: ZStream[R, E, A]) {
var i = 0
var foundFullBoundary = idx >= boundarySize - 1
while (i < boundarySize && foundFullBoundary) {
if (buffer(idx + 1 - crlfBoundary.length + i) != crlfBoundary(i)) {
foundFullBoundary = false
}
i += 1
}

private def mapAccumImmediate[S1, B](
self: Chunk[A],
)(s1: S1)(f1: (S1, A) => (S1, Option[B])): (S1, Option[(B, Chunk[A])]) = {
val iterator = self.chunkIterator
var index = 0
var s = s1
var result: Option[B] = None
while (iterator.hasNextAt(index) && result.isEmpty) {
val a = iterator.nextAt(index)
index += 1
val tuple = f1(s, a)
s = tuple._1
result = tuple._2
if (foundFullBoundary) {
reset()
val toTake = idx + 1 - boundarySize
if (toTake == 0) Chunk(Take.end)
else Chunk(Take.chunk(Chunk.fromArray(buffer.take(toTake))), Take.end)
} else if (!bufferUpToBoundary && isLastByte && byte != '-' && !matchesPartialBoundary(idx)) {
reset()
Chunk(Take.chunk(Chunk.fromArray(buffer.take(idx + 1))))
} else {
Chunk.empty
}
(s, result.map(b => (b, self.drop(index))))
}

/**
* Statefully maps over the elements of this stream to sometimes produce new
* elements. Each new element gets immediately emitted regardless of the
* upstream chunk size.
*
* @param s
* initial state, taken by name and passed through `ZStream.succeed`
* @param f
* step function: receives the current state and one upstream element, and
* returns the next state together with an optional element to emit
* @return
* a stream in which every produced element is written downstream as its own
* single-element chunk
*/
def mapAccumImmediate[S, A1](s: => S)(f: (S, A) => (S, Option[A1]))(implicit trace: Trace): ZStream[R, E, A1] =
ZStream.succeed(s).flatMap { s =>
// Feeds `in` through the chunk-level `mapAccumImmediate` overload, which stops
// at the first element for which `f` produces an output.
def chunkAccumulator(currS: S, in: Chunk[A]): ZChannel[Any, E, Chunk[A], Any, E, Chunk[A1], Unit] =
mapAccumImmediate(in)(currS)(f) match {
// An output was produced: write it right away as a single-element chunk,
// then carry on with the part of the chunk that has not been processed yet.
case (nextS, Some((a1, remaining))) =>
ZChannel.write(Chunk.single(a1)) *>
accumulator(nextS, remaining)
// The whole chunk was consumed without producing anything: go back to reading upstream.
case (nextS, None) =>
accumulator(nextS, Chunk.empty)
}

// Main loop: with no leftovers, read the next upstream chunk (re-failing with
// the upstream cause on error and finishing when upstream is done); otherwise
// process the leftovers before reading any further input.
def accumulator(currS: S, leftovers: Chunk[A]): ZChannel[Any, E, Chunk[A], Any, E, Chunk[A1], Unit] =
if (leftovers.isEmpty) {
ZChannel.readWithCause(
(in: Chunk[A]) => {
chunkAccumulator(currS, in)
},
(err: Cause[E]) => ZChannel.refailCause(err),
(_: Any) => ZChannel.unit,
)
} else {
chunkAccumulator(currS, leftovers)
}

// Pipe this stream's channel into the accumulator, starting with the initial state and no leftovers.
ZStream.fromChannel(self.channel >>> accumulator(s, Chunk.empty))
}
/** Rewinds the write position to the start of the buffer; the bytes already stored are not cleared. */
def reset(): Unit = {
  index = 0
}
}
}

0 comments on commit dc0e883

Please sign in to comment.