Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically validate request body using Schema.validate #2360

Merged
merged 1 commit into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package zio.http.codec

import scala.util.control.NoStackTrace

import zio.Cause
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Cause, Chunk}

import zio.schema.validation.ValidationError

import zio.http.{Path, Status}

Expand All @@ -28,34 +30,44 @@ sealed trait HttpCodecError extends Exception with NoStackTrace {
def message: String
}
object HttpCodecError {
final case class MissingHeader(headerName: String) extends HttpCodecError {
final case class MissingHeader(headerName: String) extends HttpCodecError {
def message = s"Missing header $headerName"
}
final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError {
final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError {
def message = s"Expected $expected but found $actual"
}
final case class PathTooShort(path: Path, textCodec: TextCodec[_]) extends HttpCodecError {
final case class PathTooShort(path: Path, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Expected to find ${textCodec} but found pre-mature end to the path ${path}"
}
final case class MalformedPath(path: Path, pathCodec: PathCodec[_], error: String) extends HttpCodecError {
final case class MalformedPath(path: Path, pathCodec: PathCodec[_], error: String) extends HttpCodecError {
def message = s"Malformed path ${path} failed to decode using $pathCodec: $error"
}
final case class MalformedStatus(expected: Status, actual: Status) extends HttpCodecError {
final case class MalformedStatus(expected: Status, actual: Status) extends HttpCodecError {
def message = s"Expected status code ${expected} but found ${actual}"
}
final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError {
final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Malformed header $headerName failed to decode using $textCodec"
}
final case class MissingQueryParam(queryParamName: String) extends HttpCodecError {
final case class MissingQueryParam(queryParamName: String) extends HttpCodecError {
def message = s"Missing query parameter $queryParamName"
}
final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError {
final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec"
}
final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError {
final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError {
def message = s"Malformed request body failed to decode: $details"
}
final case class CustomError(message: String) extends HttpCodecError
final case class InvalidEntity(details: String, cause: Chunk[ValidationError] = Chunk.empty) extends HttpCodecError {
def message = s"A well-formed entity failed validation: $details"
}
object InvalidEntity {
def wrap(errors: Chunk[ValidationError]): InvalidEntity =
InvalidEntity(
errors.foldLeft("")((acc, err) => acc + err.message + "\n"),
errors,
)
}
final case class CustomError(message: String) extends HttpCodecError

def isHttpCodecError(cause: Cause[Any]): Boolean = {
!cause.isFailure && cause.defects.forall(e => e.isInstanceOf[HttpCodecError])
Expand Down
22 changes: 17 additions & 5 deletions zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package zio.http.codec.internal
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.stream.ZStream
import zio.stream.{ZPipeline, ZStream}

import zio.schema._
import zio.schema.codec.BinaryCodec

import zio.http.codec.HttpCodecError
import zio.http.{Body, FormField, MediaType}

/**
Expand Down Expand Up @@ -92,13 +93,12 @@ private[internal] object BodyCodec {

final case class Single[A](schema: Schema[A], mediaType: Option[MediaType], name: Option[String])
extends BodyCodec[A] {
def decodeFromBody(body: Body, codec: BinaryCodec[A])(implicit trace: Trace): IO[Throwable, A] = {
def decodeFromBody(body: Body, codec: BinaryCodec[A])(implicit trace: Trace): IO[Throwable, A] =
if (schema == Schema[Unit]) ZIO.unit.asInstanceOf[IO[Throwable, A]]
else
body.asChunk.flatMap { chunk =>
ZIO.fromEither(codec.decode(chunk))
}
}
}.flatMap(validateZIO(schema))

def encodeToBody(value: A, codec: BinaryCodec[A])(implicit trace: Trace): Body =
Body.fromChunk(codec.encode(value))
Expand All @@ -111,11 +111,23 @@ private[internal] object BodyCodec {
def decodeFromBody(body: Body, codec: BinaryCodec[E])(implicit
trace: Trace,
): IO[Throwable, ZStream[Any, Nothing, E]] =
ZIO.succeed((body.asStream >>> codec.streamDecoder).orDie)
ZIO.succeed((body.asStream >>> codec.streamDecoder >>> validateStream(schema)).orDie)

def encodeToBody(value: ZStream[Any, Nothing, E], codec: BinaryCodec[E])(implicit trace: Trace): Body =
Body.fromStream(value >>> codec.streamEncoder)

type Element = E
}

private[internal] def validateZIO[A](schema: Schema[A])(e: A)(implicit trace: Trace): ZIO[Any, HttpCodecError, A] = {
val errors = Schema.validate(e)(schema)
if (errors.isEmpty) ZIO.succeed(e)
else ZIO.fail(HttpCodecError.InvalidEntity.wrap(errors))
}

private[internal] def validateStream[E](schema: Schema[E])(implicit
trace: Trace,
): ZPipeline[Any, HttpCodecError, E, E] =
ZPipeline.mapZIO(validateZIO(schema))

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package zio.http.codec.internal

import zio._
import zio.test._

import zio.stream.{ZSink, ZStream}

import zio.schema._
import zio.schema.annotation.validate
import zio.schema.validation.Validation

import zio.http.codec.HttpCodecError

object BodyCodecSpec extends ZIOSpecDefault {
import BodyCodec._

case class User(
@validate(Validation.greaterThan(0))
id: Int,
@validate(Validation.minLength(2) && Validation.maxLength(64))
name: String,
)
object User {
val schema: Schema[User] = DeriveSchema.gen[User]
}

def spec = suite("BodyCodecSpec")(
suite("validateZIO")(
test("returns a valid entity") {
val valid = User(12, "zio")

for {
actual <- validateZIO(User.schema)(valid)
} yield assertTrue(valid == actual)
} +
test("fails with HttpCodecError for invalid entity") {
val invalid = User(-4, "z")
val validated = BodyCodec.validateZIO(User.schema)(invalid)

assertZIO(validated.exit)(Assertion.failsWithA[HttpCodecError.InvalidEntity])
},
),
suite("validateStream")(
test("returns all valid entities") {
val users = Chunk(
User(1, "Will"),
User(2, "Ammon"),
)
val valids = ZStream.fromChunk(users)

for {
validatedUsers <- valids.via(validateStream(User.schema)).runCollect
} yield assertTrue(validatedUsers == users)
},
test("fails with HttpCodecError for invalid entity") {
val users = Chunk(
User(1, "Will"),
User(-5, "Ammon"),
)
val invalid = ZStream.fromChunk(users)

for {
validatedUsers <- invalid.via(validateStream(User.schema)).runCollect.exit
} yield assert(validatedUsers)(Assertion.failsWithA[HttpCodecError.InvalidEntity])
},
),
)
}
40 changes: 40 additions & 0 deletions zio-http/src/test/scala/zio/http/endpoint/EndpointSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import zio.test._

import zio.stream.ZStream

import zio.schema.annotation.validate
import zio.schema.codec.{DecodeError, JsonCodec}
import zio.schema.validation.Validation
import zio.schema.{DeriveSchema, Schema, StandardType}

import zio.http.Header.ContentType
Expand All @@ -38,6 +40,11 @@ object EndpointSpec extends ZIOSpecDefault {

case class NewPost(value: String)

case class User(
@validate(Validation.greaterThan(0))
id: Int,
)

def spec = suite("EndpointSpec")(
suite("handler")(
test("simple request") {
Expand Down Expand Up @@ -547,6 +554,39 @@ object EndpointSpec extends ZIOSpecDefault {
body2 == "{\"message\":\"something went wrong\"}",
)
},
test("validation occurs automatically on schema") {

implicit val schema: Schema[User] = DeriveSchema.gen[User]

val routes =
Endpoint(POST / "users")
.in[User]
.out[String]
.implement {
Handler.fromFunctionZIO { _ =>
ZIO.succeed("User ID is greater than 0")
}
}
.handleErrorCause { case cause =>
Response.text("Caught: " + cause.defects.headOption.fold("no known cause")(d => d.getMessage))
}

val request1 = Request.post(URL.decode("/users").toOption.get, Body.fromString("""{"id":0}"""))
val request2 = Request.post(URL.decode("/users").toOption.get, Body.fromString("""{"id":1}"""))

for {
response1 <- routes.toHttpApp.runZIO(request1)
body1 <- response1.body.asString.orDie

response2 <- routes.toHttpApp.runZIO(request2)
body2 <- response2.body.asString.orDie
} yield assertTrue(
extractStatus(response1) == Status.BadRequest,
body1 == "",
extractStatus(response2) == Status.Ok,
body2 == "\"User ID is greater than 0\"",
)
},
),
suite("byte stream input/output")(
test("responding with a byte stream") {
Expand Down