Skip to content

Commit

Permalink
inserted record can be returned from query
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandjohann authored and deusaquilus committed Jul 4, 2019
1 parent 08cf5b3 commit 7219984
Show file tree
Hide file tree
Showing 83 changed files with 1,930 additions and 395 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MysqlAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConnect
def this(naming: N, config: Config) = this(naming, MysqlAsyncContextConfig(config))
def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix))

override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result match {
case r: MySQLQueryResult =>
returningExtractor(new ArrayRowData(0, Map.empty, Array(r.lastInsertId)))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.getquill.context.async.mysql

import com.github.mauricio.async.db.QueryResult
import io.getquill.ReturnAction.ReturnColumns

import scala.concurrent.ExecutionContext.Implicits.global
import io.getquill.{ Literal, MysqlAsyncContext, Spec }
import io.getquill.{ Literal, MysqlAsyncContext, ReturnAction, Spec }

class MysqlAsyncContextSpec extends Spec {

Expand All @@ -18,7 +19,7 @@ class MysqlAsyncContextSpec extends Spec {

"Insert with returning with single column table" in {
val inserted: Long = await(testContext.run {
qr4.insert(lift(TestEntity4(0))).returning(_.i)
qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i)
})
await(testContext.run(qr4.filter(_.i == lift(inserted))))
.head.i mustBe inserted
Expand All @@ -35,13 +36,13 @@ class MysqlAsyncContextSpec extends Spec {
"cannot extract" in {
object ctx extends MysqlAsyncContext(Literal, "testMysqlDB") {
override def extractActionResult[O](
returningColumn: String,
returningAction: ReturnAction,
returningExtractor: ctx.Extractor[O]
)(result: QueryResult) =
super.extractActionResult(returningColumn, returningExtractor)(result)
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
val prd = Product(0L, "test1", 1L)
val inserted = await {
testContext.run {
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
}
val returnedProduct = await(testContext.run(productById(lift(inserted)))).head
Expand All @@ -47,7 +47,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
"Single insert with free variable and explicit quotation" in {
val prd = Product(0L, "test2", 2L)
val q1 = quote {
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
val inserted = await(testContext.run(q1))
val returnedProduct = await(testContext.run(productById(lift(inserted)))).head
Expand All @@ -60,7 +60,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
case class Product(id: Id, description: String, sku: Long)
val prd = Product(Id(0L), "test2", 2L)
val q1 = quote {
query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
await(testContext.run(q1)) mustBe a[Id]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.getquill

import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.pool.PartitionedConnectionPool
import com.github.mauricio.async.db.postgresql.PostgreSQLConnection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.typesafe.config.Config
import io.getquill.ReturnAction.{ ReturnColumns, ReturnNothing, ReturnRecord }
import io.getquill.context.async.{ ArrayDecoders, ArrayEncoders, AsyncContext, UUIDObjectEncoding }
import io.getquill.util.LoadConfig
import io.getquill.util.Messages.fail
Expand All @@ -18,7 +19,7 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
def this(naming: N, config: Config) = this(naming, PostgresAsyncContextConfig(config))
def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix))

override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result.rows match {
case Some(r) if r.nonEmpty =>
returningExtractor(r.head)
Expand All @@ -27,6 +28,14 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
}
}

override protected def expandAction(sql: String, returningColumn: String): String =
s"$sql RETURNING $returningColumn"
override protected def expandAction(sql: String, returningAction: ReturnAction): String =
returningAction match {
// The Postgres dialect will create SQL that has a 'RETURNING' clause so we don't have to add one.
case ReturnRecord => s"$sql"
// The Postgres dialect will not actually use these below variants but in case we decide to plug
// in some other dialect into this context...
case ReturnColumns(columns) => s"$sql RETURNING ${columns.mkString(", ")}"
case ReturnNothing => s"$sql"
}

}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.getquill.context.async.postgres

import com.github.mauricio.async.db.QueryResult
import io.getquill.ReturnAction.ReturnColumns

import scala.concurrent.ExecutionContext.Implicits.global
import io.getquill.{ Literal, PostgresAsyncContext, Spec }
import io.getquill.{ Literal, PostgresAsyncContext, ReturnAction, Spec }

class PostgresAsyncContextSpec extends Spec {

Expand All @@ -18,11 +19,18 @@ class PostgresAsyncContextSpec extends Spec {

"Insert with returning with single column table" in {
val inserted: Long = await(testContext.run {
qr4.insert(lift(TestEntity4(0))).returning(_.i)
qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i)
})
await(testContext.run(qr4.filter(_.i == lift(inserted))))
.head.i mustBe inserted
}
"Insert with returning with multiple columns" in {
await(testContext.run(qr1.delete))
val inserted = await(testContext.run {
qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i, r.s, r.o))
})
(1, "foo", Some(123)) mustBe inserted
}

"performIO" in {
await(performIO(runIO(qr4).transactional))
Expand All @@ -35,13 +43,13 @@ class PostgresAsyncContextSpec extends Spec {
"cannot extract" in {
object ctx extends PostgresAsyncContext(Literal, "testPostgresDB") {
override def extractActionResult[O](
returningColumn: String,
returningAction: ReturnAction,
returningExtractor: ctx.Extractor[O]
)(result: QueryResult) =
super.extractActionResult(returningColumn, returningExtractor)(result)
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import com.github.mauricio.async.db.Connection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.RowData
import com.github.mauricio.async.db.pool.PartitionedConnectionPool

import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.Try
import io.getquill.context.sql.SqlContext
import io.getquill.context.sql.idiom.SqlIdiom
import io.getquill.NamingStrategy
import io.getquill.{ NamingStrategy, ReturnAction }
import io.getquill.util.ContextLogger
import io.getquill.monad.ScalaFutureIOMonad
import io.getquill.context.{ Context, TranslateContext }
Expand Down Expand Up @@ -48,9 +49,9 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
case other => f(pool)
}

protected def extractActionResult[O](returningColumn: String, extractor: Extractor[O])(result: DBQueryResult): O
protected def extractActionResult[O](returningAction: ReturnAction, extractor: Extractor[O])(result: DBQueryResult): O

protected def expandAction(sql: String, returningColumn: String) = sql
protected def expandAction(sql: String, returningAction: ReturnAction) = sql

def probe(sql: String) =
Try {
Expand Down Expand Up @@ -88,12 +89,12 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
withConnection(_.sendPreparedStatement(sql, values)).map(_.rowsAffected)
}

def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningColumn)
def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningAction)
val (params, values) = prepare(Nil)
logger.logQuery(sql, params)
withConnection(_.sendPreparedStatement(expanded, values))
.map(extractActionResult(returningColumn, extractor))
.map(extractActionResult(returningAction, extractor))
}

def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext): Future[List[Long]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.getquill.context.cassandra

import io.getquill.ast.{ TraversableOperation, _ }
import io.getquill.NamingStrategy
import io.getquill.context.CannotReturn
import io.getquill.util.Messages.fail
import io.getquill.idiom.Idiom
import io.getquill.idiom.StatementInterpolator._
Expand All @@ -10,7 +11,7 @@ import io.getquill.idiom.SetContainsToken
import io.getquill.idiom.Token
import io.getquill.util.Interleave

object CqlIdiom extends CqlIdiom
object CqlIdiom extends CqlIdiom with CannotReturn

trait CqlIdiom extends Idiom {

Expand All @@ -33,6 +34,7 @@ trait CqlIdiom extends Idiom {
case a: Operation => a.token
case a: Action => a.token
case a: Ident => a.token
case a: ExternalIdent => a.token
case a: Property => a.token
case a: Value => a.token
case a: Function => a.body.token
Expand Down Expand Up @@ -135,6 +137,10 @@ trait CqlIdiom extends Idiom {
case e => strategy.default(e.name).token
}

implicit def externalIdentTokenizer(implicit strategy: NamingStrategy): Tokenizer[ExternalIdent] = Tokenizer[ExternalIdent] {
case e => strategy.default(e.name).token
}

implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(alias, prop, value) =>
stmt"${prop.token} = ${value.token}"
Expand Down Expand Up @@ -175,6 +181,9 @@ trait CqlIdiom extends Idiom {
case _: Returning =>
fail(s"Cql doesn't support returning generated during insertion")

case _: ReturningGenerated =>
fail(s"Cql doesn't support returning generated during insertion")

case other =>
fail(s"Action ast can't be translated to cql: '$other'")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ class CqlIdiomSpec extends Spec {
"SELECT s FROM TestEntity WHERE i = 1 ORDER BY s ASC LIMIT 1"
}
"returning" in {
val q = quote {
query[TestEntity].insert(_.l -> 1L).returning(_.i)
}
"mirrorContext.run(q).string" mustNot compile
"mirrorContext.run(query[TestEntity].insert(_.l -> 1L).returning(_.i)).string" mustNot compile
}
}

Expand Down
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom

case class ActionMirror(string: String, prepareRow: PrepareRow)(implicit val ec: ExecutionContext)

case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)(implicit val ec: ExecutionContext)
case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)(implicit val ec: ExecutionContext)

case class BatchActionMirror(groups: List[(String, List[Row])])(implicit val ec: ExecutionContext)

case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext)
case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext)

case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])(implicit val ec: ExecutionContext)

Expand All @@ -74,8 +74,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom
Future(ActionMirror(string, prepare(Row())._2))

def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O],
returningColumn: String)(implicit ec: ExecutionContext) =
Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn))
returningBehavior: ReturnAction)(implicit ec: ExecutionContext) =
Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior))

