Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SqlRuntime #34

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,21 @@ class PiggyPostgresqlSpec extends PGContainerSuite {
)
.map(_.map(Person.apply))
} yield (nIns, fetchedPeople)
val result = sql.executePool()(using pgPool)
val result = sql.executePool(using pgPool)
assert(result.isSuccess)
assertEquals(result.get._1, 10)
assertEquals(result.get._2.distinct.size, 10)
}

test("PiggyPostgresql Rollback") {
given PgConnectionPool = pgPool
assert(Sql.statement(ddl).executePool().isSuccess)
assert(Sql.statement(ddl).executePool.isSuccess)

val blowup = for {
nIns <- Sql.prepareUpdate(ins, tenPeople*)
_ <- Sql.statement("this is not valid sql")
} yield nIns
assert(blowup.executePool().isFailure)
assert(blowup.executePool.isFailure)

val sql = for {
fetchedPeople <- Sql
Expand All @@ -70,7 +70,7 @@ class PiggyPostgresqlSpec extends PGContainerSuite {
} yield {
fetchedPeople
}
val result: Try[Seq[Person]] = sql.executePool()
val result: Try[Seq[Person]] = sql.executePool
assert(result.isSuccess)
assertEquals(result.get.size, 0)

Expand All @@ -80,7 +80,7 @@ class PiggyPostgresqlSpec extends PGContainerSuite {
given pool: PgConnectionPool = pgPool

val tple =
Sql.statement(s"SELECT 1, 'two'", _.tupled[(Int, String)]).executePool()
Sql.statement(s"SELECT 1, 'two'", _.tupled[(Int, String)]).executePool

assertEquals(tple.get.get, (1, "two"))
}
Expand All @@ -98,7 +98,7 @@ class PiggyPostgresqlSpec extends PGContainerSuite {
_.tupledList[(Int, String, Int)]
)
} yield fetchedPeople
}.executePool().get
}.executePool.get

assert(readBack.size == 10)
assert(readBack.forall(_._2.startsWith("Mark")))
Expand All @@ -116,7 +116,7 @@ class PiggyPostgresqlSpec extends PGContainerSuite {
)
.map(_.map(Person.apply))
} yield (nIns, fetchedPeople)
val result = sql.executePool()(using pgPool)
val result = sql.executePool(using pgPool)
assert(result.isFailure)
assert(result.toEither.left.exists(_.getMessage == "boom"))
}
Expand Down
9 changes: 4 additions & 5 deletions branch/src/main/scala/dev/wishingtree/branch/piggy/Sql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import dev.wishingtree.branch.macaroni.poolers.ResourcePool

import java.sql.{Connection, PreparedStatement, ResultSet}
import scala.compiletime.*
import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.*

