Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: EXPOSED-45 Support single statement UPSERT #1743

Merged
merged 3 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,27 @@ fun Join.update(where: (SqlExpressionBuilder.() -> Op<Boolean>)? = null, limit:
return query.execute(TransactionManager.current())!!
}

/**
* Represents the SQL command that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
* **Note:** Vendors that do not support this operation directly implement the standard MERGE USING command.
*
* @param keys (optional) Columns to include in the condition that determines a unique constraint match.
* If no columns are provided, primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* If left null, all columns will be updated with the values provided for the insert.
* @param where Condition that determines which rows to update, if a unique violation is found.
*/
fun <T : Table> T.upsert(
vararg keys: Column<*>,
onUpdate: List<Pair<Column<*>, Expression<*>>>? = null,
where: (SqlExpressionBuilder.() -> Op<Boolean>)? = null,
body: T.(UpsertStatement<Long>) -> Unit
) = UpsertStatement<Long>(this, *keys, onUpdate = onUpdate, where = where?.let { SqlExpressionBuilder.it() }).apply {
body(this)
execute(TransactionManager.current())
}

/**
* @sample org.jetbrains.exposed.sql.tests.shared.DDLTests.tableExists02
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.vendors.*

/**
* Represents the SQL command that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
* @param table Table to either insert values into or update values from.
* @param keys (optional) Columns to include in the condition that determines a unique constraint match. If no columns are provided,
* primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* If left null, all columns will be updated with the values provided for the insert.
* @param where Condition that determines which rows to update, if a unique violation is found. This clause may not be supported by all vendors.
*/
open class UpsertStatement<Key : Any>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add KDoc

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

