Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Doobie streams #21

Merged
merged 4 commits into from
Nov 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ val H2Version = "1.4.200"
libraryDependencies ++= Seq(
/* ZIO */
"dev.zio" %% "zio" % ZioVersion,
"dev.zio" %% "zio-streams" % ZioVersion,
"dev.zio" %% "zio-interop-cats" % ZioCatsVersion,

/* Doobie */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import cats.effect.Resource
import io.github.gaelrenoux.tranzactio.utils.ZCatsBlocker
import zio.blocking.Blocking
import zio.interop.catz._
import zio.stream.ZStream
import zio.stream.interop.fs2z._
import zio.{Has, Task, ZIO, ZLayer}


Expand All @@ -17,12 +19,22 @@ package object doobie extends Wrapper {
override final type Database = Has[Database.Service]
override final type Query[A] = _root_.doobie.ConnectionIO[A]
override final type TranzactIO[A] = ZIO[Connection, DbException, A]
final type TranzactIOStream[A] = ZStream[Connection, DbException, A]

/** Default queue size when converting from FS2 streams. */
final val DefaultStreamQueueSize = 16

override final def tzio[A](q: Query[A]): TranzactIO[A] =
ZIO.accessM[Connection] { c =>
c.get.trans.apply(q)
}.mapError(DbException.Wrapped)

/** Converts a Doobie stream to a ZStream. Note that you can provide a queue size, default value is the same as in ZIO. */
final def tzioStream[A](q: fs2.Stream[Query, A], queueSize: Int = DefaultStreamQueueSize): TranzactIOStream[A] =
ZStream.accessStream[Connection] { c =>
c.get.transP(monadErrorInstance).apply(q).toZStream(queueSize)
}.mapError(DbException.Wrapped)

/** Database for the Doobie wrapper */
object Database extends DatabaseModuleBase[Connection, DatabaseOps.ServiceOps[Connection]] {
self =>
Expand Down
1 change: 1 addition & 0 deletions src/samples/scala/samples/SamplesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ object SamplesSpec extends DefaultRunnableSpec {

override def spec: ZSpec[TestEnvironment, Any] = suite("SamplesSpec")(
testApp("Doobie", doobie.LayeredApp),
testApp("Doobie-Streaming", doobie.LayeredAppStreaming),
testApp("Anorm", anorm.LayeredApp)
)

Expand Down
2 changes: 1 addition & 1 deletion src/samples/scala/samples/doobie/LayeredApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object LayeredApp extends zio.App {
} yield trio

ZIO.accessM[AppEnv] { env =>
// if this implicit is not provided, tranwactio will use Conf.Root.dbRecovery instead
// if this implicit is not provided, tranzactio will use Conf.Root.dbRecovery instead
implicit val errorRecovery: ErrorStrategiesRef = env.get[Conf.Root].alternateDbRecovery
Database.transactionOrWidenR[AppEnv](queries)
}
Expand Down
57 changes: 57 additions & 0 deletions src/samples/scala/samples/doobie/LayeredAppStreaming.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package samples.doobie

import io.github.gaelrenoux.tranzactio.doobie._
import io.github.gaelrenoux.tranzactio.{DbException, ErrorStrategiesRef}
import samples.{Conf, ConnectionPool, Person}
import zio._
import zio.stream._

/** Same as LayeredApp, but using Doobie's stream (converted into ZIO strem). */
// scalastyle:off magic.number
object LayeredAppStreaming extends zio.App {

private val zenv = ZEnv.any
private val conf = Conf.live("samble-doobie-app-streaming")
private val dbRecoveryConf = conf >>> ZLayer.fromService { (c: Conf.Root) => c.dbRecovery }
private val datasource = (conf ++ zenv) >>> ConnectionPool.live
private val database = (datasource ++ zenv ++ dbRecoveryConf) >>> Database.fromDatasourceAndErrorStrategies
private val personQueries = PersonQueries.live

type AppEnv = ZEnv with Database with PersonQueries with Conf
private val appEnv = zenv ++ conf ++ database ++ personQueries

override def run(args: List[String]): ZIO[zio.ZEnv, Nothing, ExitCode] = {
val prog = for {
_ <- console.putStrLn("Starting the app")
trio <- myApp().provideLayer(appEnv)
_ <- console.putStrLn(trio.mkString(", "))
} yield ExitCode(0)

prog.orDie
}

/** Main code for the application. Results in a big ZIO depending on the AppEnv. */
def myApp(): ZIO[AppEnv, DbException, List[Person]] = {
val queries = for {
_ <- console.putStrLn("Creating the table")
_ <- PersonQueries.setup
_ <- console.putStrLn("Inserting the trio")
_ <- PersonQueries.insert(Person("Buffy", "Summers"))
_ <- PersonQueries.insert(Person("Willow", "Rosenberg"))
_ <- PersonQueries.insert(Person("Alexander", "Harris"))
_ <- PersonQueries.insert(Person("Rupert", "Giles")) // insert one more!
_ <- console.putStrLn("Reading the trio")
trio <- {
val stream: ZStream[PersonQueries with Connection, DbException, Person] = PersonQueries.listStream.take(3)
stream.run(Sink.foldLeft(List[Person]())(_.prepended(_)))
}
} yield trio.reverse

ZIO.accessM[AppEnv] { env =>
// if this implicit is not provided, tranzactio will use Conf.Root.dbRecovery instead
implicit val errorRecovery: ErrorStrategiesRef = env.get[Conf.Root].alternateDbRecovery
Database.transactionOrWidenR[AppEnv](queries)
}
}

}
9 changes: 9 additions & 0 deletions src/samples/scala/samples/doobie/PersonQueries.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import doobie.implicits._
import io.github.gaelrenoux.tranzactio.DbException
import io.github.gaelrenoux.tranzactio.doobie._
import samples.Person
import zio.stream.ZStream
import zio.{ULayer, ZIO, ZLayer}

object PersonQueries {
Expand All @@ -13,6 +14,8 @@ object PersonQueries {

val list: TranzactIO[List[Person]]

val listStream: TranzactIOStream[Person]

def insert(p: Person): TranzactIO[Unit]

val failing: TranzactIO[Unit]
Expand All @@ -33,6 +36,10 @@ object PersonQueries {
sql"""SELECT given_name, family_name FROM person""".query[Person].to[List]
}

val listStream: TranzactIOStream[Person] = tzioStream {
sql"""SELECT given_name, family_name FROM person""".query[Person].stream
}

def insert(p: Person): TranzactIO[Unit] = tzio {
sql"""INSERT INTO person (given_name, family_name) VALUES (${p.givenName}, ${p.familyName})"""
.update.run.map(_ => ())
Expand All @@ -48,6 +55,8 @@ object PersonQueries {

val list: ZIO[PersonQueries with Connection, DbException, List[Person]] = ZIO.accessM(_.get.list)

val listStream: ZStream[PersonQueries with Connection, DbException, Person] = ZStream.accessStream(_.get.listStream)

def insert(p: Person): ZIO[PersonQueries with Connection, DbException, Unit] = ZIO.accessM(_.get.insert(p))

val failing: ZIO[PersonQueries with Connection, DbException, Unit] = ZIO.accessM(_.get.failing)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package io.github.gaelrenoux.tranzactio.integration

import doobie.implicits._
import doobie.util.fragment.Fragment
import io.github.gaelrenoux.tranzactio.ConnectionSource
import io.github.gaelrenoux.tranzactio.doobie._
import samples.Person
import samples.doobie.PersonQueries
import zio.{ULayer, ZLayer}
import zio.blocking.Blocking
import zio.test.Assertion._
import zio.test._
import zio.{ULayer, ZLayer}


/** Integration tests for Doobie */
object DoobieIT extends ITSpec[Database, PersonQueries] {
Expand All @@ -18,6 +20,7 @@ object DoobieIT extends ITSpec[Database, PersonQueries] {
override val personQueriesLive: ULayer[PersonQueries] = PersonQueries.live

val buffy: Person = Person("Buffy", "Summers")
val giles: Person = Person("Rupert", "Giles")

val connectionCountQuery: TranzactIO[Int] = tzio(Fragment.const(connectionCountSql).query[Int].unique)

Expand All @@ -30,7 +33,8 @@ object DoobieIT extends ITSpec[Database, PersonQueries] {
testDataCommittedOnAutoCommitSuccess,
testConnectionClosedOnAutoCommitSuccess,
testDataRollbackedOnAutoCommitFailure,
testConnectionClosedOnAutoCommitFailure
testConnectionClosedOnAutoCommitFailure,
testStreamDoesNotLoadAllValues
)

private val testDataCommittedOnTransactionSuccess = testM("data committed on transaction success") {
Expand Down Expand Up @@ -105,4 +109,21 @@ object DoobieIT extends ITSpec[Database, PersonQueries] {
} yield assert(connectionCount)(equalTo(1)) // only the current connection
}

private val testStreamDoesNotLoadAllValues = testM("stream does not load all values") {
for {
_ <- Database.autoCommitR[PersonQueries](PersonQueries.setup)
_ <- Database.autoCommitR[PersonQueries](PersonQueries.insert(buffy))
_ <- Database.autoCommitR[PersonQueries](PersonQueries.insert(giles))
result <- Database.autoCommit {
val doobieStream = sql"""SELECT given_name, family_name FROM person""".query[Person]
.streamWithChunkSize(1) // make sure it's read one by one
.map { p =>
if (p.givenName == "Rupert") throw new IllegalStateException // fail on the second one, if it's ever read
else p
}
tzioStream(doobieStream).take(1).runHead // only keep one
}
} yield assert(result)(isSome(equalTo(buffy)))
}

}