Skip to content

Commit

Permalink
feat: EXPOSED-45 Support single statement UPSERT (#1743)
Browse files Browse the repository at this point in the history
* Add functionality for insert or update (upsert) command in all dialects.

MySQL, PostgreSQL, and SQLite use their own insert syntax, while the other
dialects use the standard merge syntax.

The implementation accepts optional user-defined conflict key columns, as well
as a single update where clause. It also allows update expressions to be provided
if they need to be different from the provided insert values. Column expressions,
defaults, and subqueries are also enabled if the individual dialect allows them.

* Extract excluding H2 v1 condition check function

* Add KDocs to UpsertStatement class
  • Loading branch information
bog-walk authored May 31, 2023
1 parent f030818 commit c0a8af0
Show file tree
Hide file tree
Showing 10 changed files with 696 additions and 0 deletions.
21 changes: 21 additions & 0 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,27 @@ fun Join.update(where: (SqlExpressionBuilder.() -> Op<Boolean>)? = null, limit:
return query.execute(TransactionManager.current())!!
}

/**
* Represents the SQL command that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
* **Note:** Vendors that do not support this operation directly implement the standard MERGE USING command.
*
* @param keys (optional) Columns to include in the condition that determines a unique constraint match.
* If no columns are provided, primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* If left null, all columns will be updated with the values provided for the insert.
* @param where Condition that determines which rows to update, if a unique violation is found.
*/
fun <T : Table> T.upsert(
vararg keys: Column<*>,
onUpdate: List<Pair<Column<*>, Expression<*>>>? = null,
where: (SqlExpressionBuilder.() -> Op<Boolean>)? = null,
body: T.(UpsertStatement<Long>) -> Unit
) = UpsertStatement<Long>(this, *keys, onUpdate = onUpdate, where = where?.let { SqlExpressionBuilder.it() }).apply {
body(this)
execute(TransactionManager.current())
}

/**
* @sample org.jetbrains.exposed.sql.tests.shared.DDLTests.tableExists02
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.vendors.*

/**
* Represents the SQL command that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
* @param table Table to either insert values into or update values from.
* @param keys (optional) Columns to include in the condition that determines a unique constraint match. If no columns are provided,
* primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* If left null, all columns will be updated with the values provided for the insert.
* @param where Condition that determines which rows to update, if a unique violation is found. This clause may not be supported by all vendors.
*/
open class UpsertStatement<Key : Any>(
table: Table,
vararg val keys: Column<*>,
val onUpdate: List<Pair<Column<*>, Expression<*>>>?,
val where: Op<Boolean>?
) : InsertStatement<Key>(table) {

override fun prepareSQL(transaction: Transaction): String {
val functionProvider = when (val dialect = transaction.db.dialect) {
is H2Dialect -> when (dialect.h2Mode) {
H2Dialect.H2CompatibilityMode.MariaDB, H2Dialect.H2CompatibilityMode.MySQL -> MysqlFunctionProvider()
else -> H2FunctionProvider
}
else -> dialect.functionProvider
}
return functionProvider.upsert(table, arguments!!.first(), onUpdate, where, transaction, *keys)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,132 @@ abstract class FunctionProvider {
transaction: Transaction
): String = transaction.throwUnsupportedException("There's no generic SQL for REPLACE. There must be vendor specific implementation.")

/**
* Returns the SQL command that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
* **Note:** Vendors that do not support this operation directly implement the standard MERGE USING command.
*
* @param table Table to either insert values into or update values from.
* @param data Pairs of columns to use for insert or update and values to insert or update.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* @param where Condition that determines which rows to update, if a unique violation is found.
* @param transaction Transaction where the operation is executed.
*/
open fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
if (where != null) {
transaction.throwUnsupportedException("MERGE implementation of UPSERT doesn't support single WHERE clause")
}
val keyColumns = getKeyColumnsForUpsert(table, *keys)
if (keyColumns.isNullOrEmpty()) {
transaction.throwUnsupportedException("UPSERT requires a unique key or constraint as a conflict target")
}

val dataColumns = data.unzip().first
val autoIncColumn = table.autoIncColumn
val nextValExpression = autoIncColumn?.autoIncColumnType?.nextValExpression
val dataColumnsWithoutAutoInc = autoIncColumn?.let { dataColumns - autoIncColumn } ?: dataColumns
val updateColumns = dataColumns.filter { it !in keyColumns }

return with(QueryBuilder(true)) {
+"MERGE INTO "
table.describe(transaction, this)
+" T USING "
data.appendTo(prefix = "(VALUES (", postfix = ")") { (column, value) ->
registerArgument(column, value)
}
dataColumns.appendTo(prefix = ") S(", postfix = ")") { column ->
append(transaction.identity(column))
}

+" ON "
keyColumns.appendTo(separator = " AND ", prefix = "(", postfix = ")") { column ->
val columnName = transaction.identity(column)
append("T.$columnName=S.$columnName")
}

+" WHEN MATCHED THEN"
appendUpdateToUpsertClause(table, updateColumns, onUpdate, transaction, isAliasNeeded = true)

+" WHEN NOT MATCHED THEN INSERT "
dataColumnsWithoutAutoInc.appendTo(prefix = "(") { column ->
append(transaction.identity(column))
}
nextValExpression?.let {
append(", ${transaction.identity(autoIncColumn)}")
}
dataColumnsWithoutAutoInc.appendTo(prefix = ") VALUES(") { column ->
append("S.${transaction.identity(column)}")
}
nextValExpression?.let {
append(", $it")
}
+")"
toString()
}
}

/**
* Returns the columns to be used in the conflict condition of an upsert statement.
*/
protected fun getKeyColumnsForUpsert(table: Table, vararg keys: Column<*>): List<Column<*>>? {
return keys.toList().ifEmpty {
table.primaryKey?.columns?.toList() ?: table.indices.firstOrNull { it.unique }?.columns
}
}

/**
* Appends the complete default SQL insert (no ignore) command to [this] QueryBuilder.
*/
protected fun QueryBuilder.appendInsertToUpsertClause(table: Table, data: List<Pair<Column<*>, Any?>>, transaction: Transaction) {
val valuesSql = if (data.isEmpty()) {
""
} else {
data.appendTo(QueryBuilder(true), prefix = "VALUES (", postfix = ")") { (column, value) ->
registerArgument(column, value)
}.toString()
}
val insertStatement = insert(false, table, data.unzip().first, valuesSql, transaction)

+insertStatement
}

/**
* Appends an SQL update command for a derived table (with or without alias identifiers) to [this] QueryBuilder.
*/
protected fun QueryBuilder.appendUpdateToUpsertClause(
table: Table,
updateColumns: List<Column<*>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
transaction: Transaction,
isAliasNeeded: Boolean
) {
+" UPDATE SET "
onUpdate?.appendTo { (columnToUpdate, updateExpression) ->
if (isAliasNeeded) {
val aliasExpression = updateExpression.toString().replace(transaction.identity(table), "T")
append("T.${transaction.identity(columnToUpdate)}=${aliasExpression}")
} else {
append("${transaction.identity(columnToUpdate)}=${updateExpression}")
}
} ?: run {
updateColumns.appendTo { column ->
val columnName = transaction.identity(column)
if (isAliasNeeded) {
append("T.$columnName=S.$columnName")
} else {
append("$columnName=EXCLUDED.$columnName")
}
}
}
}

/**
* Returns the SQL command that deletes one or more rows of a table.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.exposed.sql.vendors

import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.exceptions.throwUnsupportedException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.math.BigDecimal
Expand Down Expand Up @@ -145,6 +146,49 @@ internal open class MysqlFunctionProvider : FunctionProvider() {
limit?.let { +" LIMIT $it" }
toString()
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
if (keys.isNotEmpty()) {
transaction.throwUnsupportedException("MySQL doesn't support specifying conflict keys in UPSERT clause")
}
if (where != null) {
transaction.throwUnsupportedException("MySQL doesn't support WHERE in UPSERT clause")
}

val isAliasSupported = when (val dialect = transaction.db.dialect) {
is MysqlDialect -> dialect !is MariaDBDialect && dialect.isMysql8
else -> false // H2_MySQL mode also uses this function provider & requires older version
}

return with(QueryBuilder(true)) {
appendInsertToUpsertClause(table, data, transaction)
if (isAliasSupported) {
+" AS NEW"
}

+" ON DUPLICATE KEY UPDATE "
onUpdate?.appendTo { (columnToUpdate, updateExpression) ->
append("${transaction.identity(columnToUpdate)}=${updateExpression}")
} ?: run {
data.unzip().first.appendTo { column ->
val columnName = transaction.identity(column)
if (isAliasSupported) {
append("$columnName=NEW.$columnName")
} else {
append("$columnName=VALUES($columnName)")
}
}
}
toString()
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,28 @@ internal object OracleFunctionProvider : FunctionProvider() {
toString()
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
val statement = super.upsert(table, data, onUpdate, where, transaction, *keys)

val dualTable = data.appendTo(QueryBuilder(true), prefix = "(SELECT ", postfix = " FROM DUAL) S") { (column, value) ->
registerArgument(column, value)
+" AS "
append(transaction.identity(column))
}.toString()

val (leftReserved, rightReserved) = " USING " to " ON "
val leftBoundary = statement.indexOf(leftReserved) + leftReserved.length
val rightBoundary = statement.indexOf(rightReserved)
return statement.replaceRange(leftBoundary, rightBoundary, dualTable)
}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,40 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() {
}
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
val keyColumns = getKeyColumnsForUpsert(table, *keys)
if (keyColumns.isNullOrEmpty()) {
transaction.throwUnsupportedException("UPSERT requires a unique key or constraint as a conflict target")
}

val updateColumns = data.unzip().first.filter { it !in keyColumns }

return with(QueryBuilder(true)) {
appendInsertToUpsertClause(table, data, transaction)

+" ON CONFLICT "
keyColumns.appendTo(prefix = "(", postfix = ")") { column ->
append(transaction.identity(column))
}

+" DO"
appendUpdateToUpsertClause(table, updateColumns, onUpdate, transaction, isAliasNeeded = false)

where?.let {
+" WHERE "
+it
}
toString()
}
}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ internal object SQLServerFunctionProvider : FunctionProvider() {
toString()
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
// SQLSERVER MERGE statement must be terminated by a semi-colon (;)
return super.upsert(table, data, onUpdate, where, transaction, *keys) + ";"
}

override fun delete(ignore: Boolean, table: Table, where: String?, limit: Int?, transaction: Transaction): String {
val def = super.delete(ignore, table, where, null, transaction)
return if (limit != null) def.replaceFirst("DELETE", "DELETE TOP($limit)") else def
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,35 @@ internal object SQLiteFunctionProvider : FunctionProvider() {
return "INSERT OR REPLACE INTO ${transaction.identity(table)} ($columns) VALUES ($values)"
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String = with(QueryBuilder(true)) {
appendInsertToUpsertClause(table, data, transaction)

+" ON CONFLICT"
val keyColumns = getKeyColumnsForUpsert(table, *keys) ?: emptyList()
if (keyColumns.isNotEmpty()) {
keyColumns.appendTo(prefix = " (", postfix = ")") { column ->
append(transaction.identity(column))
}
}

+" DO"
val updateColumns = data.unzip().first.filter { it !in keyColumns }
appendUpdateToUpsertClause(table, updateColumns, onUpdate, transaction, isAliasNeeded = false)

where?.let {
+" WHERE "
+it
}
toString()
}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.junit.Assume
import org.junit.AssumptionViolatedException
import org.testcontainers.containers.MySQLContainer
import org.testcontainers.containers.PostgreSQLContainer
import java.math.BigDecimal
import java.sql.Connection
import java.sql.SQLException
import java.time.Duration
Expand Down Expand Up @@ -283,6 +284,12 @@ abstract class DatabaseTestsBase {
""
}

fun Transaction.excludingH2Version1(dbSettings: TestDB, statement: Transaction.(TestDB) -> Unit) {
if (dbSettings !in TestDB.allH2TestDB || db.isVersionCovers(BigDecimal("2.0"))) {
statement(dbSettings)
}
}

protected fun prepareSchemaForTest(schemaName: String) : Schema {
return Schema(schemaName, defaultTablespace = "USERS", temporaryTablespace = "TEMP ", quota = "20M", on = "USERS")
}
Expand Down
Loading

0 comments on commit c0a8af0

Please sign in to comment.