Skip to content

Commit

Permalink
added mediaType parameters to Body methods (#2323)
Browse files Browse the repository at this point in the history
* added mediaType parameters to Body methods

* refactored mediaType hint to follow zio-http conventions

* refactored BodySpec

* added logic to populate http content-type header based on Body typeif said header is missing

* removed header population logic from requestRaw

* changed request logic to override the header with the Body header if it exists

* pr remarks

* fixed streambody infinite recursion

* fixed body.md

---------

Co-authored-by: John A. De Goes <john@degoes.net>
  • Loading branch information
Adriani-Furtado and jdegoes authored Sep 2, 2023
1 parent cfe9e35 commit 8780eb9
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/dsl/body.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ To create an `Body` that encodes a Stream you can use `Body.fromStream`.
- Using a Stream of String

```scala mdoc:silent
val streamHttpData2: Body = Body.fromStream(ZStream("a", "b", "c"), Charsets.Http)
val streamHttpData2: Body = Body.fromCharSequenceStream(ZStream("a", "b", "c"), Charsets.Http)
```

### Creating a Body from a `File`
Expand Down
72 changes: 50 additions & 22 deletions zio-http/src/main/scala/zio/http/Body.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,21 @@ trait Body { self =>
*/
def isEmpty: Boolean

private[zio] def mediaType: Option[MediaType]
/**
* Returns the media type for this Body
*/
def mediaType: Option[MediaType]

/**
* Updates the media type attached to this body, returning a new Body with the
* updated media type
*/
def contentType(newMediaType: MediaType): Body

def contentType(newMediaType: MediaType, newBoundary: Boundary): Body

private[zio] def boundary: Option[Boundary]

private[zio] def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body
}

object Body {
Expand All @@ -142,7 +153,10 @@ object Body {
/**
* Constructs a [[zio.http.Body]] from the contents of a file.
*/
def fromCharSequence(charSequence: CharSequence, charset: Charset = Charsets.Http): Body =
def fromCharSequence(
charSequence: CharSequence,
charset: Charset = Charsets.Http,
): Body =
BodyEncoding.default.fromCharSequence(charSequence, charset)

/**
Expand All @@ -153,7 +167,8 @@ object Body {
/**
* Constructs a [[zio.http.Body]] from the contents of a file.
*/
def fromFile(file: java.io.File, chunkSize: Int = 1024 * 4): Body = FileBody(file, chunkSize)
def fromFile(file: java.io.File, chunkSize: Int = 1024 * 4): Body =
FileBody(file, chunkSize)

/**
* Constructs a [[zio.http.Body]] from from form data, using multipart
Expand Down Expand Up @@ -183,21 +198,26 @@ object Body {
/**
* Constructs a [[zio.http.Body]] from a stream of bytes.
*/
def fromStream(stream: ZStream[Any, Throwable, Byte]): Body = StreamBody(stream)
def fromStream(stream: ZStream[Any, Throwable, Byte]): Body =
StreamBody(stream)

/**
* Constructs a [[zio.http.Body]] from a stream of text, using the specified
* character set, which defaults to the HTTP character set.
*/
def fromStream(stream: ZStream[Any, Throwable, CharSequence], charset: Charset = Charsets.Http)(implicit
def fromCharSequenceStream(
stream: ZStream[Any, Throwable, CharSequence],
charset: Charset = Charsets.Http,
)(implicit
trace: Trace,
): Body =
fromStream(stream.map(seq => Chunk.fromArray(seq.toString.getBytes(charset))).flattenChunks)

/**
* Helper to create Body from String
*/
def fromString(text: String, charset: Charset = Charsets.Http): Body = fromCharSequence(text, charset)
def fromString(text: String, charset: Charset = Charsets.Http): Body =
fromCharSequence(text, charset)

/**
* Constructs a [[zio.http.Body]] from form data using URL encoding and the
Expand Down Expand Up @@ -235,11 +255,13 @@ object Body {

override private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte] = Array.empty[Byte]

override private[zio] def mediaType: Option[MediaType] = None

override private[zio] def boundary: Option[Boundary] = None

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body = EmptyBody
override def mediaType: Option[MediaType] = None

override def contentType(newMediaType: MediaType): Body = EmptyBody

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body = EmptyBody
}

private[zio] final case class ChunkBody(
Expand All @@ -265,8 +287,10 @@ object Body {

override private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte] = data.toArray

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(newBoundary))
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
}

private[zio] final case class FileBody(
Expand Down Expand Up @@ -311,8 +335,10 @@ object Body {
override private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte] =
Files.readAllBytes(file.toPath)

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(newBoundary))
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
}

private[zio] final case class StreamBody(
Expand All @@ -331,8 +357,10 @@ object Body {

override def asStream(implicit trace: Trace): ZStream[Any, Throwable, Byte] = stream

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(newBoundary))
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
}

private[zio] final case class WebsocketBody(socketApp: WebSocketApp[Any]) extends Body {
Expand All @@ -347,16 +375,16 @@ object Body {

private[zio] def boundary: Option[Boundary] = None

private[zio] def contentType(
newMediaType: MediaType,
newBoundary: Option[Boundary],
): Body = this

def isComplete: Boolean = true

def isEmpty: Boolean = true

private[zio] def mediaType: Option[MediaType] = None
def mediaType: Option[MediaType] = None

def contentType(newMediaType: zio.http.MediaType): zio.http.Body = this

def contentType(newMediaType: zio.http.MediaType, newBoundary: zio.http.Boundary): zio.http.Body = this

}

private val zioEmptyArray = ZIO.succeed(Array.empty[Byte])(Trace.empty)
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zio/http/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ object Handler {
): Handler[R, Throwable, Any, Response] =
Handler.fromZIO {
ZIO.environment[R].map { env =>
fromBody(Body.fromStream(stream.provideEnvironment(env), charset))
fromBody(Body.fromCharSequenceStream(stream.provideEnvironment(env), charset))
}
}.flatten

Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zio/http/Response.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ object Response {
* \- stream of data to be sent as Server Sent Events
*/
def fromServerSentEvents(data: ZStream[Any, Nothing, ServerSentEvent])(implicit trace: Trace): Response =
Response(Status.Ok, contentTypeEventStream, Body.fromStream(data.map(_.encode)))
Response(Status.Ok, contentTypeEventStream, Body.fromCharSequenceStream(data.map(_.encode)))

/**
* Creates a new response for the provided socket app
Expand Down
11 changes: 9 additions & 2 deletions zio-http/src/main/scala/zio/http/ZClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ final case class ZClient[-Env, -In, +Err, +Out](

private def requestRaw(method: Method, suffix: String, body: Body)(implicit
trace: Trace,
): ZIO[Env & Scope, Err, Response] =
): ZIO[Env & Scope, Err, Response] = {
driver
.request(
version,
Expand All @@ -202,6 +202,7 @@ final case class ZClient[-Env, -In, +Err, +Out](
sslConfig,
proxy,
)
}

def retry[Env1 <: Env](policy: Schedule[Env1, Err, Any]): ZClient[Env1, In, Err, Out] =
transform[Env1, In, Err, Out](bodyEncoder, bodyDecoder, self.driver.retry(policy))
Expand Down Expand Up @@ -645,8 +646,14 @@ object ZClient {
sslConfig: Option[ClientSSLConfig],
proxy: Option[Proxy],
)(implicit trace: Trace): ZIO[Scope, Throwable, Response] = {
val request = Request(version, method, url, headers, body, None)
val requestHeaders = body.mediaType match {
case None => headers
case Some(value) => headers.removeHeader(Header.ContentType).addHeader(Header.ContentType(value))
}

val request = Request(version, method, url, requestHeaders, body, None)
val cfg = config.copy(ssl = sslConfig.orElse(config.ssl), proxy = proxy.orElse(config.proxy))

requestAsync(request, cfg, () => WebSocketApp.unit, None)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ private[codec] object EncoderDecoder {
Body.fromMultipartForm(encodeMultipartFormData(inputs), formBoundary)
} else {
if (isEventStream) {
Body.fromStream(inputs(0).asInstanceOf[ZStream[Any, Nothing, ServerSentEvent]].map(_.encode))
Body.fromCharSequenceStream(inputs(0).asInstanceOf[ZStream[Any, Nothing, ServerSentEvent]].map(_.encode))
} else if (inputs.length < 1) {
Body.empty
} else {
Expand Down
18 changes: 12 additions & 6 deletions zio-http/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ object NettyBody extends BodyEncoding {

private[zio] override def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte] = asciiString.array()

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(newBoundary))
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
}

private[zio] final case class ByteBufBody(
Expand All @@ -103,8 +105,10 @@ object NettyBody extends BodyEncoding {
override private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte] =
ByteBufUtil.getBytes(byteBuf)

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(newBoundary))
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
}

private[zio] final case class AsyncBody(
Expand Down Expand Up @@ -142,8 +146,10 @@ object NettyBody extends BodyEncoding {

override def toString(): String = s"AsyncBody($unsafeAsync)"

override def contentType(newMediaType: MediaType, newBoundary: Option[Boundary] = None): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(newBoundary))
override def contentType(newMediaType: MediaType): Body = copy(mediaType = Some(newMediaType))

override def contentType(newMediaType: MediaType, newBoundary: Boundary): Body =
copy(mediaType = Some(newMediaType), boundary = boundary.orElse(Some(newBoundary)))
}

private[zio] trait UnsafeAsync {
Expand Down
8 changes: 7 additions & 1 deletion zio-http/src/test/scala/zio/http/BodySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,11 @@ object BodySpec extends ZIOHttpSpec {
),
),
),
)
suite("mediaType")(
test("updates the Body media type with the provided value") {
val body = Body.fromString("test").contentType(MediaType.text.plain)
assertTrue(body.mediaType == Option(MediaType.text.plain))
},
),
) @@ timeout(10 seconds)
}
2 changes: 1 addition & 1 deletion zio-http/src/test/scala/zio/http/ClientSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ object ClientSpec extends HttpRunnableSpec {
val app = Handler.fromFunctionZIO[Request] { req => req.body.asString.map(Response.text(_)) }.sandbox.toHttpApp
val stream = ZStream.fromIterable(List("a", "b", "c"), chunkSize = 1)
val res = app
.deploy(Request(method = Method.POST, body = Body.fromStream(stream)))
.deploy(Request(method = Method.POST, body = Body.fromCharSequenceStream(stream)))
.flatMap(_.body.asString)
assertZIO(res)(equalTo("abc"))
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object ResponseCompressionSpec extends ZIOHttpSpec {
Headers(
Header.ContentType(MediaType.text.plain),
),
Body.fromStream(
Body.fromCharSequenceStream(
ZStream
.unfold[Long, String](0L) { s =>
if (s < 1000) Some((s"$s\n", s + 1)) else None
Expand Down
3 changes: 2 additions & 1 deletion zio-http/src/test/scala/zio/http/ServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ object ServerSpec extends HttpRunnableSpec {
val dataStream = ZStream.repeat("A").take(MaxSize.toLong)
val app =
Routes(RoutePattern.any -> handler((_: Path, req: Request) => Response(body = req.body))).toHttpApp
val res = app.deploy.body.mapZIO(_.asChunk.map(_.length)).run(body = Body.fromStream(dataStream))
val res =
app.deploy.body.mapZIO(_.asChunk.map(_.length)).run(body = Body.fromCharSequenceStream(dataStream))
assertZIO(res)(equalTo(MaxSize))
}
} +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec {
.deploy(
Request(
method = Method.POST,
body = Body.fromStream(stream),
body = Body.fromCharSequenceStream(stream),
headers = extraHeaders,
),
)
Expand Down Expand Up @@ -113,7 +113,7 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec {
Request(
method = Method.POST,
url = URL.root / "streaming",
body = Body.fromStream(stream),
body = Body.fromCharSequenceStream(stream),
headers = extraHeaders,
),
)
Expand Down

0 comments on commit 8780eb9

Please sign in to comment.