Skip to content

Commit

Permalink
Improve translate function. Make contexts more lazy. (#3146)
Browse files Browse the repository at this point in the history
  • Loading branch information
deusaquilus authored Dec 5, 2024
1 parent cb465cc commit 87bcdb2
Show file tree
Hide file tree
Showing 23 changed files with 187 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class CassandraPekkoContext[+N <: NamingStrategy](
groups: List[BatchGroup]
)(info: ExecutionInfo, dc: Runner)(implicit executionContext: ExecutionContext): Result[RunBatchActionResult] =
Future.sequence {
groups.flatMap { case BatchGroup(cql, prepare) =>
groups.flatMap { case BatchGroup(cql, prepare, _) =>
prepare.map(executeAction(cql, _)(info, dc))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class CassandraZioContext[+N <: NamingStrategy](val naming: N)
env <- ZIO.service[CassandraZioSession]
_ <- {
val batchGroups =
groups.flatMap { case BatchGroup(cql, prepare) =>
groups.flatMap { case BatchGroup(cql, prepare, _) =>
prepare
.map(prep => executeAction(cql, prep)(info, dc).provideEnvironment(ZEnvironment(env)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class CassandraAsyncContext[+N <: NamingStrategy](
groups: List[BatchGroup]
)(info: ExecutionInfo, dc: Runner)(implicit executionContext: ExecutionContext): Result[RunBatchActionResult] =
Future.sequence {
groups.flatMap { case BatchGroup(cql, prepare) =>
groups.flatMap { case BatchGroup(cql, prepare, _) =>
prepare.map(executeAction(cql, _)(info, dc))
}
}.map(_ => ())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CassandraSyncContext[+N <: NamingStrategy](
}

def executeBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner): Unit =
groups.foreach { case BatchGroup(cql, prepare) =>
groups.foreach { case BatchGroup(cql, prepare, _) =>
prepare.foreach(executeAction(cql, _)(info, dc))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class AsyncMirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](
)(executionInfo: ExecutionInfo, dc: Runner)(implicit ec: ExecutionContext) =
Future {
BatchActionMirror(
groups.map { case BatchGroup(string, prepare) =>
groups.map { case BatchGroup(string, prepare, _) =>
(string, prepare.map(_(Row(), session)._2))
},
executionInfo
Expand All @@ -148,7 +148,7 @@ class AsyncMirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](
)(executionInfo: ExecutionInfo, dc: Runner)(implicit ec: ExecutionContext) =
Future {
BatchActionReturningMirror[T](
groups.map { case BatchGroupReturning(string, returningBehavior, prepare) =>
groups.map { case BatchGroupReturning(string, returningBehavior, prepare, _) =>
(string, returningBehavior, prepare.map(_(Row(), session)._2))
},
extractor,
Expand Down
6 changes: 3 additions & 3 deletions quill-core/src/main/scala/io/getquill/MirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class MirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](

def executeBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner) =
BatchActionMirror(
groups.map { case BatchGroup(string, prepare) =>
groups.map { case BatchGroup(string, prepare, _) =>
(string, prepare.map(_(Row(), session)._2))
},
info
Expand All @@ -139,7 +139,7 @@ class MirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](
extractor: Extractor[T]
)(info: ExecutionInfo, dc: Runner) =
new BatchActionReturningMirror[T](
groups.map { case BatchGroupReturning(string, returningBehavior, prepare) =>
groups.map { case BatchGroupReturning(string, returningBehavior, prepare, _) =>
(string, returningBehavior, prepare.map(_(Row(), session)._2))
},
extractor,
Expand All @@ -151,7 +151,7 @@ class MirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](

def prepareBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner) =
(session: Session) =>
groups.flatMap { case BatchGroup(string, prepare) =>
groups.flatMap { case BatchGroup(string, prepare, _) =>
prepare.map(_(Row(), session)._2)
}

Expand Down
29 changes: 15 additions & 14 deletions quill-core/src/main/scala/io/getquill/context/ActionMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,28 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
import c.universe.{Function => _, Ident => _, _}

def translateQuery(quoted: Tree): Tree =
translateQueryPrettyPrint(quoted, q"false")
translateQueryPrettyPrint(quoted, q"io.getquill.context.TranslateOptions()")

def translateQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree = {
def translateQueryPrettyPrint(quoted: Tree, options: Tree): Tree = {
val expanded = expand(extractAst(quoted), inferQuat(quoted.tpe))
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val (idiomContext, expanded) = $expanded
${c.prefix}.translateQuery(
expanded.string,
expanded.prepare,
prettyPrint = ${prettyPrint}
expanded.liftings,
options = ${options}
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
}

def translateBatchQuery(quoted: Tree): Tree =
translateBatchQueryPrettyPrint(quoted, q"false")
translateBatchQueryPrettyPrint(quoted, q"io.getquill.context.TranslateOptions()")

def translateBatchQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree =
// TODO need to change this to include liftings
def translateBatchQueryPrettyPrint(quoted: Tree, options: Tree): Tree =
expandBatchActionNew(quoted, isReturning = false) {
case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, idiomContext, canDoBatch) =>
q"""
Expand All @@ -50,12 +51,12 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
${c.prefix}.translateBatchQuery(
batches.map { subBatch =>
val expanded = $expanded
(expanded.string, expanded.prepare)
(expanded.string, expanded.prepare, expanded.liftings)
}.groupBy(_._1).map {
case (string, items) =>
${c.prefix}.BatchGroup(string, items.map(_._2).toList)
${c.prefix}.BatchGroup(string, items.map(_._2).toList, items.map(_._3).toList)
}.toList,
$prettyPrint
$options
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
Expand Down Expand Up @@ -142,11 +143,11 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
} else {
$batch.toList.map(element => List(element))
}
/* batchesSharded: List[(String, (Row, MirrorSession) => (List[Any], Row) <a.k.a: prepare>)] */
/* batchesSharded: List[(String, (Row, MirrorSession) => (List[Any], Row <a.k.a: prepare>), List[ScalarLift]) ] */
val batchesSharded = batches.map { subBatch => {
/* `expanded` is io.getquill.context.ExpandWithInjectables(ast, subBatch, injectableLiftList) */
val expanded = $expanded
(expanded.string, expanded.prepare)
(expanded.string, expanded.prepare, expanded.liftings)
}
}
/*
Expand All @@ -163,7 +164,7 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
*/
batchesSharded.groupByOrdered(_._1).map {
case (string, items) =>
${c.prefix}.BatchGroup(string, items.map(_._2).toList)
${c.prefix}.BatchGroup(string, items.map(_._2).toList, items.map(_._3).toList)
}.toList
})(io.getquill.context.ExecutionInfo.unknown, ())
"""
Expand Down Expand Up @@ -194,12 +195,12 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
val batchesSharded = batches.map { subBatch => {
/* `expanded` is io.getquill.context.ExpandWithInjectables(ast, subBatch, injectableLiftList) */
val expanded = $expanded
((expanded.string, $returningColumn), expanded.prepare)
((expanded.string, $returningColumn), expanded.prepare, expanded.liftings)
}
}
batchesSharded.groupByOrdered(_._1).map {
case ((string, column), items) =>
${c.prefix}.BatchGroupReturning(string, column, items.map(_._2).toList)
${c.prefix}.BatchGroupReturning(string, column, items.map(_._2).toList, items.map(_._3).toList)
}.toList
}, ${returningExtractor[T]})(io.getquill.context.ExecutionInfo.unknown, ())
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.getquill.context

import io.getquill.ast.ScalarLift
import io.getquill.{Action, BatchAction, NamingStrategy, Query, Quoted}
import io.getquill.idiom.Idiom

Expand All @@ -16,34 +17,40 @@ trait ContextVerbTranslate extends ContextTranslateMacro {
override def seq[A](list: List[A]): List[A] = list
}

case class TranslateOptions(
prettyPrint: Boolean = false,
plugLifts: Boolean = true,
demarcatePluggedLifts: Boolean = true
)

trait ContextTranslateMacro extends ContextTranslateProto {
this: Context[_ <: Idiom, _ <: NamingStrategy] =>

def translate[T](quoted: Quoted[T]): TranslateResult[String] = macro QueryMacro.translateQuery[T]
def translate[T](quoted: Quoted[Query[T]]): TranslateResult[String] = macro QueryMacro.translateQuery[T]
def translate(quoted: Quoted[Action[_]]): TranslateResult[String] = macro ActionMacro.translateQuery
def translate(quoted: Quoted[BatchAction[Action[_]]]): TranslateResult[List[String]] =
def translate[T](quoted: Quoted[T]): String = macro QueryMacro.translateQuery[T]
def translate[T](quoted: Quoted[Query[T]]): String = macro QueryMacro.translateQuery[T]
def translate(quoted: Quoted[Action[_]]): String = macro ActionMacro.translateQuery
def translate(quoted: Quoted[BatchAction[Action[_]]]): List[String] =
macro ActionMacro.translateBatchQuery

def translate[T](quoted: Quoted[T], prettyPrint: Boolean): TranslateResult[String] =
def translate[T](quoted: Quoted[T], options: TranslateOptions): TranslateResult[String] =
macro QueryMacro.translateQueryPrettyPrint[T]
def translate[T](quoted: Quoted[Query[T]], prettyPrint: Boolean): TranslateResult[String] =
def translate[T](quoted: Quoted[Query[T]], options: TranslateOptions): TranslateResult[String] =
macro QueryMacro.translateQueryPrettyPrint[T]
def translate(quoted: Quoted[Action[_]], prettyPrint: Boolean): TranslateResult[String] =
def translate(quoted: Quoted[Action[_]], options: TranslateOptions): TranslateResult[String] =
macro ActionMacro.translateQueryPrettyPrint
def translate(quoted: Quoted[BatchAction[Action[_]]], prettyPrint: Boolean): TranslateResult[List[String]] =
def translate(quoted: Quoted[BatchAction[Action[_]]], options: TranslateOptions): TranslateResult[List[String]] =
macro ActionMacro.translateBatchQueryPrettyPrint

def translateQuery[T](
statement: String,
prepare: Prepare = identityPrepare,
extractor: Extractor[T] = identityExtractor,
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[String]
def translateBatchQuery(groups: List[BatchGroup], prettyPrint: Boolean = false)(
lifts: List[ScalarLift] = List(),
options: TranslateOptions
)(executionInfo: ExecutionInfo, dc: Runner): String

def translateBatchQuery(groups: List[BatchGroup], options: TranslateOptions = TranslateOptions())(
executionInfo: ExecutionInfo,
dc: Runner
): TranslateResult[List[String]]
): List[String]
}

trait ContextTranslateProto {
Expand All @@ -58,40 +65,51 @@ trait ContextTranslateProto {

def translateQuery[T](
statement: String,
prepare: Prepare = identityPrepare,
extractor: Extractor[T] = identityExtractor,
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[String] =
try {
push(prepareParams(statement, prepare)) { params =>
val query =
if (params.nonEmpty) {
params.foldLeft(statement) { case (expanded, param) =>
expanded.replaceFirst("\\?", param)
}
} else {
statement
}

if (prettyPrint)
idiom.format(query)
else
query
liftings: List[ScalarLift] = List(),
options: TranslateOptions = TranslateOptions()
)(executionInfo: ExecutionInfo, dc: Runner): String = {
def quoteIfNeeded(value: Any): String =
value match {
case _: String => s"'${value}'"
case _: Char => s"'${value}'"
case _ => s"${value}"
}
} catch {
case e: Exception =>
wrap("<!-- Cannot display parameters due to preparation error: " + e.getMessage + " -->\n" + statement)
}

val outputQuery =
if (liftings.isEmpty)
statement
else
options.plugLifts match {
case true =>
liftings.foldLeft(statement) { case (expanded, lift) =>
expanded.replaceFirst("\\?", if (options.demarcatePluggedLifts) s"lift(${quoteIfNeeded(lift.value)})" else quoteIfNeeded(lift.value))
}
case false =>
var varNum: Int = 0
val dol = '$'
val numberedQuery =
liftings.foldLeft(statement) { case (expanded, lift) =>
val res = expanded.replaceFirst("\\?", s"${dol}${varNum + 1}")
varNum += 1
res
}
numberedQuery + "\n" + liftings.zipWithIndex.map { case (lift, i) => s"${dol}${i + 1} = ${lift.value}" }.mkString("\n")
}

if (options.prettyPrint)
idiom.format(outputQuery)
else
outputQuery
}

def translateBatchQuery(
// TODO these groups need to have liftings lists
groups: List[BatchGroup],
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[List[String]] =
seq {
groups.flatMap { group =>
group.prepare.map { prepare =>
translateQuery(group.string, prepare, prettyPrint = prettyPrint)(executionInfo, dc)
}
options: TranslateOptions = TranslateOptions()
)(executionInfo: ExecutionInfo, dc: Runner): List[String] =
groups.flatMap { group =>
(group.prepare zip group.liftings).map { case (_, liftings) =>
translateQuery(group.string, options = options, liftings = liftings)(executionInfo, dc)
}
}

Expand Down
Loading

0 comments on commit 87bcdb2

Please sign in to comment.