diff --git a/example/src/main/scala/example/PlainTextBenchmarkServer.scala b/example/src/main/scala/example/PlainTextBenchmarkServer.scala index 397eacdd0d..17952e0e5d 100644 --- a/example/src/main/scala/example/PlainTextBenchmarkServer.scala +++ b/example/src/main/scala/example/PlainTextBenchmarkServer.scala @@ -27,8 +27,8 @@ object Main extends App { .provideCustomLayer(ServerChannelFactory.auto ++ EventLoopGroup.auto(8)) .exitCode } - - private def app(response: Response) = Http.fromHExit(HExit.succeed(response)) + private val path = "/plaintext" + private def app(response: Response) = Http.fromHExit(HExit.succeed(response)).whenPathEq(path) private def server(response: Response) = Server.app(app(response)) ++ diff --git a/zio-http/src/main/scala/zhttp/http/Http.scala b/zio-http/src/main/scala/zhttp/http/Http.scala index 7ea3da92ca..7c77873c42 100644 --- a/zio-http/src/main/scala/zhttp/http/Http.scala +++ b/zio-http/src/main/scala/zhttp/http/Http.scala @@ -360,6 +360,12 @@ sealed trait Http[-R, +E, -A, +B] extends (A => ZIO[R, Option[E], B]) { self => final def unwrap[R1 <: R, E1 >: E, C](implicit ev: B <:< ZIO[R1, E1, C]): Http[R1, E1, A, C] = self.flatMap(Http.fromZIO(_)) + /** + * Applies Http based only if the condition function evaluates to true + */ + final def when[A2 <: A](f: A2 => Boolean): Http[R, E, A2, B] = + Http.When(f, self) + /** * Widens the type of the output */ @@ -413,6 +419,8 @@ sealed trait Http[-R, +E, -A, +B] extends (A => ZIO[R, Option[E], B]) { self => self.execute(a).foldExit(ee(_).execute(a), bb(_).execute(a), dd.execute(a)) case RunMiddleware(app, mid) => mid(app).execute(a) + + case When(f, other) => if (f(a)) other.execute(a) else HExit.empty } } @@ -451,6 +459,16 @@ object Http { */ override def updateHeaders(update: Headers => Headers): HttpApp[R, E] = http.map(_.updateHeaders(update)) + /** + * Applies Http based on the path + */ + def whenPathEq(p: Path): HttpApp[R, E] = http.whenPathEq(p.toString) + + /** + * Applies Http based on the path as string + */ + def whenPathEq(p: String): HttpApp[R, E] = http.when(_.unsafeEncode.uri().contentEquals(p)) + private[zhttp] def compile[R1 <: R]( zExec: HttpRuntime[R1], settings: Server.Config[R1, Throwable], @@ -825,9 +843,11 @@ object Http { private case class Attempt[A](a: () => A) extends Http[Any, Nothing, Any, A] - private case object Empty extends Http[Any, Nothing, Any, Nothing] - private final case class FromHExit[R, E, B](h: HExit[R, E, B]) extends Http[R, E, Any, B] + private final case class When[R, E, A, B](f: A => Boolean, other: Http[R, E, A, B]) extends Http[R, E, A, B] + + private case object Empty extends Http[Any, Nothing, Any, Nothing] + private case object Identity extends Http[Any, Nothing, Any, Nothing] } diff --git a/zio-http/src/main/scala/zhttp/http/Request.scala b/zio-http/src/main/scala/zhttp/http/Request.scala index 558fcfe5c4..144c6c8407 100644 --- a/zio-http/src/main/scala/zhttp/http/Request.scala +++ b/zio-http/src/main/scala/zhttp/http/Request.scala @@ -1,6 +1,7 @@ package zhttp.http import io.netty.buffer.{ByteBuf, ByteBufUtil} +import io.netty.handler.codec.http.{DefaultFullHttpRequest, HttpRequest} import zhttp.http.headers.HeaderExtension import zio.{Chunk, Task, UIO} @@ -21,6 +22,7 @@ trait Request extends HeaderExtension[Request] { self => override def method: Method = m override def url: URL = u override def headers: Headers = h + override def unsafeEncode: HttpRequest = self.unsafeEncode override def remoteAddress: Option[InetAddress] = self.remoteAddress override def data: HttpData = self.data } @@ -83,6 +85,11 @@ trait Request extends HeaderExtension[Request] { self => */ def setUrl(url: URL): Request = self.copy(url = url) + /** + * Gets the HttpRequest + */ + private[zhttp] def unsafeEncode: HttpRequest + /** * Gets the complete url */ @@ -108,10 +115,16 @@ object Request { val h = headers val ra = remoteAddress val d = data + new Request { override def method: Method = m override def url: URL = u override def headers: Headers = h + override def unsafeEncode: HttpRequest = { + val jVersion = Version.`HTTP/1.1`.toJava + val path = url.relative.encode + new DefaultFullHttpRequest(jVersion, method.toJava, path) + } override def remoteAddress: Option[InetAddress] = ra override def data: HttpData = d } @@ -137,6 +150,7 @@ object Request { override def method: Method = req.method override def remoteAddress: Option[InetAddress] = req.remoteAddress override def url: URL = req.url + override def unsafeEncode: HttpRequest = req.unsafeEncode override def data: HttpData = req.data } diff --git a/zio-http/src/main/scala/zhttp/service/Handler.scala b/zio-http/src/main/scala/zhttp/service/Handler.scala index 5667e673cd..4c2076ccfe 100644 --- a/zio-http/src/main/scala/zhttp/service/Handler.scala +++ b/zio-http/src/main/scala/zhttp/service/Handler.scala @@ -33,6 +33,8 @@ private[zhttp] final case class Handler[R]( override def headers: Headers = Headers.make(jReq.headers()) + override def unsafeEncode: HttpRequest = jReq + override def remoteAddress: Option[InetAddress] = { ctx.channel().remoteAddress() match { case m: InetSocketAddress => Some(m.getAddress) diff --git a/zio-http/src/test/scala/zhttp/http/HttpSpec.scala b/zio-http/src/test/scala/zhttp/http/HttpSpec.scala index 5fefc4e8c3..4bc849e50b 100644 --- a/zio-http/src/test/scala/zhttp/http/HttpSpec.scala +++ b/zio-http/src/test/scala/zhttp/http/HttpSpec.scala @@ -322,6 +322,18 @@ object HttpSpec extends DefaultRunnableSpec with HExitAssertion { assert(actual)(isSuccess(equalTo("bar"))) } } - }, + } + + suite("when")( + test("should execute http only when condition applies") { + val app = Http.succeed(1).when((_: Any) => true) + val actual = app.execute(0) + assert(actual)(isSuccess(equalTo(1))) + } + + test("should not execute http when condition doesn't apply") { + val app = Http.succeed(1).when((_: Any) => false) + val actual = app.execute(0) + assert(actual)(isEmpty) + }, + ), ) @@ timeout(10 seconds) }