diff --git a/exposed-dao/src/main/kotlin/org/jetbrains/exposed/dao/EntityLifecycleInterceptor.kt b/exposed-dao/src/main/kotlin/org/jetbrains/exposed/dao/EntityLifecycleInterceptor.kt index e1e6d8441a..0fd8d3168a 100644 --- a/exposed-dao/src/main/kotlin/org/jetbrains/exposed/dao/EntityLifecycleInterceptor.kt +++ b/exposed-dao/src/main/kotlin/org/jetbrains/exposed/dao/EntityLifecycleInterceptor.kt @@ -33,9 +33,17 @@ class EntityLifecycleInterceptor : GlobalStatementInterceptor { @Suppress("ComplexMethod") override fun beforeExecution(transaction: Transaction, context: StatementContext) { - when (val statement = context.statement) { + beforeExecution(transaction = transaction, context = context, childStatement = null) + } + + private fun beforeExecution(transaction: Transaction, context: StatementContext, childStatement: Statement<*>?) { + when (val statement = childStatement ?: context.statement) { is Query -> transaction.flushEntities(statement) + is ReturningStatement -> { + beforeExecution(transaction = transaction, context = context, childStatement = statement.mainStatement) + } + is DeleteStatement -> { transaction.flushCache() transaction.entityCache.removeTablesReferrers(statement.targetsSet.targetTables(), false) diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt index 684fa2eed9..a6cd0b0d2b 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ReturningTests.kt @@ -1,5 +1,8 @@ package org.jetbrains.exposed.sql.tests.shared.dml +import org.jetbrains.exposed.dao.IntEntity +import org.jetbrains.exposed.dao.IntEntityClass +import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.IntIdTable import org.jetbrains.exposed.sql.* import org.jetbrains.exposed.sql.SqlExpressionBuilder.times @@ -21,6 +24,13 @@ class ReturningTests : DatabaseTestsBase() { val price = double("price") } + class ItemDAO(id: EntityID) : IntEntity(id) { + companion object : IntEntityClass(Items) + + var name by Items.name + var price by Items.price + } + @Test fun testInsertReturning() { withTables(TestDB.ALL - returningSupportedDb, Items) { @@ -115,6 +125,34 @@ class ReturningTests : DatabaseTestsBase() { } } + @Test + fun testUpsertReturningWithDAO() { + withTables(TestDB.ALL - returningSupportedDb, Items) { + val result1 = Items.upsertReturning { + it[name] = "A" + it[price] = 99.0 + }.let { + ItemDAO.wrapRow(it.single()) + } + assertEquals(1, result1.id.value) + assertEquals("A", result1.name) + assertEquals(99.0, result1.price) + + val result2 = Items.upsertReturning { + it[id] = 1 + it[name] = "B" + it[price] = 200.0 + }.let { + ItemDAO.wrapRow(it.single()) + } + assertEquals(1, result2.id.value) + assertEquals("B", result2.name) + assertEquals(200.0, result2.price) + + assertEquals(1, Items.selectAll().count()) + } + } + @Test fun testReturningWithNoResults() { withTables(TestDB.enabledDialects() - returningSupportedDb, Items) {