table: Table,
vararg val keys: Column<*>,
val onUpdate: List<Pair<Column<*>, Expression<*>>>?,
val where: Op<Boolean>?
) : InsertStatement<Key>(table) {

override fun prepareSQL(transaction: Transaction): String {
val functionProvider = when (val dialect = transaction.db.dialect) {
is H2Dialect -> when (dialect.h2Mode) {
H2Dialect.H2CompatibilityMode.MariaDB, H2Dialect.H2CompatibilityMode.MySQL -> MysqlFunctionProvider()
else -> H2FunctionProvider
}
else -> dialect.functionProvider
}
return functionProvider.upsert(table, arguments!!.first(), onUpdate, where, transaction, *keys)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,132 @@ abstract class FunctionProvider {
transaction: Transaction
): String = transaction.throwUnsupportedException("There's no generic SQL for REPLACE. There must be vendor specific implementation.")

/**
* Returns the SQL command that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
* **Note:** Vendors that do not support this operation directly implement the standard MERGE USING command.
*
* @param table Table to either insert values into or update values from.
* @param data Pairs of columns to use for insert or update and values to insert or update.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* @param where Condition that determines which rows to update, if a unique violation is found.
* @param transaction Transaction where the operation is executed.
*/
open fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
if (where != null) {
transaction.throwUnsupportedException("MERGE implementation of UPSERT doesn't support single WHERE clause")
}
val keyColumns = getKeyColumnsForUpsert(table, *keys)
if (keyColumns.isNullOrEmpty()) {
transaction.throwUnsupportedException("UPSERT requires a unique key or constraint as a conflict target")
}

val dataColumns = data.unzip().first
val autoIncColumn = table.autoIncColumn
val nextValExpression = autoIncColumn?.autoIncColumnType?.nextValExpression
val dataColumnsWithoutAutoInc = autoIncColumn?.let { dataColumns - autoIncColumn } ?: dataColumns
val updateColumns = dataColumns.filter { it !in keyColumns }

return with(QueryBuilder(true)) {
+"MERGE INTO "
table.describe(transaction, this)
+" T USING "
data.appendTo(prefix = "(VALUES (", postfix = ")") { (column, value) ->
registerArgument(column, value)
}
dataColumns.appendTo(prefix = ") S(", postfix = ")") { column ->
append(transaction.identity(column))
}

+" ON "
keyColumns.appendTo(separator = " AND ", prefix = "(", postfix = ")") { column ->
val columnName = transaction.identity(column)
append("T.$columnName=S.$columnName")
}

+" WHEN MATCHED THEN"
appendUpdateToUpsertClause(table, updateColumns, onUpdate, transaction, isAliasNeeded = true)

+" WHEN NOT MATCHED THEN INSERT "
dataColumnsWithoutAutoInc.appendTo(prefix = "(") { column ->
append(transaction.identity(column))
}
nextValExpression?.let {
append(", ${transaction.identity(autoIncColumn)}")
}
dataColumnsWithoutAutoInc.appendTo(prefix = ") VALUES(") { column ->
append("S.${transaction.identity(column)}")
}
nextValExpression?.let {
append(", $it")
}
+")"
toString()
}
}

/**
* Returns the columns to be used in the conflict condition of an upsert statement.
*/
protected fun getKeyColumnsForUpsert(table: Table, vararg keys: Column<*>): List<Column<*>>? {
return keys.toList().ifEmpty {
table.primaryKey?.columns?.toList() ?: table.indices.firstOrNull { it.unique }?.columns
}
}

/**
* Appends the complete default SQL insert (no ignore) command to [this] QueryBuilder.
*/
protected fun QueryBuilder.appendInsertToUpsertClause(table: Table, data: List<Pair<Column<*>, Any?>>, transaction: Transaction) {
val valuesSql = if (data.isEmpty()) {
""
} else {
data.appendTo(QueryBuilder(true), prefix = "VALUES (", postfix = ")") { (column, value) ->
registerArgument(column, value)
}.toString()
}
val insertStatement = insert(false, table, data.unzip().first, valuesSql, transaction)

+insertStatement
}

/**
* Appends an SQL update command for a derived table (with or without alias identifiers) to [this] QueryBuilder.
*/
protected fun QueryBuilder.appendUpdateToUpsertClause(
table: Table,
updateColumns: List<Column<*>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
transaction: Transaction,
isAliasNeeded: Boolean
) {
+" UPDATE SET "
onUpdate?.appendTo { (columnToUpdate, updateExpression) ->
if (isAliasNeeded) {
val aliasExpression = updateExpression.toString().replace(transaction.identity(table), "T")
append("T.${transaction.identity(columnToUpdate)}=${aliasExpression}")
} else {
append("${transaction.identity(columnToUpdate)}=${updateExpression}")
}
} ?: run {
updateColumns.appendTo { column ->
val columnName = transaction.identity(column)
if (isAliasNeeded) {
append("T.$columnName=S.$columnName")
} else {
append("$columnName=EXCLUDED.$columnName")
}
}
}
}

/**
* Returns the SQL command that deletes one or more rows of a table.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.exposed.sql.vendors

import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.exceptions.throwUnsupportedException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.math.BigDecimal
Expand Down Expand Up @@ -143,6 +144,49 @@ internal open class MysqlFunctionProvider : FunctionProvider() {
limit?.let { +" LIMIT $it" }
toString()
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
if (keys.isNotEmpty()) {
transaction.throwUnsupportedException("MySQL doesn't support specifying conflict keys in UPSERT clause")
}
if (where != null) {
transaction.throwUnsupportedException("MySQL doesn't support WHERE in UPSERT clause")
}

val isAliasSupported = when (val dialect = transaction.db.dialect) {
is MysqlDialect -> dialect !is MariaDBDialect && dialect.isMysql8
else -> false // H2_MySQL mode also uses this function provider & requires older version
}

return with(QueryBuilder(true)) {
appendInsertToUpsertClause(table, data, transaction)
if (isAliasSupported) {
+" AS NEW"
}

+" ON DUPLICATE KEY UPDATE "
onUpdate?.appendTo { (columnToUpdate, updateExpression) ->
append("${transaction.identity(columnToUpdate)}=${updateExpression}")
} ?: run {
data.unzip().first.appendTo { column ->
val columnName = transaction.identity(column)
if (isAliasSupported) {
append("$columnName=NEW.$columnName")
} else {
append("$columnName=VALUES($columnName)")
}
}
}
toString()
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,28 @@ internal object OracleFunctionProvider : FunctionProvider() {
toString()
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
val statement = super.upsert(table, data, onUpdate, where, transaction, *keys)

val dualTable = data.appendTo(QueryBuilder(true), prefix = "(SELECT ", postfix = " FROM DUAL) S") { (column, value) ->
registerArgument(column, value)
+" AS "
append(transaction.identity(column))
}.toString()

val (leftReserved, rightReserved) = " USING " to " ON "
val leftBoundary = statement.indexOf(leftReserved) + leftReserved.length
val rightBoundary = statement.indexOf(rightReserved)
return statement.replaceRange(leftBoundary, rightBoundary, dualTable)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A technicality, but I'm not entirely sure about this choice. If anyone has a cleaner idea of how to replace this chunk in the middle of a statement, I'm open to alternatives. The other way I considered was:

return statement.replaceAfter(" USING ", dualTable) + statement.replaceBefore(" ON ", "")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's ok to make it this way

}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,40 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() {
}
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
val keyColumns = getKeyColumnsForUpsert(table, *keys)
if (keyColumns.isNullOrEmpty()) {
transaction.throwUnsupportedException("UPSERT requires a unique key or constraint as a conflict target")
}

val updateColumns = data.unzip().first.filter { it !in keyColumns }

return with(QueryBuilder(true)) {
appendInsertToUpsertClause(table, data, transaction)

+" ON CONFLICT "
keyColumns.appendTo(prefix = "(", postfix = ")") { column ->
append(transaction.identity(column))
}

+" DO"
appendUpdateToUpsertClause(table, updateColumns, onUpdate, transaction, isAliasNeeded = false)

where?.let {
+" WHERE "
+it
}
toString()
}
}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ internal object SQLServerFunctionProvider : FunctionProvider() {
toString()
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String {
// SQLSERVER MERGE statement must be terminated by a semi-colon (;)
return super.upsert(table, data, onUpdate, where, transaction, *keys) + ";"
}

override fun delete(ignore: Boolean, table: Table, where: String?, limit: Int?, transaction: Transaction): String {
val def = super.delete(ignore, table, where, null, transaction)
return if (limit != null) def.replaceFirst("DELETE", "DELETE TOP($limit)") else def
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,35 @@ internal object SQLiteFunctionProvider : FunctionProvider() {
return "INSERT OR REPLACE INTO ${transaction.identity(table)} ($columns) VALUES ($values)"
}

override fun upsert(
table: Table,
data: List<Pair<Column<*>, Any?>>,
onUpdate: List<Pair<Column<*>, Expression<*>>>?,
where: Op<Boolean>?,
transaction: Transaction,
vararg keys: Column<*>
): String = with(QueryBuilder(true)) {
appendInsertToUpsertClause(table, data, transaction)

+" ON CONFLICT"
val keyColumns = getKeyColumnsForUpsert(table, *keys) ?: emptyList()
if (keyColumns.isNotEmpty()) {
keyColumns.appendTo(prefix = " (", postfix = ")") { column ->
append(transaction.identity(column))
}
}

+" DO"
val updateColumns = data.unzip().first.filter { it !in keyColumns }
appendUpdateToUpsertClause(table, updateColumns, onUpdate, transaction, isAliasNeeded = false)

where?.let {
+" WHERE "
+it
}
toString()
}

override fun delete(
ignore: Boolean,
table: Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.junit.Assume
import org.junit.AssumptionViolatedException
import org.testcontainers.containers.MySQLContainer
import org.testcontainers.containers.PostgreSQLContainer
import java.math.BigDecimal
import java.sql.Connection
import java.sql.SQLException
import java.time.Duration
Expand Down Expand Up @@ -283,6 +284,12 @@ abstract class DatabaseTestsBase {
""
}

fun Transaction.excludingH2Version1(dbSettings: TestDB, statement: Transaction.(TestDB) -> Unit) {
if (dbSettings !in TestDB.allH2TestDB || db.isVersionCovers(BigDecimal("2.0"))) {
statement(dbSettings)
}
}

protected fun prepareSchemaForTest(schemaName: String) : Schema {
return Schema(schemaName, defaultTablespace = "USERS", temporaryTablespace = "TEMP ", quota = "20M", on = "USERS")
}
Expand Down
Loading