Skip to content

Commit

Permalink
Implement missing operators in Middleware (#807)
Browse files Browse the repository at this point in the history
* Implement missing operators in Middleware

* fix as operator

* headers Middleware changes

* sign cookie

* extend with HeaderExtensions

* rename suite

* PR comments
  • Loading branch information
d11-amitsingh authored Jan 17, 2022
1 parent 7c387de commit 9a5d6bc
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 38 deletions.
18 changes: 17 additions & 1 deletion zio-http/src/main/scala/zhttp/http/Middleware.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ sealed trait Middleware[-R, +E, +AIn, -BIn, -AOut, +BOut] { self =>
*/
final def apply[R1 <: R, E1 >: E](http: Http[R1, E1, AIn, BIn]): Http[R1, E1, AOut, BOut] = execute(http)

/**
* Makes the middleware resolve with a constant Middleware
*/
final def as[BOut0](
bout: BOut0,
): Middleware[R, E, AIn, BIn, AOut, BOut0] =
self.map(_ => bout)

/**
* Combines two middleware that operate on the same input and output types, into one.
*/
Expand Down Expand Up @@ -129,7 +137,15 @@ sealed trait Middleware[-R, +E, +AIn, -BIn, -AOut, +BOut] { self =>
* Applies Middleware based only if the condition function evaluates to true
*/
final def when[AOut0 <: AOut](cond: AOut0 => Boolean): Middleware[R, E, AIn, BIn, AOut0, BOut] =
Middleware.ifThenElse[AOut0](cond(_))(
whenZIO(a => UIO(cond(a)))

/**
* Applies Middleware based only if the condition effectful function evaluates to true
*/
final def whenZIO[R1 <: R, E1 >: E, AOut0 <: AOut](
cond: AOut0 => ZIO[R1, E1, Boolean],
): Middleware[R1, E1, AIn, BIn, AOut0, BOut] =
Middleware.ifThenElseZIO[AOut0](cond(_))(
isTrue = _ => self,
isFalse = _ => Middleware.identity,
)
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zhttp/http/middleware/Cors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ private[zhttp] trait Cors {
),
)
case (_, Some(origin), _) if allowCORS(origin, req.method) =>
Middleware.addHeader(corsHeaders(origin, req.method, isPreflight = false))
Middleware.addHeaders(corsHeaders(origin, req.method, isPreflight = false))
case _ => Middleware.identity
}
})
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zhttp/http/middleware/Csrf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private[zhttp] trait Csrf {
tokenName: String = "x-csrf-token",
tokenGen: ZIO[R, Nothing, String] = UIO(UUID.randomUUID.toString),
): HttpMiddleware[R, E] =
Middleware.addCookieM(tokenGen.map(Cookie(tokenName, _)))
Middleware.addCookieZIO(tokenGen.map(Cookie(tokenName, _)))

