diff --git a/.gitignore b/.gitignore index 51e62d7a7e..be23d59da5 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,7 @@ quill-jdbc/src/main/resources/logback.xml log.txt* tmp nohup.out +.bloop/ +.metals/ +project/.bloop/ +/io/ diff --git a/quill-async-mysql/src/main/scala/io/getquill/MysqlAsyncContext.scala b/quill-async-mysql/src/main/scala/io/getquill/MysqlAsyncContext.scala index f6b63f47f0..0a1279c6e0 100644 --- a/quill-async-mysql/src/main/scala/io/getquill/MysqlAsyncContext.scala +++ b/quill-async-mysql/src/main/scala/io/getquill/MysqlAsyncContext.scala @@ -17,7 +17,7 @@ class MysqlAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConnect def this(naming: N, config: Config) = this(naming, MysqlAsyncContextConfig(config)) def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix)) - override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = { + override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = { result match { case r: MySQLQueryResult => returningExtractor(new ArrayRowData(0, Map.empty, Array(r.lastInsertId))) diff --git a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala index 57330c35e0..d845bf338b 100644 --- a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala +++ b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala @@ -1,9 +1,10 @@ package io.getquill.context.async.mysql import com.github.mauricio.async.db.QueryResult +import io.getquill.ReturnAction.ReturnColumns import scala.concurrent.ExecutionContext.Implicits.global -import io.getquill.{ Literal, MysqlAsyncContext, Spec } +import io.getquill.{ Literal, MysqlAsyncContext, ReturnAction, Spec } class MysqlAsyncContextSpec extends Spec { @@ -18,7 +19,7 @@ class MysqlAsyncContextSpec extends Spec { "Insert with returning with single column table" in { val inserted: Long = await(testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) }) await(testContext.run(qr4.filter(_.i == lift(inserted)))) .head.i mustBe inserted @@ -35,13 +36,13 @@ class MysqlAsyncContextSpec extends Spec { "cannot extract" in { object ctx extends MysqlAsyncContext(Literal, "testMysqlDB") { override def extractActionResult[O]( - returningColumn: String, + returningAction: ReturnAction, returningExtractor: ctx.Extractor[O] )(result: QueryResult) = - super.extractActionResult(returningColumn, returningExtractor)(result) + super.extractActionResult(returningAction, returningExtractor)(result) } intercept[IllegalStateException] { - ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e")) + ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e")) } ctx.close } diff --git a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/ProductMysqlAsyncSpec.scala b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/ProductMysqlAsyncSpec.scala index 70bd122f07..b061ec3f3f 100644 --- a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/ProductMysqlAsyncSpec.scala +++ b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/ProductMysqlAsyncSpec.scala @@ -35,7 +35,7 @@ class ProductMysqlAsyncSpec extends ProductSpec { val prd = Product(0L, "test1", 1L) val inserted = await { testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } } val returnedProduct = await(testContext.run(productById(lift(inserted)))).head @@ -47,7 +47,7 @@ class ProductMysqlAsyncSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val inserted = await(testContext.run(q1)) val returnedProduct = await(testContext.run(productById(lift(inserted)))).head @@ -60,7 +60,7 @@ class ProductMysqlAsyncSpec extends ProductSpec { case class Product(id: Id, description: String, sku: Long) val prd = Product(Id(0L), "test2", 2L) val q1 = quote { - query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } await(testContext.run(q1)) mustBe a[Id] } diff --git a/quill-async-postgres/src/main/scala/io/getquill/PostgresAsyncContext.scala b/quill-async-postgres/src/main/scala/io/getquill/PostgresAsyncContext.scala index 482dcf0942..de539050f7 100644 --- a/quill-async-postgres/src/main/scala/io/getquill/PostgresAsyncContext.scala +++ b/quill-async-postgres/src/main/scala/io/getquill/PostgresAsyncContext.scala @@ -1,9 +1,10 @@ package io.getquill -import com.github.mauricio.async.db.{ QueryResult => DBQueryResult } import com.github.mauricio.async.db.pool.PartitionedConnectionPool import com.github.mauricio.async.db.postgresql.PostgreSQLConnection +import com.github.mauricio.async.db.{ QueryResult => DBQueryResult } import com.typesafe.config.Config +import io.getquill.ReturnAction.{ ReturnColumns, ReturnNothing, ReturnRecord } import io.getquill.context.async.{ ArrayDecoders, ArrayEncoders, AsyncContext, UUIDObjectEncoding } import io.getquill.util.LoadConfig import io.getquill.util.Messages.fail @@ -18,7 +19,7 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn def this(naming: N, config: Config) = this(naming, PostgresAsyncContextConfig(config)) def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix)) - override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = { + override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = { result.rows match { case Some(r) if r.nonEmpty => returningExtractor(r.head) @@ -27,6 +28,14 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn } } - override protected def expandAction(sql: String, returningColumn: String): String = - s"$sql RETURNING $returningColumn" + override protected def expandAction(sql: String, returningAction: ReturnAction): String = + returningAction match { + // The Postgres dialect will create SQL that has a 'RETURNING' clause so we don't have to add one. + case ReturnRecord => s"$sql" + // The Postgres dialect will not actually use these below variants but in case we decide to plug + // in some other dialect into this context... + case ReturnColumns(columns) => s"$sql RETURNING ${columns.mkString(", ")}" + case ReturnNothing => s"$sql" + } + } diff --git a/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala b/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala index 52c92f1175..9e2ef115e4 100644 --- a/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala +++ b/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala @@ -1,9 +1,10 @@ package io.getquill.context.async.postgres import com.github.mauricio.async.db.QueryResult +import io.getquill.ReturnAction.ReturnColumns import scala.concurrent.ExecutionContext.Implicits.global -import io.getquill.{ Literal, PostgresAsyncContext, Spec } +import io.getquill.{ Literal, PostgresAsyncContext, ReturnAction, Spec } class PostgresAsyncContextSpec extends Spec { @@ -18,11 +19,18 @@ class PostgresAsyncContextSpec extends Spec { "Insert with returning with single column table" in { val inserted: Long = await(testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) }) await(testContext.run(qr4.filter(_.i == lift(inserted)))) .head.i mustBe inserted } + "Insert with returning with multiple columns" in { + await(testContext.run(qr1.delete)) + val inserted = await(testContext.run { + qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i, r.s, r.o)) + }) + (1, "foo", Some(123)) mustBe inserted + } "performIO" in { await(performIO(runIO(qr4).transactional)) @@ -35,13 +43,13 @@ class PostgresAsyncContextSpec extends Spec { "cannot extract" in { object ctx extends PostgresAsyncContext(Literal, "testPostgresDB") { override def extractActionResult[O]( - returningColumn: String, + returningAction: ReturnAction, returningExtractor: ctx.Extractor[O] )(result: QueryResult) = - super.extractActionResult(returningColumn, returningExtractor)(result) + super.extractActionResult(returningAction, returningExtractor)(result) } intercept[IllegalStateException] { - ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e")) + ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e")) } ctx.close } diff --git a/quill-async/src/main/scala/io/getquill/context/async/AsyncContext.scala b/quill-async/src/main/scala/io/getquill/context/async/AsyncContext.scala index 5cd103c094..af83621215 100644 --- a/quill-async/src/main/scala/io/getquill/context/async/AsyncContext.scala +++ b/quill-async/src/main/scala/io/getquill/context/async/AsyncContext.scala @@ -4,6 +4,7 @@ import com.github.mauricio.async.db.Connection import com.github.mauricio.async.db.{ QueryResult => DBQueryResult } import com.github.mauricio.async.db.RowData import com.github.mauricio.async.db.pool.PartitionedConnectionPool + import scala.concurrent.Await import scala.concurrent.ExecutionContext import scala.concurrent.Future @@ -11,7 +12,7 @@ import scala.concurrent.duration.Duration import scala.util.Try import io.getquill.context.sql.SqlContext import io.getquill.context.sql.idiom.SqlIdiom -import io.getquill.NamingStrategy +import io.getquill.{ NamingStrategy, ReturnAction } import io.getquill.util.ContextLogger import io.getquill.monad.ScalaFutureIOMonad import io.getquill.context.{ Context, TranslateContext } @@ -48,9 +49,9 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection] case other => f(pool) } - protected def extractActionResult[O](returningColumn: String, extractor: Extractor[O])(result: DBQueryResult): O + protected def extractActionResult[O](returningAction: ReturnAction, extractor: Extractor[O])(result: DBQueryResult): O - protected def expandAction(sql: String, returningColumn: String) = sql + protected def expandAction(sql: String, returningAction: ReturnAction) = sql def probe(sql: String) = Try { @@ -88,12 +89,12 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection] withConnection(_.sendPreparedStatement(sql, values)).map(_.rowsAffected) } - def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String)(implicit ec: ExecutionContext): Future[T] = { - val expanded = expandAction(sql, returningColumn) + def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction)(implicit ec: ExecutionContext): Future[T] = { + val expanded = expandAction(sql, returningAction) val (params, values) = prepare(Nil) logger.logQuery(sql, params) withConnection(_.sendPreparedStatement(expanded, values)) - .map(extractActionResult(returningColumn, extractor)) + .map(extractActionResult(returningAction, extractor)) } def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext): Future[List[Long]] = diff --git a/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala b/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala index bf4f3fdaf7..7dfff5f38f 100644 --- a/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala +++ b/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala @@ -2,6 +2,7 @@ package io.getquill.context.cassandra import io.getquill.ast.{ TraversableOperation, _ } import io.getquill.NamingStrategy +import io.getquill.context.CannotReturn import io.getquill.util.Messages.fail import io.getquill.idiom.Idiom import io.getquill.idiom.StatementInterpolator._ @@ -10,7 +11,7 @@ import io.getquill.idiom.SetContainsToken import io.getquill.idiom.Token import io.getquill.util.Interleave -object CqlIdiom extends CqlIdiom +object CqlIdiom extends CqlIdiom with CannotReturn trait CqlIdiom extends Idiom { @@ -33,6 +34,7 @@ trait CqlIdiom extends Idiom { case a: Operation => a.token case a: Action => a.token case a: Ident => a.token + case a: ExternalIdent => a.token case a: Property => a.token case a: Value => a.token case a: Function => a.body.token @@ -135,6 +137,10 @@ trait CqlIdiom extends Idiom { case e => strategy.default(e.name).token } + implicit def externalIdentTokenizer(implicit strategy: NamingStrategy): Tokenizer[ExternalIdent] = Tokenizer[ExternalIdent] { + case e => strategy.default(e.name).token + } + implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy): Tokenizer[Assignment] = Tokenizer[Assignment] { case Assignment(alias, prop, value) => stmt"${prop.token} = ${value.token}" @@ -172,7 +178,7 @@ trait CqlIdiom extends Idiom { case Delete(table) => stmt"TRUNCATE ${table.token}" - case _: Returning => + case _: Returning | _: ReturningGenerated => fail(s"Cql doesn't support returning generated during insertion") case other => diff --git a/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala b/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala index 2cb7d80d19..e2816d1732 100644 --- a/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala +++ b/quill-cassandra/src/test/scala/io/getquill/context/cassandra/CqlIdiomSpec.scala @@ -3,6 +3,7 @@ package io.getquill.context.cassandra import io.getquill._ import io.getquill.idiom.StatementInterpolator._ import io.getquill.ast.{ Action => AstAction, _ } +import io.getquill.idiom.StringToken class CqlIdiomSpec extends Spec { @@ -38,10 +39,7 @@ class CqlIdiomSpec extends Spec { "SELECT s FROM TestEntity WHERE i = 1 ORDER BY s ASC LIMIT 1" } "returning" in { - val q = quote { - query[TestEntity].insert(_.l -> 1L).returning(_.i) - } - "mirrorContext.run(q).string" mustNot compile + "mirrorContext.run(query[TestEntity].insert(_.l -> 1L).returning(_.i)).string" mustNot compile } } @@ -392,5 +390,10 @@ class CqlIdiomSpec extends Spec { intercept[IllegalStateException](t.token(null: AstAction)) intercept[IllegalStateException](t.token(Insert(Nested(Ident("a")), Nil))) } + // not actually used anywhere but doing a sanity check here + "external ident sanity check" in { + val t = implicitly[Tokenizer[ExternalIdent]] + t.token(ExternalIdent("TestIdent")) mustBe StringToken("TestIdent") + } } } diff --git a/quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala b/quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala index 1a811bce9c..652d4b63f6 100644 --- a/quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala +++ b/quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala @@ -56,11 +56,11 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom case class ActionMirror(string: String, prepareRow: PrepareRow)(implicit val ec: ExecutionContext) - case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)(implicit val ec: ExecutionContext) + case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)(implicit val ec: ExecutionContext) case class BatchActionMirror(groups: List[(String, List[Row])])(implicit val ec: ExecutionContext) - case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext) + case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext) case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])(implicit val ec: ExecutionContext) @@ -74,8 +74,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom Future(ActionMirror(string, prepare(Row())._2)) def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], - returningColumn: String)(implicit ec: ExecutionContext) = - Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn)) + returningBehavior: ReturnAction)(implicit ec: ExecutionContext) = + Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior)) def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext) = Future { @@ -91,8 +91,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom Future { BatchActionReturningMirror[T]( groups.map { - case BatchGroupReturning(string, column, prepare) => - (string, column, prepare.map(_(Row())._2)) + case BatchGroupReturning(string, returningBehavior, prepare) => + (string, returningBehavior, prepare.map(_(Row())._2)) }, extractor ) } diff --git a/quill-core/src/main/scala/io/getquill/MirrorContext.scala b/quill-core/src/main/scala/io/getquill/MirrorContext.scala index e569f0a84a..c809868516 100644 --- a/quill-core/src/main/scala/io/getquill/MirrorContext.scala +++ b/quill-core/src/main/scala/io/getquill/MirrorContext.scala @@ -41,11 +41,11 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi case class ActionMirror(string: String, prepareRow: PrepareRow) - case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String) + case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction) case class BatchActionMirror(groups: List[(String, List[Row])]) - case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T]) + case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T]) case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T]) @@ -59,8 +59,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi ActionMirror(string, prepare(Row())._2) def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], - returningColumn: String) = - ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn) + returningBehavior: ReturnAction) = + ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior) def executeBatchAction(groups: List[BatchGroup]) = BatchActionMirror { @@ -73,8 +73,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]) = BatchActionReturningMirror[T]( groups.map { - case BatchGroupReturning(string, column, prepare) => - (string, column, prepare.map(_(Row())._2)) + case BatchGroupReturning(string, returningBehavior, prepare) => + (string, returningBehavior, prepare.map(_(Row())._2)) }, extractor ) diff --git a/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala b/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala index e4b5d3875b..f9522b2639 100644 --- a/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala +++ b/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala @@ -1,16 +1,16 @@ package io.getquill import io.getquill.ast._ -import io.getquill.idiom.Idiom -import io.getquill.idiom.SetContainsToken -import io.getquill.idiom.Statement +import io.getquill.context.CanReturnClause +import io.getquill.idiom.{ Idiom, SetContainsToken, Statement } import io.getquill.idiom.StatementInterpolator._ import io.getquill.norm.Normalize import io.getquill.util.Interleave object MirrorIdiom extends MirrorIdiom +class MirrorIdiom extends MirrorIdiomBase with CanReturnClause -class MirrorIdiom extends Idiom { +trait MirrorIdiomBase extends Idiom { override def prepareForProbing(string: String) = string @@ -28,6 +28,7 @@ class MirrorIdiom extends Idiom { case ast: Operation => ast.token case ast: Action => ast.token case ast: Ident => ast.token + case ast: ExternalIdent => ast.token case ast: Property => ast.token case ast: Infix => ast.token case ast: OptionOperation => ast.token @@ -176,7 +177,8 @@ class MirrorIdiom extends Idiom { } implicit def propertyTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Property] = Tokenizer[Property] { - case Property(ref, name) => stmt"${scopedTokenizer(ref)}.${name.token}" + case Property(ExternalIdent(_), name) => stmt"${name.token}" + case Property(ref, name) => stmt"${scopedTokenizer(ref)}.${name.token}" } implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] { @@ -192,6 +194,10 @@ class MirrorIdiom extends Idiom { case e => stmt"${e.name.token}" } + implicit val typeTokenizer: Tokenizer[ExternalIdent] = Tokenizer[ExternalIdent] { + case e => stmt"${e.name.token}" + } + implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] = Tokenizer[OnConflict.Excluded] { case OnConflict.Excluded(ident) => stmt"${ident.token}" } @@ -201,12 +207,13 @@ class MirrorIdiom extends Idiom { } implicit def actionTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Action] = Tokenizer[Action] { - case Update(query, assignments) => stmt"${query.token}.update(${assignments.token})" - case Insert(query, assignments) => stmt"${query.token}.insert(${assignments.token})" - case Delete(query) => stmt"${query.token}.delete" - case Returning(query, alias, body) => stmt"${query.token}.returning((${alias.token}) => ${body.token})" - case Foreach(query, alias, body) => stmt"${query.token}.foreach((${alias.token}) => ${body.token})" - case c: OnConflict => stmt"${c.token}" + case Update(query, assignments) => stmt"${query.token}.update(${assignments.token})" + case Insert(query, assignments) => stmt"${query.token}.insert(${assignments.token})" + case Delete(query) => stmt"${query.token}.delete" + case Returning(query, alias, body) => stmt"${query.token}.returning((${alias.token}) => ${body.token})" + case ReturningGenerated(query, alias, body) => stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})" + case Foreach(query, alias, body) => stmt"${query.token}.foreach((${alias.token}) => ${body.token})" + case c: OnConflict => stmt"${c.token}" } implicit def conflictTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[OnConflict] = { diff --git a/quill-core/src/main/scala/io/getquill/ReturnAction.scala b/quill-core/src/main/scala/io/getquill/ReturnAction.scala new file mode 100644 index 0000000000..ce2f622258 --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/ReturnAction.scala @@ -0,0 +1,8 @@ +package io.getquill + +sealed trait ReturnAction +object ReturnAction { + case object ReturnNothing extends ReturnAction + case class ReturnColumns(columns: List[String]) extends ReturnAction + case object ReturnRecord extends ReturnAction +} diff --git a/quill-core/src/main/scala/io/getquill/ast/Ast.scala b/quill-core/src/main/scala/io/getquill/ast/Ast.scala index be8b3a9275..9d344d1773 100644 --- a/quill-core/src/main/scala/io/getquill/ast/Ast.scala +++ b/quill-core/src/main/scala/io/getquill/ast/Ast.scala @@ -69,6 +69,10 @@ case class Function(params: List[Ident], body: Ast) extends Ast case class Ident(name: String) extends Ast +// Like identity but is but defined in a clause external to the query. Currently this is used +// for 'returning' clauses to define properties being retruned. +case class ExternalIdent(name: String) extends Ast + case class Property(ast: Ast, name: String) extends Ast sealed trait OptionOperation extends Ast @@ -134,7 +138,18 @@ case class Update(query: Ast, assignments: List[Assignment]) extends Action case class Insert(query: Ast, assignments: List[Assignment]) extends Action case class Delete(query: Ast) extends Action -case class Returning(action: Ast, alias: Ident, property: Ast) extends Action +sealed trait ReturningAction extends Action +object ReturningAction { + def unapply(returningClause: ReturningAction): Option[(Ast, Ident, Ast)] = + returningClause match { + case Returning(action, alias, property) => Some((action, alias, property)) + case ReturningGenerated(action, alias, property) => Some((action, alias, property)) + case _ => None + } + +} +case class Returning(action: Ast, alias: Ident, property: Ast) extends ReturningAction +case class ReturningGenerated(action: Ast, alias: Ident, property: Ast) extends ReturningAction case class Foreach(query: Ast, alias: Ident, body: Ast) extends Action diff --git a/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala b/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala index dd6924a117..f2b362e9b0 100644 --- a/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala +++ b/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala @@ -12,6 +12,7 @@ trait StatefulTransformer[T] { case e: Value => apply(e) case e: Assignment => apply(e) case e: Ident => (e, this) + case e: ExternalIdent => (e, this) case e: OptionOperation => apply(e) case e: TraversableOperation => apply(e) case e: Property => apply(e) @@ -259,6 +260,10 @@ trait StatefulTransformer[T] { val (at, att) = apply(a) val (ct, ctt) = att.apply(c) (Returning(at, b, ct), ctt) + case ReturningGenerated(a, b, c) => + val (at, att) = apply(a) + val (ct, ctt) = att.apply(c) + (ReturningGenerated(at, b, ct), ctt) case Foreach(a, b, c) => val (at, att) = apply(a) val (ct, ctt) = att.apply(c) diff --git a/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala b/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala index d594b4a6ed..8ea21c53cb 100644 --- a/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala +++ b/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala @@ -11,6 +11,7 @@ trait StatelessTransformer { case e: Assignment => apply(e) case Function(params, body) => Function(params, apply(body)) case e: Ident => e + case e: ExternalIdent => e case e: Property => apply(e) case Infix(a, b) => Infix(a, b.map(apply)) case e: OptionOperation => apply(e) @@ -108,12 +109,13 @@ trait StatelessTransformer { def apply(e: Action): Action = e match { - case Update(query, assignments) => Update(apply(query), assignments.map(apply)) - case Insert(query, assignments) => Insert(apply(query), assignments.map(apply)) - case Delete(query) => Delete(apply(query)) - case Returning(query, alias, property) => Returning(apply(query), alias, apply(property)) - case Foreach(query, alias, body) => Foreach(apply(query), alias, apply(body)) - case OnConflict(query, target, action) => OnConflict(apply(query), apply(target), apply(action)) + case Update(query, assignments) => Update(apply(query), assignments.map(apply)) + case Insert(query, assignments) => Insert(apply(query), assignments.map(apply)) + case Delete(query) => Delete(apply(query)) + case Returning(query, alias, property) => Returning(apply(query), alias, apply(property)) + case ReturningGenerated(query, alias, property) => ReturningGenerated(apply(query), alias, apply(property)) + case Foreach(query, alias, body) => Foreach(apply(query), alias, apply(body)) + case OnConflict(query, target, action) => OnConflict(apply(query), apply(target), apply(action)) } def apply(e: OnConflict.Target): OnConflict.Target = diff --git a/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala b/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala index 1a171a1315..b71e1f39ae 100644 --- a/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala +++ b/quill-core/src/main/scala/io/getquill/context/ActionMacro.scala @@ -1,12 +1,12 @@ package io.getquill.context -import io.getquill.ast._ +import io.getquill.ast._ // Only .returning(r => r.prop) or .returning(r => OneElementCaseClass(r.prop)) is allowed. import io.getquill.norm.BetaReduction import io.getquill.quotation.ReifyLiftings -import io.getquill.util.EnableReflectiveCalls import io.getquill.util.Messages._ import scala.reflect.macros.whitebox.{ Context => MacroContext } +import io.getquill.util.{ EnableReflectiveCalls, OptionalTypecheck } class ActionMacro(val c: MacroContext) extends ContextMacro @@ -135,17 +135,16 @@ class ActionMacro(val c: MacroContext) private def returningColumn = q""" - expanded.ast match { - case io.getquill.ast.Returning(_, _, io.getquill.ast.Property(_, property)) => - expanded.naming.column(property) + (expanded.ast match { + case ret: io.getquill.ast.ReturningAction => + io.getquill.norm.ExpandReturning.applyMap(ret)( + (ast, statement) => io.getquill.context.Expand(${c.prefix}, ast, statement, idiom, naming).string + )(idiom, naming) case ast => io.getquill.util.Messages.fail(s"Can't find returning column. Ast: '$$ast'") - } + }) """ - private def returningExtractor[T](implicit t: WeakTypeTag[T]) = - q"(row: ${c.prefix}.ResultRow) => implicitly[Decoder[$t]].apply(0, row)" - def bindAction(quoted: Tree): Tree = c.untypecheck { q""" @@ -158,4 +157,14 @@ class ActionMacro(val c: MacroContext) """ } + private def returningExtractor[T](implicit t: WeakTypeTag[T]) = { + OptionalTypecheck(c)(q"implicitly[${c.prefix}.Decoder[$t]]") match { + case Some(decoder) => + q"(row: ${c.prefix}.ResultRow) => $decoder.apply(0, row)" + case None => + val metaTpe = c.typecheck(tq"${c.prefix}.QueryMeta[$t]", c.TYPEmode).tpe + val meta = c.inferImplicitValue(metaTpe).orElse(q"${c.prefix}.materializeQueryMeta[$t]") + q"$meta.extract" + } + } } diff --git a/quill-core/src/main/scala/io/getquill/context/Context.scala b/quill-core/src/main/scala/io/getquill/context/Context.scala index fe91f264fc..57f7f3b5c9 100644 --- a/quill-core/src/main/scala/io/getquill/context/Context.scala +++ b/quill-core/src/main/scala/io/getquill/context/Context.scala @@ -5,8 +5,9 @@ import scala.language.experimental.macros import io.getquill.dsl.CoreDsl import io.getquill.util.Messages.fail import java.io.Closeable + import scala.util.Try -import io.getquill.NamingStrategy +import io.getquill.{ NamingStrategy, ReturnAction } trait Context[Idiom <: io.getquill.idiom.Idiom, Naming <: NamingStrategy] extends Closeable @@ -25,7 +26,7 @@ trait Context[Idiom <: io.getquill.idiom.Idiom, Naming <: NamingStrategy] type Extractor[T] = ResultRow => T case class BatchGroup(string: String, prepare: List[Prepare]) - case class BatchGroupReturning(string: String, column: String, prepare: List[Prepare]) + case class BatchGroupReturning(string: String, returningBehavior: ReturnAction, prepare: List[Prepare]) def probe(statement: String): Try[_] diff --git a/quill-core/src/main/scala/io/getquill/context/ReturnFieldCapability.scala b/quill-core/src/main/scala/io/getquill/context/ReturnFieldCapability.scala new file mode 100644 index 0000000000..25f6da9838 --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/context/ReturnFieldCapability.scala @@ -0,0 +1,57 @@ +package io.getquill.context + +sealed trait ReturningCapability + +/** + * Data cannot be returned Insert/Update/etc... clauses in the target database. + */ +sealed trait ReturningNotSupported extends ReturningCapability + +/** + * Returning a single field from Insert/Update/etc... clauses is supported. This is the most common + * databases e.g. MySQL, Sqlite, and H2 (although as of h2database/h2database#1972 this may change. See #1496 + * regarding this. Typically this needs to be setup in the JDBC `connection.prepareStatement(sql, Array("returnColumn"))`. + */ +sealed trait ReturningSingleFieldSupported extends ReturningCapability + +/** + * Returning multiple columns from Insert/Update/etc... clauses is supported. This generally means that + * columns besides auto-incrementing ones can be returned. This is supported by Oracle. + * In JDBC, the following is done: + * `connection.prepareStatement(sql, Array("column1, column2, ..."))`. + */ +sealed trait ReturningMultipleFieldSupported extends ReturningCapability + +/** + * An actual `RETURNING` clause is supported in the SQL dialect of the specified database e.g. Postgres. + * this typically means that columns returned from Insert/Update/etc... clauses can have other database + * operations done on them such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. + * In JDBC, the following is done: + * `connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS))`. + */ +sealed trait ReturningClauseSupported extends ReturningCapability + +object ReturningNotSupported extends ReturningNotSupported +object ReturningSingleFieldSupported extends ReturningSingleFieldSupported +object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported +object ReturningClauseSupported extends ReturningClauseSupported + +trait Capabilities { + def idiomReturningCapability: ReturningCapability +} + +trait CanReturnClause extends Capabilities { + override def idiomReturningCapability: ReturningClauseSupported = ReturningClauseSupported +} + +trait CanReturnField extends Capabilities { + override def idiomReturningCapability: ReturningSingleFieldSupported = ReturningSingleFieldSupported +} + +trait CanReturnMultiField extends Capabilities { + override def idiomReturningCapability: ReturningMultipleFieldSupported = ReturningMultipleFieldSupported +} + +trait CannotReturn extends Capabilities { + override def idiomReturningCapability: ReturningNotSupported = ReturningNotSupported +} diff --git a/quill-core/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala b/quill-core/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala index ec4303f3a2..60d91cd8c5 100644 --- a/quill-core/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala +++ b/quill-core/src/main/scala/io/getquill/dsl/DynamicQueryDSL.scala @@ -357,6 +357,11 @@ trait DynamicQueryDsl { DynamicActionReturning(splice(Returning(q.ast, v, f(splice(v)).ast))) } + def returningGenerated[R](f: Quoted[E] => Quoted[R]): DynamicActionReturning[E, R] = + withFreshIdent { v => + DynamicActionReturning(splice(ReturningGenerated(q.ast, v, f(splice(v)).ast))) + } + def onConflictIgnore: DynamicInsert[E] = dyn(OnConflict(DynamicInsert.this.q.ast, OnConflict.NoTarget, OnConflict.Ignore)) diff --git a/quill-core/src/main/scala/io/getquill/dsl/MetaDslMacro.scala b/quill-core/src/main/scala/io/getquill/dsl/MetaDslMacro.scala index 97867bc7ee..306d8739d0 100644 --- a/quill-core/src/main/scala/io/getquill/dsl/MetaDslMacro.scala +++ b/quill-core/src/main/scala/io/getquill/dsl/MetaDslMacro.scala @@ -1,11 +1,9 @@ package io.getquill.dsl -import io.getquill.Embedded import io.getquill.util.Messages._ import scala.reflect.macros.whitebox.{ Context => MacroContext } -import io.getquill.util.OptionalTypecheck -class MetaDslMacro(val c: MacroContext) { +class MetaDslMacro(val c: MacroContext) extends ValueComputation { import c.universe._ def schemaMeta[T](entity: Tree, columns: Tree*)(implicit t: WeakTypeTag[T]): Tree = @@ -137,132 +135,4 @@ class MetaDslMacro(val c: MacroContext) { """ } } - - def flatten(base: Tree, value: Value): List[Tree] = { - def nest(tree: Tree, term: Option[TermName]) = - term match { - case None => tree - case Some(term) => q"$tree.$term" - } - - def apply(base: Tree, params: List[List[Value]]): List[Tree] = - params.flatten.flatMap(flatten(base, _)) - - value match { - case Scalar(term, _, _, _) => - List(nest(base, term)) - case Nested(term, _, params, optional) => - if (optional) - apply(q"v", params) - .map(body => q"${nest(base, term)}.map(v => $body)") - else - apply(nest(base, term), params) - } - } - - sealed trait Value { - val term: Option[TermName] - def nestedAndOptional: Boolean - } - case class Nested(term: Option[TermName], tpe: Type, params: List[List[Value]], optional: Boolean) extends Value { - def nestedAndOptional: Boolean = optional - } - case class Scalar(term: Option[TermName], tpe: Type, decoder: Tree, optional: Boolean) extends Value { - def nestedAndOptional: Boolean = false - } - - private def is[T](tpe: Type)(implicit t: TypeTag[T]) = - tpe <:< t.tpe - - private def value(encoding: String, tpe: Type, exclude: Tree*): Value = { - - def nest(tpe: Type, term: Option[TermName]): Nested = - caseClassConstructor(tpe) match { - case None => - c.fail(s"Found the embedded '$tpe', but it is not a case class") - case Some(constructor) => - val params = - constructor.paramLists.map { - _.map { param => - apply( - param.typeSignature.asSeenFrom(tpe, tpe.typeSymbol), - Some(param.name.toTermName), - nested = !isTuple(tpe) - ) - } - } - Nested(term, tpe, params, optional = false) - } - - def apply(tpe: Type, term: Option[TermName], nested: Boolean): Value = { - OptionalTypecheck(c)(q"implicitly[${c.prefix}.${TypeName(encoding)}[$tpe]]") match { - case Some(encoding) => - Scalar(term, tpe, encoding, optional = is[Option[Any]](tpe)) - case None => - - def value(tpe: Type) = - tpe match { - case tpe if !is[Embedded](tpe) && nested => - c.fail( - s"""Can't find implicit `$encoding[$tpe]`. Please, do one of the following things: - |1. ensure that implicit `$encoding[$tpe]` is provided and there are no other conflicting implicits; - |2. make `$tpe` `Embedded` case class or `AnyVal`. - """.stripMargin - ) - - case tpe => - nest(tpe, term) - } - - if (isNone(tpe)) { - c.fail("Cannot handle untyped `None` objects. Use a cast e.g. `None:Option[String]` or `Option.empty`.") - } else if (is[Option[Any]](tpe)) { - value(tpe.typeArgs.head).copy(optional = true) - } else { - value(tpe) - } - } - } - - def filterExcludes(value: Value) = { - val paths = - exclude.map { - case f: Function => - def path(tree: Tree): List[TermName] = - tree match { - case q"$a.$b" => path(a) :+ b - case q"$a.map[$t]($b => $c)" => path(a) ++ path(c) - case _ => Nil - } - path(f.body) - } - - def filter(value: Value, path: List[TermName] = Nil): Option[Value] = - value match { - case value if paths.contains(path ++ value.term) => - None - case Nested(term, tpe, params, optional) => - Some(Nested(term, tpe, params.map(_.flatMap(filter(_, path ++ term))), optional)) - case value => - Some(value) - } - - filter(value).getOrElse { - c.fail("Can't exclude all entity properties") - } - } - - filterExcludes(apply(tpe, term = None, nested = false)) - } - - private def isNone(tpe: Type) = - tpe.typeSymbol.name.toString == "None" - - private def isTuple(tpe: Type) = - tpe.typeSymbol.name.toString.startsWith("Tuple") - - private def caseClassConstructor(t: Type) = - t.members.collect { - case m: MethodSymbol if m.isPrimaryConstructor => m - }.headOption } \ No newline at end of file diff --git a/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala b/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala index e2965777c9..fe3bfecc23 100644 --- a/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala +++ b/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala @@ -5,7 +5,7 @@ import io.getquill.quotation.NonQuotedException import scala.annotation.compileTimeOnly -private[dsl] trait QueryDsl { +private[getquill] trait QueryDsl { dsl: CoreDsl => def query[T]: EntityQuery[T] = macro QueryDslMacro.expandEntity[T] @@ -73,6 +73,7 @@ private[dsl] trait QueryDsl { def groupBy[R](f: T => R): Query[(R, Query[T])] + def value[U >: T]: Option[T] def min[U >: T]: Option[T] def max[U >: T]: Option[T] def avg[U >: T](implicit n: Numeric[U]): Option[BigDecimal] @@ -130,6 +131,9 @@ private[dsl] trait QueryDsl { @compileTimeOnly(NonQuotedException.message) def returning[R](f: E => R): ActionReturning[E, R] = NonQuotedException() + @compileTimeOnly(NonQuotedException.message) + def returningGenerated[R](f: E => R): ActionReturning[E, R] = NonQuotedException() + @compileTimeOnly(NonQuotedException.message) def onConflictIgnore: Insert[E] = NonQuotedException() diff --git a/quill-core/src/main/scala/io/getquill/dsl/ValueComputation.scala b/quill-core/src/main/scala/io/getquill/dsl/ValueComputation.scala new file mode 100644 index 0000000000..989ebcf232 --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/dsl/ValueComputation.scala @@ -0,0 +1,139 @@ +package io.getquill.dsl + +import scala.reflect.macros.whitebox.{ Context => MacroContext } +import io.getquill.Embedded +import io.getquill.util.OptionalTypecheck +import io.getquill.util.Messages._ + +trait ValueComputation { + val c: MacroContext + import c.universe._ + + sealed trait Value { + val term: Option[TermName] + def nestedAndOptional: Boolean + } + case class Nested(term: Option[TermName], tpe: Type, params: List[List[Value]], optional: Boolean) extends Value { + def nestedAndOptional: Boolean = optional + } + case class Scalar(term: Option[TermName], tpe: Type, decoder: Tree, optional: Boolean) extends Value { + def nestedAndOptional: Boolean = false + } + + private def is[T](tpe: Type)(implicit t: TypeTag[T]) = + tpe <:< t.tpe + + private[getquill] def value(encoding: String, tpe: Type, exclude: Tree*): Value = { + + def nest(tpe: Type, term: Option[TermName]): Nested = + caseClassConstructor(tpe) match { + case None => + c.fail(s"Found the embedded '$tpe', but it is not a case class") + case Some(constructor) => + val params = + constructor.paramLists.map { + _.map { param => + apply( + param.typeSignature.asSeenFrom(tpe, tpe.typeSymbol), + Some(param.name.toTermName), + nested = !isTuple(tpe) + ) + } + } + Nested(term, tpe, params, optional = false) + } + + def apply(tpe: Type, term: Option[TermName], nested: Boolean): Value = { + OptionalTypecheck(c)(q"implicitly[${c.prefix}.${TypeName(encoding)}[$tpe]]") match { + case Some(encoding) => + Scalar(term, tpe, encoding, optional = is[Option[Any]](tpe)) + case None => + + def value(tpe: Type) = + tpe match { + case tpe if !is[Embedded](tpe) && nested => + c.fail( + s"""Can't find implicit `$encoding[$tpe]`. Please, do one of the following things: + |1. ensure that implicit `$encoding[$tpe]` is provided and there are no other conflicting implicits; + |2. make `$tpe` `Embedded` case class or `AnyVal`. + """.stripMargin + ) + + case tpe => + nest(tpe, term) + } + + if (isNone(tpe)) { + c.fail("Cannot handle untyped `None` objects. Use a cast e.g. `None:Option[String]` or `Option.empty`.") + } else if (is[Option[Any]](tpe)) { + value(tpe.typeArgs.head).copy(optional = true) + } else { + value(tpe) + } + } + } + + def filterExcludes(value: Value) = { + val paths = + exclude.map { + case f: Function => + def path(tree: Tree): List[TermName] = + tree match { + case q"$a.$b" => path(a) :+ b + case q"$a.map[$t]($b => $c)" => path(a) ++ path(c) + case _ => Nil + } + path(f.body) + } + + def filter(value: Value, path: List[TermName] = Nil): Option[Value] = + value match { + case value if paths.contains(path ++ value.term) => + None + case Nested(term, tpe, params, optional) => + Some(Nested(term, tpe, params.map(_.flatMap(filter(_, path ++ term))), optional)) + case value => + Some(value) + } + + filter(value).getOrElse { + c.fail("Can't exclude all entity properties") + } + } + + filterExcludes(apply(tpe, term = None, nested = false)) + } + + private def isNone(tpe: Type) = + tpe.typeSymbol.name.toString == "None" + + private def isTuple(tpe: Type) = + tpe.typeSymbol.name.toString.startsWith("Tuple") + + private def caseClassConstructor(t: Type) = + t.members.collect { + case m: MethodSymbol if m.isPrimaryConstructor => m + }.headOption + + def flatten(base: Tree, value: Value): List[Tree] = { + def nest(tree: Tree, term: Option[TermName]) = + term match { + case None => tree + case Some(term) => q"$tree.$term" + } + + def apply(base: Tree, params: List[List[Value]]): List[Tree] = + params.flatten.flatMap(flatten(base, _)) + + value match { + case Scalar(term, _, _, _) => + List(nest(base, term)) + case Nested(term, _, params, optional) => + if (optional) + apply(q"v", params) + .map(body => q"${nest(base, term)}.map(v => $body)") + else + apply(nest(base, term), params) + } + } +} diff --git a/quill-core/src/main/scala/io/getquill/idiom/Idiom.scala b/quill-core/src/main/scala/io/getquill/idiom/Idiom.scala index f5166bbd01..91df617a7f 100644 --- a/quill-core/src/main/scala/io/getquill/idiom/Idiom.scala +++ b/quill-core/src/main/scala/io/getquill/idiom/Idiom.scala @@ -2,8 +2,9 @@ package io.getquill.idiom import io.getquill.ast._ import io.getquill.NamingStrategy +import io.getquill.context.Capabilities -trait Idiom { +trait Idiom extends Capabilities { def emptySetContainsToken(field: Token): Token = StringToken("FALSE") diff --git a/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala b/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala index 9ceaf21a70..7f448a22a7 100644 --- a/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala +++ b/quill-core/src/main/scala/io/getquill/norm/BetaReduction.scala @@ -58,6 +58,10 @@ case class BetaReduction(map: collection.Map[Ast, Ast]) val t = BetaReduction(map - alias) Returning(apply(action), alias, t(prop)) + case ReturningGenerated(action, alias, prop) => + val t = BetaReduction(map - alias) + ReturningGenerated(apply(action), alias, t(prop)) + case other => super.apply(other) } diff --git a/quill-core/src/main/scala/io/getquill/norm/ExpandReturning.scala b/quill-core/src/main/scala/io/getquill/norm/ExpandReturning.scala new file mode 100644 index 0000000000..6d4b1d69eb --- /dev/null +++ b/quill-core/src/main/scala/io/getquill/norm/ExpandReturning.scala @@ -0,0 +1,61 @@ +package io.getquill.norm + +import io.getquill.ReturnAction.ReturnColumns +import io.getquill.{ NamingStrategy, ReturnAction } +import io.getquill.ast._ +import io.getquill.context.{ + ReturningClauseSupported, + ReturningMultipleFieldSupported, + ReturningNotSupported, + ReturningSingleFieldSupported +} +import io.getquill.idiom.{ Idiom, Statement } + +/** + * Take the `.returning` part in a query that contains it and return the array of columns + * representing of the returning seccovtion with any other operations etc... that they might contain. + */ +object ExpandReturning { + + def applyMap(returning: ReturningAction)(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = { + val initialExpand = ExpandReturning.apply(returning)(idiom, naming) + + idiom.idiomReturningCapability match { + case ReturningClauseSupported => + ReturnAction.ReturnRecord + case ReturningMultipleFieldSupported => + ReturnColumns(initialExpand.map { case (ast, statement) => f(ast, statement) }) + case ReturningSingleFieldSupported => + if (initialExpand.length == 1) + ReturnColumns(initialExpand.map { case (ast, statement) => f(ast, statement) }) + else + throw new IllegalArgumentException(s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified.") + case ReturningNotSupported => + throw new IllegalArgumentException(s"RETURNING columns are not allowed in the ${idiom} dialect.") + } + } + + def apply(returning: ReturningAction)(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = { + val ReturningAction(_, alias, properties) = returning + + // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1)))) + // => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))) + val dePropertized = + Transform(properties) { + case `alias` => ExternalIdent(alias.name) + } + + val aliasName = alias.name + + // Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))) + // => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))) + val deTuplified = dePropertized match { + case Tuple(values) => values + case CaseClass(values) => values.map(_._2) + case other => List(other) + } + + implicit val namingStrategy: NamingStrategy = naming + deTuplified.map(v => idiom.translate(v)) + } +} diff --git a/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala b/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala index 9ddf1d8ccf..fc37445feb 100644 --- a/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala +++ b/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala @@ -1,37 +1,119 @@ package io.getquill.norm import io.getquill.ast._ +import io.getquill.norm.capture.AvoidAliasConflict +/** + * When actions are used with a `.returning` clause, remove the columns used in the returning clause from the action. + * E.g. for `insert(Person(id, name)).returning(_.id)` remove the `id` column from the original insert. + */ object NormalizeReturning { def apply(e: Action): Action = { e match { - case Returning(a: Action, alias, body) => Returning(apply(a, body), alias, body) - case _ => e + case ReturningGenerated(a: Action, alias, body) => + // De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action. + // This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres. + // For example: + // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max)) + // Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally + // be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case. + // In order to fix this, we need to change `t` into a different alias. + val newBody = dealiasBody(body, alias) + ReturningGenerated(apply(a, newBody, alias), alias, newBody) + + // For a regular return clause, do not need to exclude assignments from insertion however, we still + // need to de-alias the Action body in case conflicts result. For example the following query: + // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max)) + // would incorrectly be interpreted as: + // INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t + // whereas it should be: + // INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1) + case Returning(a: Action, alias, body) => + val newBody = dealiasBody(body, alias) + Returning(a, alias, newBody) + + case _ => e } } - private def apply(e: Action, body: Ast): Action = e match { - case Insert(query, assignments) => Insert(query, filterReturnedColumn(assignments, body)) - case Update(query, assignments) => Update(query, filterReturnedColumn(assignments, body)) - case OnConflict(a: Action, target, act) => OnConflict(apply(a, body), target, act) + /** + * In some situations, a query can exist inside of a `returning` clause. In this case, we need to rename + * if the aliases used in that query override the alias used in the `returning` clause otherwise + * they will be treated as returning-clause aliases ExpandReturning (i.e. they will become ExternalAlias instances) + * and later be tokenized incorrectly. + */ + private def dealiasBody(body: Ast, alias: Ident): Ast = + Transform(body) { + case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias)) + } + + private def apply(e: Action, body: Ast, returningIdent: Ident): Action = e match { + case Insert(query, assignments) => Insert(query, filterReturnedColumn(assignments, body, returningIdent)) + case Update(query, assignments) => Update(query, filterReturnedColumn(assignments, body, returningIdent)) + case OnConflict(a: Action, target, act) => OnConflict(apply(a, body, returningIdent), target, act) case _ => e } - private def filterReturnedColumn(assignments: List[Assignment], column: Ast): List[Assignment] = - assignments.flatMap(filterReturnedColumn(_, column)) + private def filterReturnedColumn(assignments: List[Assignment], column: Ast, returningIdent: Ident): List[Assignment] = + assignments.flatMap(filterReturnedColumn(_, column, returningIdent)) + + /** + * In situations like Property(Property(ident, foo), bar) pull out the inner-most ident + */ + object NestedProperty { + def unapply(ast: Property): Option[Ast] = { + ast match { + case p @ Property(subAst, _) => Some(innerMost(subAst)) + case _ => None + } + } + + private def innerMost(ast: Ast): Ast = ast match { + case Property(inner, _) => innerMost(inner) + case other => other + } + } + + /** + * Remove the specified column from the assignment. For example, in a query like `insert(Person(id, name)).returning(r => r.id)` + * we need to remove the `id` column from the insertion. The value of the `column:Ast` in this case will be `Property(Ident(r), id)` + * and the values fo the assignment `p1` property will typically be `v.id` and `v.name` (the `v` variable is a default + * used for `insert` queries). + */ + private def filterReturnedColumn(assignment: Assignment, column: Ast, returningIdent: Ident): Option[Assignment] = + assignment match { + case Assignment(_, p1: Property, _) => { + // Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but + // if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2)) + // or it can even be a query since queries are allowed to be in return sections e.g: + // query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max)) + // In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment + // in order to know what to exclude. + val matchedProps = + CollectAst(column) { + case prop @ NestedProperty(`returningIdent`) => prop + } - private def filterReturnedColumn(assignment: Assignment, column: Ast): Option[Assignment] = - (assignment, column) match { - case (Assignment(_, p1: Property, _), p2: Property) if p1.name == p2.name && isSameProperties(p1, p2) => - None - case (assignment, column) => - Some(assignment) + if (matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp))) + None + else + Some(assignment) + } + case assignment => Some(assignment) } + /** + * Is it the same property (but possibly of a different identity). E.g. `p.foo.bar` and `v.foo.bar` + */ private def isSameProperties(p1: Property, p2: Property): Boolean = (p1.ast, p2.ast) match { - case (_: Ident, _: Ident) => p1.name == p2.name - case (pp1: Property, pp2: Property) => isSameProperties(pp1, pp2) - case _ => false + case (_: Ident, _: Ident) => + p1.name == p2.name + // If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the + // outer properties are the same before moving on to the inner ones. + case (pp1: Property, pp2: Property) if (p1.name == p2.name) => + isSameProperties(pp1, pp2) + case _ => + false } } diff --git a/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala b/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala index 1809b9323b..020a67796c 100644 --- a/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala +++ b/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala @@ -34,6 +34,13 @@ object RenameProperties extends StatelessTransformer { val bodyr = BetaReduction(body, replace: _*) (Returning(action, alias, bodyr), schema) } + case ReturningGenerated(action: Action, alias, body) => + applySchema(action) match { + case (action, schema) => + val replace = replacements(alias, schema) + val bodyr = BetaReduction(body, replace: _*) + (ReturningGenerated(action, alias, bodyr), schema) + } case OnConflict(a: Action, target, act) => applySchema(a) match { case (action, schema) => (OnConflict(action, target, act), schema) diff --git a/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala b/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala index b44497f8ab..e660bba623 100644 --- a/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala +++ b/quill-core/src/main/scala/io/getquill/norm/capture/AvoidAliasConflict.scala @@ -1,7 +1,7 @@ package io.getquill.norm.capture import io.getquill.ast.{ Entity, Filter, FlatJoin, FlatMap, GroupBy, Ident, Join, Map, Query, SortBy, StatefulTransformer, _ } -import io.getquill.norm.BetaReduction +import io.getquill.norm.{ BetaReduction, Normalize } private[getquill] case class AvoidAliasConflict(state: collection.Set[Ident]) extends StatefulTransformer[collection.Set[Ident]] { @@ -143,4 +143,11 @@ private[getquill] object AvoidAliasConflict { def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = { AvoidAliasConflict(dangerousVariables).applyForeach(f) } + + def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = { + AvoidAliasConflict(dangerousVariables).apply(q) match { + // Propagate aliasing changes to the rest of the query + case (q, _) => Normalize(q) + } + } } diff --git a/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala b/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala index da04ddfc8c..0060e486e5 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala @@ -55,6 +55,8 @@ case class FreeVariables(state: State) action match { case q @ Returning(a, b, c) => (q, free(a, b, c)) + case q @ ReturningGenerated(a, b, c) => + (q, free(a, b, c)) case other => super.apply(other) } diff --git a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala index 04aaf282bb..1a84604422 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala @@ -149,12 +149,13 @@ trait Liftables { } implicit val actionLiftable: Liftable[Action] = Liftable[Action] { - case Update(a, b) => q"$pack.Update($a, $b)" - case Insert(a, b) => q"$pack.Insert($a, $b)" - case Delete(a) => q"$pack.Delete($a)" - case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)" - case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)" - case OnConflict(a, b, c) => q"$pack.OnConflict($a, $b, $c)" + case Update(a, b) => q"$pack.Update($a, $b)" + case Insert(a, b) => q"$pack.Insert($a, $b)" + case Delete(a) => q"$pack.Delete($a)" + case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)" + case ReturningGenerated(a, b, c) => q"$pack.ReturningGenerated($a, $b, $c)" + case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)" + case OnConflict(a, b, c) => q"$pack.OnConflict($a, $b, $c)" } implicit val conflictTargetLiftable: Liftable[OnConflict.Target] = Liftable[OnConflict.Target] { @@ -180,6 +181,9 @@ trait Liftables { implicit val identLiftable: Liftable[Ident] = Liftable[Ident] { case Ident(a) => q"$pack.Ident($a)" } + implicit val externalIdentLiftable: Liftable[ExternalIdent] = Liftable[ExternalIdent] { + case ExternalIdent(a) => q"$pack.ExternalIdent($a)" + } implicit val liftLiftable: Liftable[Lift] = Liftable[Lift] { case ScalarValueLift(a, b: Tree, c: Tree) => q"$pack.ScalarValueLift($a, $b, $c)" diff --git a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala index 6b98289e38..963891b681 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala @@ -3,18 +3,20 @@ package io.getquill.quotation import scala.reflect.ClassTag import io.getquill.ast._ import io.getquill.Embedded +import io.getquill.context.{ ReturningMultipleFieldSupported, _ } import io.getquill.norm.BetaReduction import io.getquill.util.Messages.RichContext -import io.getquill.util.Interleave -import io.getquill.dsl.CoreDsl +import io.getquill.dsl.{ CoreDsl, QueryDsl, ValueComputation } import io.getquill.norm.capture.AvoidAliasConflict +import io.getquill.idiom.Idiom import scala.annotation.tailrec import scala.collection.immutable.StringOps import scala.reflect.macros.TypecheckException import io.getquill.ast.Implicits._ +import io.getquill.util.Interleave -trait Parsing { +trait Parsing extends ValueComputation { this: Quotation => import c.universe.{ Ident => _, Constant => _, Function => _, If => _, Block => _, _ } @@ -197,6 +199,7 @@ trait Parsing { case q"$source.groupBy[$t](($alias) => $body)" if (is[CoreDsl#Query[Any]](source)) => GroupBy(astParser(source), identParser(alias), astParser(body)) + case q"$a.value[$t]" if (is[CoreDsl#Query[Any]](a)) => astParser(a) case q"$a.min[$t]" if (is[CoreDsl#Query[Any]](a)) => Aggregation(AggregationOperator.`min`, astParser(a)) case q"$a.max[$t]" if (is[CoreDsl#Query[Any]](a)) => Aggregation(AggregationOperator.`max`, astParser(a)) case q"$a.avg[$t]($n)" if (is[CoreDsl#Query[Any]](a)) => Aggregation(AggregationOperator.`avg`, astParser(a)) @@ -231,7 +234,7 @@ trait Parsing { Distinct(astParser(source)) case q"$source.nested" if (is[CoreDsl#Query[Any]](source)) => - Nested(astParser(source)) + io.getquill.ast.Nested(astParser(source)) } @@ -651,6 +654,13 @@ trait Parsing { era =:= typeOf[Option[Any]] || era =:= typeOf[Some[Any]] || era =:= typeOf[None.type] } + object ClassTypeRefMatch { + def unapply(tpe: Type) = tpe match { + case TypeRef(_, cls, args) if (cls.isClass) => Some((cls.asClass, args)) + case _ => None + } + } + /** * Recursively traverse an `Option[T]` or `Option[Option[T]]`, or `Option[Option[Option[T]]]` etc... * until we find the `T`. Stop at a specified depth. @@ -760,6 +770,73 @@ trait Parsing { tpe.paramLists(0).map(_.name.toString) } + private[getquill] def currentIdiom: Option[Type] = { + c.prefix.tree.tpe + .baseClasses + .flatMap { baseClass => + val baseClassTypeArgs = c.prefix.tree.tpe.baseType(baseClass).typeArgs + baseClassTypeArgs.find { typeArg => + typeArg <:< typeOf[Idiom] + } + } + .headOption + } + + private[getquill] def idiomReturnCapability: ReturningCapability = { + val returnAfterInsertType = + currentIdiom + .toSeq + .flatMap(_.members) + .collect { + case ms: MethodSymbol if (ms.name.toString == "idiomReturningCapability") => Some(ms.returnType) + } + .headOption + .flatten + + returnAfterInsertType match { + case Some(returnType) if (returnType =:= typeOf[ReturningClauseSupported]) => ReturningClauseSupported + case Some(returnType) if (returnType =:= typeOf[ReturningSingleFieldSupported]) => ReturningSingleFieldSupported + case Some(returnType) if (returnType =:= typeOf[ReturningMultipleFieldSupported]) => ReturningMultipleFieldSupported + case Some(returnType) if (returnType =:= typeOf[ReturningNotSupported]) => ReturningNotSupported + // Since most SQL Dialects support returing a single field (that is auto-incrementing) allow a return + // of a single field in the case that a dialect is not actually specified. E.g. when SqlContext[_, _] + // is used to define `returning` clauses. + case other => ReturningSingleFieldSupported + } + } + + implicit class InsertReturnCapabilityExtension(capability: ReturningCapability) { + def verifyAst(returnBody: Ast) = capability match { + case ReturningClauseSupported => + // Only .returning(r => r.prop) or .returning(r => OneElementCaseClass(r.prop1..., propN)) or .returning(r => (r.prop1..., propN)) (well actually it's prop22) is allowed. + case ReturningMultipleFieldSupported => + returnBody match { + case CaseClass(list) if (list.forall { + case (_, Property(_, _)) => true + case _ => false + }) => + case Tuple(list) if (list.forall { + case Property(_, _) => true + case _ => false + }) => + case Property(_, _) => + case other => + c.fail(s"${currentIdiom.map(n => s"The dialect ${n} only allows").getOrElse("Unspecified dialects only allow")} single a single property or multiple properties in case classes / tuples in 'returning' clauses ${other}.") + } + // Only .returning(r => r.prop) or .returning(r => OneElementCaseClass(r.prop)) is allowed. + case ReturningSingleFieldSupported => + returnBody match { + case Property(_, _) => + case other => + c.fail(s"${currentIdiom.map(n => s"The dialect ${n} only allows").getOrElse("Unspecified dialects only allow")} single, auto-incrementing columns in 'returning' clauses.") + } + // This is not actually the case for unspecified dialects (e.g. when doing `returning` from `SqlContext[_, _]` but error message + // says what it would say if either case happened. Otherwise doing currentIdiom.get would be allowed which is bad practice. + case ReturningNotSupported => + c.fail(s"${currentIdiom.map(n => s"The dialect ${n} does").getOrElse("Unspecified dialects do")} not allow 'returning' clauses.") + } + } + val actionParser: Parser[Ast] = Parser[Ast] { case q"$query.$method(..$assignments)" if (method.decodedName.toString == "update") => Update(astParser(query), assignments.map(assignmentParser(_))) @@ -767,14 +844,80 @@ trait Parsing { Insert(astParser(query), assignments.map(assignmentParser(_))) case q"$query.delete" => Delete(astParser(query)) + case q"$action.returning[$r]" => + c.fail(s"A 'returning' clause must have arguments.") case q"$action.returning[$r](($alias) => $body)" => - Returning(astParser(action), identParser(alias), astParser(body)) + val ident = identParser(alias) + val bodyAst = reprocessReturnClause(ident, astParser(body), action) + // Verify that the idiom supports this type of returning clause + idiomReturnCapability match { + case ReturningMultipleFieldSupported | ReturningClauseSupported => + case ReturningSingleFieldSupported => + c.fail(s"The 'returning' clause is not supported by the ${currentIdiom.getOrElse("specified")} idiom. Use 'returningGenerated' instead.") + case ReturningNotSupported => + c.fail(s"The 'returning' or 'returningGenerated' clauses are not supported by the ${currentIdiom.getOrElse("specified")} idiom.") + } + // Verify that the AST in the returning-body is valid + idiomReturnCapability.verifyAst(bodyAst) + Returning(astParser(action), ident, bodyAst) + + case q"$action.returningGenerated[$r](($alias) => $body)" => + val ident = identParser(alias) + val bodyAst = reprocessReturnClause(ident, astParser(body), action) + // Verify that the idiom supports this type of returning clause + idiomReturnCapability match { + case ReturningNotSupported => + c.fail(s"The 'returning' or 'returningGenerated' clauses are not supported by the ${currentIdiom.getOrElse("specified")} idiom.") + case _ => + } + // Verify that the AST in the returning-body is valid + idiomReturnCapability.verifyAst(bodyAst) + ReturningGenerated(astParser(action), ident, bodyAst) + case q"$query.foreach[$t1, $t2](($alias) => $body)($f)" if (is[CoreDsl#Query[Any]](query)) => // If there are actions inside the subtree, we need to do some additional sanitizations // of the variables so that their content will not collide with code that we have generated. AvoidAliasConflict.sanitizeVariables(Foreach(astParser(query), identParser(alias), astParser(body)), dangerousVariables) } + /** + * In situations where the a `.returning` clause returns the initial record i.e. `.returning(r => r)`, + * we need to expand out the record into it's fields i.e. `.returning(r => (r.foo, r.bar))` + * otherwise the tokenizer would be force to pass `RETURNING *` to the SQL which is a problem + * because the fields inside of the record could arrive out of order in the result set + * (e.g. arrive as `r.bar, r.foo`). Use use the value/flatten methods in order to expand + * the case-class out into fields. + */ + private def reprocessReturnClause(ident: Ident, originalBody: Ast, action: Tree) = { + val actionType = typecheckUnquoted(action) + + (ident == originalBody, actionType.tpe) match { + // Note, tuples are also case classes so this also matches for tuples + case (true, ClassTypeRefMatch(cls, List(arg))) if (cls == asClass[QueryDsl#Insert[_]] && isTypeCaseClass(arg)) => + + val elements = flatten(q"${TermName(ident.name)}", value("Decoder", arg)) + if (elements.size == 0) c.fail("Case class in the 'returning' clause has no values") + + // Create an intermediate scala API that can then be parsed into clauses. This needs to be + // typechecked first in order to function properly. + val tpe = c.typecheck( + q"((${TermName(ident.name)}:$arg) => io.getquill.dsl.UnlimitedTuple.apply(..$elements))" + ) + val newBody = + tpe match { + case q"(($newAlias) => $newBody)" => newBody + case _ => c.fail("Could not process whole-record 'returning' clause. Consider trying to return individual columns.") + } + astParser(newBody) + + case (true, _) => + c.fail("Could not process whole-record 'returning' clause. Consider trying to return individual columns.") + + case _ => + originalBody + } + } + private val assignmentParser: Parser[Assignment] = Parser[Assignment] { case q"((${ identParser(i1) }) => $pack.Predef.ArrowAssoc[$t]($prop).$arrow[$v]($value))" => checkTypes(prop, value) diff --git a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala index dd784b798e..dea27fd1a1 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala @@ -158,6 +158,7 @@ trait Unliftables { case q"$pack.Insert.apply(${ a: Ast }, ${ b: List[Assignment] })" => Insert(a, b) case q"$pack.Delete.apply(${ a: Ast })" => Delete(a) case q"$pack.Returning.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Returning(a, b, c) + case q"$pack.ReturningGenerated.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => ReturningGenerated(a, b, c) case q"$pack.Foreach.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Foreach(a, b, c) case q"$pack.OnConflict.apply(${ a: Ast }, ${ b: OnConflict.Target }, ${ c: OnConflict.Action })" => OnConflict(a, b, c) } @@ -185,6 +186,9 @@ trait Unliftables { implicit val identUnliftable: Unliftable[Ident] = Unliftable[Ident] { case q"$pack.Ident.apply(${ a: String })" => Ident(a) } + implicit val externalIdentUnliftable: Unliftable[ExternalIdent] = Unliftable[ExternalIdent] { + case q"$pack.ExternalIdent.apply(${ a: String })" => ExternalIdent(a) + } implicit val liftUnliftable: Unliftable[Lift] = Unliftable[Lift] { case q"$pack.ScalarValueLift.apply(${ a: String }, $b, $c)" => ScalarValueLift(a, b, c) diff --git a/quill-core/src/test/scala/io/getquill/MirrorIdiomExt.scala b/quill-core/src/test/scala/io/getquill/MirrorIdiomExt.scala new file mode 100644 index 0000000000..9bfaef3e24 --- /dev/null +++ b/quill-core/src/test/scala/io/getquill/MirrorIdiomExt.scala @@ -0,0 +1,25 @@ +package io.getquill + +import io.getquill.context.{ CanReturnField, CanReturnMultiField, CannotReturn } + +class TestMirrorContextTemplate[Dialect <: MirrorIdiomBase, Naming <: NamingStrategy](dialect: Dialect, naming: Naming) + extends MirrorContext[Dialect, Naming](dialect, naming) with TestEntities { + + def withDialect[I <: MirrorIdiomBase](dialect: I)(f: TestMirrorContextTemplate[I, Naming] => Any): Unit = { + val ctx = new TestMirrorContextTemplate[I, Naming](dialect, naming) + f(ctx) + ctx.close + } +} + +// Mirror idiom supporting only single-field returning clauses +trait MirrorIdiomReturningSingle extends MirrorIdiomBase with CanReturnField +object MirrorIdiomReturningSingle extends MirrorIdiomReturningSingle + +// Mirror idiom supporting only multi-field returning clauses +trait MirrorIdiomReturningMulti extends MirrorIdiomBase with CanReturnMultiField +object MirrorIdiomReturningMulti extends MirrorIdiomReturningMulti + +// Mirror idiom not supporting any returns +trait MirrorIdiomReturningUnsupported extends MirrorIdiomBase with CannotReturn +object MirrorIdiomReturningUnsupported extends MirrorIdiomReturningUnsupported \ No newline at end of file diff --git a/quill-core/src/test/scala/io/getquill/TestEntities.scala b/quill-core/src/test/scala/io/getquill/TestEntities.scala index c3314201c8..50a9f0c9b6 100644 --- a/quill-core/src/test/scala/io/getquill/TestEntities.scala +++ b/quill-core/src/test/scala/io/getquill/TestEntities.scala @@ -6,14 +6,22 @@ trait TestEntities { this: Context[_, _] => case class TestEntity(s: String, i: Int, l: Long, o: Option[Int]) + case class Emb(s: String, i: Int) extends Embedded + case class TestEntityEmb(emb: Emb, l: Long, o: Option[Int]) case class TestEntity2(s: String, i: Int, l: Long, o: Option[Int]) case class TestEntity3(s: String, i: Int, l: Long, o: Option[Int]) case class TestEntity4(i: Long) - case class TestEntity5(s: String, i: Long) + case class TestEntity5(i: Long, s: String) + case class EmbSingle(i: Long) extends Embedded + case class TestEntity4Emb(emb: EmbSingle) + case class TestEntityRegular(s: String, i: Long) val qr1 = quote { query[TestEntity] } + val qr1Emb = quote { + querySchema[TestEntityEmb]("TestEntity") + } val qr2 = quote { query[TestEntity2] } @@ -24,8 +32,14 @@ trait TestEntities { query[TestEntity4] } val qr5 = quote { + query[TestEntity5] + } + val qr4Emb = quote { + querySchema[TestEntity4Emb]("TestEntity4") + } + val qrRegular = quote { for { a <- query[TestEntity] - } yield TestEntity5(a.s, a.l) + } yield TestEntityRegular(a.s, a.l) } } diff --git a/quill-core/src/test/scala/io/getquill/ast/AstOpsSpec.scala b/quill-core/src/test/scala/io/getquill/ast/AstOpsSpec.scala index 6126745016..1cc7a0ca41 100644 --- a/quill-core/src/test/scala/io/getquill/ast/AstOpsSpec.scala +++ b/quill-core/src/test/scala/io/getquill/ast/AstOpsSpec.scala @@ -95,4 +95,28 @@ class AstOpsSpec extends Spec { } } } + + "returning matcher" - { + val insert = Insert(Entity("Ent", List()), List(Assignment(Ident("p"), Property(Ident("p"), "prop"), Constant(123)))) + val r = Ident("r") + val prop = Property(r, "value") + + "must match returning" in { + Returning(insert, r, prop) must matchPattern { + case ReturningAction(`insert`, `r`, `prop`) => + } + } + + "must match returning generated" in { + ReturningGenerated(insert, r, prop) must matchPattern { + case ReturningAction(`insert`, `r`, `prop`) => + } + } + + "must not match anything else" in { + insert mustNot matchPattern { + case ReturningAction(_, _, _) => + } + } + } } diff --git a/quill-core/src/test/scala/io/getquill/context/ActionMacroSpec.scala b/quill-core/src/test/scala/io/getquill/context/ActionMacroSpec.scala index 3500378167..5f3974437d 100644 --- a/quill-core/src/test/scala/io/getquill/context/ActionMacroSpec.scala +++ b/quill-core/src/test/scala/io/getquill/context/ActionMacroSpec.scala @@ -1,9 +1,11 @@ package io.getquill.context -import io.getquill.Spec -import io.getquill.testContext +import io.getquill.{ Spec, testContext } +import io.getquill.ReturnAction.{ ReturnColumns, ReturnRecord } import io.getquill.testContext._ import io.getquill.context.mirror.Row +import io.getquill.MirrorIdiomReturningSingle +import io.getquill.MirrorIdiomReturningMulti class ActionMacroSpec extends Spec { @@ -40,32 +42,121 @@ class ActionMacroSpec extends Spec { r.string mustEqual """querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.l -> ?, v => v.o -> ?)""" r.prepareRow mustEqual Row("s", 1, 2L, None) } - "returning value" in { - val q = quote { - qr1.insert(t => t.i -> 1).returning(t => t.l) - } - val r = testContext.run(q) - r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returning((t) => t.l)""" - r.prepareRow mustEqual Row() - r.returningColumn mustEqual "l" - } - "scalar lifting + returning value" in { - val q = quote { - qr1.insert(t => t.i -> lift(1)).returning(t => t.l) - } - val r = testContext.run(q) - r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> ?).returning((t) => t.l)""" - r.prepareRow mustEqual Row(1) - r.returningColumn mustEqual "l" - } - "case class lifting + returning value" in { - val q = quote { - qr1.insert(lift(TestEntity("s", 1, 2L, None))).returning(t => t.l) + + "returning" - { + "returning value" in { + val q = quote { + qr1.insert(t => t.i -> 1).returning(t => t.l) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returning((t) => t.l)""" + r.prepareRow mustEqual Row() + r.returningBehavior mustEqual ReturnRecord + } + "returning value - with single - should not compile" in testContext.withDialect(MirrorIdiomReturningSingle) { ctx => + import ctx._ + "ctx.run(qr1.insert(t => t.i -> 1).returning(t => t.l))" mustNot compile + } + "returning value - with multi" in testContext.withDialect(MirrorIdiomReturningMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(t => t.i -> 1).returning(t => t.l) + } + val r = ctx.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returning((t) => t.l)""" + r.prepareRow mustEqual Row() + r.returningBehavior mustEqual ReturnColumns(List("l")) + } + "returning generated value" in { + val q = quote { + qr1.insert(t => t.i -> 1).returningGenerated(t => t.l) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returningGenerated((t) => t.l)""" + r.prepareRow mustEqual Row() + r.returningBehavior mustEqual ReturnRecord + } + "returning generated value - with single" in testContext.withDialect(MirrorIdiomReturningSingle) { ctx => + import ctx._ + val q = quote { + qr1.insert(t => t.i -> 1).returningGenerated(t => t.l) + } + val r = ctx.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returningGenerated((t) => t.l)""" + r.prepareRow mustEqual Row() + r.returningBehavior mustEqual ReturnColumns(List("l")) + } + "returning generated value - with single - multi should not compile" in testContext.withDialect(MirrorIdiomReturningSingle) { ctx => + "import ctx._; ctx.run(qr1.insert(t => t.i -> 1).returningGenerated(t => (t.l, t.i))" mustNot compile + } + "returning generated value - with multi" in testContext.withDialect(MirrorIdiomReturningMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(t => t.i -> 1).returningGenerated(t => (t.l, t.s)) + } + val r = ctx.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returningGenerated((t) => (t.l, t.s))""" + r.prepareRow mustEqual Row() + r.returningBehavior mustEqual ReturnColumns(List("l", "s")) + } + "returning generated value - with multi - operation in clause should not compile" in testContext.withDialect(MirrorIdiomReturningMulti) { ctx => + "import ctx._; ctx.run(qr1.insert(t => t.i -> 1).returningGenerated(t => (t.l, t.i + 1)))" mustNot compile + } + "returning generated value - with multi - single" in testContext.withDialect(MirrorIdiomReturningMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(t => t.i -> 1).returningGenerated(t => t.l) + } + val r = ctx.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> 1).returningGenerated((t) => t.l)""" + r.prepareRow mustEqual Row() + r.returningBehavior mustEqual ReturnColumns(List("l")) + } + "scalar lifting + returning value" in { + val q = quote { + qr1.insert(t => t.i -> lift(1)).returning(t => t.l) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(t => t.i -> ?).returning((t) => t.l)""" + r.prepareRow mustEqual Row(1) + r.returningBehavior mustEqual ReturnRecord + } + "case class lifting + returning value" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, None))).returning(t => t.l) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.l -> ?, v => v.o -> ?).returning((t) => t.l)""" + r.prepareRow mustEqual Row("s", 1, 2, None) + r.returningBehavior mustEqual ReturnRecord + } + "case class lifting + returning generated value" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, None))).returningGenerated(t => t.l) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.o -> ?).returningGenerated((t) => t.l)""" + r.prepareRow mustEqual Row("s", 1, None) + r.returningBehavior mustEqual ReturnRecord + } + "case class lifting + returning multi value" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, None))).returning(t => (t.l, t.i)) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.l -> ?, v => v.o -> ?).returning((t) => (t.l, t.i))""" + r.prepareRow mustEqual Row("s", 1, 2, None) + r.returningBehavior mustEqual ReturnRecord + } + "case class lifting + returning generated multi value" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, None))).returningGenerated(t => (t.l, t.i)) + } + val r = testContext.run(q) + r.string mustEqual """querySchema("TestEntity").insert(v => v.s -> ?, v => v.o -> ?).returningGenerated((t) => (t.l, t.i))""" + r.prepareRow mustEqual Row("s", None) + r.returningBehavior mustEqual ReturnRecord } - val r = testContext.run(q) - r.string mustEqual """querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.o -> ?).returning((t) => t.l)""" - r.prepareRow mustEqual Row("s", 1, None) - r.returningColumn mustEqual "l" } } @@ -146,7 +237,7 @@ class ActionMacroSpec extends Spec { } val r = testContext.run(q) r.groups mustEqual List( - ("""querySchema("TestEntity").insert(t => t.i -> ?).returning((t) => t.l)""", "l", List(Row(1), Row(2))) + ("""querySchema("TestEntity").insert(t => t.i -> ?).returning((t) => t.l)""", ReturnRecord, List(Row(1), Row(2))) ) } "case class + returning" in { @@ -155,7 +246,20 @@ class ActionMacroSpec extends Spec { } val r = testContext.run(q) r.groups mustEqual List( - ("""querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.o -> ?).returning((t) => t.l)""", "l", + ("""querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.l -> ?, v => v.o -> ?).returning((t) => t.l)""", + ReturnRecord, + List(Row("s1", 2, 3, Some(4)), Row("s5", 6, 7, Some(8))) + ) + ) + } + "case class + returning generated" in { + val q = quote { + liftQuery(entities).foreach(p => qr1.insert(p).returningGenerated(t => t.l)) + } + val r = testContext.run(q) + r.groups mustEqual List( + ("""querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.o -> ?).returningGenerated((t) => t.l)""", + ReturnRecord, List(Row("s1", 2, Some(4)), Row("s5", 6, Some(8))) ) ) @@ -166,7 +270,20 @@ class ActionMacroSpec extends Spec { } val r = testContext.run(liftQuery(entities).foreach(p => insert(p))) r.groups mustEqual List( - ("""querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.o -> ?).returning((t) => t.l)""", "l", + ("""querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.l -> ?, v => v.o -> ?).returning((t) => t.l)""", + ReturnRecord, + List(Row("s1", 2, 3, Some(4)), Row("s5", 6, 7, Some(8))) + ) + ) + } + "case class + returning generated + nested action" in { + val insert = quote { + (p: TestEntity) => qr1.insert(p).returningGenerated(t => t.l) + } + val r = testContext.run(liftQuery(entities).foreach(p => insert(p))) + r.groups mustEqual List( + ("""querySchema("TestEntity").insert(v => v.s -> ?, v => v.i -> ?, v => v.o -> ?).returningGenerated((t) => t.l)""", + ReturnRecord, List(Row("s1", 2, Some(4)), Row("s5", 6, Some(8))) ) ) @@ -221,7 +338,14 @@ class ActionMacroSpec extends Spec { qr1.insert(lift(TestEntity("s", 1, 2L, None))).returning(t => t.l) } testContext.translate(q) mustEqual - """querySchema("TestEntity").insert(v => v.s -> 's', v => v.i -> 1, v => v.o -> null).returning((t) => t.l)""" + """querySchema("TestEntity").insert(v => v.s -> 's', v => v.i -> 1, v => v.l -> 2, v => v.o -> null).returning((t) => t.l)""" + } + "case class lifting + returning generated value" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, None))).returningGenerated(t => t.l) + } + testContext.translate(q) mustEqual + """querySchema("TestEntity").insert(v => v.s -> 's', v => v.i -> 1, v => v.o -> null).returningGenerated((t) => t.l)""" } } @@ -306,8 +430,8 @@ class ActionMacroSpec extends Spec { liftQuery(entities).foreach(p => qr1.insert(p).returning(t => t.l)) } testContext.translate(q) mustEqual List( - """querySchema("TestEntity").insert(v => v.s -> 's1', v => v.i -> 2, v => v.o -> 4).returning((t) => t.l)""", - """querySchema("TestEntity").insert(v => v.s -> 's5', v => v.i -> 6, v => v.o -> 8).returning((t) => t.l)""" + """querySchema("TestEntity").insert(v => v.s -> 's1', v => v.i -> 2, v => v.l -> 3, v => v.o -> 4).returning((t) => t.l)""", + """querySchema("TestEntity").insert(v => v.s -> 's5', v => v.i -> 6, v => v.l -> 7, v => v.o -> 8).returning((t) => t.l)""" ) } "case class + returning + nested action" in { @@ -315,8 +439,26 @@ class ActionMacroSpec extends Spec { (p: TestEntity) => qr1.insert(p).returning(t => t.l) } testContext.translate(liftQuery(entities).foreach(p => insert(p))) mustEqual List( - """querySchema("TestEntity").insert(v => v.s -> 's1', v => v.i -> 2, v => v.o -> 4).returning((t) => t.l)""", - """querySchema("TestEntity").insert(v => v.s -> 's5', v => v.i -> 6, v => v.o -> 8).returning((t) => t.l)""" + """querySchema("TestEntity").insert(v => v.s -> 's1', v => v.i -> 2, v => v.l -> 3, v => v.o -> 4).returning((t) => t.l)""", + """querySchema("TestEntity").insert(v => v.s -> 's5', v => v.i -> 6, v => v.l -> 7, v => v.o -> 8).returning((t) => t.l)""" + ) + } + "case class + returning generated" in { + val q = quote { + liftQuery(entities).foreach(p => qr1.insert(p).returningGenerated(t => t.l)) + } + testContext.translate(q) mustEqual List( + """querySchema("TestEntity").insert(v => v.s -> 's1', v => v.i -> 2, v => v.o -> 4).returningGenerated((t) => t.l)""", + """querySchema("TestEntity").insert(v => v.s -> 's5', v => v.i -> 6, v => v.o -> 8).returningGenerated((t) => t.l)""" + ) + } + "case class + returning generated + nested action" in { + val insert = quote { + (p: TestEntity) => qr1.insert(p).returningGenerated(t => t.l) + } + testContext.translate(liftQuery(entities).foreach(p => insert(p))) mustEqual List( + """querySchema("TestEntity").insert(v => v.s -> 's1', v => v.i -> 2, v => v.o -> 4).returningGenerated((t) => t.l)""", + """querySchema("TestEntity").insert(v => v.s -> 's5', v => v.i -> 6, v => v.o -> 8).returningGenerated((t) => t.l)""" ) } } diff --git a/quill-core/src/test/scala/io/getquill/context/BindMacroSpec.scala b/quill-core/src/test/scala/io/getquill/context/BindMacroSpec.scala index d900edfa9b..c0a846060f 100644 --- a/quill-core/src/test/scala/io/getquill/context/BindMacroSpec.scala +++ b/quill-core/src/test/scala/io/getquill/context/BindMacroSpec.scala @@ -56,6 +56,20 @@ class BindMacroSpec extends Spec { qr1.insert(lift(TestEntity("s", 1, 2L, None))).returning(t => t.l) } val r = testContext.prepare(q) + r(session) mustEqual Row("s", 1, 2, None) + } + "scalar lifting + returning generated value" in { + val q = quote { + qr1.insert(t => t.i -> lift(1)).returningGenerated(t => t.l) + } + val r = testContext.prepare(q) + r(session) mustEqual Row(1) + } + "case class lifting + returning generated value" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, None))).returningGenerated(t => t.l) + } + val r = testContext.prepare(q) r(session) mustEqual Row("s", 1, None) } } @@ -137,6 +151,13 @@ class BindMacroSpec extends Spec { liftQuery(entities).foreach(p => qr1.insert(p).returning(t => t.l)) } val r = testContext.prepare(q) + r(session) mustEqual List(Row("s1", 2, 3L, Some(4)), Row("s5", 6, 7L, Some(8))) + } + "case class + returning generated" in { + val q = quote { + liftQuery(entities).foreach(p => qr1.insert(p).returningGenerated(t => t.l)) + } + val r = testContext.prepare(q) r(session) mustEqual List(Row("s1", 2, Some(4)), Row("s5", 6, Some(8))) } "case class + returning + nested action" in { @@ -144,6 +165,13 @@ class BindMacroSpec extends Spec { (p: TestEntity) => qr1.insert(p).returning(t => t.l) } val r = testContext.prepare(liftQuery(entities).foreach(p => insert(p))) + r(session) mustEqual List(Row("s1", 2, 3L, Some(4)), Row("s5", 6, 7L, Some(8))) + } + "case class + returning generated + nested action" in { + val insert = quote { + (p: TestEntity) => qr1.insert(p).returningGenerated(t => t.l) + } + val r = testContext.prepare(liftQuery(entities).foreach(p => insert(p))) r(session) mustEqual List(Row("s1", 2, Some(4)), Row("s5", 6, Some(8))) } } diff --git a/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala b/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala new file mode 100644 index 0000000000..7832b9df66 --- /dev/null +++ b/quill-core/src/test/scala/io/getquill/norm/ExpandReturningSpec.scala @@ -0,0 +1,161 @@ +package io.getquill.norm + +import io.getquill.ReturnAction.{ ReturnColumns, ReturnRecord } +import io.getquill._ +import io.getquill.ast._ +import io.getquill.context.Expand + +class ExpandReturningSpec extends Spec { + + case class Person(name: String, age: Int) + case class Foo(bar: String, baz: Int) + + "inner apply" - { + val mi = MirrorIdiom + val ctx = new MirrorContext(mi, Literal) + import ctx._ + + "should replace tuple clauses with ExternalIdent" in { + val q = quote { + query[Person].insert(lift(Person("Joe", 123))).returning(p => (p.name, p.age)) + } + val list = + ExpandReturning.apply(q.ast.asInstanceOf[Returning])(MirrorIdiom, Literal) + list must matchPattern { + case List((Property(ExternalIdent("p"), "name"), _), (Property(ExternalIdent("p"), "age"), _)) => + } + } + + "should replace case class clauses with ExternalIdent" in { + val q = quote { + query[Person].insert(lift(Person("Joe", 123))).returning(p => Foo(p.name, p.age)) + } + val list = + ExpandReturning.apply(q.ast.asInstanceOf[Returning])(MirrorIdiom, Literal) + list must matchPattern { + case List((Property(ExternalIdent("p"), "name"), _), (Property(ExternalIdent("p"), "age"), _)) => + } + } + } + + "returning clause" - { + val mi = MirrorIdiom + val ctx = new MirrorContext(mi, Literal) + import ctx._ + val q = quote { query[Person].insert(lift(Person("Joe", 123))) } + + "should expand tuples with plain record" in { + val qi = quote { q.returning(p => (p.name, p.age)) } + val ret = + ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { + case (ast, stmt) => fail("Should not use this method for the returning clause") + }(mi, Literal) + + ret mustBe ReturnRecord + } + "should expand case classes with plain record" in { + val qi = quote { q.returning(p => Foo(p.name, p.age)) } + val ret = + ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { + case (ast, stmt) => fail("Should not use this method for the returning clause") + }(mi, Literal) + + ret mustBe ReturnRecord + } + "should expand whole record with plain record (converted to tuple in parser)" in { + val qi = quote { q.returning(p => p) } + val ret = + ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { + case (ast, stmt) => fail("Should not use this method for the returning clause") + }(mi, Literal) + + ret mustBe ReturnRecord + } + } + + "returning multi" - { + val mi = MirrorIdiomReturningMulti + val ctx = new MirrorContext(mi, Literal) + import ctx._ + val q = quote { query[Person].insert(lift(Person("Joe", 123))) } + + "should expand tuples" in { + val qi = quote { q.returning(p => (p.name, p.age)) } + val ret = + ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + ret mustBe ReturnColumns(List("name", "age")) + } + "should expand case classes" in { + val qi = quote { q.returning(p => Foo(p.name, p.age)) } + val ret = + ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + ret mustBe ReturnColumns(List("name", "age")) + } + "should expand case classes (converted to tuple in parser)" in { + val qi = quote { q.returning(p => p) } + val ret = + ExpandReturning.applyMap(qi.ast.asInstanceOf[Returning]) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + ret mustBe ReturnColumns(List("name", "age")) + } + } + + "returning single and unsupported" - { + val insert = Insert( + Map( + Entity("Person", List()), + Ident("p"), + Tuple(List(Property(Ident("p"), "name"), Property(Ident("p"), "age"))) + ), + List(Assignment(Ident("pp"), Property(Ident("pp"), "name"), Constant("Joe"))) + ) + val retMulti = + Returning(insert, Ident("r"), Tuple(List(Property(Ident("r"), "name"), Property(Ident("r"), "age")))) + val retSingle = + Returning(insert, Ident("r"), Tuple(List(Property(Ident("r"), "name")))) + + "returning single" - { + val mi = MirrorIdiomReturningSingle + val ctx = new MirrorContext(mi, Literal) + + "should fail if multiple fields encountered" in { + assertThrows[IllegalArgumentException] { + ExpandReturning.applyMap(retMulti) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + } + } + "should succeed if single field encountered" in { + val ret = + ExpandReturning.applyMap(retSingle) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + ret mustBe ReturnColumns(List("name")) + } + } + "returning unsupported" - { + val mi = MirrorIdiomReturningUnsupported + val ctx = new MirrorContext(mi, Literal) + + "should fail if multiple fields encountered" in { + assertThrows[IllegalArgumentException] { + ExpandReturning.applyMap(retMulti) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + } + } + "should fail if single field encountered" in { + assertThrows[IllegalArgumentException] { + ExpandReturning.applyMap(retSingle) { + case (ast, stmt) => Expand(ctx, ast, stmt, mi, Literal).string + }(mi, Literal) + } + } + } + } +} diff --git a/quill-core/src/test/scala/io/getquill/norm/NormalizeReturningSpec.scala b/quill-core/src/test/scala/io/getquill/norm/NormalizeReturningSpec.scala index d693ba92b8..6e7b842bf9 100644 --- a/quill-core/src/test/scala/io/getquill/norm/NormalizeReturningSpec.scala +++ b/quill-core/src/test/scala/io/getquill/norm/NormalizeReturningSpec.scala @@ -1,7 +1,8 @@ package io.getquill.norm +import io.getquill.ReturnAction.{ ReturnColumns, ReturnRecord } import io.getquill.context.mirror.Row -import io.getquill.{ Spec, testContext } +import io.getquill.{ MirrorIdiomReturningSingle, MirrorIdiomReturningMulti, Spec, testContext } import io.getquill.testContext._ class NormalizeReturningSpec extends Spec { @@ -18,16 +19,55 @@ class NormalizeReturningSpec extends Spec { "when returning parent col" in { val r = testContext.run(q.returning(p => p.id)) - r.string mustEqual """querySchema("Entity").insert(v => v.emb.id -> ?).returning((p) => p.id)""" + r.string mustEqual """querySchema("Entity").insert(v => v.id -> ?, v => v.emb.id -> ?).returning((p) => p.id)""" + r.prepareRow mustEqual Row(1, 2) + r.returningBehavior mustEqual ReturnRecord + } + "when returning parent col - single - returning should not compile" in testContext.withDialect(MirrorIdiomReturningSingle) { ctx => + "ctx.run(query[Entity].insert(lift(e)).returning(p => p.id))" mustNot compile + } + "when returning parent col - single - returning generated" in testContext.withDialect(MirrorIdiomReturningSingle) { ctx => + import ctx._ + val r = ctx.run(query[Entity].insert(lift(e)).returningGenerated(p => p.id)) + r.string mustEqual """querySchema("Entity").insert(v => v.emb.id -> ?).returningGenerated((p) => p.id)""" + r.prepareRow mustEqual Row(2) + r.returningBehavior mustEqual ReturnColumns(List("id")) + } + "when returning parent col - multi - returning (supported)" in testContext.withDialect(MirrorIdiomReturningMulti) { ctx => + import ctx._ + val r = ctx.run(query[Entity].insert(lift(e)).returning(p => p.id)) + r.string mustEqual """querySchema("Entity").insert(v => v.id -> ?, v => v.emb.id -> ?).returning((p) => p.id)""" + r.prepareRow mustEqual Row(1, 2) + r.returningBehavior mustEqual ReturnColumns(List("id")) + } + "when returningGenerated parent col" in { + val r = testContext.run(q.returningGenerated(p => p.id)) + r.string mustEqual """querySchema("Entity").insert(v => v.emb.id -> ?).returningGenerated((p) => p.id)""" r.prepareRow mustEqual Row(2) - r.returningColumn mustEqual "id" + r.returningBehavior mustEqual ReturnRecord } "when returning embedded col" in { val r = testContext.run(q.returning(p => p.emb.id)) - r.string mustEqual """querySchema("Entity").insert(v => v.id -> ?).returning((p) => p.emb.id)""" + r.string mustEqual """querySchema("Entity").insert(v => v.id -> ?, v => v.emb.id -> ?).returning((p) => p.emb.id)""" + r.prepareRow mustEqual Row(1, 2) + r.returningBehavior mustEqual ReturnRecord + } + "when returningGenerated embedded col" in { + val r = testContext.run(q.returningGenerated(p => p.emb.id)) + r.string mustEqual """querySchema("Entity").insert(v => v.id -> ?).returningGenerated((p) => p.emb.id)""" + r.prepareRow mustEqual Row(1) + r.returningBehavior mustEqual ReturnRecord + } + + "when returning embedded col - single" in testContext.withDialect(MirrorIdiomReturningSingle) { ctx => + import ctx._ + val r = ctx.run(query[Entity].insert(lift(e)).returningGenerated(p => p.emb.id)) + r.string mustEqual """querySchema("Entity").insert(v => v.id -> ?).returningGenerated((p) => p.emb.id)""" r.prepareRow mustEqual Row(1) - r.returningColumn mustEqual "id" + // As of #1489 the Idiom now decides how to tokenize a `returning` clause when for MirrorIdiom is `emb.id` + // since the mirror idiom specifically does not parse out embedded objects. + r.returningBehavior mustEqual ReturnColumns(List("emb.id")) } } diff --git a/quill-core/src/test/scala/io/getquill/package.scala b/quill-core/src/test/scala/io/getquill/package.scala index a68bb56a7c..024d8b353d 100644 --- a/quill-core/src/test/scala/io/getquill/package.scala +++ b/quill-core/src/test/scala/io/getquill/package.scala @@ -5,7 +5,7 @@ import scala.util.{ Failure, Try } package object getquill { - object testContext extends MirrorContext(MirrorIdiom, Literal) with TestEntities + object testContext extends TestMirrorContextTemplate(MirrorIdiom, Literal) with TestEntities object testAsyncContext extends AsyncMirrorContext(MirrorIdiom, Literal) with TestEntities { // hack to avoid Await.result since scala.js doesn't support it diff --git a/quill-core/src/test/scala/io/getquill/quotation/DynamicQuerySpec.scala b/quill-core/src/test/scala/io/getquill/quotation/DynamicQuerySpec.scala index 458d48aa29..91e013c4c4 100644 --- a/quill-core/src/test/scala/io/getquill/quotation/DynamicQuerySpec.scala +++ b/quill-core/src/test/scala/io/getquill/quotation/DynamicQuerySpec.scala @@ -1,10 +1,11 @@ package io.getquill.quotation import io.getquill._ +import io.getquill.dsl.DynamicQueryDsl class DynamicQuerySpec extends Spec { - // object testContext extends MirrorContext(MirrorIdiom, Literal) with TestEntities with DynamicQueryDsl + object testContext extends MirrorContext(MirrorIdiom, Literal) with TestEntities with DynamicQueryDsl import testContext._ "implicit classes" - { @@ -49,7 +50,9 @@ class DynamicQuerySpec extends Spec { d.q mustEqual q } "action returning" in { - val q: Quoted[ActionReturning[TestEntity, Int]] = qr1.insert(_.i -> 1).returning(_.i) + val q: Quoted[ActionReturning[TestEntity, Int]] = quote { + qr1.insert(_.i -> 1).returningGenerated(_.i) + } val d = { val d = q.dynamic (d: DynamicActionReturning[TestEntity, Int]) @@ -557,8 +560,16 @@ class DynamicQuerySpec extends Spec { } "returning" in { test( - dynamicQuery[TestEntity].insert(set(_.i, 1)).returning(v0 => v0.l), - query[TestEntity].insert(v => v.i -> 1).returning(v0 => v0.l) + dynamicQuery[TestEntity].insert(set(_.i, 1)).returningGenerated(v0 => v0.l), + quote { + query[TestEntity].insert(v => v.i -> 1).returningGenerated(v0 => v0.l) + } + ) + } + "returning non quoted" in { + test( + dynamicQuery[TestEntity].insert(set(_.i, 1)).returningGenerated(v0 => v0.l), + query[TestEntity].insert(v => v.i -> 1).returningGenerated((v0: TestEntity) => v0.l) ) } "onConflictIgnore" - { diff --git a/quill-core/src/test/scala/io/getquill/quotation/IsDynamicSpec.scala b/quill-core/src/test/scala/io/getquill/quotation/IsDynamicSpec.scala index 3750c22d14..196ecf9537 100644 --- a/quill-core/src/test/scala/io/getquill/quotation/IsDynamicSpec.scala +++ b/quill-core/src/test/scala/io/getquill/quotation/IsDynamicSpec.scala @@ -4,7 +4,7 @@ import io.getquill.Spec import io.getquill.ast.Dynamic import io.getquill.ast.Property import io.getquill.testContext.qr1 -import io.getquill.testContext.qr5 +import io.getquill.testContext.qrRegular class IsDynamicSpec extends Spec { @@ -21,7 +21,7 @@ class IsDynamicSpec extends Spec { IsDynamic(qr1.ast) mustEqual false } "false when using CaseClass" in { - IsDynamic(qr5.ast) mustEqual false + IsDynamic(qrRegular.ast) mustEqual false } } } diff --git a/quill-finagle-mysql/src/main/scala/io/getquill/FinagleMysqlContext.scala b/quill-finagle-mysql/src/main/scala/io/getquill/FinagleMysqlContext.scala index 3eeb753d2c..d9e2654a9d 100644 --- a/quill-finagle-mysql/src/main/scala/io/getquill/FinagleMysqlContext.scala +++ b/quill-finagle-mysql/src/main/scala/io/getquill/FinagleMysqlContext.scala @@ -138,7 +138,7 @@ class FinagleMysqlContext[N <: NamingStrategy]( .map(r => toOk(r).affectedRows) } - def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String): Future[T] = { + def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction): Future[T] = { val (params, prepared) = prepare(Nil) logger.logQuery(sql, params) withClient(Write)(_.prepare(sql)(prepared: _*)) diff --git a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala index c68c604c3b..59cb46d725 100644 --- a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala +++ b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala @@ -24,7 +24,7 @@ class FinagleMysqlContextSpec extends Spec { "Insert with returning with single column table" in { val inserted: Long = await(testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) }) await(testContext.run(qr4.filter(_.i == lift(inserted)))) .head.i mustBe inserted diff --git a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/ProductFinagleMysqlSpec.scala b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/ProductFinagleMysqlSpec.scala index e468b6e18f..c5722503fb 100644 --- a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/ProductFinagleMysqlSpec.scala +++ b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/ProductFinagleMysqlSpec.scala @@ -35,7 +35,7 @@ class ProductFinagleMysqlSpec extends ProductSpec { val prd = Product(0L, "test1", 1L) val inserted = await { testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } } val returnedProduct = await(testContext.run(productById(lift(inserted)))).head @@ -47,7 +47,7 @@ class ProductFinagleMysqlSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val inserted = await(testContext.run(q1)) val returnedProduct = await(testContext.run(productById(lift(inserted)))).head @@ -69,7 +69,7 @@ class ProductFinagleMysqlSpec extends ProductSpec { case class Product(id: Id, description: String, sku: Long) val prd = Product(Id(0L), "test2", 2L) val q1 = quote { - query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } await(testContext.run(q1)) mustBe a[Id] } diff --git a/quill-finagle-postgres/src/main/scala/io/getquill/FinaglePostgresContext.scala b/quill-finagle-postgres/src/main/scala/io/getquill/FinaglePostgresContext.scala index 2198dd9d79..c833ce5619 100644 --- a/quill-finagle-postgres/src/main/scala/io/getquill/FinaglePostgresContext.scala +++ b/quill-finagle-postgres/src/main/scala/io/getquill/FinaglePostgresContext.scala @@ -3,9 +3,11 @@ package io.getquill import com.twitter.util.{ Await, Future, Local } import com.twitter.finagle.postgres._ import com.typesafe.config.Config +import io.getquill.ReturnAction.{ ReturnColumns, ReturnNothing, ReturnRecord } import io.getquill.context.finagle.postgres._ import io.getquill.context.sql.SqlContext import io.getquill.util.{ ContextLogger, LoadConfig } + import scala.util.Try import io.getquill.context.{ Context, TranslateContext } import io.getquill.monad.TwitterFutureIOMonad @@ -43,8 +45,15 @@ class FinaglePostgresContext[N <: NamingStrategy](val naming: N, client: Postgre override def close = Await.result(client.close()) - private def expandAction(sql: String, returningColumn: String): String = - s"$sql RETURNING $returningColumn" + private def expandAction(sql: String, returningAction: ReturnAction): String = + returningAction match { + // The Postgres dialect will create SQL that has a 'RETURNING' clause so we don't have to add one. + case ReturnRecord => s"$sql" + // The Postgres dialect will not actually use these below variants but in case we decide to plug + // in some other dialect into this context... + case ReturnColumns(columns) => s"$sql RETURNING ${columns.mkString(", ")}" + case ReturnNothing => s"$sql" + } def probe(sql: String) = Try(Await.result(client.query(sql))) @@ -91,10 +100,10 @@ class FinaglePostgresContext[N <: NamingStrategy](val naming: N, client: Postgre } }.map(_.flatten.toList) - def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String): Future[T] = { + def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction): Future[T] = { val (params, prepared) = prepare(Nil) logger.logQuery(sql, params) - withClient(_.prepareAndQuery(expandAction(sql, returningColumn), prepared: _*)(extractor)).map(v => handleSingleResult(v.toList)) + withClient(_.prepareAndQuery(expandAction(sql, returningAction), prepared: _*)(extractor)).map(v => handleSingleResult(v.toList)) } def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]): Future[List[T]] = diff --git a/quill-jdbc-monix/src/main/scala/io/getquill/context/monix/MonixJdbcContext.scala b/quill-jdbc-monix/src/main/scala/io/getquill/context/monix/MonixJdbcContext.scala index 97b9ba3e50..ff37ec7dc1 100644 --- a/quill-jdbc-monix/src/main/scala/io/getquill/context/monix/MonixJdbcContext.scala +++ b/quill-jdbc-monix/src/main/scala/io/getquill/context/monix/MonixJdbcContext.scala @@ -3,7 +3,7 @@ package io.getquill.context.monix import java.io.Closeable import java.sql.{ Array => _, _ } -import io.getquill.NamingStrategy +import io.getquill.{ NamingStrategy, ReturnAction } import io.getquill.context.StreamingContext import io.getquill.context.jdbc.JdbcContextBase import io.getquill.context.sql.idiom.SqlIdiom @@ -43,8 +43,8 @@ abstract class MonixJdbcContext[Dialect <: SqlIdiom, Naming <: NamingStrategy]( super.executeQuery(sql, prepare, extractor) override def executeQuerySingle[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor): Task[T] = super.executeQuerySingle(sql, prepare, extractor) - override def executeActionReturning[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningColumn: String): Task[O] = - super.executeActionReturning(sql, prepare, extractor, returningColumn) + override def executeActionReturning[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningBehavior: ReturnAction): Task[O] = + super.executeActionReturning(sql, prepare, extractor, returningBehavior) override def executeBatchAction(groups: List[BatchGroup]): Task[List[Long]] = super.executeBatchAction(groups) override def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]): Task[List[T]] = diff --git a/quill-jdbc-monix/src/test/scala/io/getquill/h2/ProductJdbcSpec.scala b/quill-jdbc-monix/src/test/scala/io/getquill/h2/ProductJdbcSpec.scala index 1a9b31770c..60495996bb 100644 --- a/quill-jdbc-monix/src/test/scala/io/getquill/h2/ProductJdbcSpec.scala +++ b/quill-jdbc-monix/src/test/scala/io/getquill/h2/ProductJdbcSpec.scala @@ -42,7 +42,7 @@ class ProductJdbcSpec extends ProductSpec { val (inserted, returnedProduct) = (for { i <- testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } rps <- testContext.run(productById(lift(i))) } yield (i, rps.head)).runSyncUnsafe() @@ -55,7 +55,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val (inserted, returnedProduct) = (for { diff --git a/quill-jdbc-monix/src/test/scala/io/getquill/mysql/ProductJdbcSpec.scala b/quill-jdbc-monix/src/test/scala/io/getquill/mysql/ProductJdbcSpec.scala index dfbb0dc64f..0e65bed944 100644 --- a/quill-jdbc-monix/src/test/scala/io/getquill/mysql/ProductJdbcSpec.scala +++ b/quill-jdbc-monix/src/test/scala/io/getquill/mysql/ProductJdbcSpec.scala @@ -42,7 +42,7 @@ class ProductJdbcSpec extends ProductSpec { val (inserted, returnedProduct) = (for { i <- testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } rps <- testContext.run(productById(lift(i))) } yield (i, rps.head)).runSyncUnsafe() @@ -55,7 +55,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val (inserted, returnedProduct) = (for { diff --git a/quill-jdbc-monix/src/test/scala/io/getquill/sqlite/ProductJdbcSpec.scala b/quill-jdbc-monix/src/test/scala/io/getquill/sqlite/ProductJdbcSpec.scala index 957c784048..9de31fa7f4 100644 --- a/quill-jdbc-monix/src/test/scala/io/getquill/sqlite/ProductJdbcSpec.scala +++ b/quill-jdbc-monix/src/test/scala/io/getquill/sqlite/ProductJdbcSpec.scala @@ -43,7 +43,7 @@ class ProductJdbcSpec extends ProductSpec { val (inserted, returnedProduct) = (for { i <- testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } rps <- testContext.run(productById(lift(i))) } yield (i, rps.head)).runSyncUnsafe() @@ -56,7 +56,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val (inserted, returnedProduct) = (for { diff --git a/quill-jdbc-monix/src/test/scala/io/getquill/sqlserver/ProductJdbcSpec.scala b/quill-jdbc-monix/src/test/scala/io/getquill/sqlserver/ProductJdbcSpec.scala index d98aee5e2f..1db773d542 100644 --- a/quill-jdbc-monix/src/test/scala/io/getquill/sqlserver/ProductJdbcSpec.scala +++ b/quill-jdbc-monix/src/test/scala/io/getquill/sqlserver/ProductJdbcSpec.scala @@ -43,7 +43,7 @@ class ProductJdbcSpec extends ProductSpec { val (inserted, returnedProduct) = (for { i <- testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } rps <- testContext.run(productById(lift(i))) } yield (i, rps.head)).runSyncUnsafe() @@ -56,7 +56,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val (inserted, returnedProduct) = (for { diff --git a/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContext.scala b/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContext.scala index 24b26c493a..45745f2adf 100644 --- a/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContext.scala +++ b/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContext.scala @@ -5,7 +5,7 @@ import java.sql.{ Connection, PreparedStatement } import javax.sql.DataSource import io.getquill.context.sql.idiom.SqlIdiom -import io.getquill.NamingStrategy +import io.getquill.{ NamingStrategy, ReturnAction } import io.getquill.context.{ ContextEffect, TranslateContext } import scala.util.{ DynamicVariable, Try } @@ -40,8 +40,8 @@ abstract class JdbcContext[Dialect <: SqlIdiom, Naming <: NamingStrategy] super.executeQuery(sql, prepare, extractor) override def executeQuerySingle[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor): T = super.executeQuerySingle(sql, prepare, extractor) - override def executeActionReturning[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningColumn: String): O = - super.executeActionReturning(sql, prepare, extractor, returningColumn) + override def executeActionReturning[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningBehavior: ReturnAction): O = + super.executeActionReturning(sql, prepare, extractor, returningBehavior) override def executeBatchAction(groups: List[BatchGroup]): List[Long] = super.executeBatchAction(groups) override def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]): List[T] = diff --git a/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContextBase.scala b/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContextBase.scala index a3de051953..f1344185e0 100644 --- a/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContextBase.scala +++ b/quill-jdbc/src/main/scala/io/getquill/context/jdbc/JdbcContextBase.scala @@ -1,8 +1,9 @@ package io.getquill.context.jdbc -import java.sql.{ Connection, JDBCType, PreparedStatement, ResultSet } +import java.sql._ -import io.getquill.NamingStrategy +import io.getquill._ +import io.getquill.ReturnAction._ import io.getquill.context.sql.SqlContext import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.context.{ Context, ContextEffect } @@ -44,14 +45,21 @@ trait JdbcContextBase[Dialect <: SqlIdiom, Naming <: NamingStrategy] def executeQuerySingle[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor): Result[T] = handleSingleWrappedResult(executeQuery(sql, prepare, extractor)) - def executeActionReturning[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningColumn: String): Result[O] = + def executeActionReturning[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningBehavior: ReturnAction): Result[O] = withConnectionWrapped { conn => - val (params, ps) = prepare(conn.prepareStatement(sql, Array(returningColumn))) + val (params, ps) = prepare(prepareWithReturning(sql, conn, returningBehavior)) logger.logQuery(sql, params) ps.executeUpdate() handleSingleResult(extractResult(ps.getGeneratedKeys, extractor)) } + protected def prepareWithReturning(sql: String, conn: Connection, returningBehavior: ReturnAction) = + returningBehavior match { + case ReturnRecord => conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS) + case ReturnColumns(columns) => conn.prepareStatement(sql, columns.toArray) + case ReturnNothing => conn.prepareStatement(sql) + } + def executeBatchAction(groups: List[BatchGroup]): Result[List[Long]] = withConnectionWrapped { conn => groups.flatMap { @@ -70,8 +78,8 @@ trait JdbcContextBase[Dialect <: SqlIdiom, Naming <: NamingStrategy] def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]): Result[List[T]] = withConnectionWrapped { conn => groups.flatMap { - case BatchGroupReturning(sql, column, prepare) => - val ps = conn.prepareStatement(sql, Array(column)) + case BatchGroupReturning(sql, returningBehavior, prepare) => + val ps = prepareWithReturning(sql, conn, returningBehavior) logger.underlying.debug("Batch: {}", sql) prepare.foreach { f => val (params, _) = f(ps) diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/JdbcContextSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/JdbcContextSpec.scala index 8289539653..702a468ef7 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/JdbcContextSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/JdbcContextSpec.scala @@ -46,7 +46,7 @@ class JdbcContextSpec extends Spec { "Insert with returning with single column table" in { val inserted = testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) } testContext.run(qr4.filter(_.i == lift(inserted))).head.i mustBe inserted } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/ProductJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/ProductJdbcSpec.scala index a5294ec29e..e0567df8f1 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/ProductJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/h2/ProductJdbcSpec.scala @@ -34,7 +34,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with inlined free variable" in { val prd = Product(0L, "test1", 1L) val inserted = testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val returnedProduct = testContext.run(productById(lift(inserted))).head returnedProduct.description mustEqual "test1" @@ -45,7 +45,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val inserted = testContext.run(q1) val returnedProduct = testContext.run(productById(lift(inserted))).head diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/DistinctJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/DistinctJdbcSpec.scala index d60ea861d9..f8d22f7131 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/DistinctJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/DistinctJdbcSpec.scala @@ -46,6 +46,6 @@ class DistinctJdbcSpec extends DistinctSpec { testContext.run(`Ex 7 Distinct Subquery with Map Multi Field Tuple`) should contain theSameElementsAs `Ex 7 Distinct Subquery with Map Multi Field Tuple Result` } "Ex 8 Distinct With Sort" in { - testContext.run(`Ex 8 Distinct With Sort`) mustEqual `Ex 8 Distinct With Sort Result` + testContext.run(`Ex 8 Distinct With Sort`) should contain theSameElementsAs `Ex 8 Distinct With Sort Result` } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/JdbcContextSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/JdbcContextSpec.scala index 95b9c226a4..66b8ff29e3 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/JdbcContextSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/JdbcContextSpec.scala @@ -54,7 +54,7 @@ class JdbcContextSpec extends Spec { "Insert with returning with single column table" in { val inserted = testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) } testContext.run(qr4.filter(_.i == lift(inserted))).head.i mustBe inserted } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/ProductJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/ProductJdbcSpec.scala index f0dc2d8686..3e2c422398 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/ProductJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/ProductJdbcSpec.scala @@ -28,7 +28,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with inlined free variable" in { val prd = Product(0L, "test1", 1L) val inserted = testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val returnedProduct = testContext.run(productById(lift(inserted))).head returnedProduct.description mustEqual "test1" @@ -39,7 +39,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val inserted = testContext.run(q1) val returnedProduct = testContext.run(productById(lift(inserted))).head diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/JdbcContextSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/JdbcContextSpec.scala index b11e237651..9ea5acb524 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/JdbcContextSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/oracle/JdbcContextSpec.scala @@ -61,4 +61,21 @@ class JdbcContextSpec extends Spec { } testContext.run(qr4.filter(_.i == lift(inserted))).head.i mustBe inserted } + + "Insert with returning with multiple columns" in { + testContext.run(qr1.delete) + val inserted = testContext.run { + qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i, r.s, r.o)) + } + (1, "foo", Some(123)) mustBe inserted + } + + "Insert with returning with multiple columns - case class" in { + case class Return(id: Int, str: String, opt: Option[Int]) + testContext.run(qr1.delete) + val inserted = testContext.run { + qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => Return(r.i, r.s, r.o)) + } + Return(1, "foo", Some(123)) mustBe inserted + } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcContextSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcContextSpec.scala index a1c32d5c07..12e5581134 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcContextSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcContextSpec.scala @@ -52,10 +52,83 @@ class JdbcContextSpec extends Spec { } } - "Insert with returning with single column table" in { + "Insert with returning generated with single column table" in { + testContext.run(qr4.delete) + val insert = quote { + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) + } + + val inserted1 = testContext.run(insert) + testContext.run(qr4.filter(_.i == lift(inserted1))).head.i mustBe inserted1 + + val inserted2 = testContext.run(insert) + testContext.run(qr4.filter(_.i == lift(inserted2))).head.i mustBe inserted2 + + val inserted3 = testContext.run(insert) + testContext.run(qr4.filter(_.i == lift(inserted3))).head.i mustBe inserted3 + } + + "Insert with returning generated with single column table using query" in { + testContext.run(qr5.delete) + val id = testContext.run(qr5.insert(lift(TestEntity5(0, "foo"))).returningGenerated(_.i)) + val id2 = testContext.run { + qr5.insert(_.s -> "bar").returningGenerated(r => query[TestEntity5].filter(_.s == "foo").map(_.i).max) + }.get + id mustBe id2 + } + + "Insert with returning with multiple columns" in { + testContext.run(qr1.delete) + val inserted = testContext.run { + qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i, r.s, r.o)) + } + (1, "foo", Some(123)) mustBe inserted + } + + "Insert with returning with multiple columns and operations" in { + testContext.run(qr1.delete) + val inserted = testContext.run { + qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i + 100, r.s, r.o.map(_ + 100))) + } + (1 + 100, "foo", Some(123 + 100)) mustBe inserted + } + + "Insert with returning with multiple columns and query" in { + testContext.run(qr1.delete) + testContext.run(qr1.insert(lift(TestEntity("one", 1, 18L, Some(1))))) + val inserted = testContext.run { + qr1.insert(lift(TestEntity("two", 2, 18L, Some(123)))).returning(r => + (r.i, r.s + "_s", qr1.filter(rr => rr.o.exists(_ == r.i)).map(_.s).max)) + } + (2, "two_s", Some("one")) mustBe inserted + } + + "Insert returning with multiple columns and query" in { + testContext.run(qr1.delete) + testContext.run(qr1.insert(lift(TestEntity("one", 1, 18L, Some(1))))) + val inserted = testContext.run { + qr1.insert(lift(TestEntity("two", 2, 18L, Some(123)))).returning(r => + (r.i, r.s + "_s", qr1.filter(rr => rr.o.exists(_ == r.i)).map(_.s).max)) + } + (2, "two_s", Some("one")) mustBe inserted + } + + "Insert with returning with multiple columns and query embedded" in { + testContext.run(qr1Emb.delete) + testContext.run(qr1Emb.insert(lift(TestEntityEmb(Emb("one", 1), 18L, Some(123))))) + val inserted = testContext.run { + qr1Emb.insert(lift(TestEntityEmb(Emb("two", 2), 18L, Some(123)))).returning(r => + (r.emb.i, r.o)) + } + (2, Some(123)) mustBe inserted + } + + "Insert with returning with multiple columns - case class" in { + case class Return(id: Int, str: String, opt: Option[Int]) + testContext.run(qr1.delete) val inserted = testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => Return(r.i, r.s, r.o)) } - testContext.run(qr4.filter(_.i == lift(inserted))).head.i mustBe inserted + Return(1, "foo", Some(123)) mustBe inserted } } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/ProductJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/ProductJdbcSpec.scala index 326ccc7cb8..d8386c0815 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/ProductJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/ProductJdbcSpec.scala @@ -27,15 +27,17 @@ class ProductJdbcSpec extends ProductSpec { product.id mustEqual inserted } + case class Foo(id: Long, description: String, sku: Long) + "Single insert with inlined free variable" in { val prd = Product(0L, "test1", 1L) val inserted = testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(r => r) } - val returnedProduct = testContext.run(productById(lift(inserted))).head + val returnedProduct = testContext.run(productById(lift(inserted.id))).head returnedProduct.description mustEqual "test1" returnedProduct.sku mustEqual 1L - returnedProduct.id mustEqual inserted + returnedProduct mustEqual inserted } "Single insert with free variable and explicit quotation" in { diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/JdbcContextSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/JdbcContextSpec.scala index f0db2ab3e2..f99cbe113d 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/JdbcContextSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/JdbcContextSpec.scala @@ -46,7 +46,7 @@ class JdbcContextSpec extends Spec { "Insert with returning with single column table" in { val inserted = testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) } testContext.run(qr4.filter(_.i == lift(inserted))).head.i mustBe inserted } diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/ProductJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/ProductJdbcSpec.scala index 0212f59cc1..9db94e56ff 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/ProductJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlite/ProductJdbcSpec.scala @@ -32,7 +32,7 @@ class ProductJdbcSpec extends ProductSpec { val result = testContext.run { liftQuery(list).foreach { prd => - query[Product].insert(prd).returning(_.id) + query[Product].insert(prd).returningGenerated(_.id) } } result.size mustEqual list.size @@ -41,7 +41,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with inlined free variable" in { val prd = Product(0L, "test1", 1L) val inserted = testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val returnedProduct = testContext.run(productById(lift(inserted))).head returnedProduct.description mustEqual "test1" @@ -52,7 +52,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val inserted = testContext.run(q1) val returnedProduct = testContext.run(productById(lift(inserted))).head diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/JdbcContextSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/JdbcContextSpec.scala index 2db7015da1..ef63e13477 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/JdbcContextSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/JdbcContextSpec.scala @@ -54,10 +54,17 @@ class JdbcContextSpec extends Spec { "Insert with returning with single column table" in { val inserted = testContext.run { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) } testContext.run(qr4.filter(_.i == lift(inserted))).head.i mustBe inserted } + + "Insert with returning with multiple columns and query embedded" in { + val inserted = testContext.run { + qr4Emb.insert(lift(TestEntity4Emb(EmbSingle(0)))).returningGenerated(_.emb.i) + } + testContext.run(qr4Emb.filter(_.emb.i == lift(inserted))).head.emb.i mustBe inserted + } } class PendingUntilFixed extends Spec { diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/ProductJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/ProductJdbcSpec.scala index 260add7671..d05d19078b 100644 --- a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/ProductJdbcSpec.scala +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/sqlserver/ProductJdbcSpec.scala @@ -31,7 +31,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with inlined free variable" in { val prd = Product(0L, "test1", 1L) val inserted = testContext.run { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val returnedProduct = testContext.run(productById(lift(inserted))).head returnedProduct.description mustEqual "test1" @@ -42,7 +42,7 @@ class ProductJdbcSpec extends ProductSpec { "Single insert with free variable and explicit quotation" in { val prd = Product(0L, "test2", 2L) val q1 = quote { - product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id) + product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id) } val inserted = testContext.run(q1) val returnedProduct = testContext.run(productById(lift(inserted))).head diff --git a/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala b/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala index fdcdd55b25..4c497a4a81 100644 --- a/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala +++ b/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala @@ -5,13 +5,14 @@ import io.getquill.context.sql.norm._ import io.getquill.ast.{ AggregationOperator, Lift, _ } import io.getquill.context.sql._ import io.getquill.NamingStrategy +import io.getquill.context.CannotReturn import io.getquill.util.Messages.fail import io.getquill.idiom._ import io.getquill.context.sql.norm.SqlNormalize import io.getquill.util.Interleave import io.getquill.context.sql.idiom.VerifySqlQuery -object OrientDBIdiom extends OrientDBIdiom +object OrientDBIdiom extends OrientDBIdiom with CannotReturn trait OrientDBIdiom extends Idiom { @@ -48,6 +49,8 @@ trait OrientDBIdiom extends Idiom { a.token case a: Ident => a.token + case a: ExternalIdent => + a.token case a: Property => a.token case a: Value => @@ -254,6 +257,9 @@ trait OrientDBIdiom extends Idiom { implicit def identTokenizer(implicit strategy: NamingStrategy): Tokenizer[Ident] = Tokenizer[Ident](e => strategy.default(e.name).token) + implicit def externalIdentTokenizer(implicit strategy: NamingStrategy): Tokenizer[ExternalIdent] = + Tokenizer[ExternalIdent](e => strategy.default(e.name).token) + implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy): Tokenizer[Assignment] = Tokenizer[Assignment] { case Assignment(alias, prop, value) => stmt"${prop.token} = ${scopedTokenizer(value)}" diff --git a/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala b/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala index a6128f408f..e072bf8126 100644 --- a/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala +++ b/quill-orientdb/src/test/scala/io/getquill/context/orientdb/OrientDBQuerySpec.scala @@ -1,8 +1,9 @@ package io.getquill.context.orientdb -import io.getquill.ast.{ Query => AstQuery, Action => AstAction, _ } +import io.getquill.ast.{ Action => AstAction, Query => AstQuery, _ } import io.getquill.context.sql._ import io.getquill.idiom.StatementInterpolator._ +import io.getquill.idiom.StringToken import io.getquill.{ Literal, Spec } class OrientDBQuerySpec extends Spec { @@ -263,5 +264,10 @@ class OrientDBQuerySpec extends Spec { t.token(ins("nonEmpty")) mustBe stmt"INSERT INTO tb (x IS NOT NULL) VALUES(i)" t.token(Insert(Entity("tb", Nil), List(Assignment(i, Property(i, "i"), i)))) mustBe stmt"INSERT INTO tb (i) VALUES(i)" } + // not actually used anywhere but doing a sanity check here + "external ident sanity check" in { + val t = implicitly[Tokenizer[ExternalIdent]] + t.token(ExternalIdent("TestIdent")) mustBe StringToken("TestIdent") + } } } \ No newline at end of file diff --git a/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala b/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala index f21496fdad..90643ff677 100644 --- a/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala +++ b/quill-spark/src/main/scala/io/getquill/context/spark/SparkDialect.scala @@ -19,6 +19,7 @@ import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.Token import io.getquill.util.Messages.trace import io.getquill.ast.Constant +import io.getquill.context.CannotReturn class SparkDialect extends SparkIdiom @@ -72,7 +73,7 @@ object SparkDialectRecursor { } } -trait SparkIdiom extends SqlIdiom { self => +trait SparkIdiom extends SqlIdiom with CannotReturn { self => def parentTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = super.sqlQueryTokenizer diff --git a/quill-sql/src/main/scala/io/getquill/H2Dialect.scala b/quill-sql/src/main/scala/io/getquill/H2Dialect.scala index 691dafd369..7a16cd0a89 100644 --- a/quill-sql/src/main/scala/io/getquill/H2Dialect.scala +++ b/quill-sql/src/main/scala/io/getquill/H2Dialect.scala @@ -2,6 +2,8 @@ package io.getquill import io.getquill.idiom.StatementInterpolator._ import java.util.concurrent.atomic.AtomicInteger + +import io.getquill.context.CanReturnField import io.getquill.context.sql.idiom.PositionalBindVariables import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.context.sql.idiom.ConcatSupport @@ -9,7 +11,8 @@ import io.getquill.context.sql.idiom.ConcatSupport trait H2Dialect extends SqlIdiom with PositionalBindVariables - with ConcatSupport { + with ConcatSupport + with CanReturnField { private[getquill] val preparedStatementId = new AtomicInteger diff --git a/quill-sql/src/main/scala/io/getquill/MirrorSqlDialect.scala b/quill-sql/src/main/scala/io/getquill/MirrorSqlDialect.scala index ba202e9922..4d243c5024 100644 --- a/quill-sql/src/main/scala/io/getquill/MirrorSqlDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/MirrorSqlDialect.scala @@ -1,5 +1,6 @@ package io.getquill +import io.getquill.context.{ CanReturnClause, CanReturnField, CanReturnMultiField, CannotReturn } import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.context.sql.idiom.QuestionMarkBindVariables import io.getquill.context.sql.idiom.ConcatSupport @@ -8,7 +9,38 @@ trait MirrorSqlDialect extends SqlIdiom with QuestionMarkBindVariables with ConcatSupport + with CanReturnField + +trait MirrorSqlDialectWithReturnMulti + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnMultiField + +trait MirrorSqlDialectWithReturnClause + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnClause + +trait MirrorSqlDialectWithNoReturn + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CannotReturn object MirrorSqlDialect extends MirrorSqlDialect { override def prepareForProbing(string: String) = string } + +object MirrorSqlDialectWithReturnMulti extends MirrorSqlDialectWithReturnMulti { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithReturnClause extends MirrorSqlDialectWithReturnClause { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithNoReturn extends MirrorSqlDialectWithNoReturn { + override def prepareForProbing(string: String) = string +} diff --git a/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala b/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala index 7d9f0cfc2d..6f1e3eb01a 100644 --- a/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala @@ -1,6 +1,7 @@ package io.getquill import io.getquill.ast.{ Ast, _ } +import io.getquill.context.CanReturnField import io.getquill.context.sql.OrderByCriteria import io.getquill.context.sql.idiom.{ NoConcatSupport, QuestionMarkBindVariables, SqlIdiom } import io.getquill.idiom.StatementInterpolator._ @@ -10,7 +11,8 @@ import io.getquill.util.Messages.fail trait MySQLDialect extends SqlIdiom with QuestionMarkBindVariables - with NoConcatSupport { + with NoConcatSupport + with CanReturnField { override def prepareForProbing(string: String) = { val quoted = string.replace("'", "\\'") diff --git a/quill-sql/src/main/scala/io/getquill/OracleDialect.scala b/quill-sql/src/main/scala/io/getquill/OracleDialect.scala index c1aacf3d31..d6659739e8 100644 --- a/quill-sql/src/main/scala/io/getquill/OracleDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/OracleDialect.scala @@ -1,6 +1,7 @@ package io.getquill import io.getquill.ast._ +import io.getquill.context.CanReturnMultiField import io.getquill.context.sql._ import io.getquill.context.sql.idiom._ import io.getquill.idiom.StatementInterpolator._ @@ -11,7 +12,8 @@ import io.getquill.norm.ConcatBehavior.NonAnsiConcat trait OracleDialect extends SqlIdiom with QuestionMarkBindVariables - with ConcatSupport { + with ConcatSupport + with CanReturnMultiField { class OracleFlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) extends FlattenSqlQueryTokenizerHelper(q)(astTokenizer, strategy) { diff --git a/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala b/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala index cac1accd84..97a5eb6979 100644 --- a/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala @@ -3,6 +3,7 @@ package io.getquill import java.util.concurrent.atomic.AtomicInteger import io.getquill.ast._ +import io.getquill.context.CanReturnClause import io.getquill.context.sql.idiom._ import io.getquill.idiom.StatementInterpolator._ @@ -10,7 +11,8 @@ trait PostgresDialect extends SqlIdiom with QuestionMarkBindVariables with ConcatSupport - with OnConflictSupport { + with OnConflictSupport + with CanReturnClause { override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ast] = Tokenizer[Ast] { diff --git a/quill-sql/src/main/scala/io/getquill/SQLServerDialect.scala b/quill-sql/src/main/scala/io/getquill/SQLServerDialect.scala index fc3d7de39a..6418a49083 100644 --- a/quill-sql/src/main/scala/io/getquill/SQLServerDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/SQLServerDialect.scala @@ -1,6 +1,7 @@ package io.getquill import io.getquill.ast._ +import io.getquill.context.CanReturnField import io.getquill.context.sql.{ FlattenSqlQuery, SqlQuery } import io.getquill.context.sql.idiom._ import io.getquill.context.sql.norm.AddDropToNestedOrderBy @@ -14,7 +15,8 @@ import io.getquill.util.Messages.fail trait SQLServerDialect extends SqlIdiom with QuestionMarkBindVariables - with ConcatSupport { + with ConcatSupport + with CanReturnField { override def querifyAst(ast: Ast) = AddDropToNestedOrderBy(SqlQuery(ast)) diff --git a/quill-sql/src/main/scala/io/getquill/SqliteDialect.scala b/quill-sql/src/main/scala/io/getquill/SqliteDialect.scala index 64ffa97968..6c06d8d929 100644 --- a/quill-sql/src/main/scala/io/getquill/SqliteDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/SqliteDialect.scala @@ -5,13 +5,15 @@ import io.getquill.context.sql.idiom._ import io.getquill.idiom.StatementInterpolator.Tokenizer import io.getquill.idiom.{ StringToken, Token } import io.getquill.ast._ +import io.getquill.context.CanReturnField import io.getquill.context.sql.OrderByCriteria trait SqliteDialect extends SqlIdiom with QuestionMarkBindVariables with NoConcatSupport - with OnConflictSupport { + with OnConflictSupport + with CanReturnField { override def emptySetContainsToken(field: Token) = StringToken("0") diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala index 8dc02bce16..1a52e92998 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala @@ -5,15 +5,18 @@ import io.getquill.ast.BooleanOperator._ import io.getquill.ast.Lift import io.getquill.context.sql._ import io.getquill.context.sql.norm._ -import io.getquill.idiom.{ Idiom, SetContainsToken, Statement } +import io.getquill.idiom._ import io.getquill.idiom.StatementInterpolator._ import io.getquill.NamingStrategy +import io.getquill.context.ReturningClauseSupported import io.getquill.util.Interleave import io.getquill.util.Messages.{ fail, trace } import io.getquill.idiom.Token -import io.getquill.norm.{ ConcatBehavior, EqualityBehavior } +import io.getquill.norm.EqualityBehavior +import io.getquill.norm.ConcatBehavior import io.getquill.norm.ConcatBehavior.AnsiConcat import io.getquill.norm.EqualityBehavior.AnsiEquality +import io.getquill.norm.ExpandReturning trait SqlIdiom extends Idiom { @@ -60,6 +63,7 @@ trait SqlIdiom extends Idiom { case a: Infix => a.token case a: Action => a.token case a: Ident => a.token + case a: ExternalIdent => a.token case a: Property => a.token case a: Value => a.token case a: If => a.token @@ -313,7 +317,20 @@ trait SqlIdiom extends Idiom { } Tokenizer[Property] { case Property(ast, name) => + // When we have things like Embedded tables, properties inside of one another needs to be un-nested. + // E.g. in `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` the inner + // property needs to be unwrapped and the result of this should only be `realTable.realPropertyAlias` + // as opposed to `realTable.embeddedTableAlias.realPropertyAlias`. unnest(ast) match { + // When using ExternalIdent such as .returning(eid => eid.idColumn) clauses drop the 'eid' since SQL + // returning clauses have no alias for the original table. I.e. INSERT [...] RETURNING idColumn there's no + // alias you can assign to the INSERT [...] clause that can be used as a prefix to 'idColumn'. + // In this case, `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` + // should just be `realPropertyAlias` as opposed to `realTable.realPropertyAlias`. + case (ExternalIdent(_), prefix) => + stmt"${tokenizeColumn(strategy, prefix.mkString + name).token}" + // The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` + // becomes `realTable.realPropertyAlias`. case (ast, prefix) => stmt"${scopedTokenizer(ast)}.${tokenizeColumn(strategy, prefix.mkString + name).token}" } @@ -339,6 +356,9 @@ trait SqlIdiom extends Idiom { implicit def identTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ident] = Tokenizer[Ident](e => strategy.default(e.name).token) + implicit def externalIdentTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[ExternalIdent] = + Tokenizer[ExternalIdent](e => strategy.default(e.name).token) + implicit def assignmentTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Assignment] = Tokenizer[Assignment] { case Assignment(alias, prop, value) => stmt"${prop.token} = ${scopedTokenizer(value)}" @@ -360,6 +380,19 @@ trait SqlIdiom extends Idiom { case Property(_, name) => tokenizeColumn(strategy, name).token } + def returnListTokenizer(implicit tokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[List[Ast]] = { + val customAstTokenizer = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case sq: Query => + stmt"(${tokenizer.token(sq)})" + } + + Tokenizer[List[Ast]] { + case list => + list.mkStmt(", ")(customAstTokenizer) + } + } + protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Action] = Tokenizer[Action] { @@ -381,11 +414,21 @@ trait SqlIdiom extends Idiom { case Delete(table: Entity) => stmt"DELETE FROM ${table.token}" - case Returning(Insert(table: Entity, Nil), alias, prop) => - stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" + case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) => + idiomReturningCapability match { + case ReturningClauseSupported => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer.token(ExpandReturning(r)(this, strategy).map(_._1))}" + case other => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" + } - case Returning(action, alias, prop) => - action.token + case r @ ReturningAction(action, alias, prop) => + idiomReturningCapability match { + case ReturningClauseSupported => + stmt"${action.token} RETURNING ${returnListTokenizer.token(ExpandReturning(r)(this, strategy).map(_._1))}" + case other => + stmt"${action.token}" + } case other => fail(s"Action ast can't be translated to sql: '$other'") diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/EncodingSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/EncodingSpec.scala index d8b3f7f8de..5f44d324a2 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/EncodingSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/EncodingSpec.scala @@ -154,7 +154,7 @@ trait EncodingSpec extends Spec { case class BarCode(description: String, uuid: Option[UUID] = None) - val insertBarCode = quote((b: BarCode) => query[BarCode].insert(b).returning(_.uuid)) + val insertBarCode = quote((b: BarCode) => query[BarCode].insert(b).returningGenerated(_.uuid)) val barCodeEntry = BarCode("returning UUID") def findBarCodeByUuid(uuid: UUID) = quote(query[BarCode].filter(_.uuid.forall(_ == lift(uuid)))) diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/ProductSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/ProductSpec.scala index 1a6a421de4..137557cd19 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/ProductSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/ProductSpec.scala @@ -17,7 +17,7 @@ trait ProductSpec extends Spec { } val productInsert = quote { - (p: Product) => query[Product].insert(p).returning(_.id) + (p: Product) => query[Product].insert(p).returningGenerated(_.id) } val productInsertBatch = quote { @@ -35,6 +35,6 @@ trait ProductSpec extends Spec { ) val productSingleInsert = quote { - product.insert(_.id -> 0, _.description -> "Window", _.sku -> 1004L).returning(_.id) + product.insert(_.id -> 0, _.description -> "Window", _.sku -> 1004L).returningGenerated(_.id) } } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/SqlActionMacroSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/SqlActionMacroSpec.scala index 6e721f8a12..cfde0dfcd7 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/SqlActionMacroSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/SqlActionMacroSpec.scala @@ -1,6 +1,7 @@ package io.getquill.context.sql import io.getquill._ +import io.getquill.ReturnAction.{ ReturnColumns, ReturnRecord } import io.getquill.context.mirror.Row class SqlActionMacroSpec extends Spec { @@ -48,33 +49,395 @@ class SqlActionMacroSpec extends Spec { mirror.prepareRow mustEqual Row("s", 1) } } - "with returning" in { - val q = quote { - qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(_.l) - } - val a = TestEntity - val mirror = testContext.run(q) - mirror.string mustEqual "INSERT INTO TestEntity (s,i,o) VALUES (?, ?, ?)" - mirror.returningColumn mustEqual "l" + + "returning generated" - { + "single with returning generated" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(_.l) + } + + val mirror = testContext.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,o) VALUES (?, ?, ?)" + mirror.returningBehavior mustEqual ReturnColumns(List("l")) + } + "with assigned values" in { + val q = quote { + qr1.insert(_.s -> "s", _.i -> 0).returningGenerated(_.l) + } + + val mirror = testContext.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i) VALUES ('s', 0)" + mirror.returningBehavior mustEqual ReturnColumns(List("l")) + } + "single should fail on record type with multiple fields" in { + """testContext.run(qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => r))""" mustNot compile + } + "multi" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(_.l) + } + + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,o) VALUES (?, ?, ?)" + mirror.returningBehavior mustEqual ReturnColumns(List("l")) + } + "multi with record type returning" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(r => r) + } + + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" + mirror.returningBehavior mustEqual ReturnColumns(List("s", "i", "l", "o")) + } + "multi with record type returning generated should exclude all" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => r) + } + + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity DEFAULT VALUES" + mirror.returningBehavior mustEqual ReturnColumns(List("s", "i", "l", "o")) + } + "multi - should fail on operation" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + """import ctx._; quote { qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => (r.i, r.l + 1)) }""" mustNot compile + } + "multi - should fail on operation inside case class" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + case class Foo(one: Int, two: Long) + """import ctx._; quote { qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => Foo(r.i, r.l + 1)) }""" mustNot compile + } + "no return - should fail on property" in testContext.withDialect(MirrorSqlDialectWithNoReturn) { ctx => + """import ctx._; quote { qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => r.i) }""" mustNot compile + } + "returning clause - single" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(_.l) + } + + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,o) VALUES (?, ?, ?) RETURNING l" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - multi" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => (r.i, r.l)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,o) VALUES (?, ?) RETURNING i, l" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - operation" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => (r.i, r.l + 1)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,o) VALUES (?, ?) RETURNING i, l + 1" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - record" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(r => r) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING s, i, l, o" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning generated clause - record" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returningGenerated(r => r) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity DEFAULT VALUES RETURNING s, i, l, o" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - embedded" - { + case class Dummy(i: Int) + + "embedded property" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1Emb.insert(lift(TestEntityEmb(Emb("s", 0), 1L, None))).returningGenerated(_.emb.i) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?) RETURNING i" + mirror.returningBehavior mustEqual ReturnRecord + } + "two embedded properties" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1Emb.insert(lift(TestEntityEmb(Emb("s", 0), 1L, None))).returningGenerated(r => (r.emb.i, r.emb.s)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (l,o) VALUES (?, ?) RETURNING i, s" + mirror.returningBehavior mustEqual ReturnRecord + } + "query with filter using id - id should be excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1Emb.insert(lift(TestEntityEmb(Emb("s", 0), 1L, None))).returningGenerated(r => (query[Dummy].filter(d => d.i == r.emb.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?) RETURNING (SELECT MAX(*) FROM Dummy d WHERE d.i = i)" + mirror.returningBehavior mustEqual ReturnRecord + } + } + "with returning clause - query" - { + case class Dummy(i: Int) + + "simple not using id - id not excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returningGenerated(r => (query[Dummy].map(d => d.i).value)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT d.i FROM Dummy d)" + mirror.returningBehavior mustEqual ReturnRecord + } + "simple with id - id should be excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returningGenerated(r => (r.i, query[Dummy].map(d => d.i).value)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?) RETURNING i, (SELECT d.i FROM Dummy d)" + mirror.returningBehavior mustEqual ReturnRecord + } + "simple with filter using id - id should be excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returningGenerated(r => (query[Dummy].filter(d => d.i == r.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?) RETURNING (SELECT MAX(*) FROM Dummy d WHERE d.i = i)" + mirror.returningBehavior mustEqual ReturnRecord + } + "shadow variable - id not excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returningGenerated(r => (query[Dummy].map(r => r.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(r1.i) FROM Dummy r1)" + mirror.returningBehavior mustEqual ReturnRecord + } + "shadow variable in multiple clauses - id not excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returningGenerated( + r => + (query[Dummy] + .filter( + r => r.i == r.i /* always true since r overridden! */ + ) + .map(r => r.i) + .max) + ) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(r1.i) FROM Dummy r1 WHERE r1.i = r1.i)" + mirror.returningBehavior mustEqual ReturnRecord + } + "shadow variable in one of multiple clauses - id excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returningGenerated( + r => (query[Dummy].filter(d => d.i == r.i).map(r => r.i).max) + ) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?) RETURNING (SELECT MAX(d.i) FROM Dummy d WHERE d.i = i)" + mirror.returningBehavior mustEqual ReturnRecord + } + } } - "with assigned values and returning" in { - val q = quote { - qr1.insert(_.s -> "s", _.i -> 0).returning(_.l) - } - val a = TestEntity - val mirror = testContext.run(q) - mirror.string mustEqual "INSERT INTO TestEntity (s,i) VALUES ('s', 0)" - mirror.returningColumn mustEqual "l" + + "returning" - { + "multi" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(_.l) + } + + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" + mirror.returningBehavior mustEqual ReturnColumns(List("l")) + } + "multi - should fail on operation" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + """import ctx._; quote { qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(r => (r.i, r.l + 1)) }""" mustNot compile + } + "no return - should fail on property" in testContext.withDialect(MirrorSqlDialectWithNoReturn) { ctx => + """import ctx._; quote { qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(r => r.i) }""" mustNot compile + } + "returning clause - single" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(_.l) + } + + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING l" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - multi" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(r => (r.i, r.l)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING i, l" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - operation" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 1L, None))).returning(r => (r.i, r.l + 1)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING i, l + 1" + mirror.returningBehavior mustEqual ReturnRecord + } + "returning clause - embedded" - { + case class Dummy(i: Int) + + "embedded property" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1Emb.insert(lift(TestEntityEmb(Emb("s", 0), 1L, None))).returning(_.emb.i) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING i" + mirror.returningBehavior mustEqual ReturnRecord + } + "two embedded properties" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1Emb.insert(lift(TestEntityEmb(Emb("s", 0), 1L, None))).returning(r => (r.emb.i, r.emb.s)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING i, s" + mirror.returningBehavior mustEqual ReturnRecord + } + "query with filter using id - id should be excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1Emb.insert(lift(TestEntityEmb(Emb("s", 0), 1L, None))).returning(r => (query[Dummy].filter(d => d.i == r.emb.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(*) FROM Dummy d WHERE d.i = i)" + mirror.returningBehavior mustEqual ReturnRecord + } + } + "with returning clause - query" - { + case class Dummy(i: Int) + + "simple not using id - id not excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returning(r => (query[Dummy].map(d => d.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(d.i) FROM Dummy d)" + mirror.returningBehavior mustEqual ReturnRecord + } + "simple with id - id should be excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returning(r => (r.i, query[Dummy].map(d => d.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING i, (SELECT MAX(d.i) FROM Dummy d)" + mirror.returningBehavior mustEqual ReturnRecord + } + "simple with filter using id - id should be excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returning(r => (query[Dummy].filter(d => d.i == r.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(*) FROM Dummy d WHERE d.i = i)" + mirror.returningBehavior mustEqual ReturnRecord + } + "shadow variable - id not excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returning(r => (query[Dummy].map(r => r.i).max)) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(r1.i) FROM Dummy r1)" + mirror.returningBehavior mustEqual ReturnRecord + } + "shadow variable in multiple clauses - id not excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returning( + r => + (query[Dummy] + .filter( + r => r.i == r.i /* always true since r overridden! */ + ) + .map(r => r.i) + .max) + ) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(r1.i) FROM Dummy r1 WHERE r1.i = r1.i)" + mirror.returningBehavior mustEqual ReturnRecord + } + "shadow variable in one of multiple clauses - id excluded" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val q = quote { + qr1 + .insert(lift(TestEntity("s", 0, 1L, None))) + .returning( + r => (query[Dummy].filter(d => d.i == r.i).map(r => r.i).max) + ) + } + val mirror = ctx.run(q) + mirror.string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) RETURNING (SELECT MAX(d.i) FROM Dummy d WHERE d.i = i)" + mirror.returningBehavior mustEqual ReturnRecord + } + } } + } "apply naming strategy to returning action" in testContext.withNaming(SnakeCase) { ctx => import ctx._ case class TestEntity4(intId: Int, textCol: String) val q = quote { - query[TestEntity4].insert(lift(TestEntity4(1, "s"))).returning(_.intId) + query[TestEntity4].insert(lift(TestEntity4(1, "s"))).returningGenerated(_.intId) } val mirror = ctx.run(q) mirror.string mustEqual "INSERT INTO test_entity4 (text_col) VALUES (?)" - mirror.returningColumn mustEqual "int_id" + mirror.returningBehavior mustEqual ReturnColumns(List("int_id")) } } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/SqlContextSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/SqlContextSpec.scala index ecea9fa3ed..cb3c2b721a 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/SqlContextSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/SqlContextSpec.scala @@ -9,7 +9,7 @@ import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.context.sql.testContext._ import scala.util.Try -import io.getquill.context.Context +import io.getquill.context.{ CanReturnField, Context } import io.getquill.context.sql.idiom.ConcatSupport class SqlContextSpec extends Spec { @@ -37,7 +37,7 @@ class SqlContextSpec extends Spec { "testContext.run(qr1.delete)" mustNot compile - class EvilDBDialect extends SqlIdiom with ConcatSupport { + class EvilDBDialect extends SqlIdiom with ConcatSupport with CanReturnField { override def liftingPlaceholder(index: Int): String = "?" override def prepareForProbing(string: String) = string diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/TestContext.scala b/quill-sql/src/test/scala/io/getquill/context/sql/TestContext.scala index 0372b61b83..bfccaab134 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/TestContext.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/TestContext.scala @@ -1,23 +1,30 @@ package io.getquill.context.sql +import io.getquill.context.sql.idiom.SqlIdiom import io.getquill.norm.EqualityBehavior import io.getquill.norm.EqualityBehavior.NonAnsiEquality import io.getquill.{ Literal, MirrorSqlDialect, NamingStrategy, SqlMirrorContext, TestEntities } -class TestContextTemplate[Naming <: NamingStrategy](naming: Naming) - extends SqlMirrorContext(MirrorSqlDialect, naming) +class TestContextTemplate[Dialect <: SqlIdiom, Naming <: NamingStrategy](dialect: Dialect, naming: Naming) + extends SqlMirrorContext(dialect, naming) with TestEntities with TestEncoders with TestDecoders { - def withNaming[N <: NamingStrategy](naming: N)(f: TestContextTemplate[N] => Any): Unit = { - val ctx = new TestContextTemplate[N](naming) + def withNaming[N <: NamingStrategy](naming: N)(f: TestContextTemplate[Dialect, N] => Any): Unit = { + val ctx = new TestContextTemplate[Dialect, N](dialect, naming) + f(ctx) + ctx.close + } + + def withDialect[I <: SqlIdiom](dialect: I)(f: TestContextTemplate[I, Naming] => Any): Unit = { + val ctx = new TestContextTemplate[I, Naming](dialect, naming) f(ctx) ctx.close } } -object testContext extends TestContextTemplate(Literal) +object testContext extends TestContextTemplate[MirrorSqlDialect, Literal](MirrorSqlDialect, Literal) trait NonAnsiMirrorSqlDialect extends MirrorSqlDialect { override def equalityBehavior: EqualityBehavior = NonAnsiEquality @@ -31,12 +38,6 @@ class NonAnsiTestContextTemplate[Naming <: NamingStrategy](naming: Naming) with TestEntities with TestEncoders with TestDecoders { - - def withNaming[N <: NamingStrategy](naming: N)(f: TestContextTemplate[N] => Any): Unit = { - val ctx = new TestContextTemplate[N](naming) - f(ctx) - ctx.close - } } object nonAnsiTestContext extends NonAnsiTestContextTemplate(Literal) \ No newline at end of file diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala index ac96d125a3..28a2cabe04 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala @@ -75,13 +75,25 @@ class MySQLDialectSpec extends OnConflictSpec { } } - "Insert with returning with single column table" in { + "Insert with returning generated with single column table" in { val q = quote { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) } ctx.run(q).string mustEqual "INSERT INTO TestEntity4 (i) VALUES (DEFAULT)" } + "Insert with returning generated - multiple fields - should not compile" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, Some(3)))) + } + "ctx.run(q.returningGenerated(r => (r.i, r.l))).string" mustNot compile + } + "Insert with returning should not compile" in { + val q = quote { + qr4.insert(lift(TestEntity4(0))) + } + "ctx.run(q.returning(_.i)).string" mustNot compile + } "OnConflict" - { "no target - ignore" in { diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OracleDialectSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OracleDialectSpec.scala index a8a258f021..100da6ec99 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OracleDialectSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OracleDialectSpec.scala @@ -25,12 +25,48 @@ class OracleDialectSpec extends Spec { } } - "Insert with returning with single column table" in { - val q = quote { - qr4.insert(lift(TestEntity4(0))).returning(_.i) + "Insert with returning" - { + "with single column table" in { + val q = quote { + qr4.insert(lift(TestEntity4(0))).returning(_.i) + } + ctx.run(q).string mustEqual + "INSERT INTO TestEntity4 (i) VALUES (?)" + } + + "returning generated with single column table" in { + val q = quote { + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) + } + ctx.run(q).string mustEqual + "INSERT INTO TestEntity4 (i) VALUES (DEFAULT)" + } + "returning with multi column table" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 0L, Some(3)))).returning(r => (r.i, r.l)) + } + ctx.run(q).string mustEqual + "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" + } + "returning generated with multi column table" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 0, 0L, Some(3)))).returningGenerated(r => (r.i, r.l)) + } + ctx.run(q).string mustEqual + "INSERT INTO TestEntity (s,o) VALUES (?, ?)" + } + "returning - multiple fields + operations - should not compile" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, Some(3)))) + } + "ctx.run(q.returning(r => (r.i, r.l + 1))).string" mustNot compile + } + "returning generated - multiple fields + operations - should not compile" in { + val q = quote { + qr1.insert(lift(TestEntity("s", 1, 2L, Some(3)))) + } + "ctx.run(q.returningGenerated(r => (r.i, r.l + 1))).string" mustNot compile } - ctx.run(q).string mustEqual - "INSERT INTO TestEntity4 (i) VALUES (DEFAULT)" } "offset/fetch" - { diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/SqlIdiomSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/SqlIdiomSpec.scala index 13814f198d..924dc6b5fd 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/SqlIdiomSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/SqlIdiomSpec.scala @@ -1,6 +1,7 @@ package io.getquill.context.sql.idiom -import io.getquill.Spec +import io.getquill.ReturnAction.ReturnColumns +import io.getquill.{ MirrorSqlDialectWithReturnMulti, Spec } import io.getquill.context.mirror.Row import io.getquill.context.sql.testContext import io.getquill.context.sql.testContext._ @@ -758,12 +759,20 @@ class SqlIdiomSpec extends Spec { val v = TestEntity("s", 1, 2L, Some(1)) testContext.run(q(lift(v))).string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" } - "returning" in { + "returning" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ val q = quote { (v: TestEntity) => query[TestEntity].insert(v) } val v = TestEntity("s", 1, 2L, Some(1)) - testContext.run(q(lift(v)).returning(v => v.i)).string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)" + ctx.run(q(lift(v)).returning(v => v.i)).string mustEqual "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" + } + "returning generated" in { + val q = quote { (v: TestEntity) => + query[TestEntity].insert(v) + } + val v = TestEntity("s", 1, 2L, Some(1)) + testContext.run(q(lift(v)).returningGenerated(v => v.i)).string mustEqual "INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)" } "foreach" in { val v = TestEntity("s", 1, 2L, Some(1)) @@ -771,12 +780,22 @@ class SqlIdiomSpec extends Spec { liftQuery(List(v)).foreach(v => query[TestEntity].insert(v)) ).groups mustEqual List(("INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)", List(Row(v.productIterator.toList: _*)))) } - "foreach returning" in { + "foreach returning" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ + val v = TestEntity("s", 1, 2L, Some(1)) + ctx.run(liftQuery(List(v)).foreach(v => query[TestEntity].insert(v).returning(v => v.i))).groups mustEqual + List(("INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)", + ReturnColumns(List("i")), + List(Row(v.productIterator.toList: _*)) + )) + } + "foreach returning generated" in { val v = TestEntity("s", 1, 2L, Some(1)) testContext.run( - liftQuery(List(v)).foreach(v => query[TestEntity].insert(v).returning(v => v.i)) + liftQuery(List(v)).foreach(v => query[TestEntity].insert(v).returningGenerated(v => v.i)) ).groups mustEqual - List(("INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)", "i", + List(("INSERT INTO TestEntity (s,l,o) VALUES (?, ?, ?)", + ReturnColumns(List("i")), List(Row(v.productIterator.toList.filter(m => !m.isInstanceOf[Int]): _*)) )) } @@ -795,17 +814,33 @@ class SqlIdiomSpec extends Spec { testContext.run(q).string mustEqual "INSERT INTO TestEntity (l,s) VALUES ((SELECT COUNT(t.i) FROM TestEntity2 t), 's')" } - "returning" in { + "returning" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ val q = quote { query[TestEntity].insert(lift(TestEntity("s", 1, 2L, Some(1)))).returning(_.l) } + val run = ctx.run(q).string mustEqual + "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" + } + "returning generated" in { + val q = quote { + query[TestEntity].insert(lift(TestEntity("s", 1, 2L, Some(1)))).returningGenerated(_.l) + } val run = testContext.run(q).string mustEqual "INSERT INTO TestEntity (s,i,o) VALUES (?, ?, ?)" } - "returning with single column table" in { + "returning with single column table" in testContext.withDialect(MirrorSqlDialectWithReturnMulti) { ctx => + import ctx._ val q = quote { qr4.insert(lift(TestEntity4(0))).returning(_.i) } + ctx.run(q).string mustEqual + "INSERT INTO TestEntity4 (i) VALUES (?)" + } + "returning generated with single column table" in { + val q = quote { + qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i) + } testContext.run(q).string mustEqual "INSERT INTO TestEntity4 DEFAULT VALUES" } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala index 18942ef7d5..7205972a12 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/norm/RenamePropertiesSpec.scala @@ -1,6 +1,7 @@ package io.getquill.context.sql.norm -import io.getquill.Spec +import io.getquill.{ MirrorSqlDialectWithReturnClause, Spec } +import io.getquill.ReturnAction.{ ReturnColumns, ReturnRecord } import io.getquill.context.sql.testContext._ import io.getquill.context.sql.testContext @@ -69,12 +70,23 @@ class RenamePropertiesSpec extends Spec { "DELETE FROM test_entity WHERE field_i = 999" } "returning" - { - "alias" in { + "returning - alias" in testContext.withDialect(MirrorSqlDialectWithReturnClause) { ctx => + import ctx._ + val e1 = quote { + querySchema[TestEntity]("test_entity", _.s -> "field_s", _.i -> "field_i") + } + val q = quote { + e1.insert(lift(TestEntity("s", 1, 1L, None))).returning(_.i) + } + val mirror = ctx.run(q.dynamic) + mirror.returningBehavior mustEqual ReturnRecord + } + "returning generated - alias" in { val q = quote { - e.insert(lift(TestEntity("s", 1, 1L, None))).returning(_.i) + e.insert(lift(TestEntity("s", 1, 1L, None))).returningGenerated(_.i) } val mirror = testContext.run(q.dynamic) - mirror.returningColumn mustEqual "field_i" + mirror.returningBehavior mustEqual ReturnColumns(List("field_i")) } } } diff --git a/quill-sql/src/test/sql/postgres-schema.sql b/quill-sql/src/test/sql/postgres-schema.sql index 31b624b3e2..50343abb96 100644 --- a/quill-sql/src/test/sql/postgres-schema.sql +++ b/quill-sql/src/test/sql/postgres-schema.sql @@ -81,6 +81,11 @@ CREATE TABLE TestEntity4( i SERIAL PRIMARY KEY ); +CREATE TABLE TestEntity5( + i SERIAL PRIMARY KEY, + s VARCHAR(255) +); + CREATE TABLE Product( description VARCHAR(255), id SERIAL PRIMARY KEY,