Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: EXPOSED-52 Support batch UPSERT #1749

Merged
merged 3 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,81 @@ fun <T : Table> T.upsert(
execute(TransactionManager.current())
}

/**
 * Batch inserts new rows into this table, updating the existing rows instead whenever an insertion violates a unique constraint.
 *
 * **Note**: Unlike `upsert`, `batchUpsert` does not include a `where` parameter. Please log a feature request on
 * [YouTrack](https://youtrack.jetbrains.com/newIssue?project=EXPOSED&c=Type%20Feature&draftId=25-4449790) if a use-case requires inclusion of a `where` clause.
 *
 * This overload only obtains an iterator from [data] and forwards every argument unchanged to the iterator-based implementation.
 *
 * @param data Collection of values to use in batch upsert.
 * @param keys (optional) Columns to include in the condition that determines a unique constraint match. If no columns are provided,
 * primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
 * @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
 * If left null, all columns will be updated with the values provided for the insert.
 * @param shouldReturnGeneratedValues Specifies whether newly generated values (for example, auto-incremented IDs) should be returned.
 * See [Batch Insert](https://github.com/JetBrains/Exposed/wiki/DSL#batch-insert) for more details.
 * @param body Block invoked for each element of [data], with the [BatchUpsertStatement] as receiver, to assign the column values of that row.
 * @sample org.jetbrains.exposed.sql.tests.shared.dml.UpsertTests.testBatchUpsertWithNoConflict
 */
fun <T : Table, E : Any> T.batchUpsert(
    data: Iterable<E>,
    vararg keys: Column<*>,
    onUpdate: List<Pair<Column<*>, Expression<*>>>? = null,
    shouldReturnGeneratedValues: Boolean = true,
    body: BatchUpsertStatement.(E) -> Unit
): List<ResultRow> = batchUpsert(
    data.iterator(),
    *keys,
    onUpdate = onUpdate,
    shouldReturnGeneratedValues = shouldReturnGeneratedValues,
    body = body
)

/**
 * Batch inserts new rows into this table, updating the existing rows instead whenever an insertion violates a unique constraint.
 *
 * **Note**: Unlike `upsert`, `batchUpsert` does not include a `where` parameter. Please log a feature request on
 * [YouTrack](https://youtrack.jetbrains.com/newIssue?project=EXPOSED&c=Type%20Feature&draftId=25-4449790) if a use-case requires inclusion of a `where` clause.
 *
 * This overload only obtains an iterator from [data] and forwards every argument unchanged to the iterator-based implementation.
 *
 * @param data Sequence of values to use in batch upsert.
 * @param keys (optional) Columns to include in the condition that determines a unique constraint match. If no columns are provided,
 * primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
 * @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
 * If left null, all columns will be updated with the values provided for the insert.
 * @param shouldReturnGeneratedValues Specifies whether newly generated values (for example, auto-incremented IDs) should be returned.
 * See [Batch Insert](https://github.com/JetBrains/Exposed/wiki/DSL#batch-insert) for more details.
 * @param body Block invoked for each element of [data], with the [BatchUpsertStatement] as receiver, to assign the column values of that row.
 * @sample org.jetbrains.exposed.sql.tests.shared.dml.UpsertTests.testBatchUpsertWithSequence
 */
fun <T : Table, E : Any> T.batchUpsert(
    data: Sequence<E>,
    vararg keys: Column<*>,
    onUpdate: List<Pair<Column<*>, Expression<*>>>? = null,
    shouldReturnGeneratedValues: Boolean = true,
    body: BatchUpsertStatement.(E) -> Unit
): List<ResultRow> = batchUpsert(
    data.iterator(),
    *keys,
    onUpdate = onUpdate,
    shouldReturnGeneratedValues = shouldReturnGeneratedValues,
    body = body
)