def csrfValidate(tokenName: String = "x-csrf-token"): HttpMiddleware[Any, Nothing] = {
Middleware.whenHeader(
Expand Down
55 changes: 29 additions & 26 deletions zio-http/src/main/scala/zhttp/http/middleware/Web.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package zhttp.http.middleware

import zhttp.http._
import zhttp.http.middleware.Web._
import zhttp.http.headers.HeaderModifier
import zhttp.http.middleware.Web.{PartialResponseMake, PartialResponseMakeZIO}
import zio.clock.Clock
import zio.console.Console
import zio.duration.Duration
Expand All @@ -12,7 +13,8 @@ import java.io.IOException
/**
* Middlewares on an HttpApp
*/
private[zhttp] trait Web extends Cors with Csrf with Auth {
private[zhttp] trait Web extends Cors with Csrf with Auth with HeaderModifier[HttpMiddleware[Any, Nothing]] {
self =>

/**
* Logical operator to decide which middleware to select based on the predicate.
Expand Down Expand Up @@ -65,27 +67,15 @@ private[zhttp] trait Web extends Cors with Csrf with Auth {
* Sets cookie in response headers
*/
def addCookie(cookie: Cookie): HttpMiddleware[Any, Nothing] =
addHeader(Headers.setCookie(cookie))
self.withSetCookie(cookie)

/**
* Adds the provided header and value to the response
* Updates the provided list of headers to the response
*/
def addHeader(name: String, value: String): HttpMiddleware[Any, Nothing] =
patch((_, _) => Patch.addHeader(name, value))
override def updateHeaders(update: Headers => Headers): HttpMiddleware[Any, Nothing] =
Web.updateHeaders(update)

/**
* Adds the provided header to the response
*/
def addHeader(header: Headers): HttpMiddleware[Any, Nothing] =
patch((_, _) => Patch.addHeader(header))

/**
* Adds the provided list of headers to the response
*/
def addHeaders(headers: Headers): HttpMiddleware[Any, Nothing] =
patch((_, _) => Patch.addHeader(headers))

def addCookieM[R, E](cookie: ZIO[R, E, Cookie]): HttpMiddleware[R, E] =
def addCookieZIO[R, E](cookie: ZIO[R, E, Cookie]): HttpMiddleware[R, E] =
patchZIO((_, _) => cookie.mapBoth(Option(_), c => Patch.addHeader(Headers.setCookie(c))))

/**
Expand Down Expand Up @@ -125,12 +115,6 @@ private[zhttp] trait Web extends Cors with Csrf with Auth {
def patchZIO[R, E](f: (Status, Headers) => ZIO[R, Option[E], Patch]): HttpMiddleware[R, E] =
makeResponseZIO(_ => ZIO.unit)((status, headers, _) => f(status, headers))

/**
* Removes the header by name
*/
def removeHeader(name: String): HttpMiddleware[Any, Nothing] =
patch((_, _) => Patch.removeHeaders(List(name)))

/**
* Runs the effect before the request is passed on to the HttpApp on which the middleware is applied.
*/
Expand All @@ -142,6 +126,19 @@ private[zhttp] trait Web extends Cors with Csrf with Auth {
*/
def setStatus(status: Status): HttpMiddleware[Any, Nothing] = patch((_, _) => Patch.setStatus(status))

/**
* Creates a middleware for signing cookies
*/
def signCookies(secret: String): HttpMiddleware[Any, Nothing] =
updateHeaders {
case h if h.getHeader(HeaderNames.setCookie).isDefined =>
Headers(
HeaderNames.setCookie,
Cookie.decodeResponseCookie(h.getHeader(HeaderNames.setCookie).get._2.toString).get.sign(secret).encode,
)
case h => h
}

/**
* Creates a new constants middleware that always executes the app provided, independent of where the middleware is
* applied
Expand All @@ -167,7 +164,7 @@ private[zhttp] trait Web extends Cors with Csrf with Auth {
Middleware.makeZIO(req => f(MiddlewareRequest(req)))
}

object Web {
object Web extends HeaderModifier[HttpMiddleware[Any, Nothing]] {

final case class PartialResponseMake[S](req: MiddlewareRequest => S) extends AnyVal {
def apply(res: (Status, Headers, S) => Patch): HttpMiddleware[Any, Nothing] = {
Expand All @@ -189,4 +186,10 @@ object Web {
outgoing = (response, state) => res(response.status, response.getHeaders, state).map(patch => patch(response)),
)
}

/**
* Updates the current Headers with new one, using the provided update function passed.
*/
override def updateHeaders(update: Headers => Headers): HttpMiddleware[Any, Nothing] =
Middleware.patch((_, _) => Patch.updateHeaders(update))
}
27 changes: 27 additions & 0 deletions zio-http/src/test/scala/zhttp/http/MiddlewareSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ object MiddlewareSpec extends DefaultRunnableSpec with HExitAssertion {
val app = Http.succeed(1) @@ mid
assertM(app(()))(equalTo("OK"))
} +
testM("as") {
val mid = Middleware.fromHttp(Http.succeed("Not OK")).as("OK")
val app = Http.succeed(1) @@ mid
assertM(app(()))(equalTo("OK"))
} +
testM("interceptZIO") {
for {
ref <- Ref.make(0)
Expand Down Expand Up @@ -109,6 +114,28 @@ object MiddlewareSpec extends DefaultRunnableSpec with HExitAssertion {
assertM(app(0))(equalTo("0Foo0FooBar"))
}
} +
suite("when") {
val mid = Middleware.succeed(0)
testM("condition is true") {
val app = Http.identity[Int] @@ mid.when[Int](_ => true)
assertM(app(10))(equalTo(0))
} +
testM("condition is false") {
val app = Http.identity[Int] @@ mid.when[Int](_ => false)
assertM(app(1))(equalTo(1))
}
} +
suite("whenZIO") {
val mid = Middleware.succeed(0)
testM("condition is true") {
val app = Http.identity[Int] @@ mid.whenZIO[Any, Nothing, Int](_ => UIO(true))
assertM(app(10))(equalTo(0))
} +
testM("condition is false") {
val app = Http.identity[Int] @@ mid.whenZIO[Any, Nothing, Int](_ => UIO(false))
assertM(app(1))(equalTo(1))
}
} +
suite("codec") {
testM("codec success") {
val mid = Middleware.codec[String, Int](a => Right(a.toInt), b => Right(b.toString))
Expand Down
60 changes: 51 additions & 9 deletions zio-http/src/test/scala/zhttp/http/middleware/WebSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,38 @@ object WebSpec extends DefaultRunnableSpec with HttpAppTestExtensions {
private val midB = Middleware.addHeader("X-Custom", "B")

def spec = suite("HttpMiddleware") {

suite("debug") {
testM("log status method url and time") {
val program = run(app @@ debug) *> TestConsole.output
assertM(program)(equalTo(Vector("200 GET /health 1000ms\n")))
suite("headers suite") {
testM("addHeaders") {
val middleware = addHeaders(Headers("KeyA", "ValueA") ++ Headers("KeyB", "ValueB"))
val headers = (Http.ok @@ middleware).getHeaderValues
assertM(headers(Request()))(contains("ValueA") && contains("ValueB"))
} +
testM("log 404 status method url and time") {
val program = run(Http.empty ++ Http.notFound @@ debug) *> TestConsole.output
assertM(program)(equalTo(Vector("404 GET /health 0ms\n")))
testM("addHeader") {
val middleware = addHeader("KeyA", "ValueA")
val headers = (Http.ok @@ middleware).getHeaderValues
assertM(headers(Request()))(contains("ValueA"))
} +
testM("updateHeaders") {
val middleware = updateHeaders(_ => Headers("KeyA", "ValueA"))
val headers = (Http.ok @@ middleware).getHeaderValues
assertM(headers(Request()))(contains("ValueA"))
} +
testM("removeHeader") {
val middleware = removeHeader("KeyA")
val headers = (Http.succeed(Response.ok.setHeaders(Headers("KeyA", "ValueA"))) @@ middleware) getHeader "KeyA"
assertM(headers(Request()))(isNone)
}
} +
suite("debug") {
testM("log status method url and time") {
val program = run(app @@ debug) *> TestConsole.output
assertM(program)(equalTo(Vector("200 GET /health 1000ms\n")))
} +
testM("log 404 status method url and time") {
val program = run(Http.empty ++ Http.notFound @@ debug) *> TestConsole.output
assertM(program)(equalTo(Vector("404 GET /health 0ms\n")))
}
} +
suite("when") {
testM("condition is true") {
val program = run(app @@ debug.when(_ => true)) *> TestConsole.output
Expand All @@ -38,6 +59,16 @@ object WebSpec extends DefaultRunnableSpec with HttpAppTestExtensions {
assertM(log)(equalTo(Vector()))
}
} +
suite("whenZIO") {
testM("condition is true") {
val program = run(app @@ debug.whenZIO(_ => UIO(true))) *> TestConsole.output
assertM(program)(equalTo(Vector("200 GET /health 1000ms\n")))
} +
testM("condition is false") {
val log = run(app @@ debug.whenZIO(_ => UIO(false))) *> TestConsole.output
assertM(log)(equalTo(Vector()))
}
} +
suite("race") {
testM("achieved") {
val program = run(app @@ timeout(5 seconds)).map(_.status)
Expand Down Expand Up @@ -117,11 +148,22 @@ object WebSpec extends DefaultRunnableSpec with HttpAppTestExtensions {
testM("addCookieM") {
val cookie = Cookie("test", "testValue")
val app =
(Http.ok @@ addCookieM(UIO(cookie))).getHeader("set-cookie")
(Http.ok @@ addCookieZIO(UIO(cookie))).getHeader("set-cookie")
assertM(app(Request()))(
equalTo(Some(cookie.encode)),
)
}
} +
suite("signCookies") {
testM("should sign cookies") {
val cookie = Cookie("key", "value").withHttpOnly
val app = Http.ok.withSetCookie(cookie) @@ signCookies("secret") getHeader "set-cookie"
assertM(app(Request()))(isSome(equalTo(cookie.sign("secret").encode)))
} +
testM("sign cookies no cookie header") {
val app = (Http.ok.addHeader("keyA", "ValueA") @@ signCookies("secret")).getHeaderValues
assertM(app(Request()))(contains("ValueA"))
}
}
}

Expand Down

6 comments on commit 9a5d6bc

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance Benchmark:

Concurrency: 256
Requests/sec: 824445.20

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance Benchmark:

Concurrency: 256
Requests/sec: 814435.48

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance Benchmark:

Concurrency: 256
Requests/sec: 849142.89

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance Benchmark:

Concurrency: 256
Requests/sec: 844885.82

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance Benchmark:

Concurrency: 256
Requests/sec: 822919.80

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance Benchmark:

Concurrency: 256
Requests/sec: 845232.47

Please sign in to comment.