Skip to content

Commit

Permalink
Add SELECT FOR UPDATE variants for Postgres and MySql (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
aboisvert authored Nov 21, 2024
1 parent 78e0970 commit 5c3141c
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 11 deletions.
66 changes: 65 additions & 1 deletion docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -10243,7 +10243,7 @@ db.run(OptDataTypes.select) ==> Seq(rowSome, rowNone)
Operations specific to working with Postgres Databases
### PostgresDialect.distinctOn
ScalaSql's Postgres dialect provides teh `.distinctOn` operator, which translates
ScalaSql's Postgres dialect provides the `.distinctOn` operator, which translates
into a SQL `DISTINCT ON` clause

```scala
Expand Down Expand Up @@ -10276,6 +10276,38 @@ Purchase.select.distinctOn(_.shippingInfoId).sortBy(_.shippingInfoId).desc



### PostgresDialect.forUpdate

ScalaSql's Postgres dialect provides the `.forUpdate` operator, which translates
into a SQL `SELECT ... FOR UPDATE` clause
```scala
Invoice.select.filter(_.id === 1).forUpdate
```
*
```sql
SELECT
invoice0.id AS id,
invoice0.total AS total,
invoice0.vendor_name AS vendor_name
FROM otherschema.invoice invoice0
WHERE (invoice0.id = ?)
FOR UPDATE
```
*
```scala
Seq(
Invoice[Sc](1, 150.4, "Siemens")
)
```
### PostgresDialect.ltrim2
Expand Down Expand Up @@ -10480,6 +10512,38 @@ db.random
## MySqlDialect
Operations specific to working with MySql Databases
### MySqlDialect.forUpdate
ScalaSql's MySql dialect provides the `.forUpdate` operator, which translates
into a SQL `SELECT ... FOR UPDATE` clause

```scala
Buyer.select.filter(_.id === 1).forUpdate
```


*
```sql
SELECT
buyer0.id AS id,
buyer0.name AS name,
buyer0.date_of_birth AS date_of_birth
FROM buyer buyer0
WHERE (buyer0.id = ?)
FOR UPDATE
```



*
```scala
Seq(
Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03"))
)
```



### MySqlDialect.reverse


Expand Down
19 changes: 16 additions & 3 deletions scalasql/query/src/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,24 @@ trait Select[Q, R]
protected def newSimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
where: Seq[Expr[?]],
groupBy0: Option[GroupBy]
)(implicit qr: Queryable.Row[Q, R], dialect: DialectTypeMappers): SimpleSelect[Q, R] =
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)

def qr: Queryable.Row[Q, R]

/**
* Causes this [[Select]] to ignore duplicate rows, translates into SQL `SELECT DISTINCT`
*/
def distinct: Select[Q, R] = selectWithExprPrefix(true, _ => sql"DISTINCT")

protected def selectWithExprPrefix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R]
protected def selectWithExprSuffix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R]

protected def subqueryRef(implicit qr: Queryable.Row[Q, R]) = new SubqueryRef(this)

Expand Down Expand Up @@ -227,7 +230,7 @@ trait Select[Q, R]
* in this [[Select]]
*/
def subquery: SimpleSelect[Q, R] = {
newSimpleSelect(expr, None, false, Seq(subqueryRef(qr)), Nil, Nil, None)(qr, dialect)
newSimpleSelect(expr, None, None, false, Seq(subqueryRef(qr)), Nil, Nil, None)(qr, dialect)
}

/**
Expand Down Expand Up @@ -278,19 +281,23 @@ object Select {
lhs: Select[Q, R],
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
where: Seq[Expr[?]],
groupBy0: Option[GroupBy]
)(implicit qr: Queryable.Row[Q, R], dialect: DialectTypeMappers): SimpleSelect[Q, R] =
lhs.newSimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
lhs.newSimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)

def toSimpleFrom[Q, R](s: Select[Q, R]) = s.selectToSimpleSelect()

def withExprPrefix[Q, R](s: Select[Q, R], preserveAll: Boolean, str: Context => SqlStr) =
s.selectWithExprPrefix(preserveAll, str)

def withExprSuffix[Q, R](s: Select[Q, R], preserveAll: Boolean, str: Context => SqlStr) =
s.selectWithExprSuffix(preserveAll, str)

implicit class ExprSelectOps[T](s: Select[Expr[T], T]) {
def sorted(implicit tm: TypeMapper[T]): Select[Expr[T], T] = s.sortBy(identity)
}
Expand All @@ -303,6 +310,12 @@ object Select {
): Select[Q, R] =
selectToSimpleSelect().selectWithExprPrefix(preserveAll, s)

override protected def selectWithExprSuffix(
preserveAll: Boolean,
s: Context => SqlStr
): Select[Q, R] =
selectToSimpleSelect().selectWithExprSuffix(preserveAll, s)

override def map[Q2, R2](f: Q => Q2)(implicit qr: Queryable.Row[Q2, R2]): Select[Q2, R2] =
selectToSimpleSelect().map(f)

Expand Down
12 changes: 10 additions & 2 deletions scalasql/query/src/SimpleSelect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scalasql.renderer.JoinsToSql
class SimpleSelect[Q, R](
val expr: Q,
val exprPrefix: Option[Context => SqlStr],
val exprSuffix: Option[Context => SqlStr],
val preserveAll: Boolean,
val from: Seq[Context.From],
val joins: Seq[Join],
Expand All @@ -33,17 +34,21 @@ class SimpleSelect[Q, R](
protected def copy[Q, R](
expr: Q = this.expr,
exprPrefix: Option[Context => SqlStr] = this.exprPrefix,
exprSuffix: Option[Context => SqlStr] = this.exprSuffix,
preserveAll: Boolean = this.preserveAll,
from: Seq[Context.From] = this.from,
joins: Seq[Join] = this.joins,
where: Seq[Expr[?]] = this.where,
groupBy0: Option[GroupBy] = this.groupBy0
)(implicit qr: Queryable.Row[Q, R]) =
newSimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
newSimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)

def selectWithExprPrefix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] =
this.copy(exprPrefix = Some(s), preserveAll = preserveAll)

def selectWithExprSuffix(preserveAll: Boolean, s: Context => SqlStr): Select[Q, R] =
this.copy(exprSuffix = Some(s), preserveAll = preserveAll)

def aggregateExpr[V: TypeMapper](
f: Q => Context => SqlStr
)(implicit qr2: Queryable.Row[Expr[V], V]): Expr[V] = {
Expand Down Expand Up @@ -111,6 +116,7 @@ class SimpleSelect[Q, R](
copy(
expr = newExpr,
exprPrefix = exprPrefix,
exprSuffix = exprSuffix,
joins = joins ++ newJoins,
where = where ++ newWheres
)
Expand Down Expand Up @@ -178,6 +184,7 @@ class SimpleSelect[Q, R](
copy(
expr = newExpr,
exprPrefix = exprPrefix,
exprSuffix = exprSuffix,
from = Seq(this.subqueryRef),
joins = Nil,
where = Nil,
Expand Down Expand Up @@ -287,11 +294,12 @@ object SimpleSelect {
)

lazy val exprPrefix = SqlStr.opt(query.exprPrefix) { p => p(context) + sql" " }
lazy val exprSuffix = SqlStr.opt(query.exprSuffix) { p => p(context) }

val tables = SqlStr
.join(query.from.map(renderedFroms(_)), SqlStr.commaSep)

sql"SELECT " + exprPrefix + exprStr + sql" FROM " + tables + joins + filtersOpt + groupByOpt
sql"SELECT " + exprPrefix + exprStr + sql" FROM " + tables + joins + filtersOpt + groupByOpt + exprSuffix
}

}
Expand Down
1 change: 1 addition & 0 deletions scalasql/query/src/WithCte.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ object WithCte {
lhs,
expr = WithSqlExpr.get(lhs),
exprPrefix = None,
exprSuffix = None,
preserveAll = false,
from = Seq(lhsSubQueryRef),
joins = Nil,
Expand Down
6 changes: 5 additions & 1 deletion scalasql/src/dialects/H2Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ object H2Dialect extends H2Dialect {
new SimpleSelect(
Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]],
None,
None,
false,
Seq(ref),
Nil,
Expand Down Expand Up @@ -132,6 +133,7 @@ object H2Dialect extends H2Dialect {
override def newSimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -141,13 +143,14 @@ object H2Dialect extends H2Dialect {
implicit qr: Queryable.Row[Q, R],
dialect: scalasql.core.DialectTypeMappers
): scalasql.query.SimpleSelect[Q, R] = {
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)
}
}

class SimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -157,6 +160,7 @@ object H2Dialect extends H2Dialect {
extends scalasql.query.SimpleSelect(
expr,
exprPrefix,
exprSuffix,
preserveAll,
from,
joins,
Expand Down
39 changes: 38 additions & 1 deletion scalasql/src/dialects/MySqlDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import java.sql.PreparedStatement
import java.time.{Instant, LocalDateTime}
import java.util.UUID
import scala.reflect.ClassTag
import scalasql.query.Select

trait MySqlDialect extends Dialect {
protected def dialectCastParams = false
Expand Down Expand Up @@ -116,6 +117,38 @@ trait MySqlDialect extends Dialect {
implicit def ExprAggOpsConv[T](v: Aggregatable[Expr[T]]): operations.ExprAggOps[T] =
new MySqlDialect.ExprAggOps(v)

implicit class SelectForUpdateConv[Q, R](r: Select[Q, R]) {

/**
* SELECT .. FOR UPDATE acquires an exclusive lock, blocking other transactions from
* modifying or locking the selected rows, which is for managing concurrent transactions
* and ensuring data consistency in multi-step operations.
*/
def forUpdate: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE")

/**
* SELECT ... FOR SHARE: Locks the selected rows for reading, allowing other transactions
* to read but not modify the locked rows
*/
def forShare: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR SHARE")

/**
* SELECT ... FOR UPDATE NOWAIT: Immediately returns an error if the selected rows are
* already locked, instead of waiting
*/
def forUpdateNoWait: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE NOWAIT")

/**
* SELECT ... FOR UPDATE SKIP LOCKED: Skips any rows that are already locked by other
* transactions, instead of waiting
*/
def forUpdateSkipLocked: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE SKIP LOCKED")
}

override implicit def DbApiOpsConv(db: => DbApi): MySqlDialect.DbApiOps =
new MySqlDialect.DbApiOps(this)

Expand Down Expand Up @@ -207,6 +240,7 @@ object MySqlDialect extends MySqlDialect {
new SimpleSelect(
Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]],
None,
None,
false,
Seq(ref),
Nil,
Expand Down Expand Up @@ -309,6 +343,7 @@ object MySqlDialect extends MySqlDialect {
override def newSimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -318,13 +353,14 @@ object MySqlDialect extends MySqlDialect {
implicit qr: Queryable.Row[Q, R],
dialect: scalasql.core.DialectTypeMappers
): scalasql.query.SimpleSelect[Q, R] = {
new SimpleSelect(expr, exprPrefix, preserveAll, from, joins, where, groupBy0)
new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0)
}
}

class SimpleSelect[Q, R](
expr: Q,
exprPrefix: Option[Context => SqlStr],
exprSuffix: Option[Context => SqlStr],
preserveAll: Boolean,
from: Seq[Context.From],
joins: Seq[Join],
Expand All @@ -334,6 +370,7 @@ object MySqlDialect extends MySqlDialect {
extends scalasql.query.SimpleSelect(
expr,
exprPrefix,
exprSuffix,
preserveAll,
from,
joins,
Expand Down
29 changes: 29 additions & 0 deletions scalasql/src/dialects/PostgresDialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ trait PostgresDialect extends Dialect with ReturningDialect with OnConflictOps {
}
}

implicit class SelectForUpdateConv[Q, R](r: Select[Q, R]) {

/**
* SELECT .. FOR UPDATE acquires an exclusive lock, blocking other transactions from
* modifying or locking the selected rows, which is for managing concurrent transactions
* and ensuring data consistency in multi-step operations.
*/
def forUpdate: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR UPDATE")

/**
* SELECT ... FOR NO KEY UPDATE: A weaker lock that doesn't block inserts into child
* tables with foreign key references
*/
def forNoKeyUpdate: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR NO KEY UPDATE")

/**
* SELECT ... FOR SHARE: Locks the selected rows for reading, allowing other transactions
* to read but not modify the locked rows.
*/
def forShare: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR SHARE")

/** SELECT ... FOR KEY SHARE: The weakest lock, only conflicts with FOR UPDATE */
def forKeyShare: Select[Q, R] =
Select.withExprSuffix(r, true, _ => sql" FOR KEY SHARE")
}

override implicit def DbApiOpsConv(db: => DbApi): PostgresDialect.DbApiOps =
new PostgresDialect.DbApiOps(this)
}
Expand Down
Loading

0 comments on commit 5c3141c

Please sign in to comment.