/**
 * Iterator-based implementation that the public `Iterable` and `Sequence` overloads of `batchUpsert` delegate to.
 * Batch inserts new rows into this table, updating the existing rows instead whenever an insertion violates a unique constraint.
 *
 * **Note**: Unlike `upsert`, `batchUpsert` does not include a `where` parameter. Please log a feature request on
 * [YouTrack](https://youtrack.jetbrains.com/newIssue?project=EXPOSED&c=Type%20Feature&draftId=25-4449790) if a use-case requires inclusion of a `where` clause.
 *
 * @param data Iterator over a collection of values to use in batch upsert.
 * @param keys (optional) Columns to include in the condition that determines a unique constraint match. If no columns are provided,
 * primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
 * @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
 * If left null, all columns will be updated with the values provided for the insert.
 * @param shouldReturnGeneratedValues Specifies whether newly generated values (for example, auto-incremented IDs) should be returned.
 * See [Batch Insert](https://github.com/JetBrains/Exposed/wiki/DSL#batch-insert) for more details.
 * @param body Block invoked for each element of [data], with the [BatchUpsertStatement] as receiver, to assign the column values of that row.
 * @sample org.jetbrains.exposed.sql.tests.shared.dml.UpsertTests.testBatchUpsertWithNoConflict
 */
private fun <T : Table, E> T.batchUpsert(
    data: Iterator<E>,
    vararg keys: Column<*>,
    onUpdate: List<Pair<Column<*>, Expression<*>>>? = null,
    shouldReturnGeneratedValues: Boolean = true,
    body: BatchUpsertStatement.(E) -> Unit
): List<ResultRow> {
    // Bind the receiver table to a name so the statement factory below does not rely on an implicit `this`.
    val targetTable = this
    return executeBatch(data, body) {
        BatchUpsertStatement(
            targetTable,
            *keys,
            onUpdate = onUpdate,
            shouldReturnGeneratedValues = shouldReturnGeneratedValues
        )
    }
}

/**
* @sample org.jetbrains.exposed.sql.tests.shared.DDLTests.tableExists02
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package org.jetbrains.exposed.sql.statements

import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.Expression
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.vendors.H2Dialect
import org.jetbrains.exposed.sql.vendors.H2FunctionProvider
import org.jetbrains.exposed.sql.vendors.MysqlFunctionProvider

/**
* Represents the SQL command that either batch inserts new rows into a table, or updates the existing rows if insertions violate unique constraints.
*
* **Note**: Unlike `UpsertStatement`, `BatchUpsertStatement` does not include a `where` parameter. Please log a feature request
* on [YouTrack](https://youtrack.jetbrains.com/newIssue?project=EXPOSED&c=Type%20Feature&draftId=25-4449790) if a use-case requires inclusion of a `where` clause.
*
* @param table Table to either insert values into or update values from.
* @param keys (optional) Columns to include in the condition that determines a unique constraint match. If no columns are provided,
* primary keys will be used. If the table does not have any primary keys, the first unique index will be attempted.
* @param onUpdate List of pairs of specific columns to update and the expressions to update them with.
* If left null, all columns will be updated with the values provided for the insert.
* @param shouldReturnGeneratedValues Specifies whether newly generated values (for example, auto-incremented IDs) should be returned.
* See [Batch Insert](https://github.com/JetBrains/Exposed/wiki/DSL#batch-insert) for more details.
*/
open class BatchUpsertStatement(
table: Table,
vararg val keys: Column<*>,
val onUpdate: List<Pair<Column<*>, Expression<*>>>?,
shouldReturnGeneratedValues: Boolean = true
) : BaseBatchInsertStatement(table, ignore = false, shouldReturnGeneratedValues) {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Single statement upsert() takes a WHERE clause as an argument (to be used in update part), but this option hasn't been included in batchUpsert().

The rationale is that I couldn't find a use case for a batch upsert including a where that applies to all batch data (it seems to defeat the purpose?). All the PostgreSQL examples I came across used an UPDATE without WHERE (and MySQL doesn't allow WHERE for even single upsert).

I can implement its inclusion if anybody believes it should definitely still be an option. Or we can wait to see if users request it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can note this in the KDoc, with a link to YouTrack in case anyone has a use case for it.

/**
 * Builds the upsert SQL for this batch from the first set of batch arguments.
 *
 * The function provider is taken from the transaction's dialect, except for H2: in MariaDB or MySQL
 * compatibility mode a [MysqlFunctionProvider] is used, and in every other H2 mode [H2FunctionProvider] is used.
 *
 * Throws a `NullPointerException` if `arguments` is null.
 */
override fun prepareSQL(transaction: Transaction): String {
    val currentDialect = transaction.db.dialect
    val functionProvider = if (currentDialect is H2Dialect) {
        when (currentDialect.h2Mode) {
            H2Dialect.H2CompatibilityMode.MariaDB, H2Dialect.H2CompatibilityMode.MySQL -> MysqlFunctionProvider()
            else -> H2FunctionProvider
        }
    } else {
        currentDialect.functionProvider
    }
    // NOTE(review): the `null` argument is presumably the `where` clause that batch upsert does not expose; confirm against FunctionProvider.upsert.
    val firstBatchArguments = arguments!!.first()
    return functionProvider.upsert(table, firstBatchArguments, onUpdate, null, transaction, *keys)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class EntityLifecycleInterceptor : GlobalStatementInterceptor {
}
}

