Skip to content

Commit

Permalink
feat: EXPOSED-498 Handle auto-increment status change on a column (#2216
Browse files Browse the repository at this point in the history
)

The function `MigrationUtils.statementsRequiredForDatabaseMigration` now includes the proper statements to change the type of a column when it detects that the auto-increment status has changed for PostgreSQL and SQL Server, and statements to create/drop the necessary sequences.

- **Why**:

**PostgreSQL**: The generated ALTER COLUMN statement was setting SERIAL as the target column type and that was throwing an error because it is not recognised as a type.

**SQL Server**: The generated ALTER COLUMN statement was setting IDENTITY as the target column type and that was throwing an error because it is not possible to add it to an existing column later on and must be defined that way from the start.

- **How**:

Two functions, `checkMissingSequences` and `checkUnmappedSequences` were added to `MigrationUtils`. For `checkUnmappedSequences`, H2_V1 is excluded because its system sequences do not have predictable naming patterns and it’s impossible to know if an existing sequence in the database is mapped in the Exposed code, so we would get false DROP statements that should not be there.

**PostgreSQL**: The SQL equivalent to setting a column type as SERIAL was used. Check PostgreSQL documentation [here](https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-SERIAL). This means that when changing a column to be auto-incrementing in PostgreSQL, Exposed will generate three SQL statements:
1. One to create a sequence
2. One to set the column's default values to be assigned from a sequence generator
3. One to set the column as the owner of the sequence created in Step 1

**SQL Server**:
The simplest way to solve it is to create a new IDENTITY column and drop the old one. This means that when changing a column to be auto-incrementing in SQL Server, Exposed will generate three SQL statements:
1. One to add a new IDENTITY column named "NEW_columnName"
2. One to drop the old column
3. One to rename the new column to the same name as the old dropped column
  • Loading branch information
