Skip to content

Commit

Permalink
Improve translate function. No longer needs preparers.
Browse files Browse the repository at this point in the history
  • Loading branch information
deusaquilus committed Dec 4, 2024
1 parent cb465cc commit bd73907
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class CassandraPekkoContext[+N <: NamingStrategy](
groups: List[BatchGroup]
)(info: ExecutionInfo, dc: Runner)(implicit executionContext: ExecutionContext): Result[RunBatchActionResult] =
Future.sequence {
groups.flatMap { case BatchGroup(cql, prepare) =>
groups.flatMap { case BatchGroup(cql, prepare, _) =>
prepare.map(executeAction(cql, _)(info, dc))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class CassandraZioContext[+N <: NamingStrategy](val naming: N)
env <- ZIO.service[CassandraZioSession]
_ <- {
val batchGroups =
groups.flatMap { case BatchGroup(cql, prepare) =>
groups.flatMap { case BatchGroup(cql, prepare, _) =>
prepare
.map(prep => executeAction(cql, prep)(info, dc).provideEnvironment(ZEnvironment(env)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class CassandraAsyncContext[+N <: NamingStrategy](
groups: List[BatchGroup]
)(info: ExecutionInfo, dc: Runner)(implicit executionContext: ExecutionContext): Result[RunBatchActionResult] =
Future.sequence {
groups.flatMap { case BatchGroup(cql, prepare) =>
groups.flatMap { case BatchGroup(cql, prepare, _) =>
prepare.map(executeAction(cql, _)(info, dc))
}
}.map(_ => ())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CassandraSyncContext[+N <: NamingStrategy](
}

def executeBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner): Unit =
groups.foreach { case BatchGroup(cql, prepare) =>
groups.foreach { case BatchGroup(cql, prepare, _) =>
prepare.foreach(executeAction(cql, _)(info, dc))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class AsyncMirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](
)(executionInfo: ExecutionInfo, dc: Runner)(implicit ec: ExecutionContext) =
Future {
BatchActionMirror(
groups.map { case BatchGroup(string, prepare) =>
groups.map { case BatchGroup(string, prepare, _) =>
(string, prepare.map(_(Row(), session)._2))
},
executionInfo
Expand Down
4 changes: 2 additions & 2 deletions quill-core/src/main/scala/io/getquill/MirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class MirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](

def executeBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner) =
BatchActionMirror(
groups.map { case BatchGroup(string, prepare) =>
groups.map { case BatchGroup(string, prepare, _) =>
(string, prepare.map(_(Row(), session)._2))
},
info
Expand All @@ -151,7 +151,7 @@ class MirrorContext[+Idiom <: BaseIdiom, +Naming <: NamingStrategy](

def prepareBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner) =
(session: Session) =>
groups.flatMap { case BatchGroup(string, prepare) =>
groups.flatMap { case BatchGroup(string, prepare, _) =>
prepare.map(_(Row(), session)._2)
}

Expand Down
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/context/ActionMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
def translateQuery(quoted: Tree): Tree =
translateQueryPrettyPrint(quoted, q"false")

def translateQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree = {
def translateQueryPrettyPrint(quoted: Tree, options: Tree): Tree = {
val expanded = expand(extractAst(quoted), inferQuat(quoted.tpe))
c.untypecheck {
q"""
..${EnableReflectiveCalls(c)}
val (idiomContext, expanded) = $expanded
${c.prefix}.translateQuery(
expanded.string,
expanded.prepare,
prettyPrint = ${prettyPrint}
options = ${options}
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
Expand All @@ -36,7 +35,8 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
def translateBatchQuery(quoted: Tree): Tree =
translateBatchQueryPrettyPrint(quoted, q"false")

def translateBatchQueryPrettyPrint(quoted: Tree, prettyPrint: Tree): Tree =
// TODO need to change this to include liftings
def translateBatchQueryPrettyPrint(quoted: Tree, options: Tree): Tree =
expandBatchActionNew(quoted, isReturning = false) {
case (batch, param, expanded, injectableLiftList, idiomNamingOriginalAstVars, idiomContext, canDoBatch) =>
q"""
Expand All @@ -50,12 +50,12 @@ class ActionMacro(val c: MacroContext) extends ContextMacro with ReifyLiftings {
${c.prefix}.translateBatchQuery(
batches.map { subBatch =>
val expanded = $expanded
(expanded.string, expanded.prepare)
(expanded.string, expanded.prepare, expanded.liftings)
}.groupBy(_._1).map {
case (string, items) =>
${c.prefix}.BatchGroup(string, items.map(_._2).toList)
}.toList,
$prettyPrint
$options
)(io.getquill.context.ExecutionInfo.unknown, ())
"""
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.getquill.context

import io.getquill.ast.ScalarLift
import io.getquill.{Action, BatchAction, NamingStrategy, Query, Quoted}
import io.getquill.idiom.Idiom

Expand All @@ -16,34 +17,40 @@ trait ContextVerbTranslate extends ContextTranslateMacro {
override def seq[A](list: List[A]): List[A] = list
}

case class TranslateOptions(
prettyPrint: Boolean = false,
plugLifts: Boolean = true,
demarcateLifts: Boolean = true
)

trait ContextTranslateMacro extends ContextTranslateProto {
this: Context[_ <: Idiom, _ <: NamingStrategy] =>

def translate[T](quoted: Quoted[T]): TranslateResult[String] = macro QueryMacro.translateQuery[T]
def translate[T](quoted: Quoted[Query[T]]): TranslateResult[String] = macro QueryMacro.translateQuery[T]
def translate(quoted: Quoted[Action[_]]): TranslateResult[String] = macro ActionMacro.translateQuery
def translate(quoted: Quoted[BatchAction[Action[_]]]): TranslateResult[List[String]] =
def translate[T](quoted: Quoted[T]): String = macro QueryMacro.translateQuery[T]
def translate[T](quoted: Quoted[Query[T]]): String = macro QueryMacro.translateQuery[T]
def translate(quoted: Quoted[Action[_]]): String = macro ActionMacro.translateQuery
def translate(quoted: Quoted[BatchAction[Action[_]]]): List[String] =
macro ActionMacro.translateBatchQuery

def translate[T](quoted: Quoted[T], prettyPrint: Boolean): TranslateResult[String] =
def translate[T](quoted: Quoted[T], options: TranslateOptions): TranslateResult[String] =
macro QueryMacro.translateQueryPrettyPrint[T]
def translate[T](quoted: Quoted[Query[T]], prettyPrint: Boolean): TranslateResult[String] =
def translate[T](quoted: Quoted[Query[T]], options: TranslateOptions): TranslateResult[String] =
macro QueryMacro.translateQueryPrettyPrint[T]
def translate(quoted: Quoted[Action[_]], prettyPrint: Boolean): TranslateResult[String] =
def translate(quoted: Quoted[Action[_]], options: TranslateOptions): TranslateResult[String] =
macro ActionMacro.translateQueryPrettyPrint
def translate(quoted: Quoted[BatchAction[Action[_]]], prettyPrint: Boolean): TranslateResult[List[String]] =
def translate(quoted: Quoted[BatchAction[Action[_]]], options: TranslateOptions): TranslateResult[List[String]] =
macro ActionMacro.translateBatchQueryPrettyPrint

def translateQuery[T](
statement: String,
prepare: Prepare = identityPrepare,
extractor: Extractor[T] = identityExtractor,
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[String]
def translateBatchQuery(groups: List[BatchGroup], prettyPrint: Boolean = false)(
lifts: List[ScalarLift] = List(),
options: TranslateOptions
)(executionInfo: ExecutionInfo, dc: Runner): String

def translateBatchQuery(groups: List[BatchGroup], options: TranslateOptions = TranslateOptions())(
executionInfo: ExecutionInfo,
dc: Runner
): TranslateResult[List[String]]
): List[String]
}

trait ContextTranslateProto {
Expand All @@ -58,40 +65,36 @@ trait ContextTranslateProto {

def translateQuery[T](
statement: String,
prepare: Prepare = identityPrepare,
extractor: Extractor[T] = identityExtractor,
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[String] =
try {
push(prepareParams(statement, prepare)) { params =>
val query =
if (params.nonEmpty) {
params.foldLeft(statement) { case (expanded, param) =>
expanded.replaceFirst("\\?", param)
}
} else {
statement
liftings: List[ScalarLift] = List(),
options: TranslateOptions = TranslateOptions()
)(executionInfo: ExecutionInfo, dc: Runner): String =
(liftings.nonEmpty, options.plugLifts) match {
case (true, true) =>
liftings.foldLeft(statement) { case (expanded, lift) =>
expanded.replaceFirst("\\?", if (options.demarcateLifts) s"prep(${lift.value})" else s"${lift.value}")
}
case (true, false) =>
var varNum: Int = 0
val dol = '$'
val numberedQuery =
liftings.foldLeft(statement) { case (expanded, lift) =>
val res = expanded.replaceFirst("\\?", s"${dol}${varNum}")
varNum += 1
res
}

if (prettyPrint)
idiom.format(query)
else
query
}
} catch {
case e: Exception =>
wrap("<!-- Cannot display parameters due to preparation error: " + e.getMessage + " -->\n" + statement)
numberedQuery + "\n" + liftings.map(lift => s"${dol} = ${lift.value}").mkString("\n")
case _ =>
statement
}

def translateBatchQuery(
// TODO these groups need to have liftings lists
groups: List[BatchGroup],
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[List[String]] =
seq {
groups.flatMap { group =>
group.prepare.map { prepare =>
translateQuery(group.string, prepare, prettyPrint = prettyPrint)(executionInfo, dc)
}
options: TranslateOptions = TranslateOptions()
)(executionInfo: ExecutionInfo, dc: Runner): List[String] =
groups.flatMap { group =>
group.prepare.map { _ =>
translateQuery(group.string, options = options)(executionInfo, dc)
}
}

Expand Down
43 changes: 21 additions & 22 deletions quill-core/src/main/scala/io/getquill/context/QueryMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ class QueryMacro(val c: MacroContext) extends ContextMacro {
case object UsesDefaultFetch extends FetchSizeArg
case object DoesNotUseFetch extends FetchSizeArg

sealed trait PrettyPrintingArg
case class ExplicitPrettyPrint(tree: Tree) extends PrettyPrintingArg
case object DefaultPrint extends PrettyPrintingArg
sealed trait PrettyPrintingOptions
case class ExplicitOptions(tree: Tree) extends PrettyPrintingOptions
case object DefaultPrint extends PrettyPrintingOptions

sealed trait ContextMethod { def name: String }
case class StreamQuery(fetchSizeBehavior: FetchSizeArg) extends ContextMethod { val name = "streamQuery" }
case object ExecuteQuery extends ContextMethod { val name = "executeQuery" }
case object ExecuteQuerySingle extends ContextMethod { val name = "executeQuerySingle" }
case class TranslateQuery(prettyPrintingArg: PrettyPrintingArg) extends ContextMethod { val name = "translateQuery" }
case object PrepareQuery extends ContextMethod { val name = "prepareQuery" }
case class StreamQuery(fetchSizeBehavior: FetchSizeArg) extends ContextMethod { val name = "streamQuery" }
case object ExecuteQuery extends ContextMethod { val name = "executeQuery" }
case object ExecuteQuerySingle extends ContextMethod { val name = "executeQuerySingle" }
case class TranslateQuery(prettyPrintingOpts: PrettyPrintingOptions) extends ContextMethod { val name = "translateQuery" }
case object PrepareQuery extends ContextMethod { val name = "prepareQuery" }

def streamQuery[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree =
expandQuery[T](quoted, StreamQuery(UsesDefaultFetch))
Expand All @@ -40,8 +40,8 @@ class QueryMacro(val c: MacroContext) extends ContextMacro {
def translateQuery[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree =
expandQuery[T](quoted, TranslateQuery(DefaultPrint))

def translateQueryPrettyPrint[T](quoted: Tree, prettyPrint: Tree)(implicit t: WeakTypeTag[T]): Tree =
expandQuery[T](quoted, TranslateQuery(ExplicitPrettyPrint(prettyPrint)))
def translateQueryPrettyPrint[T](quoted: Tree, options: Tree)(implicit t: WeakTypeTag[T]): Tree =
expandQuery[T](quoted, TranslateQuery(ExplicitOptions(options)))

def prepareQuery[T](quoted: Tree)(implicit t: WeakTypeTag[T]): Tree =
expandQuery[T](quoted, PrepareQuery)
Expand Down Expand Up @@ -85,22 +85,23 @@ class QueryMacro(val c: MacroContext) extends ContextMacro {
(row, session) => $decoder(0, row, session)
)(io.getquill.context.ExecutionInfo(expanded.executionType, expanded.ast, staticTopLevelQuat), ())
"""
case TranslateQuery(ExplicitPrettyPrint(argValue)) =>
case TranslateQuery(ExplicitOptions(argValue)) =>
// use 'liftings' instead of 'prepare' I.e. the List[ScalarLifts] extracted from the query during Expand
q"""
${c.prefix}.${TermName(method.name)}(
expanded.string,
expanded.prepare,
expanded.liftings,
(row, session) => $decoder(0, row, session),
prettyPrint = ${argValue}
options = ${argValue}
)(io.getquill.context.ExecutionInfo(expanded.executionType, expanded.ast, staticTopLevelQuat), ())
"""
case TranslateQuery(DefaultPrint) =>
q"""
${c.prefix}.${TermName(method.name)}(
expanded.string,
expanded.prepare,
expanded.liftings,
(row, session) => $decoder(0, row, session),
prettyPrint = false
options = io.getquill.context.TranslateOptions()
)(io.getquill.context.ExecutionInfo(expanded.executionType, expanded.ast, staticTopLevelQuat), ())
"""
case PrepareQuery =>
Expand Down Expand Up @@ -167,22 +168,20 @@ class QueryMacro(val c: MacroContext) extends ContextMacro {
$meta.extract
)(io.getquill.context.ExecutionInfo(expanded.executionType, expanded.ast, staticTopLevelQuat), ())
"""
case TranslateQuery(ExplicitPrettyPrint(argValue)) =>
case TranslateQuery(ExplicitOptions(argValue)) =>
q"""
${c.prefix}.${TermName(method.name)}(
expanded.string,
expanded.prepare,
$meta.extract,
prettyPrint = ${argValue}
expanded.liftings,
options = ${argValue}
)(io.getquill.context.ExecutionInfo(expanded.executionType, expanded.ast, staticTopLevelQuat), ())
"""
case TranslateQuery(DefaultPrint) =>
q"""
${c.prefix}.${TermName(method.name)}(
expanded.string,
expanded.prepare,
$meta.extract,
prettyPrint = false
expanded.liftings,
options = io.getquill.context.TranslateOptions()
)(io.getquill.context.ExecutionInfo(expanded.executionType, expanded.ast, staticTopLevelQuat), ())
"""
case PrepareQuery =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ trait DoobieContextBase[+Dialect <: SqlIdiom, +Naming <: NamingStrategy]
)(
info: ExecutionInfo,
dc: Runner
): ConnectionIO[List[Long]] = groups.flatTraverse { case BatchGroup(sql, preps) =>
): ConnectionIO[List[Long]] = groups.flatTraverse { case BatchGroup(sql, preps, _) =>
HC.prepareStatement(sql) {
useConnection { implicit connection =>
for {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.getquill.context

import io.getquill.ReturnAction
import io.getquill.ast.ScalarLift

trait RowContext {
type PrepareRow
Expand All @@ -10,7 +11,7 @@ trait RowContext {
private val _identityExtractor: Extractor[Any] = (rr: ResultRow, _: Session) => rr
protected def identityExtractor[T]: Extractor[T] = _identityExtractor.asInstanceOf[Extractor[T]]

case class BatchGroup(string: String, prepare: List[Prepare])
case class BatchGroup(string: String, prepare: List[Prepare], liftings: List[ScalarLift])
case class BatchGroupReturning(string: String, returningBehavior: ReturnAction, prepare: List[Prepare])

type Prepare = (PrepareRow, Session) => (List[Any], PrepareRow)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.getquill.context.qzio

import io.getquill.ast.ScalarLift
import io.getquill.context.ZioJdbc._
import io.getquill.context._
import io.getquill.context.jdbc.JdbcContextTypes
Expand Down Expand Up @@ -104,22 +105,19 @@ abstract class ZioJdbcContext[+Dialect <: SqlIdiom, +Naming <: NamingStrategy]

override def translateQuery[T](
statement: String,
prepare: Prepare = identityPrepare,
extractor: Extractor[T] = identityExtractor,
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[String] =
onConnection(connDelegate.translateQuery[T](statement, prepare, extractor, prettyPrint)(executionInfo, dc))
liftings: List[ScalarLift] = List(),
options: TranslateOptions = TranslateOptions()
)(executionInfo: ExecutionInfo, dc: Runner): String =
connDelegate.translateQuery[T](statement, liftings, options)(executionInfo, dc)

override def translateBatchQuery(
groups: List[BatchGroup],
prettyPrint: Boolean = false
)(executionInfo: ExecutionInfo, dc: Runner): TranslateResult[List[String]] =
onConnection(
connDelegate.translateBatchQuery(
groups.asInstanceOf[List[ZioJdbcContext.this.connDelegate.BatchGroup]],
prettyPrint
)(executionInfo, dc)
)
options: TranslateOptions = TranslateOptions()
)(executionInfo: ExecutionInfo, dc: Runner): List[String] =
connDelegate.translateBatchQuery(
groups.asInstanceOf[List[ZioJdbcContext.this.connDelegate.BatchGroup]],
options
)(executionInfo, dc)

def streamQuery[T](
fetchSize: Option[Int],
Expand Down
Loading

0 comments on commit bd73907

Please sign in to comment.