is UpsertStatement<*>, is BatchUpsertStatement -> {
transaction.flushCache()
transaction.entityCache.removeTablesReferrers(statement.targets, true)
if (!isExecutedWithinEntityLifecycle) {
statement.targets.filterIsInstance<IdTable<*>>().forEach {
transaction.entityCache.data[it]?.clear()
}
}
}

is InsertStatement<*> -> {
transaction.flushCache()
transaction.entityCache.removeTablesReferrers(listOf(statement.table), true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ enum class TestDB(

companion object {
val allH2TestDB = listOf(H2, H2_MYSQL, H2_PSQL, H2_MARIADB, H2_ORACLE, H2_SQLSERVER)
val mySqlRelatedDB = listOf(MYSQL, MARIADB, H2_MYSQL, H2_MARIADB)
fun enabledInTests(): Set<TestDB> {
val concreteDialects = System.getProperty("exposed.test.dialects", "")
.split(",")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@ import org.jetbrains.exposed.sql.tests.*
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.expectException
import org.junit.Test
import java.util.*

// Upsert implementation does not support H2 version 1
// https://youtrack.jetbrains.com/issue/EXPOSED-30/Phase-Out-Support-for-H2-Version-1.x
class UpsertTests : DatabaseTestsBase() {
// These databases require the key columns from the ON clause to be included in the derived source table (USING clause).
private val upsertViaMergeDB = listOf(TestDB.SQLSERVER, TestDB.ORACLE) + TestDB.allH2TestDB - TestDB.H2_MYSQL

private val mySqlLikeDB = listOf(TestDB.MYSQL, TestDB.H2_MYSQL, TestDB.MARIADB, TestDB.H2_MARIADB)

@Test
fun testUpsertWithPKConflict() {
val tester = object : Table("tester") {
Expand Down Expand Up @@ -94,29 +93,24 @@ class UpsertTests : DatabaseTestsBase() {

@Test
fun testUpsertWithUniqueIndexConflict() {
val tester = object : Table("tester") {
val name = varchar("name", 64).uniqueIndex()
val age = integer("age")
}

withTables(tester) { testDb ->
withTables(Words) { testDb ->
excludingH2Version1(testDb) {
val nameA = tester.upsert {
it[name] = "A"
it[age] = 10
} get tester.name
tester.upsert {
it[name] = "B"
it[age] = 10
val wordA = Words.upsert {
it[word] = "A"
it[count] = 10
} get Words.word
Words.upsert {
it[word] = "B"
it[count] = 10
}
tester.upsert {
it[name] = "A"
it[age] = 9
Words.upsert {
it[word] = wordA
it[count] = 9
}

assertEquals(2, tester.selectAll().count())
val updatedResult = tester.select { tester.name eq nameA }.single()
assertEquals(9, updatedResult[tester.age])
assertEquals(2, Words.selectAll().count())
val updatedResult = Words.select { Words.word eq wordA }.single()
assertEquals(9, updatedResult[Words.count])
}
}
}
Expand All @@ -129,7 +123,7 @@ class UpsertTests : DatabaseTestsBase() {
val name = varchar("name", 64)
}

withTables(excludeSettings = mySqlLikeDB, tester) { testDb ->
withTables(excludeSettings = TestDB.mySqlRelatedDB, tester) { testDb ->
excludingH2Version1(testDb) {
val oldIdA = tester.insert {
it[idA] = 1
Expand Down Expand Up @@ -174,7 +168,7 @@ class UpsertTests : DatabaseTestsBase() {
val name = varchar("name", 64)
}

val okWithNoUniquenessDB = mySqlLikeDB + listOf(TestDB.SQLITE)
val okWithNoUniquenessDB = TestDB.mySqlRelatedDB + TestDB.SQLITE

withTables(tester) { testDb ->
excludingH2Version1(testDb) {
Expand All @@ -196,23 +190,18 @@ class UpsertTests : DatabaseTestsBase() {

@Test
fun testUpsertWithManualUpdateAssignment() {
val tester = object : Table("tester") {
val word = varchar("word", 256).uniqueIndex()
val count = integer("count").default(1)
}

withTables(tester) { testDb ->
withTables(Words) { testDb ->
excludingH2Version1(testDb) {
val testWord = "Test"
val incrementCount = listOf(tester.count to tester.count.plus(1))
val incrementCount = listOf(Words.count to Words.count.plus(1))

repeat(3) {
tester.upsert(onUpdate = incrementCount) {
Words.upsert(onUpdate = incrementCount) {
it[word] = testWord
}
}

assertEquals(3, tester.selectAll().single()[tester.count])
assertEquals(3, Words.selectAll().single()[Words.count])
}
}
}
Expand Down Expand Up @@ -298,7 +287,7 @@ class UpsertTests : DatabaseTestsBase() {
val age = integer("age")
}

withTables(excludeSettings = mySqlLikeDB + upsertViaMergeDB, tester) {
withTables(excludeSettings = TestDB.mySqlRelatedDB + upsertViaMergeDB, tester) {
val id1 = tester.insertAndGetId {
it[name] = "A"
it[address] = "Place A"
Expand Down Expand Up @@ -365,4 +354,78 @@ class UpsertTests : DatabaseTestsBase() {
}
}
}

@Test
fun testBatchUpsertWithNoConflict() {
    withTables(Words) { testDb ->
        excludingH2Version1(testDb) {
            val wordTotal = 10
            // Distinct words ("Word A", "Word B", ...) paired with counts 10, 20, ..., so no row conflicts with another.
            val wordsWithCounts = (0 until wordTotal).map { index ->
                "Word ${'A' + index}" to wordTotal * index + wordTotal
            }

            val upsertedRows = Words.batchUpsert(wordsWithCounts) { (text, occurrences) ->
                this[Words.word] = text
                this[Words.count] = occurrences
            }

            // One result row per input element, and one stored row per input element.
            assertEquals(wordTotal, upsertedRows.size)
            assertEquals(wordTotal.toLong(), Words.selectAll().count())
        }
    }
}

@Test
fun testBatchUpsertWithConflict() {
    withTables(Words) { testDb ->
        excludingH2Version1(testDb) {
            val vowels = listOf("A", "E", "I", "O", "U")
            val alphabet = ('A'..'Z').map { letter -> letter.toString() }
            // Every letter appears once and each vowel a second time, so only the vowels hit the unique index on the word column.
            val lettersWithDuplicates = alphabet + vowels
            val incrementCount = listOf(Words.count to Words.count.plus(1))

            Words.batchUpsert(lettersWithDuplicates, onUpdate = incrementCount) { letter ->
                this[Words.word] = letter
            }

            assertEquals(alphabet.size.toLong(), Words.selectAll().count())
            Words.selectAll().forEach { row ->
                // Count starts at its default of 1 and is incremented once for the duplicated vowels.
                val wasUpsertedTwice = row[Words.word] in vowels
                val expectedCount = if (wasUpsertedTwice) 2 else 1
                assertEquals(expectedCount, row[Words.count])
            }
        }
    }
}

@Test
fun testBatchUpsertWithSequence() {
    withTables(Words) { testDb ->
        excludingH2Version1(testDb) {
            val wordTotal = 25
            // The random UUID strings are generated eagerly and only then exposed as a Sequence.
            val randomWords = (1..wordTotal).map { UUID.randomUUID().toString() }.asSequence()

            Words.batchUpsert(randomWords) { word ->
                this[Words.word] = word
            }

            val storedRowCount = Words.selectAll().count()
            assertEquals(wordTotal.toLong(), storedRowCount)
        }
    }
}

@Test
fun testBatchUpsertWithEmptySequence() {
    withTables(Words) { testDb ->
        excludingH2Version1(testDb) {
            // Upserting an empty sequence must leave the table empty.
            val noWords = emptySequence<String>()

            Words.batchUpsert(noWords) { word ->
                this[Words.word] = word
            }

            val storedRowCount = Words.selectAll().count()
            assertEquals(0, storedRowCount)
        }
    }
}

// Table shared by the upsert tests: a uniquely indexed word and a counter for it.
private object Words : Table("words") {
    // Unique index used by the tests as the upsert conflict target; note the database column is called "name", not "word".
    val word = varchar("name", 64).uniqueIndex()
    // Defaults to 1, so rows inserted without an explicit count start at 1.
    val count = integer("count").default(1)
}
}
Loading