Skip to content

Commit

Permalink
Merge pull request #794 from softwaremill/tapir-based-secuirty
Browse files Browse the repository at this point in the history
Tapir based secuirty
  • Loading branch information
adamw authored Feb 3, 2022
2 parents 5407d35 + b75ba4b commit 50ec550
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 27 deletions.
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
version=3.4.0
maxColumn = 140
runner.dialect = scala213
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class Http() extends Tapir with TapirJsonCirce with TapirSchemas with StrictLogg
/** Base endpoint description for secured endpoints. Specifies that errors are always returned as JSON values corresponding to the
* [[Error_OUT]] class, and that authentication is read from the `Authorization: Bearer` header.
*/
val secureEndpoint: Endpoint[Unit, Id, (StatusCode, Error_OUT), Unit, Any] =
baseEndpoint.in(auth.bearer[String]().map(_.asInstanceOf[Id])(identity))
val secureEndpoint: Endpoint[Id, Unit, (StatusCode, Error_OUT), Unit, Any] =
baseEndpoint.securityIn(auth.bearer[String]().map(_.asInstanceOf[Id])(identity))

//

Expand Down Expand Up @@ -62,9 +62,9 @@ class Http() extends Tapir with TapirJsonCirce with TapirSchemas with StrictLogg

//

implicit class TaskOut[T](f: IO[T]) {
implicit class IOOut[T](f: IO[T]) {

/** An extension method for [[IO]], which converts a possibly failed task, to a task which either returns the error converted to an
/** An extension method for [[IO]], which converts a possibly failed IO, to a task which either returns the error converted to an
* [[Error_OUT]] instance, or returns the successful value unchanged.
*/
def toOut: IO[Either[(StatusCode, Error_OUT), T]] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.softwaremill.bootzooka.security

import java.security.SecureRandom
import java.time.Instant
import cats.data.OptionT
import cats.effect.IO
import com.softwaremill.bootzooka._
Expand All @@ -11,6 +9,8 @@ import com.softwaremill.bootzooka.util._
import com.softwaremill.tagging._
import com.typesafe.scalalogging.StrictLogging

import java.security.SecureRandom
import java.time.Instant
import scala.concurrent.duration._

class Auth[T](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,33 +47,36 @@ class UserApi(http: Http, auth: Auth[ApiKey], userService: UserService, xa: Tran
.in(UserPath / "changepassword")
.in(jsonBody[ChangePassword_IN])
.out(jsonBody[ChangePassword_OUT])
.serverLogic { case (authData, data) =>
(for {
userId <- auth(authData)
_ <- userService.changePassword(userId, data.currentPassword, data.newPassword).transact(xa)
} yield ChangePassword_OUT()).toOut
}
.serverSecurityLogic(authData => auth(authData).toOut)
.serverLogic(id =>
data =>
(for {
_ <- userService.changePassword(id, data.currentPassword, data.newPassword).transact(xa)
} yield ChangePassword_OUT()).toOut
)

private val getUserEndpoint = secureEndpoint.get
.in(UserPath)
.out(jsonBody[GetUser_OUT])
.serverLogic { authData =>
(for {
userId <- auth(authData)
user <- userService.findById(userId).transact(xa)
} yield GetUser_OUT(user.login, user.emailLowerCased, user.createdOn)).toOut
}
.serverSecurityLogic(authData => auth(authData).toOut)
.serverLogic(id =>
(_: Unit) =>
(for {
user <- userService.findById(id).transact(xa)
} yield GetUser_OUT(user.login, user.emailLowerCased, user.createdOn)).toOut
)

private val updateUserEndpoint = secureEndpoint.post
.in(UserPath)
.in(jsonBody[UpdateUser_IN])
.out(jsonBody[UpdateUser_OUT])
.serverLogic { case (authData, data) =>
(for {
userId <- auth(authData)
_ <- userService.changeUser(userId, data.login, data.email).transact(xa)
} yield UpdateUser_OUT()).toOut
}
.serverSecurityLogic(authData => auth(authData).toOut)
.serverLogic(id =>
data =>
(for {
_ <- userService.changeUser(id, data.login, data.email).transact(xa)
} yield UpdateUser_OUT()).toOut
)

val endpoints: ServerEndpoints =
NonEmptyList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ trait TestEmbeddedPostgres extends BeforeAndAfterEach with BeforeAndAfterAll wit
override protected def beforeAll(): Unit = {
super.beforeAll()
postgres = EmbeddedPostgres.builder().start()
val url = postgres.getJdbcUrl("postgres", "postgres")
val url = postgres.getJdbcUrl("postgres")
postgres.getPostgresDatabase.getConnection.asInstanceOf[PgConnection].setPrepareThreshold(100)
currentDbConfig = TestConfig.db.copy(
username = "postgres",
Expand Down
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import scala.sys.process.Process
import complete.DefaultParsers._

val doobieVersion = "1.0.0-RC2"
val http4sVersion = "0.23.8"
val http4sVersion = "0.23.9"
val circeVersion = "0.14.1"
val tsecVersion = "0.4.0"
val sttpVersion = "3.4.1"
Expand Down Expand Up @@ -83,7 +83,7 @@ val emailDependencies = Seq(
val scalatest = "org.scalatest" %% "scalatest" % "3.2.11" % Test
val unitTestingStack = Seq(scalatest)

val embeddedPostgres = "com.opentable.components" % "otj-pg-embedded" % "0.13.4" % Test
val embeddedPostgres = "com.opentable.components" % "otj-pg-embedded" % "1.0.0" % Test
val dbTestingStack = Seq(embeddedPostgres)

val commonDependencies = baseDependencies ++ unitTestingStack ++ loggingDependencies ++ configDependencies
Expand Down

0 comments on commit 50ec550

Please sign in to comment.