Skip to content

Commit

Permalink
Try to fetch not only autoincrement values but all returned after ins…
Browse files Browse the repository at this point in the history
…ert / #2
  • Loading branch information
Tapac committed Apr 4, 2019
1 parent 3cd1388 commit 5f1d3a4
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ open class InsertStatement<Key:Any>(val table: Table, val isIgnore: Boolean = fa
val autoGeneratedKeys = arrayListOf<MutableMap<Column<*>, Any?>>()

if (inserted > 0) {
val returnedColumns = targets.flatMap { it.columns }.mapNotNull { col ->
try {
rs?.findColumn(col.name)?.let { col to it }
} catch (e: SQLException) {
null
val returnedColumns =
(if (currentDialect.supportsOnlyIdentifiersInGeneratedKeys) autoIncColumns else table.columns).mapNotNull { col ->
try {
rs?.findColumn(col.name)?.let { col to it }
} catch (e: SQLException) {
null
}
}
}

val firstAutoIncColumn = autoIncColumns.firstOrNull()
val firstAutoIncColumn = autoIncColumns.firstOrNull()
if (firstAutoIncColumn != null || returnedColumns.isNotEmpty()) {
while (rs?.next() == true) {
val returnedValues = returnedColumns.associateTo(mutableMapOf()) { it.first to rs.getObject(it.second) }
Expand Down Expand Up @@ -113,7 +114,9 @@ open class InsertStatement<Key:Any>(val table: Table, val isIgnore: Boolean = fa
}
}

protected val autoIncColumns = targets.flatMap { it.columns }.filter { it.columnType.isAutoInc || it.columnType is EntityIDColumnType<*> }
protected val autoIncColumns = targets.flatMap { it.columns }.filter {
it.columnType.isAutoInc || (it.columnType is EntityIDColumnType<*> && !currentDialect.supportsOnlyIdentifiersInGeneratedKeys)
}

override fun prepared(transaction: Transaction, sql: String): PreparedStatement = when {
// https://github.com/pgjdbc/pgjdbc/issues/1168
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ abstract class Statement<out T>(val type: StatementType, val targets: List<Table
open fun prepared(transaction: Transaction, sql: String) : PreparedStatement =
transaction.connection.prepareStatement(sql, PreparedStatement.NO_GENERATED_KEYS)!!

open val isAlwaysBatch: Boolean get() = false
open val isAlwaysBatch: Boolean = false

fun execute(transaction: Transaction): T? = transaction.exec(this)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ interface DatabaseDialect {

val defaultReferenceOption : ReferenceOption get() = ReferenceOption.RESTRICT

val supportsOnlyIdentifiersInGeneratedKeys get() = false

// Specific SQL statements

fun createIndex(index: Index): String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ internal object H2FunctionProvider : FunctionProvider() {
private fun currentMode(): String =
((TransactionManager.current().connection as Wrapper).unwrap(JdbcConnection::class.java).session as? Session)?.database?.mode?.name ?: ""

private val isMySQLMode: Boolean get() = currentMode() == "MySQL"
internal val isMySQLMode: Boolean get() = currentMode() == "MySQL"

private fun dbReleaseDate(transaction: Transaction) : DateTime {
val releaseDate = transaction.db.metadata.databaseProductVersion.substringAfterLast('(').substringBeforeLast(')')
Expand Down Expand Up @@ -59,6 +59,8 @@ open class H2Dialect : VendorDialect(dialectName, H2DataTypeProvider, H2Function

override val supportsMultipleGeneratedKeys: Boolean = false

override val supportsOnlyIdentifiersInGeneratedKeys get() = !(functionProvider as H2FunctionProvider).isMySQLMode

override fun existingIndices(vararg tables: Table): Map<Table, List<Index>> =
super.existingIndices(*tables).mapValues { it.value.filterNot { it.indexName.startsWith("PRIMARY_KEY_") } }.filterValues { it.isNotEmpty() }

Expand All @@ -70,6 +72,11 @@ open class H2Dialect : VendorDialect(dialectName, H2DataTypeProvider, H2Function
return super.createIndex(index)
}

override val name: String
get() = when {
(functionProvider as H2FunctionProvider).isMySQLMode -> "$dialectName (Mysql Mode)"
else -> dialectName
}
companion object {
const val dialectName = "h2"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.jetbrains.exposed.sql.vendors

class MariaDBDialect : MysqlDialect() {
override val name: String = dialectName
override val supportsOnlyIdentifiersInGeneratedKeys = true
companion object {
const val dialectName = "mariadb"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ open class SQLServerDialect : VendorDialect(dialectName, SQLServerDataTypeProvid

override val defaultReferenceOption: ReferenceOption get() = ReferenceOption.NO_ACTION

override val supportsOnlyIdentifiersInGeneratedKeys: Boolean = true

override fun modifyColumn(column: Column<*>) =
super.modifyColumn(column).replace("MODIFY COLUMN", "ALTER COLUMN")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.exposed.sql.tests.shared
import org.jetbrains.exposed.dao.*
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.currentDialect
import org.junit.Test
import java.util.concurrent.atomic.AtomicInteger

Expand Down Expand Up @@ -79,19 +80,21 @@ class NonAutoIncEntities : DatabaseTestsBase() {

@Test fun testFlushNonAutoincEntityWithoutDefaultValue() {
withTables(AutoIncSharedTable) {
SharedNonAutoIncEntity.new {
bool = true
}
if (!currentDialect.supportsOnlyIdentifiersInGeneratedKeys) {
SharedNonAutoIncEntity.new {
bool = true
}

SharedNonAutoIncEntity.new {
bool = false
}
SharedNonAutoIncEntity.new {
bool = false
}

val entities = flushCache()
val entities = flushCache()

assertEquals(2, entities.size)
assertEquals(1, entities[0].id._value)
assertEquals(2, entities[1].id._value)
assertEquals(2, entities.size)
assertEquals(1, entities[0].id._value)
assertEquals(2, entities[1].id._value)
}
}
}
}

0 comments on commit 5f1d3a4

Please sign in to comment.