From c085e1ca8b71df74de82d069f6eefc5405536e79 Mon Sep 17 00:00:00 2001 From: deusaquilus Date: Tue, 27 Aug 2019 13:24:11 -0400 Subject: [PATCH] Infix should be in query sub-clauses even if contents not selected. --- .../test/resources/application-codegen.conf | 2 +- .../main/scala/io/getquill/AstPrinter.scala | 30 ++- .../scala/io/getquill/norm/ApplyMap.scala | 11 +- .../scala/io/getquill/norm/Normalize.scala | 32 ++- .../scala/io/getquill/util/Interpolator.scala | 140 +++++++++++++ .../scala/io/getquill/util/Messages.scala | 46 ++-- .../io/getquill/util/InterpolatorSpec.scala | 196 ++++++++++++++++++ .../io/getquill/context/sql/SqlQuery.scala | 22 +- .../context/sql/idiom/VerifySqlQuery.scala | 31 ++- .../sql/norm/ExpandNestedQueries.scala | 70 +------ .../context/sql/norm/nested/Elements.scala | 28 +++ .../sql/norm/nested/ExpandSelect.scala | 184 ++++++++++++++++ .../norm/nested/FindUnexpressedInfixes.scala | 82 ++++++++ .../getquill/context/sql/EmbeddedSpec.scala | 21 ++ .../io/getquill/context/sql/InfixSpec.scala | 62 +++++- .../getquill/context/sql/SqlQuerySpec.scala | 2 +- .../sql/norm/ExpandNestedQueriesSpec.scala | 18 ++ 17 files changed, 871 insertions(+), 106 deletions(-) create mode 100644 quill-core/src/main/scala/io/getquill/util/Interpolator.scala create mode 100644 quill-core/src/test/scala/io/getquill/util/InterpolatorSpec.scala create mode 100644 quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/Elements.scala create mode 100644 quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/ExpandSelect.scala create mode 100644 quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/FindUnexpressedInfixes.scala create mode 100644 quill-sql/src/test/scala/io/getquill/context/sql/EmbeddedSpec.scala diff --git a/quill-codegen-jdbc/src/test/resources/application-codegen.conf b/quill-codegen-jdbc/src/test/resources/application-codegen.conf index 9d15f12a0d..40b9944b23 100644 --- a/quill-codegen-jdbc/src/test/resources/application-codegen.conf +++ b/quill-codegen-jdbc/src/test/resources/application-codegen.conf @@ -9,7 +9,7 @@ testPostgresDB.dataSource.databaseName=codegen_test # Otherwise can get PSQLException: FATAL: sorry, too many clients already testPostgresDB.maximumPoolSize=1 -testH2DB.dataSource.url="jdbc:h2:file:./codegen_test.h2" +testH2DB.dataSource.url="jdbc:h2:file:./codegen_test.h2;DB_CLOSE_ON_EXIT=TRUE" testSqliteDB.jdbcUrl="jdbc:sqlite:codegen_test.db" diff --git a/quill-core/src/main/scala/io/getquill/AstPrinter.scala b/quill-core/src/main/scala/io/getquill/AstPrinter.scala index c279bc4dc5..739cddca86 100644 --- a/quill-core/src/main/scala/io/getquill/AstPrinter.scala +++ b/quill-core/src/main/scala/io/getquill/AstPrinter.scala @@ -1,10 +1,21 @@ package io.getquill +import fansi.Str import io.getquill.ast.Renameable.{ ByStrategy, Fixed } -import io.getquill.ast.{ Entity, Property, Renameable } +import io.getquill.ast.{ Ast, Entity, Property, Renameable } import pprint.{ Renderer, Tree, Truncated } -class AstPrinter(traceOpinions: Boolean) extends pprint.Walker { +object AstPrinter { + object Implicits { + implicit class FansiStrExt(str: Str) { + def string(color: Boolean): String = + if (color) str.render + else str.plainText + } + } +} + +class AstPrinter(traceOpinions: Boolean, traceAstSimple: Boolean) extends pprint.Walker { val defaultWidth: Int = 150 val defaultHeight: Int = Integer.MAX_VALUE val defaultIndent: Int = 2 @@ -18,6 +29,12 @@ class AstPrinter(traceOpinions: Boolean) extends pprint.Walker { } override def additionalHandlers: 
PartialFunction[Any, Tree] = { + case ast: Ast if (traceAstSimple) => + Tree.Literal(ast + "") // Do not blow up if it is null + + case past: PseudoAst if (traceAstSimple) => + Tree.Literal(past + "") // Do not blow up if it is null + case p: Property if (traceOpinions) => Tree.Apply("Property", List[Tree](treeify(p.ast), treeify(p.name), printRenameable(p.renameable)).iterator) @@ -36,4 +53,11 @@ class AstPrinter(traceOpinions: Boolean) extends pprint.Walker { val truncated = new Truncated(rendered, defaultWidth, defaultHeight) truncated } -} \ No newline at end of file +} + +/** + * A trait to be used by elements that are not proper AST elements but should still be treated as though + * they were in the case where `traceAstSimple` is enabled (i.e. their toString method should be + * used instead of the standard qprint AST printing) + */ +trait PseudoAst \ No newline at end of file diff --git a/quill-core/src/main/scala/io/getquill/norm/ApplyMap.scala b/quill-core/src/main/scala/io/getquill/norm/ApplyMap.scala index 74ff79a56f..92ee0ff00d 100644 --- a/quill-core/src/main/scala/io/getquill/norm/ApplyMap.scala +++ b/quill-core/src/main/scala/io/getquill/norm/ApplyMap.scala @@ -26,6 +26,15 @@ object ApplyMap { } } + object MapWithoutInfixes { + def unapply(ast: Ast): Option[(Ast, Ident, Ast)] = + ast match { + case Map(a, b, InfixedTailOperation(c)) => None + case Map(a, b, c) => Some((a, b, c)) + case _ => None + } + } + object DetachableMap { def unapply(ast: Ast): Option[(Ast, Ident, Ast)] = ast match { @@ -50,7 +59,7 @@ object ApplyMap { // a.map(b => c).map(d => e) => // a.map(b => e[d := c]) - case Map(Map(a, b, c), d, e) => + case before @ Map(MapWithoutInfixes(a, b, c), d, e) => val er = BetaReduction(e, d -> c) Some(Map(a, b, er)) diff --git a/quill-core/src/main/scala/io/getquill/norm/Normalize.scala b/quill-core/src/main/scala/io/getquill/norm/Normalize.scala index 73a027c48e..3e007759ab 100644 --- a/quill-core/src/main/scala/io/getquill/norm/Normalize.scala +++ b/quill-core/src/main/scala/io/getquill/norm/Normalize.scala @@ -5,6 +5,8 @@ import io.getquill.ast.Query import io.getquill.ast.StatelessTransformer import io.getquill.norm.capture.AvoidCapture import io.getquill.ast.Action +import io.getquill.util.Messages.trace +import io.getquill.util.Messages.TraceType.Normalizations import scala.annotation.tailrec @@ -19,15 +21,31 @@ object Normalize extends StatelessTransformer { override def apply(q: Query): Query = norm(AvoidCapture(q)) + private def traceNorm[T](label: String) = + trace[T](s"${label} (Normalize)", 1, Normalizations) + @tailrec private def norm(q: Query): Query = q match { - case NormalizeNestedStructures(query) => norm(query) - case ApplyMap(query) => norm(query) - case SymbolicReduction(query) => norm(query) - case AdHocReduction(query) => norm(query) - case OrderTerms(query) => norm(query) - case NormalizeAggregationIdent(query) => norm(query) - case other => other + case NormalizeNestedStructures(query) => + traceNorm("NormalizeNestedStructures")(query) + norm(query) + case ApplyMap(query) => + traceNorm("ApplyMap")(query) + norm(query) + case SymbolicReduction(query) => + traceNorm("SymbolicReduction")(query) + norm(query) + case AdHocReduction(query) => + traceNorm("AdHocReduction")(query) + norm(query) + case OrderTerms(query) => + traceNorm("OrderTerms")(query) + norm(query) + case NormalizeAggregationIdent(query) => + traceNorm("NormalizeAggregationIdent")(query) + norm(query) + case other => + other } } diff --git 
a/quill-core/src/main/scala/io/getquill/util/Interpolator.scala b/quill-core/src/main/scala/io/getquill/util/Interpolator.scala new file mode 100644 index 0000000000..afc84043fc --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/util/Interpolator.scala @@ -0,0 +1,140 @@ +package io.getquill.util + +import java.io.PrintStream + +import io.getquill.AstPrinter +import io.getquill.AstPrinter.Implicits._ +import io.getquill.util.Messages.TraceType + +import scala.collection.mutable +import scala.util.matching.Regex + +class Interpolator( + traceType: TraceType, + defaultIndent: Int = 0, + color: Boolean = Messages.traceColors, + qprint: AstPrinter = Messages.qprint, + out: PrintStream = System.out, + tracesEnabled: (TraceType) => Boolean = Messages.tracesEnabled(_) +) { + implicit class InterpolatorExt(sc: StringContext) { + def trace(elements: Any*) = new Traceable(sc, elements) + } + implicit class StringOps(str: String) { + def fitsOnOneLine: Boolean = !str.contains("\n") + def multiline(indent: Int, prefix: String): String = + str.split("\n").map(elem => indent.prefix + prefix + elem).mkString("\n") + } + implicit class IndentOps(i: Int) { + def prefix = indentOf(i) + } + private def indentOf(num: Int) = + (0 to num).map(_ => "").mkString(" ") + + class Traceable(sc: StringContext, elementsSeq: Seq[Any]) { + + private val elementPrefix = "| " + + private sealed trait PrintElement + private case class Str(str: String, first: Boolean) extends PrintElement + private case class Elem(value: String) extends PrintElement + private case object Separator extends PrintElement + + private def generateStringForCommand(value: Any, indent: Int) = { + val objectString = qprint(value).string(color) + val oneLine = objectString.fitsOnOneLine + oneLine match { + case true => s"${indent.prefix}> ${objectString}" + case false => s"${indent.prefix}>\n${objectString.multiline(indent, elementPrefix)}" + } + } + + private def readFirst(first: String) = + new Regex("%([0-9]+)(.*)").findFirstMatchIn(first) match { + case Some(matches) => (matches.group(2).trim, matches.group(1).toInt) + case None => (first, defaultIndent) + } + + private def readBuffers() = { + val parts = sc.parts.iterator.toList + val elements = elementsSeq.toList.map(qprint(_).string(color)) + + val (firstStr, indent) = readFirst(parts.head) + + val partsIter = parts.iterator + partsIter.next() // already took care of the 1st element + val elementsIter = elements.iterator + + val sb = new mutable.ArrayBuffer[PrintElement]() + sb.append(Str(firstStr.trim, true)) + + while (elementsIter.hasNext) { + sb.append(Separator) + sb.append(Elem(elementsIter.next())) + val nextPart = partsIter.next().trim + sb.append(Separator) + sb.append(Str(nextPart, false)) + } + + (sb.toList, indent) + } + + def generateString() = { + val (elementsRaw, indent) = readBuffers() + + val elements = elementsRaw.filter { + case Str(value, _) => value.trim != "" + case Elem(value) => value.trim != "" + case _ => true + } + + val oneLine = elements.forall { + case Elem(value) => value.fitsOnOneLine + case Str(value, _) => value.fitsOnOneLine + case _ => true + } + val output = + elements.map { + case Str(value, true) if (oneLine) => indent.prefix + value + case Str(value, false) if (oneLine) => value + case Elem(value) if (oneLine) => value + case Separator if (oneLine) => " " + case Str(value, true) => value.multiline(indent, "") + case Str(value, false) => value.multiline(indent, "|") + case Elem(value) => value.multiline(indent, "| ") + case Separator => "\n" + } + + 
(output.mkString, indent) + } + + private def logIfEnabled[T]() = + if (tracesEnabled(traceType)) + Some(generateString()) + else + None + + def andLog(): Unit = + logIfEnabled.foreach(value => out.println(value._1)) + + def andContinue[T](command: => T) = { + logIfEnabled.foreach(value => out.println(value._1)) + command + } + + def andReturn[T](command: => T) = { + logIfEnabled() match { + case Some((output, indent)) => + // do the initial log + out.println(output) + // evaluate the command, this will activate any traces that were inside of it + val result = command + out.println(generateStringForCommand(result, indent)) + + result + case None => + command + } + } + } +} \ No newline at end of file diff --git a/quill-core/src/main/scala/io/getquill/util/Messages.scala b/quill-core/src/main/scala/io/getquill/util/Messages.scala index ef2b77f2c1..1642a32e21 100644 --- a/quill-core/src/main/scala/io/getquill/util/Messages.scala +++ b/quill-core/src/main/scala/io/getquill/util/Messages.scala @@ -1,30 +1,52 @@ package io.getquill.util import io.getquill.AstPrinter - import scala.reflect.macros.blackbox.{ Context => MacroContext } object Messages { - private val debugEnabled = { - !sys.env.get("quill.macro.log").filterNot(_.isEmpty).map(_.toLowerCase).contains("false") && - !Option(System.getProperty("quill.macro.log")).filterNot(_.isEmpty).map(_.toLowerCase).contains("false") - } + private def variable(propName: String, envName: String, default: String) = + Option(System.getProperty(propName)).orElse(sys.env.get(envName)).getOrElse(default) + + private[util] val debugEnabled = variable("quill.macro.log", "quill_macro_log", "true").toBoolean + private[util] val traceEnabled = variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean + private[util] val traceColors = variable("quill.trace.color", "quill_trace_color,", "false").toBoolean + private[util] val traceOpinions = variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean + private[util] val traceAstSimple = variable("quill.trace.ast.simple", "quill_trace_ast_simple", "false").toBoolean + private[util] val traces: List[TraceType] = + variable("quill.trace.types", "quill_trace_types", "standard") + .split(",") + .toList + .map(_.trim) + .flatMap(trace => TraceType.values.filter(traceType => trace == traceType.value)) - private val traceEnabled = false - private val traceColors = false - private val traceOpinions = false + def tracesEnabled(tt: TraceType) = + traceEnabled && traces.contains(tt) + + sealed trait TraceType { def value: String } + object TraceType { + case object Normalizations extends TraceType { val value = "norm" } + case object Standard extends TraceType { val value = "standard" } + case object NestedQueryExpansion extends TraceType { val value = "nest" } + + def values: List[TraceType] = List(Standard, Normalizations, NestedQueryExpansion) + } - val qprint = new AstPrinter(traceOpinions) + val qprint = new AstPrinter(traceOpinions, traceAstSimple) def fail(msg: String) = throw new IllegalStateException(msg) - def trace[T](label: String) = + def trace[T](label: String, numIndent: Int = 0, traceType: TraceType = TraceType.Standard) = (v: T) => { - if (traceEnabled) - println(s"$label\n${{ if (traceColors) qprint.apply(v).render else qprint.apply(v).plainText }.split("\n").map(" " + _).mkString("\n")}") + val indent = (0 to numIndent).map(_ => "").mkString(" ") + if (tracesEnabled(traceType)) + println(s"$indent$label\n${ + { + if (traceColors) qprint.apply(v).render else 
qprint.apply(v).plainText + }.split("\n").map(s"$indent " + _).mkString("\n") + }") v } diff --git a/quill-core/src/test/scala/io/getquill/util/InterpolatorSpec.scala b/quill-core/src/test/scala/io/getquill/util/InterpolatorSpec.scala new file mode 100644 index 0000000000..7f2fb80fc2 --- /dev/null +++ b/quill-core/src/test/scala/io/getquill/util/InterpolatorSpec.scala @@ -0,0 +1,196 @@ +package io.getquill.util + +import java.io.{ ByteArrayOutputStream, PrintStream } + +import io.getquill.Spec +import io.getquill.util.Messages.TraceType.Standard + +class InterpolatorSpec extends Spec { + + val interp = new Interpolator(Standard, defaultIndent = 0, color = false, tracesEnabled = _ => true) + import interp._ + + case class Small(id: Int) + val small = Small(123) + + "traces small objects on single line - single" in { + trace"small object: $small".generateString() mustEqual (("small object: Small(123) ", 0)) + } + + "traces multiple small objects on single line" in { + trace"small object: $small and $small".generateString() mustEqual (("small object: Small(123) and Small(123) ", 0)) + } + + "traces multiple small objects multline text" in { + trace"""small object: $small and foo +and bar $small""".generateString() mustEqual ( + ( + """small object: + || Small(123) + ||and foo + ||and bar + || Small(123) + |""".stripMargin, + 0 + ) + ) + } + + case class Large(id: Int, one: String, two: String, three: String, four: String, five: String, six: String, seven: String, eight: String, nine: String, ten: String) + val vars = (0 until 10).map(i => (0 until i).map(_ => "Test").mkString("")).toList + val large = Large(123, vars(0), vars(1), vars(2), vars(3), vars(4), vars(5), vars(6), vars(7), vars(8), vars(9)) + + "traces large objects on multiple line - single" in { + trace"large object: $large".generateString() mustEqual (( + """large object: + || Large( + || 123, + || "", + || "Test", + || "TestTest", + || "TestTestTest", + || "TestTestTestTest", + || "TestTestTestTestTest", + || "TestTestTestTestTestTest", + || "TestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTestTest" + || ) + |""".stripMargin, 0 + )) + } + + "traces large objects on multiple line - single - custom indent" in { + trace"%2 large object: $large".generateString() mustEqual (( + """ large object: + | | Large( + | | 123, + | | "", + | | "Test", + | | "TestTest", + | | "TestTestTest", + | | "TestTestTestTest", + | | "TestTestTestTestTest", + | | "TestTestTestTestTestTest", + | | "TestTestTestTestTestTestTest", + | | "TestTestTestTestTestTestTestTest", + | | "TestTestTestTestTestTestTestTestTest" + | | ) + |""".stripMargin, 2 + )) + } + + "traces large objects on multiple line - multi" in { + trace"large object: $large and $large".generateString() mustEqual (( + """large object: + || Large( + || 123, + || "", + || "Test", + || "TestTest", + || "TestTestTest", + || "TestTestTestTest", + || "TestTestTestTestTest", + || "TestTestTestTestTestTest", + || "TestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTestTest" + || ) + ||and + || Large( + || 123, + || "", + || "Test", + || "TestTest", + || "TestTestTest", + || "TestTestTestTest", + || "TestTestTestTestTest", + || "TestTestTestTestTestTest", + || "TestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTestTest" + || ) + |""".stripMargin, 0 + )) + } + + "should log to print stream" - { + "do not log if traces disabled" in { + 
val buff = new ByteArrayOutputStream() + val ps = new PrintStream(buff) + val interp = new Interpolator(Standard, defaultIndent = 0, color = false, tracesEnabled = _ => false, out = ps) + import interp._ + + trace"small object: $small".andLog() + ps.flush() + buff.toString mustEqual "" + } + + "log if traces disabled" in { + val buff = new ByteArrayOutputStream() + val ps = new PrintStream(buff) + val interp = new Interpolator(Standard, defaultIndent = 0, color = false, tracesEnabled = _ => true, out = ps) + import interp._ + + trace"small object: $small".andLog() + ps.flush() + buff.toString mustEqual "small object: Small(123) \n" + } + + "traces large objects on multiple line - multi - with return" in { + val buff = new ByteArrayOutputStream() + val ps = new PrintStream(buff) + val interp = new Interpolator(Standard, defaultIndent = 0, color = false, tracesEnabled = _ => true, out = ps) + import interp._ + + trace"large object: $large and $large".andReturn(large) mustEqual large + + buff.toString mustEqual ( + """large object: + || Large( + || 123, + || "", + || "Test", + || "TestTest", + || "TestTestTest", + || "TestTestTestTest", + || "TestTestTestTestTest", + || "TestTestTestTestTestTest", + || "TestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTestTest" + || ) + ||and + || Large( + || 123, + || "", + || "Test", + || "TestTest", + || "TestTestTest", + || "TestTestTestTest", + || "TestTestTestTestTest", + || "TestTestTestTestTestTest", + || "TestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTestTest" + || ) + | + |> + || Large( + || 123, + || "", + || "Test", + || "TestTest", + || "TestTestTest", + || "TestTestTestTest", + || "TestTestTestTestTest", + || "TestTestTestTestTestTest", + || "TestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTest", + || "TestTestTestTestTestTestTestTestTest" + || ) + |""".stripMargin + ) + } + } +} diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala b/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala index 20e2809e4b..f8bfdc330e 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/SqlQuery.scala @@ -4,7 +4,7 @@ import io.getquill.ast._ import io.getquill.context.sql.norm.FlattenGroupByAggregation import io.getquill.norm.BetaReduction import io.getquill.util.Messages.fail -import io.getquill.Literal +import io.getquill.{ Literal, PseudoAst } case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering) @@ -40,7 +40,9 @@ case class UnaryOperationSqlQuery( q: SqlQuery ) extends SqlQuery -case class SelectValue(ast: Ast, alias: Option[String] = None, concat: Boolean = false) +case class SelectValue(ast: Ast, alias: Option[String] = None, concat: Boolean = false) extends PseudoAst { + override def toString: String = s"${ast.toString}${alias.map("->" + _).getOrElse("")}" +} case class FlattenSqlQuery( from: List[FromContext] = List(), @@ -95,6 +97,20 @@ object SqlQuery { (List.empty, other) } + object NestedNest { + def unapply(q: Ast): Option[Ast] = + q match { + case _: Nested => recurse(q) + case _ => None + } + + private def recurse(q: Ast): Option[Ast] = + q match { + case Nested(qn) => recurse(qn) + case other => Some(other) + } + } + private def flatten(sources: List[FromContext], finalFlatMapBody: Ast, alias: String): FlattenSqlQuery = { def select(alias: String) = SelectValue(Ident(alias), None) :: Nil @@ 
-103,7 +119,7 @@ object SqlQuery { def nest(ctx: FromContext) = FlattenSqlQuery(from = sources :+ ctx, select = select(alias)) q match { case Map(_: GroupBy, _, _) => nest(source(q, alias)) - case Nested(q) => nest(QueryContext(apply(q), alias)) + case NestedNest(q) => nest(QueryContext(apply(q), alias)) case q: ConcatMap => nest(QueryContext(apply(q), alias)) case Join(tpe, a, b, iA, iB, on) => val ctx = source(q, alias) diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/VerifySqlQuery.scala b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/VerifySqlQuery.scala index 3d36f7e452..9d1a705ea7 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/VerifySqlQuery.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/VerifySqlQuery.scala @@ -1,21 +1,8 @@ package io.getquill.context.sql.idiom -import io.getquill.ast.Ast -import io.getquill.ast.Ident -import io.getquill.context.sql.FlatJoinContext -import io.getquill.context.sql.FlattenSqlQuery -import io.getquill.context.sql.FromContext -import io.getquill.context.sql.InfixContext -import io.getquill.context.sql.JoinContext -import io.getquill.context.sql.QueryContext -import io.getquill.context.sql.SetOperationSqlQuery -import io.getquill.context.sql.SqlQuery -import io.getquill.context.sql.TableContext -import io.getquill.context.sql.UnaryOperationSqlQuery +import io.getquill.ast._ +import io.getquill.context.sql._ import io.getquill.quotation.FreeVariables -import io.getquill.ast.CollectAst -import io.getquill.ast.Property -import io.getquill.ast.Aggregation case class Error(free: List[Ident], ast: Ast) case class InvalidSqlQuery(errors: List[Error]) { @@ -80,11 +67,23 @@ object VerifySqlQuery { } } + // Recursively expand children until values are fully flattened. Identities in all these should + // be skipped during verification. 
+ def expandSelect(sv: SelectValue): List[SelectValue] = + sv.ast match { + case Tuple(values) => values.map(v => SelectValue(v)).flatMap(expandSelect(_)) + case CaseClass(values) => values.map(v => SelectValue(v._2)).flatMap(expandSelect(_)) + case _ => List(sv) + } + val freeVariableErrors: List[Error] = query.where.flatMap(verifyAst).toList ++ query.orderBy.map(_.ast).flatMap(verifyAst) ++ query.limit.flatMap(verifyAst) ++ - query.select.map(_.ast).filterNot(_.isInstanceOf[Ident]).flatMap(verifyAst) ++ + query.select + .flatMap(expandSelect(_)) // Expand tuple select clauses so their top-level identities are skipped + .map(_.ast) + .filterNot(_.isInstanceOf[Ident]).flatMap(verifyAst) ++ query.from.flatMap { case j: JoinContext => verifyAst(j.on) case j: FlatJoinContext => verifyAst(j.on) diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala b/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala index 97d948e7fc..33a4eae166 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/norm/ExpandNestedQueries.scala @@ -10,7 +10,6 @@ import io.getquill.context.sql.FromContext import io.getquill.context.sql.InfixContext import io.getquill.context.sql.JoinContext import io.getquill.context.sql.QueryContext -import io.getquill.context.sql.SelectValue import io.getquill.context.sql.SetOperationSqlQuery import io.getquill.context.sql.SqlQuery import io.getquill.context.sql.TableContext @@ -18,9 +17,15 @@ import io.getquill.context.sql.UnaryOperationSqlQuery import io.getquill.context.sql.FlatJoinContext import scala.collection.mutable.LinkedHashSet +import io.getquill.util.Interpolator +import io.getquill.util.Messages.TraceType.NestedQueryExpansion +import io.getquill.context.sql.norm.nested.ExpandSelect class ExpandNestedQueries(strategy: NamingStrategy) { + val interp = new Interpolator(NestedQueryExpansion, 3) + import interp._ + def apply(q: SqlQuery, references: List[Property]): SqlQuery = apply(q, LinkedHashSet.empty ++ references) @@ -29,7 +34,9 @@ class ExpandNestedQueries(strategy: NamingStrategy) { private def apply(q: SqlQuery, references: LinkedHashSet[Property]): SqlQuery = q match { case q: FlattenSqlQuery => - expandNested(q.copy(select = expandSelect(q.select, references))) + val expand = expandNested(q.copy(select = ExpandSelect(q.select, references, strategy))) + trace"Expanded Nested Query $q into $expand" andLog () + expand case SetOperationSqlQuery(a, op, b) => SetOperationSqlQuery(apply(a, references), op, apply(b, references)) case UnaryOperationSqlQuery(op, q) => @@ -55,65 +62,6 @@ class ExpandNestedQueries(strategy: NamingStrategy) { case _: TableContext | _: InfixContext => s } - private def expandSelect(select: List[SelectValue], references: LinkedHashSet[Property]) = { - - object TupleIndex { - def unapply(s: String): Option[Int] = - if (s.matches("_[0-9]*")) - Some(s.drop(1).toInt - 1) - else - None - } - - def expandColumn(name: String, renameable: Renameable): String = - renameable.fixedOr(name)(strategy.column(name)) - - def expandReference(ref: Property): SelectValue = { - - def concat(alias: Option[String], idx: Int) = - Some(s"${alias.getOrElse("")}_${idx + 1}") - - ref match { - case Property(ast: Property, TupleIndex(idx)) => - expandReference(ast) match { - case SelectValue(Tuple(elems), alias, c) => - SelectValue(elems(idx), concat(alias, idx), c) - case SelectValue(ast, alias, c) => - SelectValue(ast, 
concat(alias, idx), c) - } - case Property.Opinionated(ast: Property, name, renameable) => - expandReference(ast) match { - case SelectValue(ast, nested, c) => - // Alias is the name of the column after the naming strategy - // The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the - // alias field when it's value is Some(T). - // Technically the aliases of a column should not be using naming strategies - // but this is an issue to fix at a later date. - SelectValue(Property.Opinionated(ast, name, renameable), Some(s"${nested.getOrElse("")}${expandColumn(name, renameable)}"), c) - } - case Property(_, TupleIndex(idx)) => - select(idx) match { - case SelectValue(ast, alias, c) => - SelectValue(ast, concat(alias, idx), c) - } - case Property.Opinionated(_, name, renameable) => - select match { - case List(SelectValue(cc: CaseClass, alias, c)) => - SelectValue(cc.values.toMap.apply(name), Some(expandColumn(name, renameable)), c) - case List(SelectValue(i: Ident, _, c)) => - SelectValue(Property.Opinionated(i, name, renameable), None, c) - case other => - SelectValue(Ident(name), Some(expandColumn(name, renameable)), false) - } - } - } - - references.toList match { - case Nil => select - case refs => refs.map(expandReference) - } - } - private def references(alias: String, asts: List[Ast]) = LinkedHashSet.empty ++ (References(State(Ident(alias), Nil))(asts)(_.apply)._2.state.references) } diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/Elements.scala b/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/Elements.scala new file mode 100644 index 0000000000..5785aa7f5d --- /dev/null +++ b/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/Elements.scala @@ -0,0 +1,28 @@ +package io.getquill.context.sql.norm.nested + +import io.getquill.PseudoAst +import io.getquill.context.sql.SelectValue + +object Elements { + /** + * In order to be able to reconstruct the original ordering of elements inside of a select clause, + * we need to keep track of their order, not only within the top-level select but also it's order + * within any possible tuples/case-classes that in which it is embedded. + * For example, in the query: + *

+   *   query[Person].map(p => (p.id, (p.name, p.age))).nested
+   *   // SELECT p.id, p.name, p.age FROM (SELECT x.id, x.name, x.age FROM person x) AS p
+   *
+ * Since the `p.name` and `p.age` elements are selected inside of a sub-tuple, their "order" is + * `List(2,1)` and `List(2,2)` respectively as opposed to `p.id` whose "order" is just `List(1)`. + * + * This class keeps track of the values needed in order to perform do this. + */ + case class OrderedSelect(order: List[Int], selectValue: SelectValue) extends PseudoAst { + override def toString: String = s"[${order.mkString(",")}]${selectValue}" + } + object OrderedSelect { + def apply(order: Int, selectValue: SelectValue) = + new OrderedSelect(List(order), selectValue) + } +} diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/ExpandSelect.scala b/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/ExpandSelect.scala new file mode 100644 index 0000000000..dbe1baf13b --- /dev/null +++ b/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/ExpandSelect.scala @@ -0,0 +1,184 @@ +package io.getquill.context.sql.norm.nested + +import io.getquill.NamingStrategy +import io.getquill.ast.Property +import io.getquill.context.sql.SelectValue +import io.getquill.util.Interpolator +import io.getquill.util.Messages.TraceType.NestedQueryExpansion + +import scala.collection.mutable.LinkedHashSet +import io.getquill.context.sql.norm.nested.Elements._ +import io.getquill.ast._ + +/** + * Takes the `SelectValue` elements inside of a sub-query (if a super/sub-query constrct exists) and flattens + * them from a nested-hiearchical structure (i.e. tuples inside case classes inside tuples etc..) into + * into a single series of top-level select elements where needed. In cases where a user wants to select an element + * that contains an entire tuple (i.e. a sub-tuple of the outer select clause) we pull out the entire tuple + * that is being selected and leave it to the tokenizer to flatten later. + * + * The part about this operation that is tricky is if there are situations where there are infix clauses + * in a sub-query representing an element that has not been selected by the query-query but in order to ensure + * the SQL operation has the same meaning, we need to keep track for it. For example: + *

+ *   val q = quote {
+ *     query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int], p.name, p.id)).map(t => (t._2, t._3))
+ *   }
+ *   run(q)
+ *   // SELECT p._2, p._3 FROM (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p
+ *
+ * Since `DISTINCT ON` significantly changes the behavior of the outer query, we need to keep track of it + * inside of the inner query. In order to do this, we need to keep track of the location of the infix in the inner query + * so that we can reconstruct it. This is why the `OrderedSelect` and `DoubleOrderedSelect` objects are used. + * See the notes on these classes for more detail. + * + * See issue #1597 for more details and another example. + */ +private class ExpandSelect(selectValues: List[SelectValue], references: LinkedHashSet[Property], strategy: NamingStrategy) { + val interp = new Interpolator(NestedQueryExpansion, 3) + import interp._ + + object TupleIndex { + def unapply(s: String): Option[Int] = + if (s.matches("_[0-9]*")) + Some(s.drop(1).toInt - 1) + else + None + } + + object MultiTupleIndex { + def unapply(s: String): Boolean = + if (s.matches("(_[0-9]+)+")) + true + else + false + } + + val select = + selectValues.zipWithIndex.map { + case (value, index) => OrderedSelect(index, value) + } + + def expandColumn(name: String, renameable: Renameable): String = + renameable.fixedOr(name)(strategy.column(name)) + + def apply: List[SelectValue] = { + trace"Expanding Select values: $selectValues into references: $references" andLog () + + def expandReference(ref: Property): OrderedSelect = { + trace"Expanding: $ref from $select" andLog () + + def expressIfTupleIndex(str: String) = + str match { + case MultiTupleIndex() => Some(str) + case _ => None + } + + def concat(alias: Option[String], idx: Int) = + Some(s"${alias.getOrElse("")}_${idx + 1}") + + val orderedSelect = ref match { + case pp @ Property(ast: Property, TupleIndex(idx)) => + trace"Reference is a sub-property of a tuple index: $idx. Walking inside." andContinue + expandReference(ast) match { + case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) => + trace"Expressing Element $idx of $elems " andReturn + OrderedSelect(o :+ idx, SelectValue(elems(idx), concat(alias, idx), c)) + case OrderedSelect(o, SelectValue(ast, alias, c)) => + trace"Appending $idx to $alias " andReturn + OrderedSelect(o, SelectValue(ast, concat(alias, idx), c)) + } + case pp @ Property.Opinionated(ast: Property, name, renameable) => + trace"Reference is a sub-property. Walking inside." andContinue + expandReference(ast) match { + case OrderedSelect(o, SelectValue(ast, nested, c)) => + // Alias is the name of the column after the naming strategy + // The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the + // alias field when it's value is Some(T). + // Technically the aliases of a column should not be using naming strategies + // but this is an issue to fix at a later date. + + // In the current implementation, aliases we add nested tuple names to queries e.g. + // SELECT foo from + // SELECT x, y FROM (SELECT foo, bar, red, orange FROM baz JOIN colors) + // Typically becomes SELECT foo _1foo, _1bar, _2red, _2orange when + // this kind of query is the result of an applicative join that looks like this: + // query[baz].join(query[colors]).nested + // this may need to change based on how distinct appends table names instead of just tuple indexes + // into the property path. + + OrderedSelect(o, SelectValue( + Property.Opinionated(ast, name, renameable), + Some(s"${nested.flatMap(expressIfTupleIndex(_)).getOrElse("")}${expandColumn(name, renameable)}"), c + )) + } + case pp @ Property(_, TupleIndex(idx)) => + trace"Reference is a tuple index: $idx from $select." 
andContinue + select(idx) match { + case OrderedSelect(o, SelectValue(ast, alias, c)) => + OrderedSelect(o, SelectValue(ast, concat(alias, idx), c)) + } + case pp @ Property.Opinionated(_, name, renameable) => + select match { + case List(OrderedSelect(o, SelectValue(cc: CaseClass, alias, c))) => + // Currently case class element name is not being appended. Need to change that in order to ensure + // path name uniqueness in future. + val ((_, ast), index) = cc.values.zipWithIndex.find(_._1._1 == name) match { + case Some(v) => v + case None => throw new IllegalArgumentException(s"Cannot find element $name in $cc") + } + trace"Reference is a case class member: " andReturn + OrderedSelect(o :+ index, SelectValue(ast, Some(expandColumn(name, renameable)), c)) + case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) => + trace"Reference is an identifier: " andReturn + OrderedSelect(o, SelectValue(Property.Opinionated(i, name, renameable), None, c)) + case other => + trace"Reference is unidentified: " andReturn + OrderedSelect(Integer.MAX_VALUE, SelectValue(Ident(name), Some(expandColumn(name, renameable)), false)) + } + } + + trace"Expanded $ref into $orderedSelect" + orderedSelect + } + + references.toList match { + case Nil => select.map(_.selectValue) + case refs => { + // elements first need to be sorted by their order in the select clause. Since some may map to multiple + // properties when expanded, we want to maintain this order of properties as a secondary value. + val mappedRefs = refs.map(expandReference) + trace"Mapped Refs: $mappedRefs" andLog () + + // are there any selects that have infix values which we have not already selected? We need to include + // them because they could be doing essential things e.g. RANK ... ORDER BY + val remainingSelectsWithInfixes = + trace"Searching Selects with Infix:" andReturn + new FindUnexpressedInfixes(select)(mappedRefs) + + implicit val ordering: scala.math.Ordering[List[Int]] = new scala.math.Ordering[List[Int]] { + override def compare(x: List[Int], y: List[Int]): Int = + (x, y) match { + case (head1 :: tail1, head2 :: tail2) => + val diff = head1 - head2 + if (diff != 0) diff + else compare(tail1, tail2) + case (Nil, Nil) => 0 // List(1,2,3) == List(1,2,3) + case (head1, Nil) => -1 // List(1,2,3) < List(1,2) + case (Nil, head2) => 1 // List(1,2) > List(1,2,3) + } + } + + val sortedRefs = + (mappedRefs ++ remainingSelectsWithInfixes).sortBy(ref => ref.order) //(ref.order, ref.secondaryOrder) + + sortedRefs.map(_.selectValue) + } + } + } +} + +object ExpandSelect { + def apply(selectValues: List[SelectValue], references: LinkedHashSet[Property], strategy: NamingStrategy): List[SelectValue] = + new ExpandSelect(selectValues, references, strategy).apply +} diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/FindUnexpressedInfixes.scala b/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/FindUnexpressedInfixes.scala new file mode 100644 index 0000000000..3802780faa --- /dev/null +++ b/quill-sql/src/main/scala/io/getquill/context/sql/norm/nested/FindUnexpressedInfixes.scala @@ -0,0 +1,82 @@ +package io.getquill.context.sql.norm.nested + +import io.getquill.context.sql.norm.nested.Elements._ +import io.getquill.util.Interpolator +import io.getquill.util.Messages.TraceType.NestedQueryExpansion +import io.getquill.ast._ +import io.getquill.context.sql.SelectValue + +/** + * The challenge with appeneding infixes (that have not been used but are still needed) + * back into the query, is that they could be inside of 
tuples/case-classes that have already + * been selected, or inside of sibling elements which have been selected. + * Take for instance a query that looks like this: + *

+ *   query[Person].map(p => (p.name, (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2._1))
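+ *   // Resulting SQL (per the "should not be dropped in nested tuples" case in InfixSpec in this patch —
+ *   // the infix survives in the inner select even though the outer map never references it):
+ *   // SELECT p._1, p._2_1 FROM (SELECT p.name AS _1, p.id AS _2_1, foo(p.other) FROM Person p) AS p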
+ *
+ * In this situation, `p.id`, which is the sibling of the non-selected infix, has been selected + * via `p._2._1` (whose select-order is `List(1,0)`, i.e. the 1st element inside the 2nd tuple element). + * We need to add its sibling infix. + * + * Or take the following situation: + *

+ *   query[Person].map(p => (p.name, (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2))
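+ *   // Resulting SQL (per the "should not be selected twice if in sub-sub tuple" case in InfixSpec —
+ *   // here the whole 2nd tuple element, infix included, is already selected via p._2):
+ *   // SELECT p._1, p._2_1, p._2_2 FROM (SELECT p.name AS _1, p.id AS _2_1, foo(p.other) AS _2_2 FROM Person p) AS p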
+ *
+ * In this case, we have selected the entire 2nd element including the infix. We need to know that + * `P._2._2` does not need to be selected since `p._2` was. + * + * In order to do these things, we use the `order` property from `OrderedSelect` in order to see + * which sub-sub-...-element has been selected. If `p._2` (that has order `List(1)`) + * has been selected, we know that any infixes inside of it e.g. `p._2._1` (ordering `List(1,0)`) + * does not need to be. + */ +class FindUnexpressedInfixes(select: List[OrderedSelect]) { + val interp = new Interpolator(NestedQueryExpansion, 3) + import interp._ + + def apply(refs: List[OrderedSelect]) = { + + def pathExists(path: List[Int]) = + refs.map(_.order).contains(path) + + def containsInfix(ast: Ast) = + CollectAst.byType[Infix](ast).length > 0 + + // build paths to every infix and see these paths were not selected already + def findMissingInfixes(ast: Ast, parentOrder: List[Int]): List[(Ast, List[Int])] = { + trace"Searching for infix: $ast in the sub-path $parentOrder" andLog () + if (pathExists(parentOrder)) + trace"No infixes found" andContinue + List() + else + ast match { + case Tuple(values) => + values.zipWithIndex + .filter(v => containsInfix(v._1)) + .flatMap { + case (ast, index) => + findMissingInfixes(ast, parentOrder :+ index) + } + case CaseClass(values) => + values.zipWithIndex + .filter(v => containsInfix(v._1._2)) + .flatMap { + case ((_, ast), index) => + findMissingInfixes(ast, parentOrder :+ index) + } + case other if (containsInfix(other)) => + trace"Found unexpressed infix inside $other in $parentOrder" andLog () + List((other, parentOrder)) + case _ => + List() + } + } + + select + .flatMap { + case OrderedSelect(o, sv) => findMissingInfixes(sv.ast, o) + }.map { + case (ast, order) => OrderedSelect(order, SelectValue(ast)) + } + } +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/EmbeddedSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/EmbeddedSpec.scala new file mode 100644 index 0000000000..6b7f1f4905 --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/EmbeddedSpec.scala @@ -0,0 +1,21 @@ +package io.getquill.context.sql + +import io.getquill._ + +class EmbeddedSpec extends Spec { + + val ctx = new SqlMirrorContext(MirrorSqlDialect, Literal) with TestEntities + import ctx._ + + "queries with embedded entities should" - { + "function property inside of nested distinct queries" in { + case class Parent(id: Int, emb1: Emb) + case class Emb(a: Int, b: Int) extends Embedded + val q = quote { + query[Emb].map(e => Parent(1, e)).distinct + } + ctx.run(q).string mustEqual "SELECT e.id, e.a, e.b FROM (SELECT DISTINCT 1 AS id, e.a AS a, e.b AS b FROM Emb e) AS e" + } + } + +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/InfixSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/InfixSpec.scala index edfa350872..251f12dc04 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/InfixSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/InfixSpec.scala @@ -87,7 +87,67 @@ class InfixSpec extends Spec { } yield ThreeData(r.id, r.value, infix"bar".as[Int]) } - ctx.run(q2).string mustEqual "SELECT d._1, d._2, d._3 FROM (SELECT d.id AS _1, d.value AS _2, bar AS _3 FROM (SELECT d.id AS id, foo AS value FROM Data d) AS d WHERE d.value = 1) AS d" + ctx.run(q2).string mustEqual "SELECT d.id, d.value, d.secondValue FROM (SELECT d.id AS id, d.value AS value, bar AS secondValue FROM (SELECT d.id AS id, foo AS value FROM Data d) AS d WHERE d.value = 
1) AS d" + } + + "excluded infix values" - { + case class Person(id: Int, name: String, other: String, other2: String) + + "should not be dropped" in { + val q = quote { + query[Person].map(p => (p.name, p.id, infix"foo(${p.other})".as[Int])).map(p => (p._1, p._2)) + } + + ctx.run(q).string mustEqual "SELECT p._1, p._2 FROM (SELECT p.name AS _1, p.id AS _2, foo(p.other) FROM Person p) AS p" + } + + "should not be dropped if pure" in { + val q = quote { + query[Person].map(p => (p.name, p.id, infix"foo(${p.other})".pure.as[Int])).map(p => (p._1, p._2)) + } + + ctx.run(q).string mustEqual "SELECT p.name, p.id FROM Person p" + } + + "should not be dropped in nested tuples" in { + val q = quote { + query[Person].map(p => (p.name, (p.id, infix"foo(${p.other})".as[Int]))).map(p => (p._1, p._2._1)) + } + + ctx.run(q).string mustEqual "SELECT p._1, p._2_1 FROM (SELECT p.name AS _1, p.id AS _2_1, foo(p.other) FROM Person p) AS p" + } + + "should not be selected twice if in sub-sub tuple" in { + val q = quote { + query[Person].map(p => (p.name, (p.id, infix"foo(${p.other})".as[Int]))).map(p => (p._1, p._2)) + } + + ctx.run(q).string mustEqual "SELECT p._1, p._2_1, p._2_2 FROM (SELECT p.name AS _1, p.id AS _2_1, foo(p.other) AS _2_2 FROM Person p) AS p" + } + + "should not be selected in sub-sub tuple if pure" in { + val q = quote { + query[Person].map(p => (p.name, (p.id, infix"foo(${p.other})".pure.as[Int]))).map(p => (p._1, p._2)) + } + + ctx.run(q).string mustEqual "SELECT p.name, p.id, foo(p.other) FROM Person p" + } + + "should not be selected twice in one field matched, one missing" in { + val q = quote { + query[Person].map(p => (p.name, (p.id, infix"foo(${p.other}, ${p.other2})".as[Int], p.other))).map(p => (p._1, p._2._1, p._2._3)) + } + + ctx.run(q).string mustEqual "SELECT p._1, p._2_1, p._2_3 FROM (SELECT p.name AS _1, p.id AS _2_1, foo(p.other, p.other2), p.other AS _2_3 FROM Person p) AS p" + } + + "distinct-on infix example" in { + val q = quote { + query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int], p.name, p.id)).map(t => (t._2, t._3)) + } + + ctx.run(q).string mustEqual "SELECT p._2, p._3 FROM (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p" + } } } } \ No newline at end of file diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala index 773e63acb7..5972991af0 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/SqlQuerySpec.scala @@ -380,7 +380,7 @@ class SqlQuerySpec extends Spec { } } testContext.run(q).string mustEqual - "SELECT t._2i, SUM(t._1i) FROM (SELECT b.i AS _2i, a.i AS _1i FROM TestEntity a INNER JOIN TestEntity2 b ON a.s = b.s) AS t GROUP BY t._2i" + "SELECT t._2i, SUM(t._1i) FROM (SELECT a.i AS _1i, b.i AS _2i FROM TestEntity a INNER JOIN TestEntity2 b ON a.s = b.s) AS t GROUP BY t._2i" } } } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandNestedQueriesSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandNestedQueriesSpec.scala index a9c1c8e07c..fafe678303 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandNestedQueriesSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/norm/ExpandNestedQueriesSpec.scala @@ -90,4 +90,22 @@ class ExpandNestedQueriesSpec extends Spec { testContext.run(q.dynamic).string mustEqual "SELECT x03._1i, x03._2i FROM (SELECT a.i AS _1i, b.i AS _2i 
FROM TestEntity a INNER JOIN TestEntity2 b ON a.i = b.i) AS x03" } + + "expands nested mapped entity correctly" in { + import testContext._ + + case class TestEntity(s: String, i: Int, l: Long, o: Option[Int]) extends Embedded + case class Dual(ta: TestEntity, tb: TestEntity) + + val qr1 = quote { + query[TestEntity] + } + + val q = quote { + qr1.join(qr1).on((a, b) => a.i == b.i).nested.map(both => both match { case (a, b) => Dual(a, b) }).nested + } + + testContext.run(q).string mustEqual + "SELECT both._1s, both._1i, both._1l, both._1o, both._2s, both._2i, both._2l, both._2o FROM (SELECT a.s AS _1s, a.i AS _1i, a.l AS _1l, a.o AS _1o, b.s AS _2s, b.i AS _2i, b.l AS _2l, b.o AS _2o FROM TestEntity a INNER JOIN TestEntity b ON a.i = b.i) AS both" + } }
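
For reference, a minimal standalone sketch (not part of the patch) of the behavior this change protects, using the same mirror context the specs above use; the Person case class and the foo(...) infix are illustrative names only:

    // Mirrors the "excluded infix values / should not be dropped" case in InfixSpec.
    import io.getquill._

    object UnselectedInfixSketch {
      val ctx = new SqlMirrorContext(MirrorSqlDialect, Literal)
      import ctx._

      case class Person(id: Int, name: String, other: String)

      // The infix is absent from the outer projection but must survive in the inner
      // query, since clauses like DISTINCT ON or RANK ... ORDER BY change the meaning
      // of the whole statement.
      val q = quote {
        query[Person]
          .map(p => (p.name, p.id, infix"foo(${p.other})".as[Int]))
          .map(p => (p._1, p._2))
      }

      def main(args: Array[String]): Unit =
        // Expected (per InfixSpec):
        // SELECT p._1, p._2 FROM (SELECT p.name AS _1, p.id AS _2, foo(p.other) FROM Person p) AS p
        println(ctx.run(q).string)
    }

The nested-query expansion performed here can be traced with the settings introduced in Messages.scala above: pass -Dquill.trace.enabled=true -Dquill.trace.types=nest (optionally -Dquill.trace.color=true) to the JVM that runs the macros, e.g. the sbt/compiler JVM.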