Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] inserted record can be returned from query #1383

Closed
wants to merge 7 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ trait CqlIdiom extends Idiom {
case a: Operation => a.token
case a: Action => a.token
case a: Ident => a.token
case a: Type => a.token
case a: Property => a.token
case a: Value => a.token
case a: Function => a.body.token
Expand Down Expand Up @@ -190,4 +191,8 @@ trait CqlIdiom extends Idiom {
case SetContains(ast, body) => stmt"${ast.token} CONTAINS ${body.token}"
case ListContains(ast, body) => stmt"${ast.token} CONTAINS ${body.token}"
}

implicit val typeTokenizer: Tokenizer[Type] = Tokenizer[Type] {
case _ => stmt""
}
}
6 changes: 6 additions & 0 deletions quill-core/src/main/scala/io/getquill/MirrorIdiom.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class MirrorIdiom extends Idiom {
case ast: Operation => ast.token
case ast: Action => ast.token
case ast: Ident => ast.token
case ast: Type => ast.token
case ast: Property => ast.token
case ast: Infix => ast.token
case ast: OptionOperation => ast.token
Expand Down Expand Up @@ -192,6 +193,10 @@ class MirrorIdiom extends Idiom {
case e => stmt"${e.name.token}"
}

implicit val typeTokenizer: Tokenizer[Type] = Tokenizer[Type] {
case _ => stmt""
}

implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] = Tokenizer[OnConflict.Excluded] {
case OnConflict.Excluded(ident) => stmt"${ident.token}"
}
Expand All @@ -204,6 +209,7 @@ class MirrorIdiom extends Idiom {
case Update(query, assignments) => stmt"${query.token}.update(${assignments.token})"
case Insert(query, assignments) => stmt"${query.token}.insert(${assignments.token})"
case Delete(query) => stmt"${query.token}.delete"
case ReturningRecord(query, tpe) => stmt"${query.token}.returning[${tpe.token}]"
case Returning(query, alias, body) => stmt"${query.token}.returning((${alias.token}) => ${body.token})"
case Foreach(query, alias, body) => stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
case c: OnConflict => stmt"${c.token}"
Expand Down
3 changes: 3 additions & 0 deletions quill-core/src/main/scala/io/getquill/ast/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ case class Ident(name: String) extends Ast

case class Property(ast: Ast, name: String) extends Ast

case class Type(name: String) extends Ast

sealed trait OptionOperation extends Ast
case class OptionFlatten(ast: Ast) extends OptionOperation
case class OptionGetOrElse(ast: Ast, body: Ast) extends OptionOperation
Expand Down Expand Up @@ -135,6 +137,7 @@ case class Insert(query: Ast, assignments: List[Assignment]) extends Action
case class Delete(query: Ast) extends Action

case class Returning(action: Ast, alias: Ident, property: Ast) extends Action
case class ReturningRecord(action: Ast, tpe: Type) extends Action

case class Foreach(query: Ast, alias: Ident, body: Ast) extends Action

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ trait StatefulTransformer[T] {
case e: Value => apply(e)
case e: Assignment => apply(e)
case e: Ident => (e, this)
case e: Type => (e, this)
case e: OptionOperation => apply(e)
case e: TraversableOperation => apply(e)
case e: Property => apply(e)
Expand Down Expand Up @@ -259,6 +260,9 @@ trait StatefulTransformer[T] {
val (at, att) = apply(a)
val (ct, ctt) = att.apply(c)
(Returning(at, b, ct), ctt)
case ReturningRecord(a, tpe) =>
val (at, att) = apply(a)
(ReturningRecord(at, tpe), att)
case Foreach(a, b, c) =>
val (at, att) = apply(a)
val (ct, ctt) = att.apply(c)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ trait StatelessTransformer {
case e: Assignment => apply(e)
case Function(params, body) => Function(params, apply(body))
case e: Ident => e
case e: Type => e
case e: Property => apply(e)
case Infix(a, b) => Infix(a, b.map(apply))
case e: OptionOperation => apply(e)
Expand Down Expand Up @@ -112,6 +113,7 @@ trait StatelessTransformer {
case Insert(query, assignments) => Insert(apply(query), assignments.map(apply))
case Delete(query) => Delete(apply(query))
case Returning(query, alias, property) => Returning(apply(query), alias, apply(property))
case ReturningRecord(query, tpe) => ReturningRecord(apply(query), tpe)
case Foreach(query, alias, body) => Foreach(apply(query), alias, apply(body))
case OnConflict(query, target, action) => OnConflict(apply(query), apply(target), apply(action))
}
Expand Down
16 changes: 13 additions & 3 deletions quill-core/src/main/scala/io/getquill/context/ActionMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.getquill.ast._
import io.getquill.quotation.ReifyLiftings
import io.getquill.util.Messages._
import io.getquill.norm.BetaReduction
import io.getquill.util.EnableReflectiveCalls
import io.getquill.util.{ EnableReflectiveCalls, OptionalTypecheck }

class ActionMacro(val c: MacroContext)
extends ContextMacro
Expand Down Expand Up @@ -130,11 +130,21 @@ class ActionMacro(val c: MacroContext)
expanded.ast match {
case io.getquill.ast.Returning(_, _, io.getquill.ast.Property(_, property)) =>
expanded.naming.column(property)
case io.getquill.ast.ReturningRecord(_, _) =>
"*"
case ast =>
io.getquill.util.Messages.fail(s"Can't find returning column. Ast: '$$ast'")
}
"""

private def returningExtractor[T](implicit t: WeakTypeTag[T]) =
q"(row: ${c.prefix}.ResultRow) => implicitly[Decoder[$t]].apply(0, row)"
private def returningExtractor[T](implicit t: WeakTypeTag[T]) = {
OptionalTypecheck(c)(q"implicitly[${c.prefix}.Decoder[$t]]") match {
case Some(decoder) =>
q"(row: ${c.prefix}.ResultRow) => $decoder.apply(0, row)"
case None =>
val metaTpe = c.typecheck(tq"${c.prefix}.QueryMeta[$t]", c.TYPEmode).tpe
val meta = c.inferImplicitValue(metaTpe).orElse(q"${c.prefix}.materializeQueryMeta[$t]")
q"$meta.extract"
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.getquill.context

sealed trait InsertReturnCapability
trait ReturnSingleField extends InsertReturnCapability
trait ReturnMultipleField extends InsertReturnCapability

trait Capabilities {
type ReturnAfterInsert <: InsertReturnCapability
}

trait CanReturnRecordAfterInsert extends Capabilities {
override type ReturnAfterInsert = ReturnMultipleField
}
3 changes: 3 additions & 0 deletions quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ private[dsl] trait QueryDsl {
sealed trait Action[E]

sealed trait Insert[E] extends Action[E] {
@compileTimeOnly(NonQuotedException.message)
def returning[R]: ActionReturning[E, R] = NonQuotedException()

@compileTimeOnly(NonQuotedException.message)
def returning[R](f: E => R): ActionReturning[E, R] = NonQuotedException()

Expand Down
17 changes: 11 additions & 6 deletions quill-core/src/main/scala/io/getquill/quotation/Liftables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,18 @@ trait Liftables {
case FullJoin => q"$pack.FullJoin"
}

implicit val typeLiftable: Liftable[io.getquill.ast.Type] = Liftable[io.getquill.ast.Type] {
case io.getquill.ast.Type(a) => q"$pack.Type($a)"
}

implicit val actionLiftable: Liftable[Action] = Liftable[Action] {
case Update(a, b) => q"$pack.Update($a, $b)"
case Insert(a, b) => q"$pack.Insert($a, $b)"
case Delete(a) => q"$pack.Delete($a)"
case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)"
case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)"
case OnConflict(a, b, c) => q"$pack.OnConflict($a, $b, $c)"
case Update(a, b) => q"$pack.Update($a, $b)"
case Insert(a, b) => q"$pack.Insert($a, $b)"
case Delete(a) => q"$pack.Delete($a)"
case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)"
case ReturningRecord(a, b) => q"$pack.ReturningRecord($a, $b)"
case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)"
case OnConflict(a, b, c) => q"$pack.OnConflict($a, $b, $c)"
}

implicit val conflictTargetLiftable: Liftable[OnConflict.Target] = Liftable[OnConflict.Target] {
Expand Down
31 changes: 31 additions & 0 deletions quill-core/src/main/scala/io/getquill/quotation/Parsing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package io.getquill.quotation
import scala.reflect.ClassTag
import io.getquill.ast._
import io.getquill.Embedded
import io.getquill.context.ReturnMultipleField
import io.getquill.norm.BetaReduction
import io.getquill.util.Messages.RichContext
import io.getquill.util.Interleave
import io.getquill.dsl.CoreDsl
import io.getquill.norm.capture.AvoidAliasConflict
import io.getquill.idiom.Idiom

import scala.annotation.tailrec
import scala.collection.immutable.StringOps
Expand Down Expand Up @@ -701,13 +703,42 @@ trait Parsing {
tpe.paramLists(0).map(_.name.toString)
}

val typeParser: Parser[io.getquill.ast.Type] = Parser[io.getquill.ast.Type] {
case t: Tree => io.getquill.ast.Type(t.tpe.dealias.typeSymbol.fullName)
}

val actionParser: Parser[Ast] = Parser[Ast] {
case q"$query.$method(..$assignments)" if (method.decodedName.toString == "update") =>
Update(astParser(query), assignments.map(assignmentParser(_)))
case q"$query.insert(..$assignments)" =>
Insert(astParser(query), assignments.map(assignmentParser(_)))
case q"$query.delete" =>
Delete(astParser(query))
case q"$action.returning[$r]" =>
val maybeIdiomTpe =
c.prefix.tree.tpe
.baseClasses
.flatMap { baseClass =>
val baseClassTypeArgs = c.prefix.tree.tpe.baseType(baseClass).typeArgs
baseClassTypeArgs.find { typeArg =>
typeArg <:< typeOf[Idiom]
}
}
.headOption

val canReturn = maybeIdiomTpe
.toSeq
.flatMap(_.members)
.exists {
case ts: TypeSymbol if ts.asType.typeSignature =:= typeOf[ReturnMultipleField] => true
case _ => false
}

(maybeIdiomTpe, canReturn) match {
case (Some(_), true) => ReturningRecord(astParser(action), typeParser(r))
case (Some(idiomTpe), false) => c.fail(s"""Idiom "${idiomTpe.typeSymbol.fullName}" doesn't support query returning multiple fields""")
case (None, _) => c.fail("provided context doesn't define sql idiom")
}
case q"$action.returning[$r](($alias) => $body)" =>
Returning(astParser(action), identParser(alias), astParser(body))
case q"$query.foreach[$t1, $t2](($alias) => $body)($f)" if (is[CoreDsl#Query[Any]](query)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ trait Unliftables {
case q"$pack.Insert.apply(${ a: Ast }, ${ b: List[Assignment] })" => Insert(a, b)
case q"$pack.Delete.apply(${ a: Ast })" => Delete(a)
case q"$pack.Returning.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Returning(a, b, c)
case q"$pack.ReturningRecord.apply(${ a: Ast }, ${ b: io.getquill.ast.Type })" => ReturningRecord(a, b)
case q"$pack.Foreach.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Foreach(a, b, c)
case q"$pack.OnConflict.apply(${ a: Ast }, ${ b: OnConflict.Target }, ${ c: OnConflict.Action })" => OnConflict(a, b, c)
}
Expand Down Expand Up @@ -192,4 +193,8 @@ trait Unliftables {
case q"$pack.ScalarQueryLift.apply(${ a: String }, $b, $c)" => ScalarQueryLift(a, b, c)
case q"$pack.CaseClassQueryLift.apply(${ a: String }, $b)" => CaseClassQueryLift(a, b)
}

implicit val typeUnliftable: Unliftable[io.getquill.ast.Type] = Unliftable[io.getquill.ast.Type] {
case q"$pack.Type.apply(${ a: String })" => io.getquill.ast.Type(a)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ trait OrientDBIdiom extends Idiom {
a.token
case a: Ident =>
a.token
case a: Type =>
a.token
case a: Property =>
a.token
case a: Value =>
Expand Down Expand Up @@ -302,4 +304,8 @@ trait OrientDBIdiom extends Idiom {
case _: Tuple => stmt"(${ast.token})"
case _ => ast.token
}

implicit val typeTokenizer: Tokenizer[Type] = Tokenizer[Type] {
case _ => stmt""
}
}
4 changes: 3 additions & 1 deletion quill-sql/src/main/scala/io/getquill/PostgresDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package io.getquill
import java.util.concurrent.atomic.AtomicInteger

import io.getquill.ast._
import io.getquill.context.CanReturnRecordAfterInsert
import io.getquill.context.sql.idiom._
import io.getquill.idiom.StatementInterpolator._

trait PostgresDialect
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with OnConflictSupport {
with OnConflictSupport
with CanReturnRecordAfterInsert {

override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ast] =
Tokenizer[Ast] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import io.getquill.ast.BooleanOperator._
import io.getquill.ast.Lift
import io.getquill.context.sql._
import io.getquill.context.sql.norm._
import io.getquill.idiom.{ Idiom, SetContainsToken, Statement }
import io.getquill.idiom._
import io.getquill.idiom.StatementInterpolator._
import io.getquill.NamingStrategy
import io.getquill.util.Interleave
import io.getquill.util.Messages.{ fail, trace }
import io.getquill.idiom.Token
import io.getquill.norm.ConcatBehavior
import io.getquill.norm.ConcatBehavior.AnsiConcat

Expand Down Expand Up @@ -56,6 +55,7 @@ trait SqlIdiom extends Idiom {
case a: Infix => a.token
case a: Action => a.token
case a: Ident => a.token
case a: Type => a.token
case a: Property => a.token
case a: Value => a.token
case a: If => a.token
Expand Down Expand Up @@ -383,6 +383,12 @@ trait SqlIdiom extends Idiom {
case Returning(action, alias, prop) =>
action.token

case ReturningRecord(Insert(table: Entity, Nil), _) =>
stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(null)}"

case ReturningRecord(action, _) =>
action.token

case other =>
fail(s"Action ast can't be translated to sql: '$other'")
}
Expand All @@ -398,4 +404,8 @@ trait SqlIdiom extends Idiom {
case _: Tuple => stmt"(${ast.token})"
case _ => ast.token
}

implicit val typeTokenizer: Tokenizer[Type] = Tokenizer[Type] {
case _ => stmt""
}
}