Skip to content

Commit

Permalink
Enhancing batch query functionality (static and dynamic) to allow mul…
Browse files Browse the repository at this point in the history
…tiple lift & liftQuery (#100)
  • Loading branch information
deusaquilus authored May 16, 2022
1 parent d95a205 commit fb14ec1
Show file tree
Hide file tree
Showing 14 changed files with 644 additions and 279 deletions.
2 changes: 1 addition & 1 deletion quill-sql/src/main/scala/io/getquill/DslModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ case class LazyPlanter[T, PrepareRow, Session](value: T, uid: String) extends Pl
}

// Equivalent to CaseClassValueLift
case class EagerEntitiesPlanter[T, PrepareRow, Session](value: Iterable[T], uid: String) extends Planter[Query[T], PrepareRow, Session] {
case class EagerEntitiesPlanter[T, PrepareRow, Session](value: Iterable[T], uid: String, fieldGetters: List[InjectableEagerPlanter[?, PrepareRow, Session]], fieldClass: ast.CaseClass) extends Planter[Query[T], PrepareRow, Session] {
def unquote: Query[T] =
throw new RuntimeException("Unquotation can only be done from a quoted block.")
}
Expand Down
404 changes: 226 additions & 178 deletions quill-sql/src/main/scala/io/getquill/context/BatchQueryExecution.scala

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions quill-sql/src/main/scala/io/getquill/context/LiftMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import io.getquill.parser.Lifter
import io.getquill.CaseClassLift
import io.getquill.ast.CaseClass
import io.getquill.InjectableEagerPlanter
import io.getquill.util.Format

object LiftQueryMacro {
private[getquill] def newUuid = java.util.UUID.randomUUID().toString
Expand All @@ -46,9 +47,16 @@ object LiftQueryMacro {
quat match
case _: Quat.Product =>
// Not sure why cast back to iterable is needed here but U param is not needed once it is inside of the planter
'{ EagerEntitiesPlanter($entity.asInstanceOf[Iterable[T]], ${ Expr(newUuid) }).unquote } // [T, PrepareRow] // adding these causes assertion failed: unresolved symbols: value Context_this
val (lifterClass, lifters) =
LiftMacro.liftInjectedProduct[T, PrepareRow, Session]
val lifterClassExpr = Lifter.caseClass(lifterClass)
val liftedLiftersExpr = Expr.ofList(lifters)
val returning =
'{ EagerEntitiesPlanter($entity.asInstanceOf[Iterable[T]], ${ Expr(newUuid) }, ${ liftedLiftersExpr }, ${ lifterClassExpr }).unquote }
returning
case _ =>
val encoder = LiftMacro.summonEncoderOrFail[T, PrepareRow, Session](entity)
// [T, PrepareRow] // adding these causes assertion failed: unresolved symbols: value Context_this
'{ EagerListPlanter($entity.asInstanceOf[Iterable[T]].toList, $encoder, ${ Expr(newUuid) }).unquote }
}
}
Expand Down Expand Up @@ -80,17 +88,10 @@ object LiftMacro {
}
}

