Skip to content

Commit

Permalink
fix: EXPOSED-50 customEnumeration reference column error (JetBrains#1785
Browse files Browse the repository at this point in the history
)

It is possible to to create a column with a reference to both EnumerationColumnType
and EnumerationNameColumnType, but attempting to do so on a column created via
customEnumeration() throws an NPE.

This occurs when attempting to clone() the column with the custom database type,
as the resulting column's kClass does not have a primary constructor. A primary
constructor is needed to create a new instance and this is not provided by the
anonymous object used as a type when the column is being registered.

Extracting this object to its own CustomEnumerationColumnType means a valid KClass
instance is provided to clone(), with a valid primary constructor.

Add unit tests for all 3 types of enumeration columns with references.

* fix: EXPOSED-50 customEnumeration reference column error

Fix broken tests:
- Shared table object is being altered differently in each unit test. New test
that needs to create a unique index for referencing must have that unique index
dropped by other unit tests.
- MySQL tests require shared table object's unique index to only
be added if it doesn't already exist (created by a previous test).
  • Loading branch information
bog-walk authored and saral committed Oct 3, 2023
1 parent 1afe347 commit 1a07180
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 21 deletions.
13 changes: 13 additions & 0 deletions exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,19 @@ public final class org/jetbrains/exposed/sql/CurrentRowWindowFrameBound : org/je
public fun toQueryBuilder (Lorg/jetbrains/exposed/sql/QueryBuilder;)V
}