joc-a authored Sep 6, 2024
1 parent 61b1ebb commit bd74c5d
Show file tree
Hide file tree
Showing 9 changed files with 609 additions and 42 deletions.
10 changes: 10 additions & 0 deletions exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -2489,6 +2489,7 @@ public class org/jetbrains/exposed/sql/Table : org/jetbrains/exposed/sql/ColumnS
public final fun getIndices ()Ljava/util/List;
public fun getPrimaryKey ()Lorg/jetbrains/exposed/sql/Table$PrimaryKey;
public final fun getSchemaName ()Ljava/lang/String;
public final fun getSequences ()Ljava/util/List;
public fun getTableName ()Ljava/lang/String;
public fun hashCode ()I
public final fun index (Ljava/lang/String;Z[Lorg/jetbrains/exposed/sql/Column;Ljava/util/List;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V
Expand Down Expand Up @@ -3993,6 +3994,7 @@ public class org/jetbrains/exposed/sql/vendors/H2Dialect : org/jetbrains/exposed
public final fun getDelegatedDialectNameProvider ()Lorg/jetbrains/exposed/sql/vendors/VendorDialect$DialectNameProvider;
public fun getFunctionProvider ()Lorg/jetbrains/exposed/sql/vendors/FunctionProvider;
public final fun getH2Mode ()Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2CompatibilityMode;
public final fun getMajorVersion ()Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2MajorVersion;
public fun getName ()Ljava/lang/String;
public fun getNeedsSequenceToAutoInc ()Z
public final fun getOriginalDataTypeProvider ()Lorg/jetbrains/exposed/sql/vendors/DataTypeProvider;
Expand Down Expand Up @@ -4031,6 +4033,14 @@ public final class org/jetbrains/exposed/sql/vendors/H2Dialect$H2CompatibilityMo
public static fun values ()[Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2CompatibilityMode;
}

public final class org/jetbrains/exposed/sql/vendors/H2Dialect$H2MajorVersion : java/lang/Enum {
public static final field One Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2MajorVersion;
public static final field Two Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2MajorVersion;
public static fun getEntries ()Lkotlin/enums/EnumEntries;
public static fun valueOf (Ljava/lang/String;)Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2MajorVersion;
public static fun values ()[Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2MajorVersion;
}

public final class org/jetbrains/exposed/sql/vendors/H2Kt {
public static final fun getH2Mode (Lorg/jetbrains/exposed/sql/vendors/DatabaseDialect;)Lorg/jetbrains/exposed/sql/vendors/H2Dialect$H2CompatibilityMode;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ object SchemaUtils {
columnType.nullable
}
val incorrectNullability = existingCol.nullable != colNullable
// Exposed doesn't support changing sequences on columns
val incorrectAutoInc = existingCol.autoIncrement != columnType.isAutoInc && col.autoIncColumnType?.autoincSeq == null

val incorrectAutoInc = isIncorrectAutoInc(existingCol, col)

val incorrectDefaults = isIncorrectDefault(dataTypeProvider, existingCol, col)

Expand Down Expand Up @@ -358,6 +358,15 @@ object SchemaUtils {
return statements
}

private fun isIncorrectAutoInc(columnMetadata: ColumnMetadata, column: Column<*>): Boolean = when {
!columnMetadata.autoIncrement && column.columnType.isAutoInc && column.autoIncColumnType?.sequence == null ->
true
columnMetadata.autoIncrement && column.columnType.isAutoInc && column.autoIncColumnType?.sequence != null ->
true
columnMetadata.autoIncrement && !column.columnType.isAutoInc -> true
else -> false
}

/**
* For DDL purposes we do not segregate the cases when the default value was not specified, and when it
* was explicitly set to `null`.
Expand Down
28 changes: 25 additions & 3 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,24 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
/** Returns all foreign key constraints declared on the table. */
val foreignKeys: List<ForeignKeyConstraint> get() = columns.mapNotNull { it.foreignKey } + _foreignKeys

/**
* Returns all sequences declared on the table, along with any auto-generated sequences that are not explicitly
* declared by the user but associated with the table.
*/
val sequences: List<Sequence>
get() = columns.filter { it.columnType.isAutoInc }.mapNotNull { column ->
column.autoIncColumnType?.sequence
?: column.takeIf { currentDialect is PostgreSQLDialect }?.let {
val fallbackSequenceName = fallbackSequenceName(tableName = tableName, columnName = it.name)
Sequence(
fallbackSequenceName,
startWith = 1,
minValue = 1,
maxValue = Long.MAX_VALUE
)
}
}

private val checkConstraints = mutableListOf<Pair<String, Op<Boolean>>>()

private val generatedCheckPrefix = "chk_${tableName}_unsigned_"
Expand Down Expand Up @@ -1572,10 +1590,9 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
private fun <T> Column<T>.cloneWithAutoInc(idSeqName: String?): Column<T> = when (columnType) {
is AutoIncColumnType -> this
is ColumnType -> {
val q = if (tableName.contains('.')) "\"" else ""
val fallbackSeqName = "$q${tableName.replace("\"", "")}_${name}_seq$q"
val fallbackSequenceName = fallbackSequenceName(tableName = tableName, columnName = name)
this.withColumnType(
AutoIncColumnType(columnType, idSeqName, fallbackSeqName)
AutoIncColumnType(columnType, idSeqName, fallbackSequenceName)
)
}

Expand Down Expand Up @@ -1711,3 +1728,8 @@ internal fun String.isAlreadyQuoted(): Boolean =
listOf("\"", "'", "`").any { quoteString ->
startsWith(quoteString) && endsWith(quoteString)
}

internal fun fallbackSequenceName(tableName: String, columnName: String): String {
val q = if (tableName.contains('.')) "\"" else ""
return "$q${tableName.replace("\"", "")}_${columnName}_seq$q"
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ interface DatabaseDialect {
/** Returns `true` if the dialect supports returning multiple generated keys as a result of an insert operation, `false` otherwise. */
val supportsMultipleGeneratedKeys: Boolean

/** Returns`true` if the dialect supports returning generated keys obtained from a sequence. */
/** Returns `true` if the dialect supports returning generated keys obtained from a sequence. */
val supportsSequenceAsGeneratedKeys: Boolean get() = supportsCreateSequence

/** Returns `true` if the dialect supports only returning generated keys that are identity columns. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,15 @@ open class H2Dialect : VendorDialect(dialectName, H2DataTypeProvider, H2Function

override fun toString(): String = "H2Dialect[$dialectName, $h2Mode]"

internal enum class H2MajorVersion {
enum class H2MajorVersion {
One, Two
}

internal val version by lazy {
exactH2Version(TransactionManager.current())
}

internal val majorVersion: H2MajorVersion by lazy {
val majorVersion: H2MajorVersion by lazy {
when {
version.startsWith("1.") -> H2MajorVersion.One
version.startsWith("2.") -> H2MajorVersion.Two
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,31 +391,55 @@ open class PostgreSQLDialect(override val name: String = dialectName) : VendorDi

override fun isAllowedAsColumnDefault(e: Expression<*>): Boolean = true

override fun modifyColumn(column: Column<*>, columnDiff: ColumnDiff): List<String> = listOf(
buildString {
val tr = TransactionManager.current()
append("ALTER TABLE ${tr.identity(column.table)} ")
val colName = tr.identity(column)
append("ALTER COLUMN $colName TYPE ${column.columnType.sqlType()}")

if (columnDiff.nullability) {
append(", ALTER COLUMN $colName ")
if (column.columnType.nullable) {
append("DROP ")
override fun modifyColumn(column: Column<*>, columnDiff: ColumnDiff): List<String> {
val list = mutableListOf(
buildString {
val tr = TransactionManager.current()
append("ALTER TABLE ${tr.identity(column.table)} ")
val colName = tr.identity(column)

if (columnDiff.autoInc && column.autoIncColumnType != null) {
val sequence = column.autoIncColumnType?.sequence
if (sequence != null) {
append("ALTER COLUMN $colName TYPE ${column.columnType.sqlType()}")
append(", ALTER COLUMN $colName DROP DEFAULT")
} else {
val fallbackSequenceName = fallbackSequenceName(tableName = column.table.tableName, columnName = column.name)
append("ALTER COLUMN $colName SET DEFAULT nextval('$fallbackSequenceName')")
}
} else {
append("SET ")
append("ALTER COLUMN $colName TYPE ${column.columnType.sqlType()}")
}
append("NOT NULL")
}
if (columnDiff.defaults) {
column.dbDefaultValue?.let {
append(", ALTER COLUMN $colName SET DEFAULT ${PostgreSQLDataTypeProvider.processForDefaultValue(it)}")
} ?: run {
append(", ALTER COLUMN $colName DROP DEFAULT")

if (columnDiff.nullability) {
append(", ALTER COLUMN $colName ")
if (column.columnType.nullable) {
append("DROP ")
} else {
append("SET ")
}
append("NOT NULL")
}
if (columnDiff.defaults) {
column.dbDefaultValue?.let {
append(", ALTER COLUMN $colName SET DEFAULT ${PostgreSQLDataTypeProvider.processForDefaultValue(it)}")
} ?: run {
append(", ALTER COLUMN $colName DROP DEFAULT")
}
}
}
)
if (columnDiff.autoInc && column.autoIncColumnType != null && column.autoIncColumnType?.sequence == null) {
list.add(
buildString {
val fallbackSequenceName = fallbackSequenceName(tableName = column.table.tableName, columnName = column.name)
val q = if (column.table.tableName.contains('.')) "\"" else ""
append("ALTER SEQUENCE $fallbackSequenceName OWNED BY $q${column.table.tableName.replace("\"", "")}.${column.name}$q")
}
)
}
)
return list
}

override fun createDatabase(name: String): String = "CREATE DATABASE ${name.inProperCase()}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,16 @@ open class SQLServerDialect : VendorDialect(dialectName, SQLServerDataTypeProvid

val statements = mutableListOf<String>()

val autoIncColumnType = column.autoIncColumnType
val replaceWithNewColumn = columnDiff.autoInc && autoIncColumnType != null && autoIncColumnType.sequence == null

statements.add(
buildString {
append(alterTablePart + "ALTER COLUMN ${transaction.identity(column)} ${column.columnType.sqlType()}")
if (replaceWithNewColumn) {
append(alterTablePart + "ADD NEW_${transaction.identity(column)} ${column.columnType.sqlType()}")
} else {
append(alterTablePart + "ALTER COLUMN ${transaction.identity(column)} ${column.columnType.sqlType()}")
}

if (columnDiff.nullability) {
val defaultValue = column.dbDefaultValue
Expand Down Expand Up @@ -392,6 +399,13 @@ open class SQLServerDialect : VendorDialect(dialectName, SQLServerDataTypeProvid
)
}

if (replaceWithNewColumn) {
with(statements) {
add(alterTablePart + "DROP COLUMN ${transaction.identity(column)}")
add("EXEC sp_rename '${transaction.identity(column.table)}.NEW_${transaction.identity(column)}', '${transaction.identity(column)}', 'COLUMN'")
}
}

return statements
}

Expand Down
73 changes: 71 additions & 2 deletions exposed-migration/src/main/kotlin/MigrationUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import org.jetbrains.exposed.sql.SchemaUtils.checkExcessiveIndices
import org.jetbrains.exposed.sql.SchemaUtils.checkMappingConsistence
import org.jetbrains.exposed.sql.SchemaUtils.createStatements
import org.jetbrains.exposed.sql.SchemaUtils.statementsRequiredToActualizeScheme
import org.jetbrains.exposed.sql.Sequence
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.exposedLogger
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.H2Dialect
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.jetbrains.exposed.sql.vendors.SQLiteDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
Expand Down Expand Up @@ -69,6 +72,9 @@ object MigrationUtils {
val createStatements = logTimeSpent("Preparing create tables statements", withLogs) {
createStatements(tables = tablesToCreate.toTypedArray())
}
val createSequencesStatements = logTimeSpent("Preparing create sequences statements", withLogs) {
checkMissingSequences(tables = tables, withLogs).flatMap { it.createStatement() }
}
val alterStatements = logTimeSpent("Preparing alter table statements", withLogs) {
addMissingColumnsStatements(tables = tablesToAlter.toTypedArray(), withLogs)
}
Expand All @@ -80,7 +86,7 @@ object MigrationUtils {
).filter { it !in (createStatements + alterStatements) }
}

val allStatements = createStatements + alterStatements + modifyTablesStatements
val allStatements = createStatements + createSequencesStatements + alterStatements + modifyTablesStatements
return allStatements
}

Expand All @@ -92,7 +98,8 @@ object MigrationUtils {
return checkMissingIndices(tables = tables, withLogs).flatMap { it.createStatement() } +
checkUnmappedIndices(tables = tables, withLogs).flatMap { it.dropStatement() } +
checkExcessiveForeignKeyConstraints(tables = tables, withLogs).flatMap { it.dropStatement() } +
checkExcessiveIndices(tables = tables, withLogs).flatMap { it.dropStatement() }
checkExcessiveIndices(tables = tables, withLogs).flatMap { it.dropStatement() } +
checkUnmappedSequences(tables = tables, withLogs).flatMap { it.dropStatement() }
}

/**
Expand Down Expand Up @@ -216,6 +223,65 @@ object MigrationUtils {
return toDrop.toList()
}

/**
* Checks all [tables] for any that have sequences that are missing in the database but are defined in the code. If
* found, this function also logs the SQL statements that can be used to create these sequences.
*
* @return List of sequences that are missing and can be created.
*/
private fun checkMissingSequences(vararg tables: Table, withLogs: Boolean): List<Sequence> {
if (!currentDialect.supportsCreateSequence) {
return emptyList()
}

fun Collection<Sequence>.log(mainMessage: String) {
if (withLogs && isNotEmpty()) {
exposedLogger.warn(joinToString(prefix = "$mainMessage\n\t\t", separator = "\n\t\t"))
}
}

val existingSequencesNames: Set<String> = currentDialect.sequences().toSet()

val missingSequences = mutableSetOf<Sequence>()

val mappedSequences: Set<Sequence> = tables.flatMap { table -> table.sequences }.toSet()

missingSequences.addAll(mappedSequences.filterNot { it.identifier.inProperCase() in existingSequencesNames })

missingSequences.log("Sequences missed from database (will be created):")
return missingSequences.toList()
}

/**
* Checks all [tables] for any that have sequences that exist in the database but are not mapped in the code. If
* found, this function also logs the SQL statements that can be used to drop these sequences.
*
* @return List of sequences that are unmapped and can be dropped.
*/
private fun checkUnmappedSequences(vararg tables: Table, withLogs: Boolean): List<Sequence> {
if (!currentDialect.supportsCreateSequence || (currentDialect as? H2Dialect)?.majorVersion == H2Dialect.H2MajorVersion.One) {
return emptyList()
}

fun Collection<Sequence>.log(mainMessage: String) {
if (withLogs && isNotEmpty()) {
exposedLogger.warn(joinToString(prefix = "$mainMessage\n\t\t", separator = "\n\t\t"))
}
}

val existingSequencesNames: Set<String> = currentDialect.sequences().toSet()

val unmappedSequences = mutableSetOf<Sequence>()

val mappedSequencesNames: Set<String> = tables.flatMap { table -> table.sequences.map { it.identifier.inProperCase() } }.toSet()

unmappedSequences.addAll(existingSequencesNames.subtract(mappedSequencesNames).map { Sequence(it) })

unmappedSequences.log("Sequences exist in database and not mapped in code:")

return unmappedSequences.toList()
}

private inline fun <R> logTimeSpent(message: String, withLogs: Boolean, block: () -> R): R {
return if (withLogs) {
val start = System.currentTimeMillis()
Expand All @@ -227,3 +293,6 @@ object MigrationUtils {
}
}
}

internal fun String.inProperCase(): String =
TransactionManager.currentOrNull()?.db?.identifierManager?.inProperCase(this@inProperCase) ?: this
Loading

0 comments on commit bd74c5d

Please sign in to comment.