Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Returning Record #1489

Merged
merged 4 commits into from
Jul 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ quill-jdbc/src/main/resources/logback.xml
log.txt*
tmp
nohup.out
.bloop/
.metals/
project/.bloop/
/io/
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MysqlAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConnect
def this(naming: N, config: Config) = this(naming, MysqlAsyncContextConfig(config))
def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix))

override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result match {
case r: MySQLQueryResult =>
returningExtractor(new ArrayRowData(0, Map.empty, Array(r.lastInsertId)))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.getquill.context.async.mysql

import com.github.mauricio.async.db.QueryResult
import io.getquill.ReturnAction.ReturnColumns

import scala.concurrent.ExecutionContext.Implicits.global
import io.getquill.{ Literal, MysqlAsyncContext, Spec }
import io.getquill.{ Literal, MysqlAsyncContext, ReturnAction, Spec }

class MysqlAsyncContextSpec extends Spec {

Expand All @@ -18,7 +19,7 @@ class MysqlAsyncContextSpec extends Spec {

"Insert with returning with single column table" in {
val inserted: Long = await(testContext.run {
qr4.insert(lift(TestEntity4(0))).returning(_.i)
qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i)
})
await(testContext.run(qr4.filter(_.i == lift(inserted))))
.head.i mustBe inserted
Expand All @@ -35,13 +36,13 @@ class MysqlAsyncContextSpec extends Spec {
"cannot extract" in {
object ctx extends MysqlAsyncContext(Literal, "testMysqlDB") {
override def extractActionResult[O](
returningColumn: String,
returningAction: ReturnAction,
returningExtractor: ctx.Extractor[O]
)(result: QueryResult) =
super.extractActionResult(returningColumn, returningExtractor)(result)
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
val prd = Product(0L, "test1", 1L)
val inserted = await {
testContext.run {
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
}
val returnedProduct = await(testContext.run(productById(lift(inserted)))).head
Expand All @@ -47,7 +47,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
"Single insert with free variable and explicit quotation" in {
val prd = Product(0L, "test2", 2L)
val q1 = quote {
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
product.insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
val inserted = await(testContext.run(q1))
val returnedProduct = await(testContext.run(productById(lift(inserted)))).head
Expand All @@ -60,7 +60,7 @@ class ProductMysqlAsyncSpec extends ProductSpec {
case class Product(id: Id, description: String, sku: Long)
val prd = Product(Id(0L), "test2", 2L)
val q1 = quote {
query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returning(_.id)
query[Product].insert(_.sku -> lift(prd.sku), _.description -> lift(prd.description)).returningGenerated(_.id)
}
await(testContext.run(q1)) mustBe a[Id]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.getquill

import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.pool.PartitionedConnectionPool
import com.github.mauricio.async.db.postgresql.PostgreSQLConnection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.typesafe.config.Config
import io.getquill.ReturnAction.{ ReturnColumns, ReturnNothing, ReturnRecord }
import io.getquill.context.async.{ ArrayDecoders, ArrayEncoders, AsyncContext, UUIDObjectEncoding }
import io.getquill.util.LoadConfig
import io.getquill.util.Messages.fail
Expand All @@ -18,7 +19,7 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
def this(naming: N, config: Config) = this(naming, PostgresAsyncContextConfig(config))
def this(naming: N, configPrefix: String) = this(naming, LoadConfig(configPrefix))

override protected def extractActionResult[O](returningColumn: String, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
override protected def extractActionResult[O](returningAction: ReturnAction, returningExtractor: Extractor[O])(result: DBQueryResult): O = {
result.rows match {
case Some(r) if r.nonEmpty =>
returningExtractor(r.head)
Expand All @@ -27,6 +28,14 @@ class PostgresAsyncContext[N <: NamingStrategy](naming: N, pool: PartitionedConn
}
}

override protected def expandAction(sql: String, returningColumn: String): String =
s"$sql RETURNING $returningColumn"
override protected def expandAction(sql: String, returningAction: ReturnAction): String =
returningAction match {
// The Postgres dialect will create SQL that has a 'RETURNING' clause so we don't have to add one.
case ReturnRecord => s"$sql"
// The Postgres dialect will not actually use these below variants but in case we decide to plug
// in some other dialect into this context...
case ReturnColumns(columns) => s"$sql RETURNING ${columns.mkString(", ")}"
case ReturnNothing => s"$sql"
}

}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.getquill.context.async.postgres

import com.github.mauricio.async.db.QueryResult
import io.getquill.ReturnAction.ReturnColumns

import scala.concurrent.ExecutionContext.Implicits.global
import io.getquill.{ Literal, PostgresAsyncContext, Spec }
import io.getquill.{ Literal, PostgresAsyncContext, ReturnAction, Spec }

class PostgresAsyncContextSpec extends Spec {

Expand All @@ -18,11 +19,18 @@ class PostgresAsyncContextSpec extends Spec {

"Insert with returning with single column table" in {
val inserted: Long = await(testContext.run {
qr4.insert(lift(TestEntity4(0))).returning(_.i)
qr4.insert(lift(TestEntity4(0))).returningGenerated(_.i)
})
await(testContext.run(qr4.filter(_.i == lift(inserted))))
.head.i mustBe inserted
}
"Insert with returning with multiple columns" in {
await(testContext.run(qr1.delete))
val inserted = await(testContext.run {
qr1.insert(lift(TestEntity("foo", 1, 18L, Some(123)))).returning(r => (r.i, r.s, r.o))
})
(1, "foo", Some(123)) mustBe inserted
}

"performIO" in {
await(performIO(runIO(qr4).transactional))
Expand All @@ -35,13 +43,13 @@ class PostgresAsyncContextSpec extends Spec {
"cannot extract" in {
object ctx extends PostgresAsyncContext(Literal, "testPostgresDB") {
override def extractActionResult[O](
returningColumn: String,
returningAction: ReturnAction,
returningExtractor: ctx.Extractor[O]
)(result: QueryResult) =
super.extractActionResult(returningColumn, returningExtractor)(result)
super.extractActionResult(returningAction, returningExtractor)(result)
}
intercept[IllegalStateException] {
ctx.extractActionResult("w/e", row => 1)(new QueryResult(0, "w/e"))
ctx.extractActionResult(ReturnColumns(List("w/e")), row => 1)(new QueryResult(0, "w/e"))
}
ctx.close
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import com.github.mauricio.async.db.Connection
import com.github.mauricio.async.db.{ QueryResult => DBQueryResult }
import com.github.mauricio.async.db.RowData
import com.github.mauricio.async.db.pool.PartitionedConnectionPool

import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.Try
import io.getquill.context.sql.SqlContext
import io.getquill.context.sql.idiom.SqlIdiom
import io.getquill.NamingStrategy
import io.getquill.{ NamingStrategy, ReturnAction }
import io.getquill.util.ContextLogger
import io.getquill.monad.ScalaFutureIOMonad
import io.getquill.context.{ Context, TranslateContext }
Expand Down Expand Up @@ -48,9 +49,9 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
case other => f(pool)
}

protected def extractActionResult[O](returningColumn: String, extractor: Extractor[O])(result: DBQueryResult): O
protected def extractActionResult[O](returningAction: ReturnAction, extractor: Extractor[O])(result: DBQueryResult): O

protected def expandAction(sql: String, returningColumn: String) = sql
protected def expandAction(sql: String, returningAction: ReturnAction) = sql

def probe(sql: String) =
Try {
Expand Down Expand Up @@ -88,12 +89,12 @@ abstract class AsyncContext[D <: SqlIdiom, N <: NamingStrategy, C <: Connection]
withConnection(_.sendPreparedStatement(sql, values)).map(_.rowsAffected)
}

def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningColumn: String)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningColumn)
def executeActionReturning[T](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[T], returningAction: ReturnAction)(implicit ec: ExecutionContext): Future[T] = {
val expanded = expandAction(sql, returningAction)
val (params, values) = prepare(Nil)
logger.logQuery(sql, params)
withConnection(_.sendPreparedStatement(expanded, values))
.map(extractActionResult(returningColumn, extractor))
.map(extractActionResult(returningAction, extractor))
}

def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext): Future[List[Long]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.getquill.context.cassandra

import io.getquill.ast.{ TraversableOperation, _ }
import io.getquill.NamingStrategy
import io.getquill.context.CannotReturn
import io.getquill.util.Messages.fail
import io.getquill.idiom.Idiom
import io.getquill.idiom.StatementInterpolator._
Expand All @@ -10,7 +11,7 @@ import io.getquill.idiom.SetContainsToken
import io.getquill.idiom.Token
import io.getquill.util.Interleave

object CqlIdiom extends CqlIdiom
object CqlIdiom extends CqlIdiom with CannotReturn

trait CqlIdiom extends Idiom {

Expand All @@ -33,6 +34,7 @@ trait CqlIdiom extends Idiom {
case a: Operation => a.token
case a: Action => a.token
case a: Ident => a.token
case a: ExternalIdent => a.token
case a: Property => a.token
case a: Value => a.token
case a: Function => a.body.token
Expand Down Expand Up @@ -135,6 +137,10 @@ trait CqlIdiom extends Idiom {
case e => strategy.default(e.name).token
}

implicit def externalIdentTokenizer(implicit strategy: NamingStrategy): Tokenizer[ExternalIdent] = Tokenizer[ExternalIdent] {
case e => strategy.default(e.name).token
}

implicit def assignmentTokenizer(implicit propertyTokenizer: Tokenizer[Property], strategy: NamingStrategy): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(alias, prop, value) =>
stmt"${prop.token} = ${value.token}"
Expand Down Expand Up @@ -172,7 +178,7 @@ trait CqlIdiom extends Idiom {
case Delete(table) =>
stmt"TRUNCATE ${table.token}"

case _: Returning =>
case _: Returning | _: ReturningGenerated =>
fail(s"Cql doesn't support returning generated during insertion")

case other =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.getquill.context.cassandra
import io.getquill._
import io.getquill.idiom.StatementInterpolator._
import io.getquill.ast.{ Action => AstAction, _ }
import io.getquill.idiom.StringToken

class CqlIdiomSpec extends Spec {

Expand Down Expand Up @@ -38,10 +39,7 @@ class CqlIdiomSpec extends Spec {
"SELECT s FROM TestEntity WHERE i = 1 ORDER BY s ASC LIMIT 1"
}
"returning" in {
val q = quote {
query[TestEntity].insert(_.l -> 1L).returning(_.i)
}
"mirrorContext.run(q).string" mustNot compile
"mirrorContext.run(query[TestEntity].insert(_.l -> 1L).returning(_.i)).string" mustNot compile
}
}

Expand Down Expand Up @@ -392,5 +390,10 @@ class CqlIdiomSpec extends Spec {
intercept[IllegalStateException](t.token(null: AstAction))
intercept[IllegalStateException](t.token(Insert(Nested(Ident("a")), Nil)))
}
// not actually used anywhere but doing a sanity check here
"external ident sanity check" in {
val t = implicitly[Tokenizer[ExternalIdent]]
t.token(ExternalIdent("TestIdent")) mustBe StringToken("TestIdent")
}
}
}
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/AsyncMirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom

case class ActionMirror(string: String, prepareRow: PrepareRow)(implicit val ec: ExecutionContext)

case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)(implicit val ec: ExecutionContext)
case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)(implicit val ec: ExecutionContext)

case class BatchActionMirror(groups: List[(String, List[Row])])(implicit val ec: ExecutionContext)

case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext)
case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])(implicit val ec: ExecutionContext)

case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])(implicit val ec: ExecutionContext)

Expand All @@ -74,8 +74,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom
Future(ActionMirror(string, prepare(Row())._2))

def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O],
returningColumn: String)(implicit ec: ExecutionContext) =
Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn))
returningBehavior: ReturnAction)(implicit ec: ExecutionContext) =
Future(ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior))

def executeBatchAction(groups: List[BatchGroup])(implicit ec: ExecutionContext) =
Future {
Expand All @@ -91,8 +91,8 @@ class AsyncMirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom
Future {
BatchActionReturningMirror[T](
groups.map {
case BatchGroupReturning(string, column, prepare) =>
(string, column, prepare.map(_(Row())._2))
case BatchGroupReturning(string, returningBehavior, prepare) =>
(string, returningBehavior, prepare.map(_(Row())._2))
}, extractor
)
}
Expand Down
12 changes: 6 additions & 6 deletions quill-core/src/main/scala/io/getquill/MirrorContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi

case class ActionMirror(string: String, prepareRow: PrepareRow)

case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningColumn: String)
case class ActionReturningMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T], returningBehavior: ReturnAction)

