Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement streaming of multipart field data #2899

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/FormSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,30 @@ object FormSpec extends ZIOHttpSpec {
)
}
} @@ samples(10),
test("output stream maintains the same chunk structure as the input stream") {
val form = Form(
Chunk(
FormField.StreamingBinary(
name = "blob",
data = ZStream.fromChunk(Chunk(1, 2).map(_.toByte)) ++
ZStream.fromChunk(Chunk(3).map(_.toByte)),
contentType = MediaType.application.`octet-stream`,
),
),
)
val boundary = Boundary("X-INSOMNIA-BOUNDARY")
val formByteStream = form.multipartBytes(boundary)
val streamingForm = StreamingForm(formByteStream, boundary)
val expected = Chunk(Chunk[Byte](1, 2), Chunk[Byte](3))
streamingForm.fields.flatMap {
case sb: FormField.StreamingBinary => sb.data
case _ => ZStream.empty
}
.mapChunks(Chunk.single)
.filter(_.nonEmpty)
.runCollect
.map { c => assertTrue(c == expected) }
},
) @@ sequential

def spec =
Expand Down
210 changes: 89 additions & 121 deletions zio-http/shared/src/main/scala/zio/http/StreamingForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import scala.annotation.tailrec
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.stream.{Take, ZChannel, ZStream}
import zio.stream.{Take, ZStream}

import zio.http.StreamingForm.{Buffer, ZStreamOps}
import zio.http.StreamingForm.Buffer
import zio.http.internal.{FormAST, FormState}

