diff --git a/zio-http/src/main/scala/zio/http/Route.scala b/zio-http/src/main/scala/zio/http/Route.scala index 125389974a..2c6aa79e93 100644 --- a/zio-http/src/main/scala/zio/http/Route.scala +++ b/zio-http/src/main/scala/zio/http/Route.scala @@ -62,7 +62,8 @@ sealed trait Route[-Env, +Err] { self => self match { case Provided(route, env) => Provided(route.handleErrorCause(f), env) case Augmented(route, aspect) => Augmented(route.handleErrorCause(f), aspect) - case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + case Handled(routePattern, handler, location) => + Handled(routePattern, handler.mapErrorCause(c => f(c.asInstanceOf[Cause[Nothing]])), location) case Unhandled(rpm, handler, zippable, location) => val handler2: Handler[Env, Response, Request, Response] = { @@ -96,7 +97,8 @@ sealed trait Route[-Env, +Err] { self => self match { case Provided(route, env) => Provided(route.handleErrorCauseZIO(f), env) case Augmented(route, aspect) => Augmented(route.handleErrorCauseZIO(f), aspect) - case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + case Handled(routePattern, handler, location) => + Handled(routePattern, handler.mapErrorCauseZIO(c => f(c.asInstanceOf[Cause[Nothing]])), location) case Unhandled(rpm, handler, zippable, location) => val handler2: Handler[Env, Response, Request, Response] = { @@ -162,7 +164,14 @@ sealed trait Route[-Env, +Err] { self => self match { case Provided(route, env) => Provided(route.handleErrorRequestCause(f), env) case Augmented(route, aspect) => Augmented(route.handleErrorRequestCause(f), aspect) - case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + case Handled(routePattern, handler, location) => + Handled( + routePattern, + Handler.fromFunctionHandler[Request] { (req: Request) => + handler.mapErrorCause(c => f(req, c.asInstanceOf[Cause[Nothing]])) + }, + location, + ) case Unhandled(rpm, handler, zippable, location) => val handler2: Handler[Env, Response, Request, Response] = { @@ -201,7 +210,14 @@ sealed trait Route[-Env, +Err] { self => self match { case Provided(route, env) => Provided(route.handleErrorRequestCauseZIO(f), env) case Augmented(route, aspect) => Augmented(route.handleErrorRequestCauseZIO(f), aspect) - case Handled(routePattern, handler, location) => Handled(routePattern, handler, location) + case Handled(routePattern, handler, location) => + Handled( + routePattern, + Handler.fromFunctionHandler[Request] { (req: Request) => + handler.mapErrorCauseZIO(c => f(req, c.asInstanceOf[Cause[Nothing]])) + }, + location, + ) case Unhandled(rpm, handler, zippable, location) => val handler2: Handler[Env, Response, Request, Response] = { diff --git a/zio-http/src/test/scala/zio/http/RouteSpec.scala b/zio-http/src/test/scala/zio/http/RouteSpec.scala index ab24e57f13..d1d4c02ea5 100644 --- a/zio-http/src/test/scala/zio/http/RouteSpec.scala +++ b/zio-http/src/test/scala/zio/http/RouteSpec.scala @@ -124,6 +124,46 @@ object RouteSpec extends ZIOHttpSpec { resultWarning == "error accessing /endpoint: hmm...", ) }, + test("handleErrorCause should handle defects") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.dieMessage("hmm...") } + val errorHandled = route.handleErrorCause(_ => Response.text("error").status(Status.InternalServerError)) + val request = Request.get(URL.decode("/endpoint").toOption.get) + for { + response <- errorHandled.toHttpApp.runZIO(request) + bodyString <- response.body.asString + } yield assertTrue(extractStatus(response) == Status.InternalServerError, bodyString == "error") + }, + test("handleErrorCauseZIO should handle defects") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.dieMessage("hmm...") } + val errorHandled = + route.handleErrorCauseZIO(_ => ZIO.succeed(Response.text("error").status(Status.InternalServerError))) + val request = Request.get(URL.decode("/endpoint").toOption.get) + for { + response <- errorHandled.toHttpApp.runZIO(request) + bodyString <- response.body.asString + } yield assertTrue(extractStatus(response) == Status.InternalServerError, bodyString == "error") + }, + test("handleErrorRequestCause should handle defects") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.dieMessage("hmm...") } + val errorHandled = + route.handleErrorRequestCause((_, _) => Response.text("error").status(Status.InternalServerError)) + val request = Request.get(URL.decode("/endpoint").toOption.get) + for { + response <- errorHandled.toHttpApp.runZIO(request) + bodyString <- response.body.asString + } yield assertTrue(extractStatus(response) == Status.InternalServerError, bodyString == "error") + }, + test("handleErrorRequestCauseZIO should handle defects") { + val route = Method.GET / "endpoint" -> handler { (_: Request) => ZIO.dieMessage("hmm...") } + val errorHandled = route.handleErrorRequestCauseZIO((_, _) => + ZIO.succeed(Response.text("error").status(Status.InternalServerError)), + ) + val request = Request.get(URL.decode("/endpoint").toOption.get) + for { + response <- errorHandled.toHttpApp.runZIO(request) + bodyString <- response.body.asString + } yield assertTrue(extractStatus(response) == Status.InternalServerError, bodyString == "error") + }, ), ) }