def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext) =
Future {
Expand All @@ -91,8 +91,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom
Future {
BatchActionReturningMirror[T](
groups.map {
case BatchGroupReturning(string, column, prepare) =>
(string, column, prepare.map(_(Row())._2))
case BatchGroupReturning(string, returningBehavior, prepare) =>
(string, returningBehavior, prepare.map(_(Row())._2))
}, extractor
)
}
Expand Down
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/MirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi

case class ActionMirror(string: String, prepareRow: PrepareRow)

case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)
case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)

case class BatchActionMirror(groups: List[(String, List[Row])])

case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])
case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])

case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])

Expand All @@ -59,8 +59,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi
ActionMirror(string, prepare(Row())._2)

def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O],
returningColumn: String) =
ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn)
returningBehavior: ReturnAction) =
ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior)

def executeBatchAction(groups: List[BatchGroup]) =
BatchActionMirror {
Expand All @@ -73,8 +73,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi
def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]) =
BatchActionReturningMirror[T](
groups.map {
case BatchGroupReturning(string, column, prepare) =>
(string, column, prepare.map(_(Row())._2))
case BatchGroupReturning(string, returningBehavior, prepare) =>
(string, returningBehavior, prepare.map(_(Row())._2))
}, extractor
)

Expand Down
Loading

0 comments on commit 7219984

Please sign in to comment.