Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select codec based on response status for endpoint client (#2727) #2929

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 3 additions & 100 deletions zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,91 +346,6 @@ object RoundtripSpec extends ZIOHttpSpec {
"42",
)
},
test("middleware error returned") {

val alwaysFailingMiddleware = EndpointMiddleware(
authorization,
HttpCodec.empty,
HttpCodec.error[String](Status.Custom(900)),
)

val endpoint =
Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddleware

val endpointRoute =
endpoint.implementHandler(Handler.identity)

val routes = endpointRoute.toRoutes

val app = routes @@ alwaysFailingMiddleware
.implement[Any, Unit](_ => ZIO.fail("FAIL"))(_ => ZIO.unit)

for {
port <- Server.install(app)
executorLayer = ZLayer(ZIO.serviceWith[Client](makeExecutor(_, port, Authorization.Basic("user", "pass"))))

out <- ZIO
.serviceWithZIO[EndpointExecutor[alwaysFailingMiddleware.In]] { executor =>
executor.apply(endpoint.apply(42))
}
.provideSome[Client & Scope](executorLayer)
.flip
} yield assert(out)(equalTo("FAIL"))
},
test("failed middleware deserialization") {
val alwaysFailingMiddleware = EndpointMiddleware(
authorization,
HttpCodec.empty,
HttpCodec.error[String](Status.Custom(900)),
)

val endpoint =
Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddleware

val alwaysFailingMiddlewareWithAnotherSignature = EndpointMiddleware(
authorization,
HttpCodec.empty,
HttpCodec.error[Long](Status.Custom(900)),
)

val endpointWithAnotherSignature =
Endpoint(GET / "users" / int("userId")).out[Int] @@ alwaysFailingMiddlewareWithAnotherSignature

val endpointRoute =
endpoint.implementHandler(Handler.identity)

val routes = endpointRoute.toRoutes

val app = routes @@ alwaysFailingMiddleware.implement[Any, Unit](_ => ZIO.fail("FAIL"))(_ => ZIO.unit)

for {
port <- Server.install(app)
executorLayer = ZLayer(ZIO.serviceWith[Client](makeExecutor(_, port, Authorization.Basic("user", "pass"))))

cause <- ZIO
.serviceWithZIO[EndpointExecutor[alwaysFailingMiddleware.In]] { executor =>
executor.apply(endpointWithAnotherSignature.apply(42))
}
.provideSome[Client with Scope](executorLayer)
.cause
} yield assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using endpoint error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using middleware error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"Suppressed: java.lang.IllegalStateException: Trying to decode with Undefined codec.",
),
) && assert(cause.prettyPrint)(
containsString(
"Suppressed: zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected a number, got F)",
),
)
},
test("Failed endpoint deserialization") {
val endpoint =
Endpoint(GET / "users" / int("userId")).out[Int].outError[Int](Status.Custom(999))
Expand All @@ -457,21 +372,9 @@ object RoundtripSpec extends ZIOHttpSpec {
}
.provideSome[Client with Scope](executorLayer)
.cause
} yield assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using endpoint error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"java.lang.IllegalStateException: Cannot deserialize using middleware error codec",
),
) && assert(cause.prettyPrint)(
containsString(
"Suppressed: java.lang.IllegalStateException: Trying to decode with Undefined codec.",
),
) && assert(cause.prettyPrint)(
containsString(
"""Suppressed: zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected '"' got '4')""",
} yield assertTrue(
cause.prettyPrint.contains(
"""zio.http.codec.HttpCodecError$MalformedBody: Malformed request body failed to decode: (expected '"' got '4')""",
),
)
},
Expand Down
26 changes: 25 additions & 1 deletion zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import zio.schema.Schema
import zio.http.Header.Accept.MediaTypeWithQFactor
import zio.http._
import zio.http.codec.HttpCodec.{Annotated, Metadata}
import zio.http.codec.internal.EncoderDecoder
import zio.http.codec.internal.{AtomizedCodecs, EncoderDecoder}