case class BatchActionMirror(groups: List[(String, List[Row])])

case class BatchActionReturningMirror[T](groups: List[(String, String, List[PrepareRow])], extractor: Extractor[T])
case class BatchActionReturningMirror[T](groups: List[(String, ReturnAction, List[PrepareRow])], extractor: Extractor[T])

case class QueryMirror[T](string: String, prepareRow: PrepareRow, extractor: Extractor[T])

Expand All @@ -59,8 +59,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi
ActionMirror(string, prepare(Row())._2)

def executeActionReturning[O](string: String, prepare: Prepare = identityPrepare, extractor: Extractor[O],
returningColumn: String) =
ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningColumn)
returningBehavior: ReturnAction) =
ActionReturningMirror[O](string, prepare(Row())._2, extractor, returningBehavior)

def executeBatchAction(groups: List[BatchGroup]) =
BatchActionMirror {
Expand All @@ -73,8 +73,8 @@ class MirrorContext[Idiom <: BaseIdiom, Naming <: NamingStrategy](val idiom: Idi
def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T]) =
BatchActionReturningMirror[T](
groups.map {
case BatchGroupReturning(string, column, prepare) =>
(string, column, prepare.map(_(Row())._2))
case BatchGroupReturning(string, returningBehavior, prepare) =>
(string, returningBehavior, prepare.map(_(Row())._2))
}, extractor
)

Expand Down
Loading