Skip to content

Commit

Permalink
Select codec based on response status for endpoint client (#2727)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jun 22, 2024
1 parent 0afa589 commit 7fbcd99
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 123 deletions.
103 changes: 3 additions & 100 deletions zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,91 +346,6 @@ object RoundtripSpec extends ZIOHttpSpec {
"42",
)
},
test("middleware error returned") {

val alwaysFailingMiddleware = EndpointMiddleware(
authorization,
HttpCodec.empty,
HttpCodec.error[String](Status.Custom(900)),
)

val endpoint =
Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddleware

val endpointRoute =
endpoint.implementHandler(Handler.identity)

val routes = endpointRoute.toRoutes

val app = routes @@ alwaysFailingMiddleware
.implement[Any, Unit](_ => ZIO.fail("FAIL"))(_ => ZIO.unit)

for {
port <- Server.install(app)
executorLayer = ZLayer(ZIO.serviceWith[Client](makeExecutor(_, port, Authorization.Basic("user", "pass"))))

out <- ZIO
.serviceWithZIO[EndpointExecutor[alwaysFailingMiddleware.In]] { executor =>
executor.apply(endpoint.apply(42))
}
.provideSome[Client & Scope](executorLayer)
.flip
} yield assert(out)(equalTo("FAIL"))
},
test("failed middleware deserialization") {
val alwaysFailingMiddleware = EndpointMiddleware(
authorization,
HttpCodec.empty,
HttpCodec.error[String](Status.Custom(900)),
)

val endpoint =
Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddleware

val alwaysFailingMiddlewareWithAnotherSignature = EndpointMiddleware(
authorization,
HttpCodec.empty,
HttpCodec.error[Long](Status.Custom(900)),
)

val endpointWithAnotherSignature =
Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddlewareWithAnotherSignature

val endpointRoute =
endpoint.implementHandler(Handler.identity)

val routes = endpointRoute.toRoutes

val app = routes @@ alwaysFailingMiddleware.implement[Any, Unit](_ => ZIO.fail("FAIL"))(_ => ZIO.unit)

for {
port <- Server.install(app)
executorLayer = ZLayer(ZIO.serviceWith[Client](makeExecutor(_, port, Authorization.Basic("user", "pass"))))

cause <- ZIO
.serviceWithZIO[EndpointExecutor[alwaysFailingMiddleware.In]] { executor =>
executor.apply(endpointWithAnotherSignature.apply(42))
}
.provideSome[Client with Scope](executorLayer)
.cause
} yield assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using endpoint error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using middleware error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"Suppressed: java.lang.IllegalStateException: Trying to decode with Undefined codec.",
),
) && assert(cause.prettyPrint)(
containsString(
"Suppressed: zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected a number, got F)",
),
)
},
test("Failed endpoint deserialization") {
val endpoint =
Endpoint(GET / "users" / int("userId")).out[Int].outError[Int](Status.Custom(999))
Expand All @@ -457,21 +372,9 @@ object RoundtripSpec extends ZIOHttpSpec {
}
.provideSome[Client with Scope](executorLayer)
.cause
} yield assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using endpoint error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using middleware error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"Suppressed: java.lang.IllegalStateException: Trying to decode with Undefined codec.",
),
) && assert(cause.prettyPrint)(
containsString(
"""Suppressed: zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected '"' got '4')""",
} yield assertTrue(
cause.prettyPrint.contains(
"""zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected '"' got '4')""",
),
)
},
Expand Down
26 changes: 25 additions & 1 deletion zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import zio.schema.Schema
import zio.http.Header.Accept.MediaTypeWithQFactor
import zio.http._
import zio.http.codec.HttpCodec.{Annotated, Metadata}
import zio.http.codec.internal.EncoderDecoder
import zio.http.codec.internal.{AtomizedCodecs, EncoderDecoder}

