diff --git a/zio-http/src/main/scala/zhttp/http/Middleware.scala b/zio-http/src/main/scala/zhttp/http/Middleware.scala index db047d6e91..12d6a8b909 100644 --- a/zio-http/src/main/scala/zhttp/http/Middleware.scala +++ b/zio-http/src/main/scala/zhttp/http/Middleware.scala @@ -124,6 +124,9 @@ sealed trait Middleware[-R, +E, +AIn, -BIn, -AOut, +BOut] { self => final def runAfter[R1 <: R, E1 >: E](effect: ZIO[R1, E1, Any]): Middleware[R1, E1, AIn, BIn, AOut, BOut] = self.mapZIO(bOut => effect.as(bOut)) + final def runBefore[R1 <: R, E1 >: E](effect: ZIO[R1, E1, Any]): Middleware[R1, E1, AIn, BIn, AOut, BOut] = + self.contramapZIO(b => effect.as(b)) + /** * Applies Middleware based only if the condition function evaluates to true */ diff --git a/zio-http/src/test/scala/zhttp/http/MiddlewareSpec.scala b/zio-http/src/test/scala/zhttp/http/MiddlewareSpec.scala index b52a15441f..66183d7a02 100644 --- a/zio-http/src/test/scala/zhttp/http/MiddlewareSpec.scala +++ b/zio-http/src/test/scala/zhttp/http/MiddlewareSpec.scala @@ -56,10 +56,20 @@ object MiddlewareSpec extends DefaultRunnableSpec with HExitAssertion { val app = Http.identity[Int] @@ mid assertM(app(0))(equalTo(3)) } + + testM("runBefore") { + val mid = Middleware.identity.runBefore(console.putStrLn("A")) + val app = Http.fromZIO(console.putStrLn("B")) @@ mid + assertM(app(()) *> TestConsole.output)(equalTo(Vector("A\n", "B\n"))) + } + testM("runAfter") { - val mid = Middleware.succeed(1).runAfter(console.putStrLn("A")) - val app = Http.succeed(1) @@ mid - assertM(app(()) *> TestConsole.output)(equalTo(Vector("A\n"))) + val mid = Middleware.identity.runAfter(console.putStrLn("B")) + val app = Http.fromZIO(console.putStrLn("A")) @@ mid + assertM(app(()) *> TestConsole.output)(equalTo(Vector("A\n", "B\n"))) + } + + testM("runBefore and runAfter") { + val mid = Middleware.identity.runBefore(console.putStrLn("A")).runAfter(console.putStrLn("C")) + val app = Http.fromZIO(console.putStrLn("B")) @@ mid + assertM(app(()) *> TestConsole.output)(equalTo(Vector("A\n", "B\n", "C\n"))) } + testM("race") { val mid = Middleware.succeed('A').delay(2 second) race Middleware.succeed("B").delay(1 second)