Skip to content

Commit

Permalink
feat: Add spring mutli container support (JetBrains#1781)
Browse files Browse the repository at this point in the history
* feat: fix to support spring multi container

* test: add test for spring multi container
  • Loading branch information
FullOfOrange authored and saral committed Oct 3, 2023
1 parent 6e5aea3 commit 9cf56be
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,59 @@ class SpringTransactionManager(
databaseConfig = databaseConfig
) { this }

@Volatile override var defaultIsolationLevel: Int = -1
@Volatile
override var defaultIsolationLevel: Int = -1
get() {
if (field == -1) {
field = Database.getDefaultIsolationLevel(db)
}
return field
}

private val springTxKey = "SPRING_TX_KEY"
private val transactionStackKey = "SPRING_TRANSACTION_STACK_KEY"

private fun getTransactionStack(): List<TransactionManager> {
return TransactionSynchronizationManager.getResource(transactionStackKey)
?.let { it as List<TransactionManager> }
?: listOf()
}

private fun setTransactionStack(list: List<TransactionManager>) {
TransactionSynchronizationManager.unbindResourceIfPossible(transactionStackKey)
TransactionSynchronizationManager.bindResource(transactionStackKey, list)
}

private fun pushTransactionStack(transaction: TransactionManager) {
val transactionList = getTransactionStack()
setTransactionStack(transactionList + transaction)
}

private fun popTransactionStack() = setTransactionStack(getTransactionStack().dropLast(1))

private fun getLastTransactionStack() = getTransactionStack().lastOrNull()

override fun doBegin(transaction: Any, definition: TransactionDefinition) {
super.doBegin(transaction, definition)

if (TransactionSynchronizationManager.hasResource(obtainDataSource())) {
currentOrNull() ?: initTransaction()
}
if (!TransactionSynchronizationManager.hasResource(springTxKey)) {
TransactionSynchronizationManager.bindResource(springTxKey, transaction)
currentOrNull() ?: initTransaction(transaction)
}

pushTransactionStack(this@SpringTransactionManager)
}

override fun doCleanupAfterCompletion(transaction: Any) {
super.doCleanupAfterCompletion(transaction)
if (!TransactionSynchronizationManager.hasResource(obtainDataSource())) {
TransactionSynchronizationManager.unbindResourceIfPossible(this)
TransactionSynchronizationManager.unbindResource(springTxKey)
}

popTransactionStack()
TransactionManager.resetCurrent(getLastTransactionStack())

if (TransactionSynchronizationManager.isSynchronizationActive() && TransactionSynchronizationManager.getSynchronizations().isEmpty()) {
TransactionSynchronizationManager.clearSynchronization()
}
TransactionManager.resetCurrent(null)
}

override fun doSuspend(transaction: Any): Any {
Expand Down Expand Up @@ -100,22 +122,22 @@ class SpringTransactionManager(
isolationLevel = isolation
}

getTransaction(tDefinition)

return currentOrNull() ?: initTransaction()
val transactionStatus = (getTransaction(tDefinition) as DefaultTransactionStatus)
return currentOrNull() ?: initTransaction(transactionStatus.transaction)
}

private fun initTransaction(): Transaction {
private fun initTransaction(transaction: Any): Transaction {
val connection = (TransactionSynchronizationManager.getResource(obtainDataSource()) as ConnectionHolder).connection

@Suppress("TooGenericExceptionCaught")
val transactionImpl = try {
SpringTransaction(JdbcConnectionImpl(connection), db, defaultIsolationLevel, defaultReadOnly, currentOrNull())
SpringTransaction(JdbcConnectionImpl(connection), db, defaultIsolationLevel, defaultReadOnly, currentOrNull(), transaction)
} catch (e: Exception) {
exposedLogger.error("Failed to start transaction. Connection will be closed.", e)
connection.close()
throw e
}

TransactionManager.resetCurrent(this)
return Transaction(transactionImpl).apply {
TransactionSynchronizationManager.bindResource(this@SpringTransactionManager, this)
Expand Down Expand Up @@ -144,7 +166,8 @@ class SpringTransactionManager(
override val db: Database,
override val transactionIsolation: Int,
override val readOnly: Boolean,
override val outerTransaction: Transaction?
override val outerTransaction: Transaction?,
private val currentTransaction: Any,
) : TransactionInterface {

override fun commit() {
Expand All @@ -157,9 +180,7 @@ class SpringTransactionManager(

override fun close() {
if (TransactionSynchronizationManager.isActualTransactionActive()) {
TransactionSynchronizationManager.getResource(springTxKey)?.let { springTx ->
this@SpringTransactionManager.doCleanupAfterCompletion(springTx)
}
this@SpringTransactionManager.doCleanupAfterCompletion(currentTransaction)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package org.jetbrains.exposed.spring

import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.deleteAll
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.selectAll
import org.jetbrains.exposed.sql.transactions.transaction
import org.junit.Assert
import org.junit.Test
import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType
import org.springframework.transaction.annotation.EnableTransactionManagement
import org.springframework.transaction.annotation.Transactional
import javax.sql.DataSource
import kotlin.test.BeforeTest

open class SpringMultiContainerTransactionTest {

val orderContainer = AnnotationConfigApplicationContext(OrderConfig::class.java)
val paymentContainer = AnnotationConfigApplicationContext(PaymentConfig::class.java)

val orders: Orders = orderContainer.getBean(Orders::class.java)
val payments: Payments = paymentContainer.getBean(Payments::class.java)

@BeforeTest
open fun beforeTest() {
orders.init()
payments.init()
}

@Test
open fun test1() {
Assert.assertEquals(0, orders.findAll().size)
Assert.assertEquals(0, payments.findAll().size)
}

@Test
open fun test2() {
orders.create()
Assert.assertEquals(1, orders.findAll().size)
payments.create()
Assert.assertEquals(1, payments.findAll().size)
}

@Test
open fun test3() {
orders.transaction {
payments.create()
orders.create()
payments.create()
}
Assert.assertEquals(1, orders.findAll().size)
Assert.assertEquals(2, payments.findAll().size)
}

@Test
open fun test4() {
kotlin.runCatching {
orders.transaction {
orders.create()
payments.create()
throw SpringTransactionTestException()
}
}
Assert.assertEquals(0, orders.findAll().size)
Assert.assertEquals(1, payments.findAll().size)
}

@Test
open fun test5() {
kotlin.runCatching {
orders.transaction {
orders.create()
payments.databaseTemplate {
payments.create()
throw SpringTransactionTestException()
}
}
}
Assert.assertEquals(0, orders.findAll().size)
Assert.assertEquals(0, payments.findAll().size)
}

@Test
open fun test6() {
Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(0, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test7() {
orders.createWithExposedTrxBlock()
Assert.assertEquals(1, orders.findAllWithExposedTrxBlock().size)
payments.createWithExposedTrxBlock()
Assert.assertEquals(1, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test8() {
orders.transaction {
payments.createWithExposedTrxBlock()
orders.createWithExposedTrxBlock()
payments.createWithExposedTrxBlock()
}
Assert.assertEquals(1, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(2, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test9() {
kotlin.runCatching {
orders.transaction {
orders.createWithExposedTrxBlock()
payments.createWithExposedTrxBlock()
throw SpringTransactionTestException()
}
}
Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(1, payments.findAllWithExposedTrxBlock().size)
}

@Test
open fun test10() {
kotlin.runCatching {
orders.transaction {
orders.createWithExposedTrxBlock()
payments.databaseTemplate {
payments.createWithExposedTrxBlock()
throw SpringTransactionTestException()
}
}
}
Assert.assertEquals(0, orders.findAllWithExposedTrxBlock().size)
Assert.assertEquals(0, payments.findAllWithExposedTrxBlock().size)
}
}

@Configuration
@EnableTransactionManagement(proxyTargetClass = true)
open class OrderConfig {

@Bean
open fun dataSource(): EmbeddedDatabase = EmbeddedDatabaseBuilder().setName("embeddedTest1").setType(EmbeddedDatabaseType.H2).build()

@Bean
open fun transactionManager(dataSource: DataSource) = SpringTransactionManager(dataSource)

@Bean
open fun orders() = Orders()
}

@Transactional
open class Orders {

open fun findAll() = Order.selectAll().map { it }

open fun findAllWithExposedTrxBlock() = org.jetbrains.exposed.sql.transactions.transaction { findAll() }

open fun create() = Order.insertAndGetId {
it[buyer] = 123
}.value

open fun createWithExposedTrxBlock() = org.jetbrains.exposed.sql.transactions.transaction { create() }

open fun init() {
SchemaUtils.create(Order)
Order.deleteAll()
}

open fun transaction(block: () -> Unit) {
block()
}
}

object Order : LongIdTable("orders") {
val buyer = long("buyer_id")
}

@Configuration
@EnableTransactionManagement(proxyTargetClass = true)
open class PaymentConfig {

@Bean
open fun dataSource(): EmbeddedDatabase = EmbeddedDatabaseBuilder().setName("embeddedTest2").setType(EmbeddedDatabaseType.H2).build()

@Bean
open fun transactionManager(dataSource: DataSource) = SpringTransactionManager(dataSource)

@Bean
open fun payments() = Payments()
}

@Transactional
open class Payments {

open fun findAll() = Payment.selectAll().map { it }

open fun findAllWithExposedTrxBlock() = transaction { findAll() }

open fun create() = Payment.insertAndGetId {
it[state] = "state"
}.value

open fun createWithExposedTrxBlock() = transaction { create() }

open fun init() {
SchemaUtils.create(Payment)
Payment.deleteAll()
}

open fun databaseTemplate(block: () -> Unit) {
block()
}
}

object Payment : LongIdTable("payments") {
val state = varchar("state", 50)
}

private class SpringTransactionTestException : Error()
Loading

0 comments on commit 9cf56be

Please sign in to comment.