public final class org/jetbrains/exposed/sql/CustomEnumerationColumnType : org/jetbrains/exposed/sql/StringColumnType {
public fun <init> (Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V
public final fun getFromDb ()Lkotlin/jvm/functions/Function1;
public final fun getName ()Ljava/lang/String;
public final fun getSql ()Ljava/lang/String;
public final fun getToDb ()Lkotlin/jvm/functions/Function1;
public fun nonNullValueToString (Ljava/lang/Object;)Ljava/lang/String;
public fun notNullValueToDB (Ljava/lang/Object;)Ljava/lang/Object;
public fun sqlType ()Ljava/lang/String;
public fun valueFromDB (Ljava/lang/Object;)Ljava/lang/Enum;
public synthetic fun valueFromDB (Ljava/lang/Object;)Ljava/lang/Object;
}

public class org/jetbrains/exposed/sql/CustomFunction : org/jetbrains/exposed/sql/Function {
public fun <init> (Ljava/lang/String;Lorg/jetbrains/exposed/sql/IColumnType;[Lorg/jetbrains/exposed/sql/Expression;)V
public final fun getExpr ()[Lorg/jetbrains/exposed/sql/Expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import java.sql.Clob
import java.sql.ResultSet
import java.util.*
import kotlin.reflect.KClass
import kotlin.reflect.full.isSubclassOf

/**
* Interface common to all column types.
Expand Down Expand Up @@ -901,6 +902,30 @@ class EnumerationNameColumnType<T : Enum<T>>(
}
}

/**
* Enumeration column for storing enums of type [T] using the custom SQL type [sql].
*/
class CustomEnumerationColumnType<T : Enum<T>>(
/** Returns the name of this column type instance. */
val name: String,
/** Returns the SQL definition used for this column type. */
val sql: String?,
/** Returns the function that converts a value received from a database to an enumeration instance [T]. */
val fromDb: (Any) -> T,
/** Returns the function that converts an enumeration instance [T] to a value that will be stored to a database. */
val toDb: (T) -> Any
) : StringColumnType() {
override fun sqlType(): String = sql ?: error("Column $name should exist in database")

@Suppress("UNCHECKED_CAST")
override fun valueFromDB(value: Any): T = if (value::class.isSubclassOf(Enum::class)) value as T else fromDb(value)

@Suppress("UNCHECKED_CAST")
override fun notNullValueToDB(value: Any): Any = toDb(value as T)

override fun nonNullValueToString(value: Any): String = super.nonNullValueToString(notNullValueToDB(value))
}

// JSON columns

/**
Expand Down
25 changes: 7 additions & 18 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import kotlin.reflect.KClass
import kotlin.reflect.KMutableProperty1
import kotlin.reflect.KParameter
import kotlin.reflect.KProperty1
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.primaryConstructor

Expand Down Expand Up @@ -626,31 +625,21 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
enumerationByName(name, length, T::class)

/**
* Creates an enumeration column with custom SQL type.
* The main usage is to use a database specific type.
* Creates an enumeration column, with the custom SQL type [sql], for storing enums of type [T] using this database-specific type.
*
* See [https://github.com/JetBrains/Exposed/wiki/DataTypes#how-to-use-database-enum-types] for more details.
* See [Wiki](https://github.com/JetBrains/Exposed/wiki/DataTypes#how-to-use-database-enum-types) for more details.
*
* @param name The column name
* @param sql A SQL definition for the column
* @param fromDb A lambda to convert a value received from a database to an enumeration instance
* @param toDb A lambda to convert an enumeration instance to a value which will be stored to a database
* @param name Name of the column
* @param sql SQL definition for the column
* @param fromDb Function that converts a value received from a database to an enumeration instance [T]
* @param toDb Function that converts an enumeration instance [T] to a value that will be stored to a database
*/
@Suppress("UNCHECKED_CAST")
fun <T : Enum<T>> customEnumeration(
name: String,
sql: String? = null,
fromDb: (Any) -> T,
toDb: (T) -> Any
): Column<T> = registerColumn(
name,
object : StringColumnType() {
override fun sqlType(): String = sql ?: error("Column $name should exists in database ")
override fun valueFromDB(value: Any): T = if (value::class.isSubclassOf(Enum::class)) value as T else fromDb(value)
override fun notNullValueToDB(value: Any): Any = toDb(value as T)
override fun nonNullValueToString(value: Any): String = super.nonNullValueToString(notNullValueToDB(value))
}
)
): Column<T> = registerColumn(name, CustomEnumerationColumnType(name, sql, fromDb, toDb))

// JSON columns

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import org.jetbrains.exposed.sql.vendors.PostgreSQLDialect
import org.junit.Test

class EnumerationTests : DatabaseTestsBase() {
private val supportsCustomEnumerationDB = TestDB.mySqlRelatedDB + listOf(TestDB.H2, TestDB.H2_PSQL, TestDB.POSTGRESQL, TestDB.POSTGRESQLNG)

object EnumTable : IntIdTable("EnumTable") {
internal var enumColumn: Column<DDLTests.Foo> = enumeration("enumColumn")

Expand All @@ -41,7 +43,7 @@ class EnumerationTests : DatabaseTestsBase() {

@Test
fun testCustomEnumeration01() {
withDb(listOf(TestDB.H2, TestDB.H2_PSQL, TestDB.MYSQL, TestDB.POSTGRESQL, TestDB.POSTGRESQLNG)) {
withDb(supportsCustomEnumerationDB) {
val sqlType = when (currentDialectTest) {
is H2Dialect, is MysqlDialect -> "ENUM('Bar', 'Baz')"
is PostgreSQLDialect -> "FooEnum"
Expand All @@ -61,6 +63,10 @@ class EnumerationTests : DatabaseTestsBase() {
}
EnumTable.initEnumColumn(sqlType)
SchemaUtils.create(EnumTable)
// drop shared table object's unique index if created in other test
if (EnumTable.indices.isNotEmpty()) {
exec(EnumTable.indices.first().dropStatement().single())
}
EnumTable.insert {
it[enumColumn] = DDLTests.Foo.Bar
}
Expand Down Expand Up @@ -90,7 +96,7 @@ class EnumerationTests : DatabaseTestsBase() {

@Test
fun testCustomEnumerationWithDefaultValue() {
withDb(listOf(TestDB.H2, TestDB.H2_MYSQL, TestDB.H2_PSQL, TestDB.MYSQL, TestDB.POSTGRESQL, TestDB.POSTGRESQLNG)) {
withDb(supportsCustomEnumerationDB) {
val sqlType = when (currentDialectTest) {
is H2Dialect, is MysqlDialect -> "ENUM('Bar', 'Baz')"
is PostgreSQLDialect -> "FooEnum2"
Expand All @@ -103,9 +109,13 @@ class EnumerationTests : DatabaseTestsBase() {
}
EnumTable.initEnumColumn(sqlType)
with(EnumTable) {
EnumTable.enumColumn.default(DDLTests.Foo.Bar)
enumColumn.default(DDLTests.Foo.Bar)
}
SchemaUtils.create(EnumTable)
// drop shared table object's unique index if created in other test
if (EnumTable.indices.isNotEmpty()) {
exec(EnumTable.indices.first().dropStatement().single())
}

EnumTable.insert { }
val default = EnumTable.selectAll().single()[EnumTable.enumColumn]
Expand All @@ -117,4 +127,84 @@ class EnumerationTests : DatabaseTestsBase() {
}
}
}

@Test
fun testCustomEnumerationWithReference() {
val referenceTable = object : Table("ref_table") {
var referenceColumn: Column<DDLTests.Foo> = enumeration("ref_column")

fun initRefColumn() {
(columns as MutableList<Column<*>>).remove(referenceColumn)
referenceColumn = reference("ref_column", EnumTable.enumColumn)
}
}

withDb(supportsCustomEnumerationDB) {
val sqlType = when (currentDialectTest) {
is H2Dialect, is MysqlDialect -> "ENUM('Bar', 'Baz')"
is PostgreSQLDialect -> "RefEnum"
else -> error("Unsupported case")
}
try {
if (currentDialectTest is PostgreSQLDialect) {
exec("DROP TYPE IF EXISTS $sqlType;")
exec("CREATE TYPE $sqlType AS ENUM ('Bar', 'Baz');")
}
EnumTable.initEnumColumn(sqlType)
with(EnumTable) {
if (indices.isEmpty()) enumColumn.uniqueIndex()
}
SchemaUtils.create(EnumTable)

referenceTable.initRefColumn()
SchemaUtils.create(referenceTable)

val fooBar = DDLTests.Foo.Bar
val id1 = EnumTable.insert {
it[enumColumn] = fooBar
} get EnumTable.enumColumn
referenceTable.insert {
it[referenceColumn] = id1
}

assertEquals(fooBar, EnumTable.selectAll().single()[EnumTable.enumColumn])
assertEquals(fooBar, referenceTable.selectAll().single()[referenceTable.referenceColumn])
} finally {
SchemaUtils.drop(referenceTable)
exec(EnumTable.indices.first().dropStatement().single())
SchemaUtils.drop(EnumTable)
}
}
}

@Test
fun testEnumerationColumnsWithReference() {
val tester = object : Table("tester") {
val enumColumn = enumeration<DDLTests.Foo>("enum_column").uniqueIndex()
val enumNameColumn = enumerationByName<DDLTests.Foo>("enum_name_column", 32).uniqueIndex()
}
val referenceTable = object : Table("ref_table") {
val referenceColumn = reference("ref_column", tester.enumColumn)
val referenceNameColumn = reference("ref_name_column", tester.enumNameColumn)
}

withTables(tester, referenceTable) {
val fooBar = DDLTests.Foo.Bar
val fooBaz = DDLTests.Foo.Baz
val entry = tester.insert {
it[enumColumn] = fooBar
it[enumNameColumn] = fooBaz
}
referenceTable.insert {
it[referenceColumn] = entry[tester.enumColumn]
it[referenceNameColumn] = entry[tester.enumNameColumn]
}

assertEquals(fooBar, tester.selectAll().single()[tester.enumColumn])
assertEquals(fooBar, referenceTable.selectAll().single()[referenceTable.referenceColumn])

assertEquals(fooBaz, tester.selectAll().single()[tester.enumNameColumn])
assertEquals(fooBaz, referenceTable.selectAll().single()[referenceTable.referenceNameColumn])
}
}
}

0 comments on commit 1a07180

Please sign in to comment.