diff --git a/example/src/main/scala/example/SignCookies.scala b/example/src/main/scala/example/SignCookies.scala new file mode 100644 index 0000000000..1d7d3cc2d6 --- /dev/null +++ b/example/src/main/scala/example/SignCookies.scala @@ -0,0 +1,24 @@ +package example + +import zhttp.http.Cookie.{httpOnly, maxAge, path, sign} +import zhttp.http.{Cookie, Method, Response, _} +import zhttp.service.Server +import zio.duration.durationInt +import zio.{App, ExitCode, URIO} + +/** + * Example to make app using cookies + */ +object SignCookies extends App { + + // Setting cookies with an expiry of 5 days + private val cookie = Cookie("key", "hello") @@ maxAge(5 days) + + private val app = Http.collect[Request] { case Method.GET -> !! / "cookie" => + Response.ok.addCookie(cookie @@ path(!! / "cookie") @@ httpOnly @@ sign("tobiiscool")) + } + + // Run it like any simple app + override def run(args: List[String]): URIO[zio.ZEnv, ExitCode] = + Server.start(8090, app).exitCode +} diff --git a/zio-http/src/main/scala/zhttp/http/Cookie.scala b/zio-http/src/main/scala/zhttp/http/Cookie.scala index 957aafe57d..75a39b3168 100644 --- a/zio-http/src/main/scala/zhttp/http/Cookie.scala +++ b/zio-http/src/main/scala/zhttp/http/Cookie.scala @@ -2,7 +2,11 @@ package zhttp.http import zio.duration._ +import java.security.MessageDigest import java.time.Instant +import java.util.Base64.getEncoder +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec import scala.util.Try final case class Cookie( @@ -15,6 +19,7 @@ final case class Cookie( isHttpOnly: Boolean = false, maxAge: Option[Long] = None, sameSite: Option[Cookie.SameSite] = None, + secret: Option[String] = None, ) { self => /** @@ -73,6 +78,20 @@ final case class Cookie( */ def withSameSite(v: Cookie.SameSite): Cookie = copy(sameSite = Some(v)) + /** + * Adds secret in the cookie + */ + def withSecret(key: String): Cookie = { + copy(secret = Some(key), content = sign(key)) + } + + /** + * Removes secret from the cookie + */ + def withoutSecret: Cookie = { + copy(secret = None) + } + /** * Resets secure flag in the cookie */ @@ -113,7 +132,7 @@ final case class Cookie( */ def encode: String = { val cookie = List( - Some(s"$name=$content"), + Some(s"$name=${content}"), expires.map(e => s"Expires=$e"), maxAge.map(a => s"Max-Age=${a.toString}"), domain.map(d => s"Domain=$d"), @@ -125,6 +144,26 @@ final case class Cookie( cookie.flatten.mkString("; ") } + def sign(secret: String): String = { + try { + val sha256 = Mac.getInstance("HmacSHA256") + val secretKey = new SecretKeySpec(secret.getBytes(), "RSA") + sha256.init(secretKey) + val signed = sha256.doFinal(self.content.getBytes()) + val mda = MessageDigest.getInstance("SHA-512") + self.content + '.' + getEncoder.encodeToString(mda.digest(signed)) + } catch { + case _: Exception => self.content + } + } + + def unSign(secret: String): Option[Cookie] = { + val str = self.content.slice(0, content.lastIndexOf('.')) + val encryptedCookie = self.withContent(str).sign(secret) + if (encryptedCookie == self.content) + Some(self.withSecret(secret).withContent(str)) + else None + } } object Cookie { @@ -139,6 +178,17 @@ object Cookie { } case class Update(f: Cookie => Cookie) + /** + * Decodes from Set-Cookie header value inside of Response into a cookie + */ + def decodeResponseSignedCookie(headerValue: String, secret: Option[String]): Option[Cookie] = { + val decodedCookie = decodeResponseCookie(headerValue) + secret match { + case Some(value) => decodedCookie.flatMap(_.unSign(value)) + case None => decodedCookie + } + } + /** * Decodes from Set-Cookie header value inside of Response into a cookie */ @@ -146,14 +196,13 @@ object Cookie { val cookieWithoutMeta = headerValue.split(";").map(_.trim) val (first, other) = (cookieWithoutMeta.head, cookieWithoutMeta.tail) val (name, content) = splitNameContent(first) - - val cookie = + val cookie = if (name.trim == "" && content.isEmpty) Option.empty[Cookie] else Some(Cookie(name, content.getOrElse(""))) other.map(splitNameContent).map(t => (t._1.toLowerCase, t._2)).foldLeft(cookie) { - case (Some(c), ("expires", Some(v))) => parseDate(v).map(c.withExpiry(_)) - case (Some(c), ("max-age", Some(v))) => Try(v.toLong).toOption.map(c.withMaxAge(_)) + case (Some(c), ("expires", Some(v))) => parseDate(v).map(c.withExpiry) + case (Some(c), ("max-age", Some(v))) => Try(v.toLong).toOption.map(c.withMaxAge) case (Some(c), ("domain", v)) => Some(c.withDomain(v.getOrElse(""))) case (Some(c), ("path", v)) => Some(c.withPath(Path(v.getOrElse("")))) case (Some(c), ("secure", _)) => Some(c.withSecure) @@ -229,4 +278,9 @@ object Cookie { * Updates sameSite in cookie */ def sameSite(sameSite: SameSite): Update = Update(_.withSameSite(sameSite)) + + /** + * Signs content in cookie + */ + def sign(secret: String): Update = Update(_.withSecret(secret)) } diff --git a/zio-http/src/main/scala/zhttp/http/HasCookie.scala b/zio-http/src/main/scala/zhttp/http/HasCookie.scala index 7902328564..7b8c3839d6 100644 --- a/zio-http/src/main/scala/zhttp/http/HasCookie.scala +++ b/zio-http/src/main/scala/zhttp/http/HasCookie.scala @@ -8,6 +8,7 @@ import io.netty.handler.codec.http.HttpHeaderNames sealed trait HasCookie[-A] { def headers(a: A): List[String] def decode(a: A): List[Cookie] + def decodeSignedCookie(a: A, secret: String): List[Cookie] } object HasCookie { @@ -22,6 +23,12 @@ object HasCookie { case Some(list) => list } } + + override def decodeSignedCookie(a: Request, secret: String): List[Cookie] = + headers(a).map(headerValue => Cookie.decodeResponseSignedCookie(headerValue, Some(secret))).collect { + case Some(cookie) => cookie + } + } implicit object ResponseCookie extends HasCookie[Response[Any, Nothing]] { @@ -30,5 +37,10 @@ object HasCookie { override def decode(a: Response[Any, Nothing]): List[Cookie] = headers(a).map(Cookie.decodeResponseCookie).collect { case Some(cookie) => cookie } + + override def decodeSignedCookie(a: Response[Any, Nothing], secret: String): List[Cookie] = + headers(a).map(headerValue => Cookie.decodeResponseSignedCookie(headerValue, Some(secret))).collect { + case Some(cookie) => cookie + } } } diff --git a/zio-http/src/main/scala/zhttp/http/HeaderExtension.scala b/zio-http/src/main/scala/zhttp/http/HeaderExtension.scala index a41d8752bd..4e4261c398 100644 --- a/zio-http/src/main/scala/zhttp/http/HeaderExtension.scala +++ b/zio-http/src/main/scala/zhttp/http/HeaderExtension.scala @@ -65,6 +65,9 @@ private[zhttp] trait HeaderExtension[+A] { self: A => final def getCookiesRaw(implicit ev: HasCookie[A]): List[CharSequence] = ev.headers(self) + final def getSignedCookies(secret: String)(implicit ev: HasCookie[A]): List[Cookie] = + ev.decodeSignedCookie(self, secret) + final def getHeader(headerName: CharSequence): Option[Header] = getHeaders.find(h => contentEqualsIgnoreCase(h.name, headerName)) diff --git a/zio-http/src/main/scala/zhttp/http/Middleware.scala b/zio-http/src/main/scala/zhttp/http/Middleware.scala index 8486529b59..4a61c44845 100644 --- a/zio-http/src/main/scala/zhttp/http/Middleware.scala +++ b/zio-http/src/main/scala/zhttp/http/Middleware.scala @@ -2,7 +2,7 @@ package zhttp.http import io.netty.handler.codec.http.HttpHeaderNames import io.netty.util.AsciiString -import io.netty.util.AsciiString.toLowerCase +import io.netty.util.AsciiString.{contentEqualsIgnoreCase, toLowerCase} import zhttp.http.CORS.DefaultCORSConfig import zhttp.http.HeaderExtension.Only import zhttp.http.Middleware.{Flag, RequestP} @@ -116,6 +116,27 @@ object Middleware { def basicAuth[R, E](u: String, p: String): Middleware[R, E] = basicAuth((user, password) => (user == u) && (password == p)) + /** + * Creates a middleware for signing cookies + */ + def signCookies[R, E](secret: String): Middleware[R, E] = + Middleware + .make((_, _, _) => ()) { case (_, _, _) => + Patch.updateHeaders(resHeaders => + resHeaders + .filter(h => contentEqualsIgnoreCase(h.name, HttpHeaderNames.SET_COOKIE)) + .map(_.value.toString) + .map(a => { + val c = Cookie.decodeResponseCookie(a) + println(c) + c + }) + .collect { case Some(cookie) => + Header.custom(HttpHeaderNames.SET_COOKIE.toString, cookie.withSecret(secret).encode) + }, + ) + } + /** * Creates a middleware for Cross-Origin Resource Sharing (CORS). * @see diff --git a/zio-http/src/main/scala/zhttp/http/Patch.scala b/zio-http/src/main/scala/zhttp/http/Patch.scala index b22d884ac5..18fb78155b 100644 --- a/zio-http/src/main/scala/zhttp/http/Patch.scala +++ b/zio-http/src/main/scala/zhttp/http/Patch.scala @@ -17,6 +17,7 @@ sealed trait Patch { self => case Patch.RemoveHeaders(headers) => res.removeHeaders(headers) case Patch.SetStatus(status) => res.setStatus(status) case Patch.Combine(self, other) => loop[R1, E1](self(res), other) + case Patch.UpdateHeaders(f) => res.updateHeaders(f) } loop(res, self) @@ -24,16 +25,18 @@ sealed trait Patch { self => } object Patch { - case object Empty extends Patch - final case class AddHeaders(headers: List[Header]) extends Patch - final case class RemoveHeaders(headers: List[String]) extends Patch - final case class SetStatus(status: Status) extends Patch - final case class Combine(left: Patch, right: Patch) extends Patch + case object Empty extends Patch + final case class AddHeaders(headers: List[Header]) extends Patch + final case class RemoveHeaders(headers: List[String]) extends Patch + final case class SetStatus(status: Status) extends Patch + final case class Combine(left: Patch, right: Patch) extends Patch + final case class UpdateHeaders(f: List[Header] => List[Header]) extends Patch - def empty: Patch = Empty - def addHeaders(headers: List[Header]): Patch = AddHeaders(headers) - def addHeader(header: Header): Patch = AddHeaders(List(header)) - def addHeader(name: String, value: String): Patch = AddHeaders(List(Header(name, value))) - def removeHeaders(headers: List[String]): Patch = RemoveHeaders(headers) - def setStatus(status: Status): Patch = SetStatus(status) + def empty: Patch = Empty + def addHeaders(headers: List[Header]): Patch = AddHeaders(headers) + def addHeader(header: Header): Patch = AddHeaders(List(header)) + def addHeader(name: String, value: String): Patch = AddHeaders(List(Header(name, value))) + def removeHeaders(headers: List[String]): Patch = RemoveHeaders(headers) + def updateHeaders(f: List[Header] => List[Header]): Patch = UpdateHeaders(f) + def setStatus(status: Status): Patch = SetStatus(status) } diff --git a/zio-http/src/test/scala/zhttp/http/CookieSpec.scala b/zio-http/src/test/scala/zhttp/http/CookieSpec.scala index bfc90c0cff..c274663520 100644 --- a/zio-http/src/test/scala/zhttp/http/CookieSpec.scala +++ b/zio-http/src/test/scala/zhttp/http/CookieSpec.scala @@ -1,16 +1,16 @@ package zhttp.http import zhttp.internal.HttpGen -import zio.test.Assertion.{equalTo, isSome} +import zio.test.Assertion.{equalTo, isNone, isSome} import zio.test._ object CookieSpec extends DefaultRunnableSpec { def spec = suite("Cookies") { suite("response cookies") { testM("encode/decode cookies with ZIO Test Gen") { - checkAll(HttpGen.cookies) { cookie => - val cookieString = cookie.encode - assert(Cookie.decodeResponseCookie(cookieString))(isSome(equalTo(cookie))) && + checkAll(HttpGen.cookies) { case (cookie, _) => + val cookieString = cookie.withoutSecret.encode + assert(Cookie.decodeResponseCookie(cookieString))(isSome(equalTo(cookie.withoutSecret))) && assert(Cookie.decodeResponseCookie(cookieString).map(_.encode))(isSome(equalTo(cookieString))) } } @@ -26,6 +26,25 @@ object CookieSpec extends DefaultRunnableSpec { assert(Cookie.decodeRequestCookie(message))(isSome(equalTo(cookies))) } } + } + + suite("sign/unsign cookies") { + testM("should sign/unsign cookies with same secret") { + checkAll(HttpGen.cookies) { case (cookie, secret) => + val cookieString = cookie.withSecret(secret.getOrElse("")).encode + assert(Cookie.decodeResponseSignedCookie(cookieString, secret))(isSome(equalTo(cookie))) && + assert( + Cookie.decodeResponseSignedCookie(cookieString, secret).map(_.withSecret(secret.getOrElse("")).encode), + )( + isSome(equalTo(cookieString)), + ) + } + } + + testM("should not unsign cookies with different secret") { + checkAll(HttpGen.cookies) { case (cookie, _) => + val cookieSigned = cookie.encode + assert(Cookie.decodeResponseSignedCookie(cookieSigned, Some("a")))(isNone) + } + } } } } diff --git a/zio-http/src/test/scala/zhttp/internal/HttpGen.scala b/zio-http/src/test/scala/zhttp/internal/HttpGen.scala index 60c834e61b..545039fcdd 100644 --- a/zio-http/src/test/scala/zhttp/internal/HttpGen.scala +++ b/zio-http/src/test/scala/zhttp/internal/HttpGen.scala @@ -17,7 +17,7 @@ object HttpGen { data <- dataGen } yield ClientParams(method -> url, headers, data) - def cookies: Gen[Random with Sized, Cookie] = for { + def cookies: Gen[Random with Sized, (Cookie, Option[String])] = for { name <- Gen.anyString content <- Gen.anyString expires <- Gen.option(Gen.anyInstant) @@ -27,7 +27,8 @@ object HttpGen { httpOnly <- Gen.boolean maxAge <- Gen.option(Gen.anyLong) sameSite <- Gen.option(Gen.fromIterable(List(Cookie.SameSite.Strict, Cookie.SameSite.Lax))) - } yield Cookie(name, content, expires, domain, path, secure, httpOnly, maxAge, sameSite) + secret <- Gen.option(Gen.anyString) + } yield (Cookie(name, content, expires, domain, path, secure, httpOnly, maxAge, sameSite, secret), secret) def header: Gen[Random with Sized, Header] = for { key <- Gen.alphaNumericStringBounded(1, 4) diff --git a/zio-http/src/test/scala/zhttp/middleware/MiddlewareSpec.scala b/zio-http/src/test/scala/zhttp/middleware/MiddlewareSpec.scala index 19002e0a70..a88ebfbbf7 100644 --- a/zio-http/src/test/scala/zhttp/middleware/MiddlewareSpec.scala +++ b/zio-http/src/test/scala/zhttp/middleware/MiddlewareSpec.scala @@ -1,5 +1,6 @@ package zhttp.middleware +import io.netty.handler.codec.http.HttpHeaderNames import zhttp.http.Middleware.cors import zhttp.http._ import zhttp.internal.HttpAppTestExtensions @@ -11,6 +12,8 @@ import zio.test.{DefaultRunnableSpec, assert, assertM} import zio.{UIO, ZIO, console} object MiddlewareSpec extends DefaultRunnableSpec with HttpAppTestExtensions { + val cookieRes = Header.custom(HttpHeaderNames.SET_COOKIE.toString, "key=value;httpOnly") + def cond(flg: Boolean) = (_: Any, _: Any, _: Any) => flg def condM(flg: Boolean) = (_: Any, _: Any, _: Any) => UIO(flg) @@ -195,6 +198,21 @@ object MiddlewareSpec extends DefaultRunnableSpec with HttpAppTestExtensions { res <- app(request) } yield assert(res.headers.map(_.toTuple))(hasSubset(expected.map(_.toTuple))) } + + } + + suite("signCookie") { + testM("should sign cookies") { + val app = + Http.ok.addHeader(cookieRes) @@ signCookies("secret") getHeader "set-cookie" + assertM(app(Request()))( + isSome( + equalTo( + "key=value.fm67q+j8zQjFXnLi22ckhgRC5qaQ9srgU3/Fli94OOWmtuo68xcm5LkXbcODb9taM/B48j6kws3eZ0MYDIeWTA==; HttpOnly", + ), + ), + ) + } + } }