/**
* A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP
Expand All @@ -48,6 +48,27 @@ sealed trait HttpCodec[-AtomTypes, Value] {

private lazy val encoderDecoder: EncoderDecoder[AtomTypes, Value] = EncoderDecoder(self)

private def statusCodecs: Chunk[SimpleCodec[Status, _]] =
self.asInstanceOf[HttpCodec[_, _]] match {
case HttpCodec.Fallback(left, right, _, _) => left.statusCodecs ++ right.statusCodecs
case HttpCodec.Combine(left, right, _) => left.statusCodecs ++ right.statusCodecs
case HttpCodec.Annotated(codec, _) => codec.statusCodecs
case HttpCodec.TransformOrFail(codec, _, _) => codec.statusCodecs
case HttpCodec.Empty => Chunk.empty
case HttpCodec.Halt => Chunk.empty
case atom: HttpCodec.Atom[_, _] =>
atom match {
case HttpCodec.Status(codec, _) => Chunk.single(codec)
case _ => Chunk.empty
}
}

private lazy val statusCodes: Set[Status] = statusCodecs.collect { case SimpleCodec.Specified(status) =>
status
}.toSet

private lazy val matchesAnyStatus: Boolean = statusCodecs.contains(SimpleCodec.Unspecified[Status]())

/**
* Returns a new codec that is the same as this one, but has attached docs,
* which will render whenever docs are generated from the codec.
Expand Down Expand Up @@ -238,6 +259,9 @@ sealed trait HttpCodec[-AtomTypes, Value] {
else Left(s"Expected ${expected} but found ${actual}"),
)(_ => expected)

private[http] def matchesStatus(status: Status) =
matchesAnyStatus || statusCodes.contains(status)

def named(name: String): HttpCodec[AtomTypes, Value] =
HttpCodec.Annotated(self, Metadata.Named(name))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ final case class EndpointExecutor[+MI](
alt: Alternator[E, invocation.middleware.Err],
ev: MI <:< invocation.middleware.In,
trace: Trace,
): ZIO[Scope, alt.Out, B] = {
): ZIO[Scope, E, B] = {
middlewareInput.flatMap { mi =>
getClient(invocation.endpoint).orDie.flatMap { endpointClient =>
endpointClient.execute(client, invocation)(ev(mi))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private[endpoint] final case class EndpointClient[P, I, E, O, M <: EndpointMiddl
) {
def execute(client: Client, invocation: Invocation[P, I, E, O, M])(
mi: invocation.middleware.In,
)(implicit alt: Alternator[E, invocation.middleware.Err], trace: Trace): ZIO[Scope, alt.Out, O] = {
)(implicit alt: Alternator[E, invocation.middleware.Err], trace: Trace): ZIO[Scope, E, O] = {
val request0 = endpoint.input.encodeRequest(invocation.input)
val request = request0.copy(url = endpointRoot ++ request0.url)

Expand All @@ -44,28 +44,12 @@ private[endpoint] final case class EndpointClient[P, I, E, O, M <: EndpointMiddl
)

client.request(withDefaultAcceptHeader).orDie.flatMap { response =>
if (response.status.isSuccess) {
if (endpoint.output.matchesStatus(response.status)) {
endpoint.output.decodeResponse(response).orDie
} else if (endpoint.error.matchesStatus(response.status)) {
endpoint.error.decodeResponse(response).orDie.flip
} else {
// Preferentially decode an error from the handler, before falling back
// to decoding the middleware error:
val handlerError =
endpoint.error
.decodeResponse(response)
.map(e => alt.left(e))
.mapError(t => new IllegalStateException("Cannot deserialize using endpoint error codec", t))

val middlewareError =
invocation.middleware.error
.decodeResponse(response)
.map(e => alt.right(e))
.mapError(t => new IllegalStateException("Cannot deserialize using middleware error codec", t))

handlerError.catchAllCause { handlerCause =>
middlewareError.catchAllCause { middlewareCause =>
ZIO.failCause(handlerCause ++ middlewareCause)
}
}.orDie.flip
ZIO.die(new IllegalStateException(s"Status code: ${response.status} is not defined in the endpoint"))
}
}
}
Expand Down

0 comments on commit 7fbcd99

Please sign in to comment.