Expand Down Expand Up @@ -54,11 +53,11 @@ object Sql {
/** Execute this Sql operation using the given Connection. See
* [[SqlRuntime.execute]].
*/
def execute(d: Duration = Duration.Inf)(using
def execute(using
connection: Connection,
executionContext: ExecutionContext
): Try[A] = {
SqlRuntime.execute(a, d)
SqlRuntime.execute(a)

/** Execute this Sql operation using the given Connection, returning the
* result as a Future. See [[SqlRuntime.executeAsync]].
Expand All @@ -73,11 +72,11 @@ object Sql {
/** Execute this Sql operation using the given ResourcePool[Connection]. See
* [[SqlRuntime.executePool]].
*/
def executePool(d: Duration = Duration.Inf)(using
def executePool(using
pool: ResourcePool[Connection],
executionContext: ExecutionContext
): Try[A] =
SqlRuntime.executePool(a, d)
SqlRuntime.executePool(a)

/** Execute this Sql operation using the given ResourcePool[Connection],
* returning the result as a Future. See [[SqlRuntime.executePoolAsync]].
Expand Down
125 changes: 51 additions & 74 deletions branch/src/main/scala/dev/wishingtree/branch/piggy/SqlRuntime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@ import dev.wishingtree.branch.macaroni.poolers.ResourcePool

import java.sql.{Connection, PreparedStatement}
import scala.annotation.tailrec
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.*

private[piggy] trait SqlRuntime {

def execute[A](
sql: Sql[A],
d: Duration
)(using connection: Connection, executionContext: ExecutionContext): Try[A]
sql: Sql[A]
)(using connection: Connection): Try[A]

def executePool[A, B <: ResourcePool[Connection]](
sql: Sql[A],
d: Duration
)(using
pool: B,
executionContext: ExecutionContext
): Try[A]
sql: Sql[A]
)(using pool: B): Try[A]

def executeAsync[A](
sql: Sql[A]
Expand All @@ -39,14 +33,13 @@ object SqlRuntime extends SqlRuntime {
* entire chain of Sql operations is done over the given Connection, and the
* transaction is rolled back on failure.
*/
override def execute[A](sql: Sql[A], d: Duration = Duration.Inf)(using
connection: Connection,
executionContext: ExecutionContext
override def execute[A](sql: Sql[A])(using
connection: Connection
): Try[A] =
Try {
try {
connection.setAutoCommit(false)
val result = Await.result(evalF(sql), d)
val result = eval(sql).get
connection.commit()
result
} catch {
Expand All @@ -56,20 +49,15 @@ object SqlRuntime extends SqlRuntime {
} finally {
connection.setAutoCommit(true)
}

}

/** Execute a Sql[A] using a ResourcePool[Connection], returning the result as
* a Try. The entire chain of Sql operations is done over a single Connection
* from the pool, and the transaction is rolled back on failure.
*/
override def executePool[A, B <: ResourcePool[Connection]](
sql: Sql[A],
d: Duration = Duration.Inf
)(using
pool: B,
executionContext: ExecutionContext
): Try[A] =
sql: Sql[A]
)(using pool: B): Try[A] =
Try {
pool.use { conn =>
execute(sql)(using conn)
Expand Down Expand Up @@ -98,72 +86,61 @@ object SqlRuntime extends SqlRuntime {
}.flatten

@tailrec
private final def evalF[A](sql: Sql[A])(using
connection: Connection,
executionContext: ExecutionContext
): Future[A] = {
private final def eval[A](sql: Sql[A])(using
connection: Connection
): Try[A] = {
sql match {
case Sql.StatementRs(sql, fn) =>
Future.fromTry {
Using.Manager { use =>
val statement = use(connection.createStatement())
val res = statement.execute(sql)
val rs = use(statement.getResultSet)
fn(rs)
}
Using.Manager { use =>
val statement = use(connection.createStatement())
val res = statement.execute(sql)
val rs = use(statement.getResultSet)
fn(rs)
}
case Sql.StatementCount(sql) =>
Future.fromTry {
Using.Manager { use =>
val statement = use(connection.createStatement())
val res = statement.execute(sql)
statement.getUpdateCount
}
Using.Manager { use =>
val statement = use(connection.createStatement())
val res = statement.execute(sql)
statement.getUpdateCount
}
case Sql.PreparedExec(sqlFn, args) =>
Future.fromTry {
Using.Manager { use =>
val helpers = args.map(sqlFn)
val ps: PreparedStatement =
use(connection.prepareStatement(helpers.head.psStr))
helpers.foreach(_.setAndExecute(ps))
}
Using.Manager { use =>
val helpers = args.map(sqlFn)
val ps: PreparedStatement =
use(connection.prepareStatement(helpers.head.psStr))
helpers.foreach(_.setAndExecute(ps))
}
case Sql.PreparedUpdate(sqlFn, args) =>
Future.fromTry {
Using.Manager { use =>
val helpers = args.map(sqlFn)
val ps: PreparedStatement =
use(connection.prepareStatement(helpers.head.psStr))
val counts: Seq[Int] = helpers.map(_.setAndExecuteUpdate(ps))
counts.foldLeft(0)(_ + _)
}
Using.Manager { use =>
val helpers = args.map(sqlFn)
val ps: PreparedStatement =
use(connection.prepareStatement(helpers.head.psStr))
val counts: Seq[Int] = helpers.map(_.setAndExecuteUpdate(ps))
counts.foldLeft(0)(_ + _)
}
case Sql.PreparedQuery(sqlFn, rsFn, args) =>
Future.fromTry {
Using.Manager { use =>
val helpers = args.map(sqlFn)
val ps: PreparedStatement =
use(connection.prepareStatement(helpers.head.psStr))
helpers.flatMap { h =>
rsFn(h.setAndExecuteQuery(ps))
}
Using.Manager { use =>
val helpers = args.map(sqlFn)
val ps: PreparedStatement =
use(connection.prepareStatement(helpers.head.psStr))
helpers.flatMap { h =>
rsFn(h.setAndExecuteQuery(ps))
}
}
case Sql.Fail(e) =>
Future.failed(e)
Failure(e)
case Sql.MappedValue(a) =>
Future.successful(a)
Success(a)
case Sql.Recover(sql, fm) =>
evalRecover(sql, fm)
case Sql.FlatMap(sql, fn) =>
sql match {
case Sql.FlatMap(s, f) =>
evalF(s.flatMap(f(_).flatMap(fn)))
eval(s.flatMap(f(_).flatMap(fn)))
case Sql.Recover(s, f) =>
evalRecoverFlatMap(s, f, fn)
case Sql.MappedValue(a) =>
evalF(fn(a))
eval(fn(a))
case s =>
evalFlatMap(s, fn)
}
Expand All @@ -174,24 +151,24 @@ object SqlRuntime extends SqlRuntime {
sql: Sql[A],
rf: Throwable => Sql[A],
fm: A => Sql[B]
)(using Connection, ExecutionContext): Future[B] = {
evalF(sql)
.recoverWith { case t: Throwable => evalF(rf(t)) }
.flatMap(a => evalF(fm(a)))
)(using Connection): Try[B] = {
eval(sql)
.recoverWith { case t: Throwable => eval(rf(t)) }
.flatMap(a => eval(fm(a)))
}

private def evalFlatMap[A, B](
sql: Sql[A],
fn: A => Sql[B]
)(using Connection, ExecutionContext): Future[B] = {
evalF(sql).flatMap(r => evalF(fn(r)))
)(using Connection): Try[B] = {
eval(sql).flatMap(r => eval(fn(r)))
}

private def evalRecover[A](
sql: Sql[A],
f: Throwable => Sql[A]
)(using Connection, ExecutionContext): Future[A] = {
evalF(sql).recoverWith { case t: Throwable => evalF(f(t)) }
)(using Connection): Try[A] = {
eval(sql).recoverWith { case t: Throwable => eval(f(t)) }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import dev.wishingtree.branch.veil.Veil
*/
trait Flag[R] {

/** The name of the flag, e.g. "help". This will be parsed as s"--$name",
* e.g. "--help"
/** The name of the flag, e.g. "help". This will be parsed as s"--$name", e.g.
* "--help"
*/
val name: String

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import dev.wishingtree.branch.ursula.args.BooleanFlag

/** A flag that triggers the display of help information.
*
* This flag is a boolean flag, meaning it does not expect an argument.
* It can be triggered using either the long form "--help" or the short form "-h".
* This flag is a boolean flag, meaning it does not expect an argument. It can
* be triggered using either the long form "--help" or the short form "-h".
*/
case object HelpFlag extends BooleanFlag {
override val name: String = "help"
override val shortKey: String = "h"
override val description: String = "Prints help"
}
}