private[getquill] def liftInjectedScalar[T, PrepareRow, Session](using qctx: Quotes, tpe: Type[T], prepareRowTpe: Type[PrepareRow], sessionTpe: Type[Session]): (ScalarTag, Expr[InjectableEagerPlanter[_, PrepareRow, Session]]) = {
import qctx.reflect._
val uuid = java.util.UUID.randomUUID.toString
(ScalarTag(uuid), injectableLiftValue[T, PrepareRow, Session]('{ (t: T) => t }, uuid))
}

// TODO Injected => Injectable
private[getquill] def liftInjectedProduct[T, PrepareRow, Session](using qctx: Quotes, tpe: Type[T], prepareRowTpe: Type[PrepareRow], sessionTpe: Type[Session]): (CaseClass, List[Expr[InjectableEagerPlanter[_, PrepareRow, Session]]]) = {
import qctx.reflect._
val (caseClassAstInitial, liftsInitial) = liftInjectedProductComponents[T, PrepareRow]
// println("========= CaseClass Initial =========\n" + io.getquill.util.Messages.qprint(caseClassAstInitial))
val TaggedLiftedCaseClass(caseClassAst, lifts) = TaggedLiftedCaseClass(caseClassAstInitial, liftsInitial).reKeyWithUids()
val liftPlanters =
lifts.map((liftKey, lift) =>
Expand All @@ -112,7 +113,6 @@ object LiftMacro {

// Get the elaboration and AST once so that it will not have to be parsed out of the liftedCombo (since they are normally returned by ElaborateStructure.ofProductValue)
val elaborated = ElaborateStructure.Term.ofProduct[T](ElaborationSide.Encoding)
// println("========= Elaboration =========\n" + io.getquill.util.Messages.qprint(elaborated))
val (_, caseClassAst) = ElaborateStructure.productValueToAst[T](elaborated)
val caseClass = caseClassAst.asInstanceOf[io.getquill.ast.CaseClass]

Expand Down
24 changes: 21 additions & 3 deletions quill-sql/src/main/scala/io/getquill/context/Particularize.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import scala.annotation.tailrec
import io.getquill.idiom._
import scala.quoted._
import io.getquill.util.Format
import io.getquill.metaprog.InjectableEagerPlanterExpr

/**
* For a query that has a filter(p => liftQuery(List("Joe","Jack")).contains(p.name)) we need to turn
Expand All @@ -28,6 +29,12 @@ import io.getquill.util.Format
* which has to be manipulated inside of a '{ ... } block.
*/
object Particularize:
// ====================================== TODO additional-lifts case here too ======================================
// ====================================== TODO additional-lifts case here too ======================================
// ====================================== TODO additional-lifts case here too ======================================
// ====================================== TODO additional-lifts case here too ======================================
// ====================================== TODO additional-lifts case here too ======================================
// the following should test for that: update - extra lift + scalars + liftQuery/setContains
object Static:
/** Convenience constructor for doing particularization from an Unparticular.Query */
def apply[PrepareRowTemp](query: Unparticular.Query, lifts: List[Expr[Planter[_, _, _]]], runtimeLiftingPlaceholder: Expr[Int => String], emptySetContainsToken: Token => Token)(using Quotes): Expr[String] =
Expand All @@ -38,7 +45,7 @@ object Particularize:

enum LiftChoice:
case ListLift(value: EagerListPlanterExpr[Any, PrepareRowTemp, Session])
case SingleLift(value: EagerPlanterExpr[Any, PrepareRowTemp, Session])
case SingleLift(value: PlanterExpr[Any, PrepareRowTemp, Session])

val listLifts: Map[String, EagerListPlanterExpr[Any, PrepareRowTemp, Session]] =
lifts.collect {
Expand All @@ -52,9 +59,16 @@ object Particularize:
planterExpr.asInstanceOf[EagerPlanterExpr[Any, PrepareRowTemp, Session]]
}.map(lift => (lift.uid, lift)).toMap

val injectableLifts: Map[String, InjectableEagerPlanterExpr[Any, PrepareRowTemp, Session]] =
lifts.collect {
case PlanterExpr.Uprootable(planterExpr: InjectableEagerPlanterExpr[_, _, _]) =>
planterExpr.asInstanceOf[InjectableEagerPlanterExpr[Any, PrepareRowTemp, Session]]
}.map(lift => (lift.uid, lift)).toMap

def getLifts(uid: String): LiftChoice =
listLifts.get(uid).map(LiftChoice.ListLift(_))
.orElse(singleLifts.get(uid).map(LiftChoice.SingleLift(_)))
.orElse(injectableLifts.get(uid).map(LiftChoice.SingleLift(_)))
.getOrElse {
throw new IllegalArgumentException(s"Cannot find list-lift with UID ${uid} (from all the lifts ${lifts.map(io.getquill.util.Format.Expr(_))})")
}
Expand Down Expand Up @@ -227,7 +241,12 @@ object Particularize:

object Dynamic:
/** Convenience constructor for doing particularization from an Unparticular.Query */
def apply[PrepareRowTemp](query: Unparticular.Query, lifts: List[Planter[_, _, _]], liftingPlaceholder: Int => String, emptySetContainsToken: Token => Token): String =
def apply[PrepareRowTemp](
query: Unparticular.Query,
lifts: List[Planter[_, _, _]],
liftingPlaceholder: Int => String,
emptySetContainsToken: Token => Token
): String =
raw(query.realQuery, lifts, liftingPlaceholder, emptySetContainsToken)

private[getquill] def raw[PrepareRowTemp, Session](statements: Statement, lifts: List[Planter[_, _, _]], liftingPlaceholder: Int => String, emptySetContainsToken: Token => Token): String = {
Expand All @@ -237,7 +256,6 @@ object Particularize:

val listLifts = lifts.collect { case e: EagerListPlanter[_, _, _] => e.asInstanceOf[EagerListPlanter[Any, PrepareRowTemp, Session]] }.map(lift => (lift.uid, lift)).toMap
val singleLifts = lifts.collect { case e: EagerPlanter[_, _, _] => e.asInstanceOf[EagerPlanter[Any, PrepareRowTemp, Session]] }.map(lift => (lift.uid, lift)).toMap
// For dynamic lifts, it is possible that we have injectable lifts that have not yet been resolved
val injectableLifts = lifts.collect { case e: InjectableEagerPlanter[_, _, _] => e.asInstanceOf[InjectableEagerPlanter[Any, PrepareRowTemp, Session]] }.map(lift => (lift.uid, lift)).toMap

def getLifts(uid: String): LiftChoice =
Expand Down
120 changes: 105 additions & 15 deletions quill-sql/src/main/scala/io/getquill/context/QueryExecution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import io.getquill.util.CommonExtensions
import io.getquill.generic.ElaborateTrivial
import io.getquill.quat.QuatMaking
import io.getquill.quat.Quat
import io.getquill.ast.External

object ContextOperation:
case class Argument[I, T, A <: QAC[I, T] with Action[I], D <: Idiom, N <: NamingStrategy, PrepareRow, ResultRow, Session, Ctx <: Context[_, _], Res](
Expand Down Expand Up @@ -413,7 +414,10 @@ object PrepareDynamicExecution:
naming: N,
elaborationBehavior: ElaborationBehavior,
topLevelQuat: Quat,
spliceBehavior: SpliceBehavior = SpliceBehavior.NeedsSplice
spliceBehavior: SpliceBehavior = SpliceBehavior.NeedsSplice,
// For a batch query, these are the other lifts besides the primary liftQuery lifts.
// This should be empty & ignored for all other query types.
additionalLifts: List[Planter[?, ?, ?]] = List()
) =
// Splice all quotation values back into the AST recursively, by this point these quotations are dynamic
// which means that the compiler has not done the splicing for us. We need to do this ourselves.
Expand All @@ -431,9 +435,6 @@ object PrepareDynamicExecution:

val splicedAst = ElaborateTrivial(elaborationBehavior)(splicedAstRaw)

// Pull out the all the Planter instances (for now they need to be EagerPlanters for Dynamic Queries)
val lifts = gatheredLifts.map(lift => (lift.uid, lift)).toMap

// TODO Should make this enable-able via a logging configuration
// println("=============== Dynamic Expanded Ast Is ===========\n" + io.getquill.util.Messages.qprint(splicedAst))

Expand Down Expand Up @@ -467,25 +468,114 @@ object PrepareDynamicExecution:
// Get the UIDs from the lifts, if they are something unexpected (e.g. Lift elements from Quill 2.x) throw an exception
val liftTags =
externals.map {
case ScalarTag(uid) => uid
case other => throw new IllegalArgumentException(s"Invalid Lift Tag: ${other}")
case tag @ ScalarTag(_) => tag
case other => throw new IllegalArgumentException(s"Invalid Lift Tag: ${other}")
}

val queryString = Particularize.Dynamic(unparticularQuery, gatheredLifts, idiom.liftingPlaceholder, idiom.emptySetContainsToken)
val queryString = Particularize.Dynamic(unparticularQuery, gatheredLifts ++ additionalLifts, idiom.liftingPlaceholder, idiom.emptySetContainsToken)

// Match the ScalarTags we pulled out earlier (in ReifyStatement) with corresponding Planters because
// the Planters can be out of order (I.e. in a different order then the ?s in the SQL query that they need to be spliced into).
// The ScalarTags are comming directly from the tokenized AST however and their order should be correct.
// also, some of they may be filtered out
val sortedLifts = liftTags.map { tag =>
lifts.get(tag) match
case Some(lift) => lift
case None => throw new IllegalArgumentException(s"Could not lookup value for the tag: ${tag}")
}

(queryString, outputAst, sortedLifts, extractor)
val (sortedLifts, sortedSecondaryLifts) =
processLifts(gatheredLifts, liftTags, additionalLifts) match
case Right((sl, ssl)) => (sl, ssl)
case Left(msg) =>
throw new IllegalArgumentException(
s"Could not process the lifts:\n" +
s"${gatheredLifts.map(_.toString).mkString("====\n")}" +
(if (additionalLifts.nonEmpty) s"${additionalLifts.map(_.toString).mkString("====\n")}" else "") +
s"Due to an error: $msg"
)

(queryString, outputAst, sortedLifts, extractor, sortedSecondaryLifts)

end apply

private[getquill] def processLifts(
lifts: List[Planter[_, _, _]],
matchingExternals: List[External],
secondaryLifts: List[Planter[_, _, _]] = List()
): Either[String, (List[Planter[_, _, _]], List[Planter[_, _, _]])] =
val encodeablesMap =
lifts.map(e => (e.uid, e)).toMap

val secondaryEncodeablesMap =
secondaryLifts.map(e => (e.uid, e)).toMap

val uidsOfScalarTags =
matchingExternals.collect {
case tag: ScalarTag => tag.uid
}

enum UidStatus:
// Most normal lifts and the liftQuery of batches
case Primary(uid: String, planter: Planter[?, ?, ?])
// In batch queries, any lifts that are not part of the initial liftQuery
case Secondary(uid: String, planter: Planter[?, ?, ?])
// Lift planter was not found, this means an error
case NotFound(uid: String)
def print: String = this match
case Primary(uid, planter) => s"PrimaryPlanter($uid, ${planter})"
case Secondary(uid, planter) => s"SecondaryPlanter($uid, ${planter})"
case NotFound(uid) => s"NotFoundPlanter($uid)"

val sortedEncodeables =
uidsOfScalarTags
.map { uid =>
encodeablesMap.get(uid) match
case Some(element) => UidStatus.Primary(uid, element)
case None =>
secondaryEncodeablesMap.get(uid) match
case Some(element) => UidStatus.Secondary(uid, element)
case None => UidStatus.NotFound(uid)
}

object HasNotFoundUids:
def unapply(statuses: List[UidStatus]) =
val collected =
statuses.collect {
case UidStatus.NotFound(uid) => uid
}
if (collected.nonEmpty) Some(collected) else None

object PrimaryThenSecondary:
def unapply(statuses: List[UidStatus]) =
val (primaries, secondaries) =
statuses.partition {
case UidStatus.Primary(_, _) => true
case _ => false
}
val primariesFound = primaries.collect { case p: UidStatus.Primary => p }
val secondariesFound = secondaries.collect { case s: UidStatus.Secondary => s }
val goodPartitioning =
primariesFound.length == primaries.length && secondariesFound.length == secondaries.length
if (goodPartitioning)
Some((primariesFound.map(_.planter), secondariesFound.map(_.planter)))
else
None

val outputEncodeables =
sortedEncodeables match
case HasNotFoundUids(uids) =>
Left(s"Invalid Transformations Encountered. Cannot find lift with IDs: ${uids}.")
case PrimaryThenSecondary(primaryPlanters, secondaryPlanters /*or List() if none*/ ) =>
Right((primaryPlanters, secondaryPlanters))
case other =>
Left(
s"Invalid transformation primary and secondary encoders were mixed.\n" +
s"All secondary planters must come after all primary ones but found:\n" +
s"${other.map(_.print).mkString("=====\n")}"
)

// TODO This should be logged if some fine-grained debug logging is enabled. Maybe as part of some phase that can be enabled via -Dquill.trace.types config
// val remaining = encodeables.removedAll(uidsOfScalarTags)
// if (!remaining.isEmpty)
// println(s"Ignoring the following lifts: [${remaining.map((_, v) => Format.Expr(v.plant)).mkString(", ")}]")
outputEncodeables
end processLifts

end PrepareDynamicExecution

/**
Expand Down Expand Up @@ -521,7 +611,7 @@ object RunDynamicExecution:
topLevelQuat: Quat
): Res = {
// println("===== Passed Ast: " + io.getquill.util.Messages.qprint(quoted.ast))
val (queryString, outputAst, sortedLifts, extractor) =
val (queryString, outputAst, sortedLifts, extractor, _) =
PrepareDynamicExecution[I, T, RawT, D, N, PrepareRow, ResultRow, Session](quoted, rawExtractor, ctx.idiom, ctx.naming, elaborationBehavior, topLevelQuat)

// Use the sortedLifts to prepare the method that will prepare the SQL statement
Expand Down
8 changes: 8 additions & 0 deletions quill-sql/src/main/scala/io/getquill/context/QuoteMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import io.getquill.metaprog.SummonParser
import io.getquill.metaprog.SummonSerializationBehaviors
import io.getquill.parser.engine.History
import io.getquill.context.sql.norm.SimplifyFilterTrue
import io.getquill.parser.Unlifter
import io.getquill.util.Format

object ExtractLifts {
// Find all lifts, dedupe by UID since lifts can be inlined multiple times hence
Expand Down Expand Up @@ -61,6 +63,7 @@ object QuoteMacro {

def apply[T](bodyRaw: Expr[T])(using Quotes, Type[T], Type[Parser]): Expr[Quoted[T]] = {
import quotes.reflect._

// NOTE Can disable underlyingArgument here if needed and make body = bodyRaw. See https://github.com/lampepfl/dotty/pull/8041 for detail
val body = bodyRaw.asTerm.underlyingArgument.asExpr

Expand All @@ -71,6 +74,11 @@ object QuoteMacro {
val ast = SimplifyFilterTrue(BetaReduction(rawAst))

val reifiedAst = Lifter.WithBehavior(serializeQuats, serializeAst)(ast)
val u = Unlifter(reifiedAst)
u match
case id @ io.getquill.ast.Ident("Ast", _) =>
println(s"******** Lifted $id (quat: ${io.getquill.util.Messages.qprint(id.quat)}) from ${Format.Expr(reifiedAst)} from ${io.getquill.util.Messages.qprint(io.getquill.util.Messages.qprint(rawAst))}")
case _ =>

// Extract runtime quotes and lifts
val (lifts, pluckedUnquotes) = ExtractLifts(bodyRaw)
Expand Down
10 changes: 9 additions & 1 deletion quill-sql/src/main/scala/io/getquill/context/StaticState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ import io.getquill.ast.Ast
import io.getquill.metaprog.PlanterExpr
import io.getquill.idiom.Idiom

case class StaticState(query: Unparticular.Query, rawLifts: List[PlanterExpr[?, ?, ?]], returnAction: Option[ReturnAction], idiom: Idiom)(queryAst: => Ast):
case class StaticState(
query: Unparticular.Query,
rawLifts: List[PlanterExpr[?, ?, ?]],
returnAction: Option[ReturnAction],
idiom: Idiom,
// For a batch query, lifts other than the one from the primary liftQuery go here. THey need to be know about separately
// in the batch query case. Should be empty & ignored for non batch cases.
secondaryLifts: List[PlanterExpr[?, ?, ?]] = List()
)(queryAst: => Ast):
/**
* Plant all the lifts and return them.
* NOTE: If this is used frequently would it be worth caching (i.e. since this object is immutable)
Expand Down
Loading

0 comments on commit fb14ec1

Please sign in to comment.