From 5c3141cd5eb8b57d007a81b7b1949a7724242089 Mon Sep 17 00:00:00 2001 From: Alex Boisvert Date: Wed, 20 Nov 2024 22:36:08 -0700 Subject: [PATCH] Add SELECT FOR UPDATE variants for Postgres and MySql (#45) --- docs/reference.md | 66 ++++++++++++++++++- scalasql/query/src/Select.scala | 19 +++++- scalasql/query/src/SimpleSelect.scala | 12 +++- scalasql/query/src/WithCte.scala | 1 + scalasql/src/dialects/H2Dialect.scala | 6 +- scalasql/src/dialects/MySqlDialect.scala | 39 ++++++++++- scalasql/src/dialects/PostgresDialect.scala | 29 ++++++++ scalasql/src/dialects/SqliteDialect.scala | 6 +- scalasql/src/dialects/TableOps.scala | 2 +- .../test/src/dialects/MySqlDialectTests.scala | 21 ++++++ .../src/dialects/PostgresDialectTests.scala | 22 ++++++- 11 files changed, 212 insertions(+), 11 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index 63e3193f..2d242780 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -10243,7 +10243,7 @@ db.run(OptDataTypes.select) ==> Seq(rowSome, rowNone) Operations specific to working with Postgres Databases ### PostgresDialect.distinctOn -ScalaSql's Postgres dialect provides teh `.distinctOn` operator, which translates +ScalaSql's Postgres dialect provides the `.distinctOn` operator, which translates into a SQL `DISTINCT ON` clause ```scala @@ -10276,6 +10276,38 @@ Purchase.select.distinctOn(_.shippingInfoId).sortBy(_.shippingInfoId).desc +### PostgresDialect.forUpdate + +ScalaSql's Postgres dialect provides the `.forUpdate` operator, which translates +into a SQL `SELECT ... FOR UPDATE` clause + +```scala +Invoice.select.filter(_.id === 1).forUpdate +``` + + +* + ```sql + SELECT + invoice0.id AS id, + invoice0.total AS total, + invoice0.vendor_name AS vendor_name + FROM otherschema.invoice invoice0 + WHERE (invoice0.id = ?) + FOR UPDATE + ``` + + + +* + ```scala + Seq( + Invoice[Sc](1, 150.4, "Siemens") + ) + ``` + + + ### PostgresDialect.ltrim2 @@ -10480,6 +10512,38 @@ db.random ## MySqlDialect Operations specific to working with MySql Databases +### MySqlDialect.forUpdate + +ScalaSql's MySql dialect provides the `.forUpdate` operator, which translates +into a SQL `SELECT ... FOR UPDATE` clause + +```scala +Buyer.select.filter(_.id === 1).forUpdate +``` + + +* + ```sql + SELECT + buyer0.id AS id, + buyer0.name AS name, + buyer0.date_of_birth AS date_of_birth + FROM buyer buyer0 + WHERE (buyer0.id = ?) + FOR UPDATE + ``` + + + +* + ```scala + Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")) + ) + ``` + + + ### MySqlDialect.reverse diff --git a/scalasql/query/src/Select.scala b/scalasql/query/src/Select.scala index c3f47b98..a493cba0 100644 --- a/scalasql/query/src/Select.scala +++ b/scalasql/query/src/Select.scala @@ -57,13 +57,14 @@ trait Select[Q, R] protected def newSimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], where: Seq[Expr[?]], groupBy0: Option[GroupBy] )(implicit qr: Queryable.Row[Q, R], dialect: DialectTypeMappers): SimpleSelect[Q, R] = - new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0) + new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) def qr: Queryable.Row[Q, R] @@ -71,7 +72,9 @@ trait Select[Q, R] * Causes this [[Select]] to ignore duplicate rows, translates into SQL `SELECT DISTINCT` */ def distinct: Select[Q, R] = selectWithExprPrefix(true, _ => sql"DISTINCT") + protected def selectWithExprPrefix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] + protected def selectWithExprSuffix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] protected def subqueryRef(implicit qr: Queryable.Row[Q, R]) = new SubqueryRef(this) @@ -227,7 +230,7 @@ trait Select[Q, R] * in this [[Select]] */ def subquery: SimpleSelect[Q, R] = { - newSimpleSelect(expr, None, false, Seq(subqueryRef(qr)), Nil, Nil, None)(qr, dialect) + newSimpleSelect(expr, None, None, false, Seq(subqueryRef(qr)), Nil, Nil, None)(qr, dialect) } /** @@ -278,19 +281,23 @@ object Select { lhs: Select[Q, R], expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], where: Seq[Expr[?]], groupBy0: Option[GroupBy] )(implicit qr: Queryable.Row[Q, R], dialect: DialectTypeMappers): SimpleSelect[Q, R] = - lhs.newSimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0) + lhs.newSimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) def toSimpleFrom[Q, R](s: Select[Q, R]) = s.selectToSimpleSelect() def withExprPrefix[Q, R](s: Select[Q, R], preserveAll: Boolean, str: Context => SqlStr) = s.selectWithExprPrefix(preserveAll, str) + def withExprSuffix[Q, R](s: Select[Q, R], preserveAll: Boolean, str: Context => SqlStr) = + s.selectWithExprSuffix(preserveAll, str) + implicit class ExprSelectOps[T](s: Select[Expr[T], T]) { def sorted(implicit tm: TypeMapper[T]): Select[Expr[T], T] = s.sortBy(identity) } @@ -303,6 +310,12 @@ object Select { ): Select[Q, R] = selectToSimpleSelect().selectWithExprPrefix(preserveAll, s) + override protected def selectWithExprSuffix( + preserveAll: Boolean, + s: Context => SqlStr + ): Select[Q, R] = + selectToSimpleSelect().selectWithExprSuffix(preserveAll, s) + override def map[Q2, R2](f: Q => Q2)(implicit qr: Queryable.Row[Q2, R2]): Select[Q2, R2] = selectToSimpleSelect().map(f) diff --git a/scalasql/query/src/SimpleSelect.scala b/scalasql/query/src/SimpleSelect.scala index c5359beb..54ebc0b4 100644 --- a/scalasql/query/src/SimpleSelect.scala +++ b/scalasql/query/src/SimpleSelect.scala @@ -22,6 +22,7 @@ import scalasql.renderer.JoinsToSql class SimpleSelect[Q, R]( val expr: Q, val exprPrefix: Option[Context => SqlStr], + val exprSuffix: Option[Context => SqlStr], val preserveAll: Boolean, val from: Seq[Context.From], val joins: Seq[Join], @@ -33,17 +34,21 @@ class SimpleSelect[Q, R]( protected def copy[Q, R]( expr: Q = this.expr, exprPrefix: Option[Context => SqlStr] = this.exprPrefix, + exprSuffix: Option[Context => SqlStr] = this.exprSuffix, preserveAll: Boolean = this.preserveAll, from: Seq[Context.From] = this.from, joins: Seq[Join] = this.joins, where: Seq[Expr[?]] = this.where, groupBy0: Option[GroupBy] = this.groupBy0 )(implicit qr: Queryable.Row[Q, R]) = - newSimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0) + newSimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) def selectWithExprPrefix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] = this.copy(exprPrefix = Some(s), preserveAll = preserveAll) + def selectWithExprSuffix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] = + this.copy(exprSuffix = Some(s), preserveAll = preserveAll) + def aggregateExpr[V: TypeMapper]( f: Q => Context => SqlStr )(implicit qr2: Queryable.Row[Expr[V], V]): Expr[V] = { @@ -111,6 +116,7 @@ class SimpleSelect[Q, R]( copy( expr = newExpr, exprPrefix = exprPrefix, + exprSuffix = exprSuffix, joins = joins ++ newJoins, where = where ++ newWheres ) @@ -178,6 +184,7 @@ class SimpleSelect[Q, R]( copy( expr = newExpr, exprPrefix = exprPrefix, + exprSuffix = exprSuffix, from = Seq(this.subqueryRef), joins = Nil, where = Nil, @@ -287,11 +294,12 @@ object SimpleSelect { ) lazy val exprPrefix = SqlStr.opt(query.exprPrefix) { p => p(context) + sql" " } + lazy val exprSuffix = SqlStr.opt(query.exprSuffix) { p => p(context) } val tables = SqlStr .join(query.from.map(renderedFroms(_)), SqlStr.commaSep) - sql"SELECT " + exprPrefix + exprStr + sql" FROM " + tables + joins + filtersOpt + groupByOpt + sql"SELECT " + exprPrefix + exprStr + sql" FROM " + tables + joins + filtersOpt + groupByOpt + exprSuffix } } diff --git a/scalasql/query/src/WithCte.scala b/scalasql/query/src/WithCte.scala index 16b51c2d..740a83dc 100644 --- a/scalasql/query/src/WithCte.scala +++ b/scalasql/query/src/WithCte.scala @@ -73,6 +73,7 @@ object WithCte { lhs, expr = WithSqlExpr.get(lhs), exprPrefix = None, + exprSuffix = None, preserveAll = false, from = Seq(lhsSubQueryRef), joins = Nil, diff --git a/scalasql/src/dialects/H2Dialect.scala b/scalasql/src/dialects/H2Dialect.scala index 1543efe4..139f3e5b 100644 --- a/scalasql/src/dialects/H2Dialect.scala +++ b/scalasql/src/dialects/H2Dialect.scala @@ -104,6 +104,7 @@ object H2Dialect extends H2Dialect { new SimpleSelect( Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]], None, + None, false, Seq(ref), Nil, @@ -132,6 +133,7 @@ object H2Dialect extends H2Dialect { override def newSimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], @@ -141,13 +143,14 @@ object H2Dialect extends H2Dialect { implicit qr: Queryable.Row[Q, R], dialect: scalasql.core.DialectTypeMappers ): scalasql.query.SimpleSelect[Q, R] = { - new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0) + new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) } } class SimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], @@ -157,6 +160,7 @@ object H2Dialect extends H2Dialect { extends scalasql.query.SimpleSelect( expr, exprPrefix, + exprSuffix, preserveAll, from, joins, diff --git a/scalasql/src/dialects/MySqlDialect.scala b/scalasql/src/dialects/MySqlDialect.scala index 7b392d99..331345aa 100644 --- a/scalasql/src/dialects/MySqlDialect.scala +++ b/scalasql/src/dialects/MySqlDialect.scala @@ -40,6 +40,7 @@ import java.sql.PreparedStatement import java.time.{Instant, LocalDateTime} import java.util.UUID import scala.reflect.ClassTag +import scalasql.query.Select trait MySqlDialect extends Dialect { protected def dialectCastParams = false @@ -116,6 +117,38 @@ trait MySqlDialect extends Dialect { implicit def ExprAggOpsConv[T](v: Aggregatable[Expr[T]]): operations.ExprAggOps[T] = new MySqlDialect.ExprAggOps(v) + implicit class SelectForUpdateConv[Q, R](r: Select[Q, R]) { + + /** + * SELECT .. FOR UPDATE acquires an exclusive lock, blocking other transactions from + * modifying or locking the selected rows, which is for managing concurrent transactions + * and ensuring data consistency in multi-step operations. + */ + def forUpdate: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR UPDATE") + + /** + * SELECT ... FOR SHARE: Locks the selected rows for reading, allowing other transactions + * to read but not modify the locked rows + */ + def forShare: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR SHARE") + + /** + * SELECT ... FOR UPDATE NOWAIT: Immediately returns an error if the selected rows are + * already locked, instead of waiting + */ + def forUpdateNoWait: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR UPDATE NOWAIT") + + /** + * SELECT ... FOR UPDATE SKIP LOCKED: Skips any rows that are already locked by other + * transactions, instead of waiting + */ + def forUpdateSkipLocked: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR UPDATE SKIP LOCKED") + } + override implicit def DbApiOpsConv(db: => DbApi): MySqlDialect.DbApiOps = new MySqlDialect.DbApiOps(this) @@ -207,6 +240,7 @@ object MySqlDialect extends MySqlDialect { new SimpleSelect( Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]], None, + None, false, Seq(ref), Nil, @@ -309,6 +343,7 @@ object MySqlDialect extends MySqlDialect { override def newSimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], @@ -318,13 +353,14 @@ object MySqlDialect extends MySqlDialect { implicit qr: Queryable.Row[Q, R], dialect: scalasql.core.DialectTypeMappers ): scalasql.query.SimpleSelect[Q, R] = { - new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0) + new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) } } class SimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], @@ -334,6 +370,7 @@ object MySqlDialect extends MySqlDialect { extends scalasql.query.SimpleSelect( expr, exprPrefix, + exprSuffix, preserveAll, from, joins, diff --git a/scalasql/src/dialects/PostgresDialect.scala b/scalasql/src/dialects/PostgresDialect.scala index 038967d1..6c98ca04 100644 --- a/scalasql/src/dialects/PostgresDialect.scala +++ b/scalasql/src/dialects/PostgresDialect.scala @@ -54,6 +54,35 @@ trait PostgresDialect extends Dialect with ReturningDialect with OnConflictOps { } } + implicit class SelectForUpdateConv[Q, R](r: Select[Q, R]) { + + /** + * SELECT .. FOR UPDATE acquires an exclusive lock, blocking other transactions from + * modifying or locking the selected rows, which is for managing concurrent transactions + * and ensuring data consistency in multi-step operations. + */ + def forUpdate: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR UPDATE") + + /** + * SELECT ... FOR NO KEY UPDATE: A weaker lock that doesn't block inserts into child + * tables with foreign key references + */ + def forNoKeyUpdate: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR NO KEY UPDATE") + + /** + * SELECT ... FOR SHARE: Locks the selected rows for reading, allowing other transactions + * to read but not modify the locked rows. + */ + def forShare: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR SHARE") + + /** SELECT ... FOR KEY SHARE: The weakest lock, only conflicts with FOR UPDATE */ + def forKeyShare: Select[Q, R] = + Select.withExprSuffix(r, true, _ => sql" FOR KEY SHARE") + } + override implicit def DbApiOpsConv(db: => DbApi): PostgresDialect.DbApiOps = new PostgresDialect.DbApiOps(this) } diff --git a/scalasql/src/dialects/SqliteDialect.scala b/scalasql/src/dialects/SqliteDialect.scala index 1e143472..aca56043 100644 --- a/scalasql/src/dialects/SqliteDialect.scala +++ b/scalasql/src/dialects/SqliteDialect.scala @@ -193,6 +193,7 @@ object SqliteDialect extends SqliteDialect { new SimpleSelect( Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]], None, + None, false, Seq(ref), Nil, @@ -221,6 +222,7 @@ object SqliteDialect extends SqliteDialect { override def newSimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], @@ -230,13 +232,14 @@ object SqliteDialect extends SqliteDialect { implicit qr: Queryable.Row[Q, R], dialect: scalasql.core.DialectTypeMappers ): scalasql.query.SimpleSelect[Q, R] = { - new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0) + new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) } } class SimpleSelect[Q, R]( expr: Q, exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], preserveAll: Boolean, from: Seq[Context.From], joins: Seq[Join], @@ -246,6 +249,7 @@ object SqliteDialect extends SqliteDialect { extends scalasql.query.SimpleSelect( expr, exprPrefix, + exprSuffix, preserveAll, from, joins, diff --git a/scalasql/src/dialects/TableOps.scala b/scalasql/src/dialects/TableOps.scala index 2c166899..ab3b89c5 100644 --- a/scalasql/src/dialects/TableOps.scala +++ b/scalasql/src/dialects/TableOps.scala @@ -22,7 +22,7 @@ class TableOps[V[_[_]]](val t: Table[V])(implicit dialect: Dialect) protected def joinableToSelect: Select[V[Expr], V[Sc]] = { val (ref, expr) = joinableToFromExpr - new SimpleSelect(expr, None, false, Seq(ref), Nil, Nil, None)( + new SimpleSelect(expr, None, None, false, Seq(ref), Nil, Nil, None)( t.containerQr, dialect ) diff --git a/scalasql/test/src/dialects/MySqlDialectTests.scala b/scalasql/test/src/dialects/MySqlDialectTests.scala index cc457fbb..96ec8837 100644 --- a/scalasql/test/src/dialects/MySqlDialectTests.scala +++ b/scalasql/test/src/dialects/MySqlDialectTests.scala @@ -11,6 +11,27 @@ trait MySqlDialectTests extends MySqlSuite { def description = "Operations specific to working with MySql Databases" override def utestBeforeEach(path: Seq[String]): Unit = checker.reset() def tests = Tests { + + test("forUpdate") - checker( + query = Buyer.select.filter(_.id === 1).forUpdate, + sql = """ + SELECT + buyer0.id AS id, + buyer0.name AS name, + buyer0.date_of_birth AS date_of_birth + FROM buyer buyer0 + WHERE (buyer0.id = ?) + FOR UPDATE + """, + value = Seq( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")) + ), + docs = """ + ScalaSql's MySql dialect provides the `.forUpdate` operator, which translates + into a SQL `SELECT ... FOR UPDATE` clause + """ + ) + test("reverse") - checker(query = Expr("Hello").reverse, sql = "SELECT REVERSE(?) AS res", value = "olleH") diff --git a/scalasql/test/src/dialects/PostgresDialectTests.scala b/scalasql/test/src/dialects/PostgresDialectTests.scala index bfc4804d..7d40d2c6 100644 --- a/scalasql/test/src/dialects/PostgresDialectTests.scala +++ b/scalasql/test/src/dialects/PostgresDialectTests.scala @@ -27,11 +27,31 @@ trait PostgresDialectTests extends PostgresSuite { Purchase[Sc](2, 1, 2, 3, 900.0) ), docs = """ - ScalaSql's Postgres dialect provides teh `.distinctOn` operator, which translates + ScalaSql's Postgres dialect provides the `.distinctOn` operator, which translates into a SQL `DISTINCT ON` clause """ ) + test("forUpdate") - checker( + query = Invoice.select.filter(_.id === 1).forUpdate, + sql = """ + SELECT + invoice0.id AS id, + invoice0.total AS total, + invoice0.vendor_name AS vendor_name + FROM otherschema.invoice invoice0 + WHERE (invoice0.id = ?) + FOR UPDATE + """, + value = Seq( + Invoice[Sc](1, 150.4, "Siemens") + ), + docs = """ + ScalaSql's Postgres dialect provides the `.forUpdate` operator, which translates + into a SQL `SELECT ... FOR UPDATE` clause + """ + ) + test("ltrim2") - checker( query = Expr("xxHellox").ltrim("x"), sql = "SELECT LTRIM(?, ?) AS res",