Skip to content

Commit

Permalink
Add Session to Prepare/Extract and Encoders/Decoders. Correct session…
Browse files Browse the repository at this point in the history
….udtValueOf delegation
  • Loading branch information
deusaquilus committed Aug 3, 2021
1 parent 9bc8d33 commit 3f3678b
Show file tree
Hide file tree
Showing 140 changed files with 643 additions and 516 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MysqlAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConnect
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result match {
case r: MySQLQueryResult =>
returningExtractor(new ArrayRowData(0, Map.empty, Array(r.lastInsertId)))
returningExtractor(new ArrayRowData(0, Map.empty, Array(r.lastInsertId)), ())
case _ =>
fail("This is a bug. Cannot extract returning value.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ class MysqlAsyncContextSpec extends Spec {
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), (row, session) => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}

"prepare" in {
testContext.prepareParams("", ps => (Nil, ps ++ List("Sarah", 127))) mustEqual List("'Sarah'", "127")
testContext.prepareParams("", (ps, session) => (Nil, ps ++ List("Sarah", 127))) mustEqual List("'Sarah'", "127")
}

override protected def beforeAll(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result.rows match {
case Some(r) if r.nonEmpty =>
returningExtractor(r.head)
returningExtractor(r.head, ())
case _ =>
fail("This is a bug. Cannot extract returning value.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ trait ArrayDecoders extends ArrayEncoding {

def arrayDecoder[I, O, Col <: Seq[O]](mapper: I => O)(implicit bf: CBF[O, Col], iTag: ClassTag[I], oTag: ClassTag[O]): Decoder[Col] =
AsyncDecoder[Col](SqlTypes.ARRAY)(new BaseDecoder[Col] {
def apply(index: Index, row: ResultRow): Col = {
def apply(index: Index, row: ResultRow, session: Session): Col = {
row(index) match {
case seq: IndexedSeq[Any] => seq.foldLeft(bf.newBuilder) {
case (b, x: I) => b += mapper(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ class PostgresAsyncContextSpec extends Spec {
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), (row, session) => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}

"prepare" in {
testContext.prepareParams("", ps => (Nil, ps ++ List("Sarah", 127))) mustEqual List("'Sarah'", "127")
testContext.prepareParams("", (ps, session) => (Nil, ps ++ List("Sarah", 127))) mustEqual List("'Sarah'", "127")
}

override protected def beforeAll(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
override type RunActionReturningResult[T] = T
override type RunBatchActionResult = List[Long]
override type RunBatchActionReturningResult[T] = List[T]
override type Session = Unit

override def close = {
Await.result(pool.close, Duration.Inf)
Expand Down Expand Up @@ -70,11 +71,11 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
}

def executeQuery[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit ec: ExecutionContext): Future[List[T]] = {
val (params, values) = prepare(Nil)
val (params, values) = prepare(Nil, ())
logger.logQuery(sql, params)
withConnection(_.sendPreparedStatement(sql, values)).map {
_.rows match {
case Some(rows) => rows.map(extractor).toList
case Some(rows) => rows.map(row => extractor(row, ())).toList
case None => Nil
}
}
Expand All @@ -84,14 +85,14 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
executeQuery(sql, prepare, extractor).map(handleSingleResult)

def executeAction[T](sql: String, prepare: Prepare = identityPrepare)(implicit ec: ExecutionContext): Future[Long] = {
val (params, values) = prepare(Nil)
val (params, values) = prepare(Nil, ())
logger.logQuery(sql, params)
withConnection(_.sendPreparedStatement(sql, values)).map(_.rowsAffected)
}

def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningAction)
val (params, values) = prepare(Nil)
val (params, values) = prepare(Nil, ())
logger.logQuery(sql, params)
withConnection(_.sendPreparedStatement(expanded, values))
.map(extractActionResult(returningAction, extractor))
Expand Down Expand Up @@ -124,6 +125,6 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
}.map(_.flatten.toList)

override private[getquill] def prepareParams(statement: String, prepare: Prepare): Seq[String] = {
prepare(Nil)._2.map(param => prepareParam(param))
prepare(Nil, ())._2.map(param => prepareParam(param))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ trait Decoders {

case class AsyncDecoder[T](sqlType: DecoderSqlType)(implicit decoder: BaseDecoder[T])
extends BaseDecoder[T] {
override def apply(index: Index, row: ResultRow) =
decoder(index, row)
override def apply(index: Index, row: ResultRow, session: Session) =
decoder(index, row, session)
}

def decoder[T: ClassTag](
f: PartialFunction[Any, T] = PartialFunction.empty,
sqlType: DecoderSqlType
): Decoder[T] =
AsyncDecoder[T](sqlType)(new BaseDecoder[T] {
def apply(index: Index, row: ResultRow) = {
def apply(index: Index, row: ResultRow, session: Session) = {
row(index) match {
case value: T => value
case value if f.isDefinedAt(value) => f(value)
Expand All @@ -40,12 +40,12 @@ trait Decoders {

implicit def mappedDecoder[I, O](implicit mapped: MappedEncoding[I, O], decoder: Decoder[I]): Decoder[O] =
AsyncDecoder(decoder.sqlType)(new BaseDecoder[O] {
def apply(index: Index, row: ResultRow): O =
mapped.f(decoder.apply(index, row))
def apply(index: Index, row: ResultRow, session: Session): O =
mapped.f(decoder.apply(index, row, session))
})

trait NumericDecoder[T] extends BaseDecoder[T] {
def apply(index: Index, row: ResultRow) =
def apply(index: Index, row: ResultRow, session: Session) =
row(index) match {
case v: Byte => decode(v)
case v: Short => decode(v)
Expand All @@ -63,10 +63,10 @@ trait Decoders {

implicit def optionDecoder[T](implicit d: Decoder[T]): Decoder[Option[T]] =
AsyncDecoder(d.sqlType)(new BaseDecoder[Option[T]] {
def apply(index: Index, row: ResultRow) = {
def apply(index: Index, row: ResultRow, session: Session) = {
row(index) match {
case null => None
case value => Some(d(index, row))
case value => Some(d(index, row, session))
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,31 @@ trait Encoders {

case class AsyncEncoder[T](sqlType: DecoderSqlType)(implicit encoder: BaseEncoder[T])
extends BaseEncoder[T] {
override def apply(index: Index, value: T, row: PrepareRow) =
encoder.apply(index, value, row)
override def apply(index: Index, value: T, row: PrepareRow, session: Session) =
encoder.apply(index, value, row, session)
}

def encoder[T](sqlType: DecoderSqlType): Encoder[T] =
encoder(identity[T], sqlType)

def encoder[T](f: T => Any, sqlType: DecoderSqlType): Encoder[T] =
AsyncEncoder[T](sqlType)(new BaseEncoder[T] {
def apply(index: Index, value: T, row: PrepareRow) =
def apply(index: Index, value: T, row: PrepareRow, session: Session) =
row :+ f(value)
})

implicit def mappedEncoder[I, O](implicit mapped: MappedEncoding[I, O], e: Encoder[O]): Encoder[I] =
AsyncEncoder(e.sqlType)(new BaseEncoder[I] {
def apply(index: Index, value: I, row: PrepareRow) =
e(index, mapped.f(value), row)
def apply(index: Index, value: I, row: PrepareRow, session: Session) =
e(index, mapped.f(value), row, session)
})

implicit def optionEncoder[T](implicit d: Encoder[T]): Encoder[Option[T]] =
AsyncEncoder(d.sqlType)(new BaseEncoder[Option[T]] {
def apply(index: Index, value: Option[T], row: PrepareRow) = {
def apply(index: Index, value: Option[T], row: PrepareRow, session: Session) = {
value match {
case None => nullEncoder(index, null, row)
case Some(v) => d(index, v, row)
case None => nullEncoder(index, null, row, session)
case Some(v) => d(index, v, row, session)
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ trait UUIDObjectEncoding {

implicit val uuidDecoder: Decoder[UUID] =
AsyncDecoder(SqlTypes.UUID)(
(index: Index, row: ResultRow) => row(index) match {
(index: Index, row: ResultRow, session: Session) => row(index) match {
case value: UUID => value
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ trait UUIDStringEncoding {

implicit val uuidDecoder: Decoder[UUID] =
AsyncDecoder(SqlTypes.UUID)(
(index: Index, row: ResultRow) => row(index) match {
(index: Index, row: ResultRow, session: Session) => row(index) match {
case value: String => UUID.fromString(value)
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ class CassandraLagomAsyncContext[N <: NamingStrategy](
override type RunQueryResult[T] = Seq[T]
override type RunActionResult = Done
override type RunBatchActionResult = Done
override type Session = CassandraLagomSession

private val logger = ContextLogger(this.getClass)

def prepareAction[T](cql: String, prepare: Prepare = identityPrepare)(implicit executionContext: ExecutionContext): CassandraSession => Future[BoundStatement] = (session: Session) => {
val prepareResult = session.prepare(cql).map(bs => prepare(bs.bind()))
def prepareAction[T](cql: String, prepare: Prepare = identityPrepare)(implicit executionContext: ExecutionContext): CassandraLagomSession => Future[BoundStatement] = (session: Session) => {
val prepareResult = session.cs.prepare(cql).map(bs => prepare(bs.bind(), session))
val preparedRow = prepareResult.map {
case (params, bs) =>
logger.logQuery(cql, params)
Expand All @@ -31,7 +32,7 @@ class CassandraLagomAsyncContext[N <: NamingStrategy](
preparedRow
}

def prepareBatchAction[T](groups: List[BatchGroup])(implicit executionContext: ExecutionContext): CassandraSession => Future[List[BoundStatement]] = (session: Session) => {
def prepareBatchAction[T](groups: List[BatchGroup])(implicit executionContext: ExecutionContext): CassandraLagomSession => Future[List[BoundStatement]] = (session: Session) => {
val batches = groups.flatMap {
case BatchGroup(cql, prepares) =>
prepares.map(cql -> _)
Expand All @@ -44,16 +45,16 @@ class CassandraLagomAsyncContext[N <: NamingStrategy](
}

def executeQuery[T](cql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit executionContext: ExecutionContext): Result[RunQueryResult[T]] = {
val statement = prepareAsyncAndGetStatement(cql, prepare, logger)
statement.flatMap(st => session.selectAll(st)).map(_.map(extractor))
val statement = prepareAsyncAndGetStatement(cql, prepare, wrappedSession, logger)
statement.flatMap(st => session.selectAll(st)).map(_.map(row => extractor(row, wrappedSession)))
}

def executeQuerySingle[T](cql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor)(implicit executionContext: ExecutionContext): Result[RunQuerySingleResult[T]] = {
executeQuery(cql, prepare, extractor).map(_.headOption)
}

def executeAction[T](cql: String, prepare: Prepare = identityPrepare)(implicit executionContext: ExecutionContext): Result[RunActionResult] = {
val statement = prepareAsyncAndGetStatement(cql, prepare, logger)
val statement = prepareAsyncAndGetStatement(cql, prepare, wrappedSession, logger)
statement.flatMap(st => session.executeWrite(st))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@ package io.getquill
import akka.Done
import com.datastax.driver.core.BoundStatement
import com.lightbend.lagom.scaladsl.persistence.cassandra.CassandraSession
import io.getquill.context.cassandra.CassandraSessionContext
import io.getquill.context.UdtValueLookup
import io.getquill.context.cassandra.CassandraSessionlessContext

import scala.concurrent.{ ExecutionContext, Future }

case class CassandraLagomSession(cs: CassandraSession) extends UdtValueLookup

abstract class CassandraLagomSessionContext[N <: NamingStrategy](
val naming: N,
val session: CassandraSession
)
extends CassandraSessionContext[N] {
extends CassandraSessionlessContext[N] {

override type RunActionResult = Done
override type RunBatchActionResult = Done
override type Session = CassandraSession
override type Session = CassandraLagomSession

val wrappedSession = CassandraLagomSession(session)

override def prepareAsync(cql: String)(implicit executionContext: ExecutionContext): Future[BoundStatement] = {
session.prepare(cql).map(_.bind())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class CassandraLagomStreamContext[N <: NamingStrategy](
prepare: Prepare = identityPrepare,
extractor: Extractor[T] = identityExtractor
)(implicit executionContext: ExecutionContext): Result[RunQueryResult[T]] = {
val statement = prepareAsyncAndGetStatement(cql, prepare, logger)
val resultSource = statement.map(st => session.select(st).map(extractor))
val statement = prepareAsyncAndGetStatement(cql, prepare, wrappedSession, logger)
val resultSource = statement.map(st => session.select(st).map(row => extractor(row, wrappedSession)))
Source
.fromFutureSource(resultSource)
.mapMaterializedValue(_ => NotUsed)
Expand All @@ -50,7 +50,7 @@ class CassandraLagomStreamContext[N <: NamingStrategy](
implicit
executionContext: ExecutionContext
): Result[RunActionResult] = {
val statement = prepareAsyncAndGetStatement(cql, prepare, logger)
val statement = prepareAsyncAndGetStatement(cql, prepare, CassandraLagomSession(session), logger)
Source.fromFuture(statement).mapAsync(1) { st =>
session.executeWrite(st)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class QueryResultTypeCassandraAsyncSpec extends QueryResultTypeCassandraSpec {
"bind" - {
"action" - {
"noArgs" in {
val bs = result(context.prepare(insert(OrderTestEntity(1, 2))))
val bs = result(context.prepare(insert(OrderTestEntity(1, 2)))(context.wrappedSession))
bs.preparedStatement().getVariables.size() mustEqual 0
}

"withArgs" in {
val bs = result(context.prepare(insert(lift(OrderTestEntity(1, 2)))))
val bs = result(context.prepare(insert(lift(OrderTestEntity(1, 2))))(context.wrappedSession))
bs.preparedStatement().getVariables.size() mustEqual 2
bs.getInt("id") mustEqual 1
bs.getInt("i") mustEqual 2
Expand All @@ -44,12 +44,12 @@ class QueryResultTypeCassandraAsyncSpec extends QueryResultTypeCassandraSpec {

"query" - {
"noArgs" in {
val bs = result(context.prepare(deleteAll))
val bs = result(context.prepare(deleteAll)(context.wrappedSession))
bs.preparedStatement().getVariables.size() mustEqual 0
}

"withArgs" in {
val batches = result(context.prepare(liftQuery(List(OrderTestEntity(1, 2))).foreach(e => insert(e))))
val batches = result(context.prepare(liftQuery(List(OrderTestEntity(1, 2))).foreach(e => insert(e)))(context.wrappedSession))
batches.foreach { bs =>
bs.preparedStatement().getVariables.size() mustEqual 2
bs.getInt("id") mustEqual 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class CassandraMonixContext[N <: NamingStrategy](
.flatMap(Observable.fromAsyncStateAction((rs: ResultSet) => page(rs).map((_, rs)))(_))
.takeWhile(_.nonEmpty)
.flatMap(Observable.fromIterable)
.map(extractor)
.map(row => extractor(row, this))
}

def executeQuery[T](cql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor): Task[List[T]] = {
Expand Down Expand Up @@ -84,7 +84,7 @@ class CassandraMonixContext[N <: NamingStrategy](
implicit val executor: Scheduler = scheduler

super.prepareAsync(cql)
.map(prepare)
.map(row => prepare(row, this))
.onComplete {
case Success((params, bs)) =>
logger.logQuery(cql, params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class CassandraStreamContext[N <: NamingStrategy](
.flatMap(Observable.fromAsyncStateAction((rs: ResultSet) => page(rs).map((_, rs)))(_))
.takeWhile(_.nonEmpty)
.flatMap(Observable.fromIterable)
.map(extractor)
.map(row => extractor(row, this))
}

def executeQuerySingle[T](cql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor): Observable[T] =
Expand All @@ -76,7 +76,7 @@ class CassandraStreamContext[N <: NamingStrategy](
implicit val executor: Scheduler = scheduler

super.prepareAsync(cql)
.map(prepare)
.map(row => prepare(row, this))
.onComplete {
case Success((params, bs)) =>
logger.logQuery(cql, params)
Expand Down
Loading

0 comments on commit 3f3678b

Please sign in to comment.