final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: Boundary, bufferSize: Int = 8192) {
Expand All @@ -35,7 +35,7 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
* Runs the streaming form and collects all parts in memory, returning a Form
*/
def collectAll(implicit trace: Trace): ZIO[Any, Throwable, Form] =
fields.mapZIO {
streamFormFields(bufferUpToBoundary = true).mapZIO {
case sb: FormField.StreamingBinary =>
sb.collect
case other: FormField =>
Expand All @@ -45,36 +45,44 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
}

def fields(implicit trace: Trace): ZStream[Any, Throwable, FormField] =
streamFormFields(bufferUpToBoundary = false)

private def streamFormFields(
bufferUpToBoundary: Boolean,
)(implicit trace: Trace): ZStream[Any, Throwable, FormField] =
ZStream.unwrapScoped {
implicit val unsafe: Unsafe = Unsafe.unsafe

for {
runtime <- ZIO.runtime[Any]
buffer <- ZIO.succeed(new Buffer(bufferSize))
buffer <- ZIO.succeed(new Buffer(bufferSize, crlfBoundary, bufferUpToBoundary))
abort <- Promise.make[Nothing, Unit]
fieldQueue <- Queue.bounded[Take[Throwable, FormField]](4)
state = initialState
reader =
source
.mapAccumImmediate(initialState) { (state, byte) =>
def handleBoundary(ast: Chunk[FormAST]): (StreamingForm.State, Option[FormField]) =
if (state.inNonStreamingPart) {
FormField.fromFormAST(ast, charset) match {
case Right(formData) =>
buffer.reset()
(state.reset, Some(formData))
case Left(e) => throw e.asException
}
} else {
buffer.reset()
(state.reset, None)
source.runForeachChunk { bytes =>
def handleBoundary(ast: Chunk[FormAST]): Option[FormField] =
if (state.inNonStreamingPart) {
FormField.fromFormAST(ast, charset) match {
case Right(formData) =>
buffer.reset()
state.reset
Some(formData)
case Left(e) => throw e.asException
}
} else {
buffer.reset()
state.reset
None
}

def handleByte(byte: Byte, isLastByte: Boolean): Option[FormField] = {
state.formState match {
case formState: FormState.FormStateBuffer =>
val nextFormState = formState.append(byte)
state.currentQueue match {
case Some(queue) =>
val takes = buffer.addByte(crlfBoundary, byte)
val takes = buffer.addByte(byte, isLastByte)
if (takes.nonEmpty) {
runtime.unsafe.run(queue.offerAll(takes).raceFirst(abort.await)).getOrThrowFiberFailure()
}
Expand All @@ -98,31 +106,43 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
streamingFormData <- FormField
.incomingStreamingBinary(newFormState.tree, newQueue)
.mapError(_.asException)
nextState = state.withCurrentQueue(newQueue)
} yield (nextState, Some(streamingFormData))
_ = state.withCurrentQueue(newQueue)
} yield Some(streamingFormData)
}.getOrThrowFiberFailure()
} else {
val nextState = state.withInNonStreamingPart(true)
(nextState, None)
val _ = state.withInNonStreamingPart(true)
None
}
} else {
(state, None)
None
}
case FormState.BoundaryEncapsulated(ast) =>
handleBoundary(ast)
case FormState.BoundaryClosed(ast) =>
handleBoundary(ast)
}
case _ =>
(state, None)
None
}
}
.mapZIO { field =>
fieldQueue.offer(Take.single(field))

val builder = Chunk.newBuilder[FormField]
val it = bytes.iterator
var hasNext = it.hasNext
while (hasNext) {
val byte = it.next()
hasNext = it.hasNext
handleByte(byte, !hasNext) match {
case Some(field) => builder += field
case _ => ()
}
}
val fields = builder.result()
fieldQueue.offer(Take.chunk(fields)).when(fields.nonEmpty)
}
// FIXME: .blocking here is temporary until we figure out a better way to avoid running effects within mapAccumImmediate
_ <- ZIO
.blocking(reader.runDrain)
.blocking(reader)
.catchAllCause(cause => fieldQueue.offer(Take.failCause(cause)))
.ensuring(fieldQueue.offer(Take.end))
.forkScoped
Expand All @@ -140,7 +160,7 @@ final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary:
private def initialState: StreamingForm.State =
StreamingForm.initialState(boundary)

private val crlfBoundary: Chunk[Byte] = Chunk[Byte](13, 10) ++ boundary.encapsulationBoundaryBytes
private def crlfBoundary: Array[Byte] = Array[Byte](13, 10) ++ boundary.encapsulationBoundaryBytes.toArray[Byte]
}

object StreamingForm {
Expand Down Expand Up @@ -174,9 +194,10 @@ object StreamingForm {
new State(FormState.fromBoundary(boundary), None, _inNonStreamingPart = false)
}

private final class Buffer(initialSize: Int) {
private var buffer: Array[Byte] = new Array[Byte](initialSize)
private var length: Int = 0
private final class Buffer(bufferSize: Int, crlfBoundary: Array[Byte], bufferUpToBoundary: Boolean) {
private var buffer: Array[Byte] = Array.ofDim(bufferSize)
private var index: Int = 0
private val boundarySize = crlfBoundary.length

private def ensureHasCapacity(requiredCapacity: Int): Unit = {
@tailrec
Expand All @@ -194,104 +215,51 @@ object StreamingForm {
} else ()
}

def addByte(
crlfBoundary: Chunk[Byte],
byte: Byte,
): Chunk[Take[Nothing, Byte]] = {
ensureHasCapacity(length + crlfBoundary.length)
buffer(length) = byte
if (length < (crlfBoundary.length - 1)) {
// Not enough bytes to check if we have the boundary
length += 1
Chunk.empty
} else {
var foundBoundary = true
var i = 0
while (i < crlfBoundary.length && foundBoundary) {
if (buffer(length - i) != crlfBoundary(crlfBoundary.length - 1 - i)) {
foundBoundary = false
}
i += 1
}

if (foundBoundary) {
// We have found the boundary
val preBoundary =
Chunk.fromArray(Chunk.fromArray(buffer).take(length + 1 - crlfBoundary.length).toArray[Byte])
length = 0
Chunk(Take.chunk(preBoundary), Take.end)
} else {
// We don't have the boundary
if (length < (buffer.length - 2)) {
length += 1
Chunk.empty
} else {
val preBoundary =
Chunk.fromArray(Chunk.fromArray(buffer).take(length + 1 - crlfBoundary.length).toArray[Byte])
for (i <- crlfBoundary.indices) {
buffer(i) = buffer(length + 1 - crlfBoundary.length + i)
}
length = crlfBoundary.length
Chunk(Take.chunk(preBoundary))
}
private def matchesPartialBoundary(idx: Int): Boolean = {
val bs = boundarySize
var i = 0
var result = false
while (i < bs && i <= idx && !result) {
val i0 = idx - i
var i1 = 0
while (i >= i1 && buffer(i0 + i1) == crlfBoundary(i1) && !result) {
if (i == i1) result = true
i1 += 1
}
i += 1
}
result
}

def reset(): Unit = {
length = 0
}
}
def addByte(byte: Byte, isLastByte: Boolean): Chunk[Take[Nothing, Byte]] = {
val idx = index
ensureHasCapacity(idx + boundarySize + 1)
buffer(idx) = byte
index += 1

implicit class ZStreamOps[R, E, A](self: ZStream[R, E, A]) {
var i = 0
var foundFullBoundary = idx >= boundarySize - 1
while (i < boundarySize && foundFullBoundary) {
if (buffer(idx + 1 - crlfBoundary.length + i) != crlfBoundary(i)) {
foundFullBoundary = false
}
i += 1
}

private def mapAccumImmediate[S1, B](
self: Chunk[A],
)(s1: S1)(f1: (S1, A) => (S1, Option[B])): (S1, Option[(B, Chunk[A])]) = {
val iterator = self.chunkIterator
var index = 0
var s = s1
var result: Option[B] = None
while (iterator.hasNextAt(index) && result.isEmpty) {
val a = iterator.nextAt(index)
index += 1
val tuple = f1(s, a)
s = tuple._1
result = tuple._2
if (foundFullBoundary) {
reset()
val toTake = idx + 1 - boundarySize
if (toTake == 0) Chunk(Take.end)
else Chunk(Take.chunk(Chunk.fromArray(buffer.take(toTake))), Take.end)
} else if (!bufferUpToBoundary && isLastByte && byte != '-' && !matchesPartialBoundary(idx)) {
reset()
Chunk(Take.chunk(Chunk.fromArray(buffer.take(idx + 1))))
} else {
Chunk.empty
}
(s, result.map(b => (b, self.drop(index))))
}

/**
* Statefully maps over the elements of this stream to sometimes produce new
* elements. Each new element gets immediately emitted regardless of the
* upstream chunk size.
*/
def mapAccumImmediate[S, A1](s: => S)(f: (S, A) => (S, Option[A1]))(implicit trace: Trace): ZStream[R, E, A1] =
ZStream.succeed(s).flatMap { s =>
def chunkAccumulator(currS: S, in: Chunk[A]): ZChannel[Any, E, Chunk[A], Any, E, Chunk[A1], Unit] =
mapAccumImmediate(in)(currS)(f) match {
case (nextS, Some((a1, remaining))) =>
ZChannel.write(Chunk.single(a1)) *>
accumulator(nextS, remaining)
case (nextS, None) =>
accumulator(nextS, Chunk.empty)
}

def accumulator(currS: S, leftovers: Chunk[A]): ZChannel[Any, E, Chunk[A], Any, E, Chunk[A1], Unit] =
if (leftovers.isEmpty) {
ZChannel.readWithCause(
(in: Chunk[A]) => {
chunkAccumulator(currS, in)
},
(err: Cause[E]) => ZChannel.refailCause(err),
(_: Any) => ZChannel.unit,
)
} else {
chunkAccumulator(currS, leftovers)
}

ZStream.fromChannel(self.channel >>> accumulator(s, Chunk.empty))
}
def reset(): Unit =
index = 0
}
}
Loading