/**
* A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP
Expand All @@ -48,6 +48,27 @@ sealed trait HttpCodec[-AtomTypes, Value] {

private lazy val encoderDecoder: EncoderDecoder[AtomTypes, Value] = EncoderDecoder(self)

private def statusCodecs: Chunk[SimpleCodec[Status, _]] =
self.asInstanceOf[HttpCodec[_, _]] match {
case HttpCodec.Fallback(left, right, _, _) => left.statusCodecs ++ right.statusCodecs
case HttpCodec.Combine(left, right, _) => left.statusCodecs ++ right.statusCodecs
case HttpCodec.Annotated(codec, _) => codec.statusCodecs
case HttpCodec.TransformOrFail(codec, _, _) => codec.statusCodecs
case HttpCodec.Empty => Chunk.empty
case HttpCodec.Halt => Chunk.empty
case atom: HttpCodec.Atom[_, _] =>
atom match {
case HttpCodec.Status(codec, _) => Chunk.single(codec)
case _ => Chunk.empty
}
}

private lazy val statusCodes: Set[Status] = statusCodecs.collect { case SimpleCodec.Specified(status) =>
status
}.toSet

private lazy val matchesAnyStatus: Boolean = statusCodecs.contains(SimpleCodec.Unspecified[Status]())

/**
* Returns a new codec that is the same as this one, but has attached docs,
* which will render whenever docs are generated from the codec.
Expand Down Expand Up @@ -238,6 +259,9 @@ sealed trait HttpCodec[-AtomTypes, Value] {
else Left(s"Expected ${expected} but found ${actual}"),
)(_ => expected)

private[http] def matchesStatus(status: Status) =
matchesAnyStatus || statusCodes.contains(status)

def named(name: String): HttpCodec[AtomTypes, Value] =
HttpCodec.Annotated(self, Metadata.Named(name))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ final case class EndpointExecutor[+MI](
alt: Alternator[E, invocation.middleware.Err],
ev: MI <:< invocation.middleware.In,
trace: Trace,
): ZIO[Scope, alt.Out, B] = {
): ZIO[Scope, E, B] = {
middlewareInput.flatMap { mi =>
getClient(invocation.endpoint).orDie.flatMap { endpointClient =>
endpointClient.execute(client, invocation)(ev(mi))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private[endpoint] final case class EndpointClient[P, I, E, O, M <: EndpointMiddl
) {
def execute(client: Client, invocation: Invocation[P, I, E, O, M])(
mi: invocation.middleware.In,
)(implicit alt: Alternator[E, invocation.middleware.Err], trace: Trace): ZIO[Scope, alt.Out, O] = {
)(implicit alt: Alternator[E, invocation.middleware.Err], trace: Trace): ZIO[Scope, E, O] = {
val request0 = endpoint.input.encodeRequest(invocation.input)
val request = request0.copy(url = endpointRoot ++ request0.url)

Expand All @@ -44,28 +44,12 @@ private[endpoint] final case class EndpointClient[P, I, E, O, M <: EndpointMiddl
)

client.request(withDefaultAcceptHeader).orDie.flatMap { response =>
if (response.status.isSuccess) {
if (endpoint.output.matchesStatus(response.status)) {
endpoint.output.decodeResponse(response).orDie
} else if (endpoint.error.matchesStatus(response.status)) {
endpoint.error.decodeResponse(response).orDie.flip
} else {
// Preferentially decode an error from the handler, before falling back
// to decoding the middleware error:
val handlerError =
endpoint.error
.decodeResponse(response)
.map(e => alt.left(e))
.mapError(t => new IllegalStateException("Cannot deserialize using endpoint error codec", t))

val middlewareError =
invocation.middleware.error
.decodeResponse(response)
.map(e => alt.right(e))
.mapError(t => new IllegalStateException("Cannot deserialize using middleware error codec", t))

handlerError.catchAllCause { handlerCause =>
middlewareError.catchAllCause { middlewareCause =>
ZIO.failCause(handlerCause ++ middlewareCause)
}
}.orDie.flip
ZIO.die(new IllegalStateException(s"Status code: ${response.status} is not defined in the endpoint"))
}
}
}
Expand Down
Loading