diff --git a/core/src/main/scala/io/github/gaelrenoux/tranzactio/ConnectionSource.scala b/core/src/main/scala/io/github/gaelrenoux/tranzactio/ConnectionSource.scala index d3d1e2b..d787d2e 100644 --- a/core/src/main/scala/io/github/gaelrenoux/tranzactio/ConnectionSource.scala +++ b/core/src/main/scala/io/github/gaelrenoux/tranzactio/ConnectionSource.scala @@ -29,13 +29,13 @@ object ConnectionSource { def runTransaction[R, E, A](task: Connection => ZIO[R, E, A], commitOnFailure: => Boolean = false) (implicit errorStrategies: ErrorStrategiesRef, trace: Trace): ZIO[R, Either[DbException, E], A] = { ZIO.acquireReleaseWith(openConnection.mapError(Left(_)))(closeConnection(_).orDie) { (c: Connection) => - setAutoCommit(c, autoCommit = false) - .mapError(Left(_)) + setAutoCommit(c, autoCommit = false).mapError(Left(_)) .zipRight(task(c).mapError(Right(_))) - .tapBoth( - _ => if (commitOnFailure) commitConnection(c).mapError(Left(_)) else rollbackConnection(c).mapError(Left(_)), - _ => commitConnection(c).mapError(Left(_)) - ) + .tapErrorCause { (queryCause: Cause[Either[DbException, E]]) => + (if (commitOnFailure) commitConnection(c) else rollbackConnection(c)) + .mapErrorCause { rollbackCause => rollbackCause.map(Left(_)) && queryCause } + } + .zipLeft(commitConnection(c).mapError(Left(_))) } } diff --git a/core/src/test/scala/io/github/gaelrenoux/tranzactio/ConnectionSourceTest.scala b/core/src/test/scala/io/github/gaelrenoux/tranzactio/ConnectionSourceTest.scala new file mode 100644 index 0000000..3c09fda --- /dev/null +++ b/core/src/test/scala/io/github/gaelrenoux/tranzactio/ConnectionSourceTest.scala @@ -0,0 +1,125 @@ +package io.github.gaelrenoux.tranzactio + +import zio.test._ +import zio.{test => _, _} + +import java.sql.Connection + + +object ConnectionSourceTest extends ZIOSpec[TestEnvironment] { + type Env = TestEnvironment + type MySpec = Spec[Env, Any] + + implicit private val errorStrategies: ErrorStrategies = ErrorStrategies.Nothing + + // TODO add aspect to timeout tests to 5 seconds + + override def bootstrap: ZLayer[Any, Any, Env] = testEnvironment + + val connectionCountSql = "select count(*) from information_schema.sessions" + + def connectionCountQuery(c: Connection): ZIO[Any, Throwable, Int] = ZIO.attemptBlocking { + val stmt = c.prepareStatement(connectionCountSql) + try { + val rs = stmt.executeQuery() + rs.next() + val count = rs.getInt(1) + rs.close() + count + } finally { + stmt.close() + } + } + + def spec: MySpec = suite("Single connection ConnectionSource Tests")( + testRunTransactionFailureOnOpen, + testRunTransactionFailureOnAutoCommit, + testRunTransactionFailureOnCommit, + testRunTransactionFailureOnCommitAfterFailure, + testRunTransactionFailureOnRollback, + testRunTransactionFailureOnClose, + testRunAutoCommitFailureOnOpen, + testRunAutoCommitFailureOnAutoCommit, + testRunAutoCommitFailureOnClose + ) + + private val testRunTransactionFailureOnOpen = test("runTransaction failure > on open") { + val cs = new FailingConnectionSource(errorStrategies)(failOnOpen = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runTransaction(connectionCountQuery) + zio.flip.map { e => + assertTrue(e == Left(DbException.Wrapped(FailingConnectionSource.OpenException))) + } + } + + private val testRunTransactionFailureOnAutoCommit = test("runTransaction failure > on auto-commit") { + val cs = new FailingConnectionSource(errorStrategies)(failOnAutoCommit = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runTransaction(connectionCountQuery) + zio.flip.map { e => + assertTrue(e == Left(DbException.Wrapped(FailingConnectionSource.AutoCommitException))) + } + } + + private val testRunTransactionFailureOnCommit = test("runTransaction failure > on commit") { + val cs = new FailingConnectionSource(errorStrategies)(failOnCommit = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runTransaction(connectionCountQuery) + zio.flip.map { e => + assertTrue(e == Left(DbException.Wrapped(FailingConnectionSource.CommitException))) + } + } + + private val testRunTransactionFailureOnCommitAfterFailure = test("runTransaction failure > on commit (after failure)") { + val cs = new FailingConnectionSource(errorStrategies)(failOnCommit = true) + val zio: ZIO[Any, Either[DbException, String], Int] = cs.runTransaction(_ => ZIO.fail("Not a good query"), commitOnFailure = true) + zio.cause.map { + case Cause.Both(Cause.Fail(left, _), Cause.Fail(right, _)) => + assertTrue( + left == Left(DbException.Wrapped(FailingConnectionSource.CommitException)), + right == Right("Not a good query") + ) + } + } + + private val testRunTransactionFailureOnRollback = test("runTransaction failure > on rollback") { + val cs = new FailingConnectionSource(errorStrategies)(failOnRollback = true) + val zio: ZIO[Any, Either[DbException, String], Int] = cs.runTransaction(_ => ZIO.fail("Not a good query")) + zio.cause.map { + case Cause.Both(Cause.Fail(left, _), Cause.Fail(right, _)) => + assertTrue( + left == Left(DbException.Wrapped(FailingConnectionSource.RollbackException)), + right == Right("Not a good query") + ) + } + } + + private val testRunTransactionFailureOnClose = test("runTransaction failure > on close") { + val cs = new FailingConnectionSource(errorStrategies)(failOnClose = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runTransaction(connectionCountQuery) + zio.cause.map { + case Cause.Die(ex, _) => assertTrue(ex == DbException.Wrapped(FailingConnectionSource.CloseException)) + } + } + + private val testRunAutoCommitFailureOnOpen = test("runAutoCommit failure > on open") { + val cs = new FailingConnectionSource(errorStrategies)(failOnOpen = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runAutoCommit(connectionCountQuery) + zio.flip.map { e => + assertTrue(e == Left(DbException.Wrapped(FailingConnectionSource.OpenException))) + } + } + + private val testRunAutoCommitFailureOnAutoCommit = test("runAutoCommit failure > on auto-commit") { + val cs = new FailingConnectionSource(errorStrategies)(failOnAutoCommit = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runAutoCommit(connectionCountQuery) + zio.flip.map { e => + assertTrue(e == Left(DbException.Wrapped(FailingConnectionSource.AutoCommitException))) + } + } + + private val testRunAutoCommitFailureOnClose = test("runAutoCommit failure > on close") { + val cs = new FailingConnectionSource(errorStrategies)(failOnClose = true) + val zio: ZIO[Any, Either[DbException, Throwable], Int] = cs.runAutoCommit(connectionCountQuery) + zio.cause.map { + case Cause.Die(ex, _) => assertTrue(ex == DbException.Wrapped(FailingConnectionSource.CloseException)) + } + } +} diff --git a/core/src/test/scala/io/github/gaelrenoux/tranzactio/FailingConnectionSource.scala b/core/src/test/scala/io/github/gaelrenoux/tranzactio/FailingConnectionSource.scala new file mode 100644 index 0000000..473b61d --- /dev/null +++ b/core/src/test/scala/io/github/gaelrenoux/tranzactio/FailingConnectionSource.scala @@ -0,0 +1,56 @@ +package io.github.gaelrenoux.tranzactio + +import zio.{Task, Trace, ZIO} + +import java.sql.{Connection, DriverManager} +import java.util.UUID + +/** A ConnectionSource that fails on some operations */ +class FailingConnectionSource(defaultErrorStrategies: ErrorStrategiesRef)( + failOnOpen: Boolean = false, + failOnAutoCommit: Boolean = false, + failOnCommit: Boolean = false, + failOnRollback: Boolean = false, + failOnClose: Boolean = false +) extends ConnectionSource.ServiceBase(defaultErrorStrategies) { + + import FailingConnectionSource._ + + override protected def getConnection(implicit trace: Trace): Task[Connection] = ZIO.attemptBlocking { + val uuid = UUID.randomUUID().toString + DriverManager.getConnection(s"jdbc:h2:mem:$uuid;DB_CLOSE_DELAY=10", "sa", "sa") + } + + override def openConnection(implicit errorStrategies: ErrorStrategiesRef, trace: Trace): ZIO[Any, DbException, Connection] = + if (failOnOpen) ZIO.fail(DbException.Wrapped(OpenException)) + else super.openConnection + + override def setAutoCommit(c: => Connection, autoCommit: => Boolean)(implicit errorStrategies: ErrorStrategiesRef, trace: Trace): ZIO[Any, DbException, Unit] = { + if (failOnAutoCommit) ZIO.fail(DbException.Wrapped(AutoCommitException)) + else super.setAutoCommit(c, autoCommit) + } + + override def commitConnection(c: => Connection)(implicit errorStrategies: ErrorStrategiesRef, trace: Trace): ZIO[Any, DbException, Unit] = + if (failOnCommit) ZIO.fail(DbException.Wrapped(CommitException)) + else super.commitConnection(c) + + override def rollbackConnection(c: => Connection)(implicit errorStrategies: ErrorStrategiesRef, trace: Trace): ZIO[Any, DbException, Unit] = + if (failOnRollback) ZIO.fail(DbException.Wrapped(RollbackException)) + else super.rollbackConnection(c) + + override def closeConnection(c: => Connection)(implicit errorStrategies: ErrorStrategiesRef, trace: Trace): ZIO[Any, DbException, Unit] = + if (failOnClose) ZIO.fail(DbException.Wrapped(CloseException)) + else super.closeConnection(c) +} + +object FailingConnectionSource { + case object OpenException extends RuntimeException("Oops my connection") + + case object AutoCommitException extends RuntimeException("Oops my auto-commit") + + case object CommitException extends RuntimeException("Oops my commit") + + case object RollbackException extends RuntimeException("Oops my rollback") + + case object CloseException extends RuntimeException("Oops my closing") +}