Skip to content

Commit

Permalink
More strict type checking for expressions and column functions/expres…
Browse files Browse the repository at this point in the history
…sions
  • Loading branch information
Tapac committed Nov 28, 2017
1 parent 69ffa6c commit 008db32
Show file tree
Hide file tree
Showing 20 changed files with 196 additions and 151 deletions.
94 changes: 48 additions & 46 deletions src/main/kotlin/org/jetbrains/exposed/dao/Entity.kt

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions src/main/kotlin/org/jetbrains/exposed/dao/EntityHook.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,33 @@ enum class EntityChangeType {
}


data class EntityChange<ID: Any>(val entityClass: EntityClass<ID, Entity<ID>>, val id: EntityID<ID>, var changeType: EntityChangeType)
data class EntityChange(val entityClass: EntityClass<*, Entity<*>>, val id: EntityID<*>, var changeType: EntityChangeType)

fun<ID: Any> EntityChange<ID>.toEntity() : Entity<ID>? = entityClass.findById(id)
fun<ID: Comparable<ID>, T: Entity<ID>> EntityChange.toEntity() : T? = (entityClass as EntityClass<ID, T>).findById(id as EntityID<ID>)

fun<ID: Any,T: Entity<ID>> EntityChange<*>.toEntity(klass: EntityClass<ID, T>) : T? {
fun<ID: Comparable<ID>,T: Entity<ID>> EntityChange.toEntity(klass: EntityClass<ID, T>) : T? {
if (!entityClass.isAssignableTo(klass)) return null
@Suppress("UNCHECKED_CAST")
return toEntity() as? T
return toEntity<ID, T>()
}

object EntityHook {
private val entitySubscribers = CopyOnWriteArrayList<(EntityChange<*>) -> Unit>()
private val entitySubscribers = CopyOnWriteArrayList<(EntityChange) -> Unit>()

private val events by transactionScope { CopyOnWriteArrayList<EntityChange<*>>() }
private val events by transactionScope { CopyOnWriteArrayList<EntityChange>() }

val registeredEvents: List<EntityChange<*>> get() = events.toList()
val registeredEvents: List<EntityChange> get() = events.toList()

fun subscribe (action: (EntityChange<*>) -> Unit): (EntityChange<*>) -> Unit {
fun subscribe (action: (EntityChange) -> Unit): (EntityChange) -> Unit {
entitySubscribers.add(action)
return action
}

fun unsubscribe (action: (EntityChange<*>) -> Unit) {
fun unsubscribe (action: (EntityChange) -> Unit) {
entitySubscribers.remove(action)
}

fun registerChange(change: EntityChange<*>) {
fun registerChange(change: EntityChange) {
if (events.lastOrNull() != change) {
events.add(change)
}
Expand All @@ -53,7 +53,7 @@ object EntityHook {
}
}

fun <T> withHook(action: (EntityChange<*>) -> Unit, body: ()->T): T {
fun <T> withHook(action: (EntityChange) -> Unit, body: ()->T): T {
EntityHook.subscribe(action)
try {
return body().apply {
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jetbrains/exposed/dao/IdTable.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.jetbrains.exposed.dao
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.Table

abstract class IdTable<T:Any>(name: String): Table(name) {
abstract class IdTable<T:Comparable<T>>(name: String): Table(name) {
abstract val id : Column<EntityID<T>>

}
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jetbrains/exposed/sql/Alias.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Alias<out T:Table>(val delegate: T, val alias: String) : Table() {
}


class ExpressionAlias<out T: Any?>(val delegate: Expression<T>, val alias: String) : Expression<T>() {
class ExpressionAlias<T>(val delegate: Expression<T>, val alias: String) : Expression<T>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "${delegate.toSQL(queryBuilder)} $alias"

fun aliasOnlyExpression() = object: Expression<T>() {
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ val Column<*>.autoIncSeqName : String? get() {
?: (columnType as? EntityIDColumnType<*>)?.idColumn?.autoIncSeqName
}

class EntityIDColumnType<T:Any>(val idColumn: Column<T>) : ColumnType(false) {
class EntityIDColumnType<T:Comparable<T>>(val idColumn: Column<T>) : ColumnType(false) {

init {
assert(idColumn.table is IdTable<*>){"EntityId supported only for IdTables"}
Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class QueryBuilder(val prepared: Boolean) {
}
}

abstract class Expression<out T> {
abstract class Expression<T> {
private val _hashCode by lazy {
toString().hashCode()
}
Expand All @@ -46,7 +46,7 @@ abstract class Expression<out T> {
}
}

abstract class ExpressionWithColumnType<out T> : Expression<T>() {
abstract class ExpressionWithColumnType<T> : Expression<T>() {
// used for operations with literals
abstract val columnType: IColumnType
}
45 changes: 23 additions & 22 deletions src/main/kotlin/org/jetbrains/exposed/sql/Function.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.joda.time.DateTime
import java.math.BigDecimal
import java.util.*

abstract class Function<out T> : ExpressionWithColumnType<T>()
abstract class Function<T> : ExpressionWithColumnType<T>()

class Count(val expr: Expression<*>, val distinct: Boolean = false): Function<Int>() {
override fun toSQL(queryBuilder: QueryBuilder): String =
Expand All @@ -13,7 +13,7 @@ class Count(val expr: Expression<*>, val distinct: Boolean = false): Function<In
override val columnType: IColumnType = IntegerColumnType()
}

class Date(val expr: Expression<DateTime?>): Function<DateTime>() {
class Date<T:DateTime?>(val expr: Expression<T>): Function<DateTime>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "DATE(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DateColumnType(false)
Expand All @@ -24,80 +24,80 @@ class CurrentDateTime : Function<DateTime>() {
override val columnType: IColumnType = DateColumnType(false)
}

class Month(val expr: Expression<DateTime?>): Function<DateTime>() {
class Month<T:DateTime?>(val expr: Expression<T>): Function<DateTime>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "MONTH(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DateColumnType(false)
}

class LowerCase<out T: String?>(val expr: Expression<T>) : Function<T>() {
class LowerCase<T: String?>(val expr: Expression<T>) : Function<T>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "LOWER(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = StringColumnType()
}

class UpperCase<out T: String?>(val expr: Expression<T>) : Function<T>() {
class UpperCase<T: String?>(val expr: Expression<T>) : Function<T>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "UPPER(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = StringColumnType()
}

class Min<out T>(val expr: Expression<T>, _columnType: IColumnType): Function<T?>() {
class Min<T:Comparable<T>, S:T?>(val expr: Expression<in S>, _columnType: IColumnType): Function<S?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "MIN(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = _columnType
}

class Max<out T>(val expr: Expression<T>, _columnType: IColumnType): Function<T?>() {
class Max<T:Comparable<T>, S:T?>(val expr: Expression<in S>, _columnType: IColumnType): Function<S>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "MAX(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = _columnType
}

class Avg<out T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
class Avg<T:Comparable<T>, in S:T?>(val expr: Expression<in S>, scale: Int): Function<BigDecimal?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "AVG(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DecimalColumnType(Int.MAX_VALUE, scale)
}

class StdDevPop<out T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
class StdDevPop<T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "STDDEV_POP(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DecimalColumnType(Int.MAX_VALUE, scale)
}

class StdDevSamp<out T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
class StdDevSamp<T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "STDDEV_SAMP(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DecimalColumnType(Int.MAX_VALUE, scale)
}

class VarPop<out T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
class VarPop<T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "VAR_POP(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DecimalColumnType(Int.MAX_VALUE, scale)
}

class VarSamp<out T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
class VarSamp<T>(val expr: Expression<T>, scale: Int): Function<BigDecimal?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "VAR_SAMP(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = DecimalColumnType(Int.MAX_VALUE, scale)
}

class Sum<out T>(val expr: Expression<T>, _columnType: IColumnType): Function<T?>() {
class Sum<T>(val expr: Expression<T>, _columnType: IColumnType): Function<T?>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "SUM(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = _columnType
}

class Coalesce<out T>(val expr: ExpressionWithColumnType<T?>, val alternate: ExpressionWithColumnType<out T>): Function<T>() {
class Coalesce<out T, S:T?, R:T>(private val expr: ExpressionWithColumnType<S>, private val alternate: ExpressionWithColumnType<out T>): Function<R>() {
override fun toSQL(queryBuilder: QueryBuilder): String =
"COALESCE(${expr.toSQL(queryBuilder)}, ${alternate.toSQL(queryBuilder)})"

override val columnType: IColumnType = alternate.columnType
}

class Substring(val expr: Expression<String?>, val start: ExpressionWithColumnType<Int>, val length: ExpressionWithColumnType<Int>): Function<String>() {
class Substring<T:String?>(private val expr: Expression<T>, private val start: ExpressionWithColumnType<Int>, val length: ExpressionWithColumnType<Int>): Function<T>() {
override fun toSQL(queryBuilder: QueryBuilder): String
= currentDialect.functionProvider.substring(expr, start, length, queryBuilder)

Expand All @@ -112,12 +112,12 @@ class Random(val seed: Int? = null) : Function<BigDecimal>() {
override val columnType: IColumnType = DecimalColumnType(38, 20)
}

class Cast<out T>(val expr: Expression<*>, override val columnType: IColumnType) : Function<T?>() {
class Cast<T>(val expr: Expression<*>, override val columnType: IColumnType) : Function<T?>() {
override fun toSQL(queryBuilder: QueryBuilder): String
= currentDialect.functionProvider.cast(expr, columnType, queryBuilder)
}

class Trim(val expr: Expression<*>): Function<String>() {
class Trim<T:String?>(val expr: Expression<T>): Function<T>() {
override fun toSQL(queryBuilder: QueryBuilder): String = "TRIM(${expr.toSQL(queryBuilder)})"

override val columnType: IColumnType = StringColumnType()
Expand All @@ -129,17 +129,18 @@ class Case(val value: Expression<*>? = null) {
}

class CaseWhen<T> (val value: Expression<*>?) {
val cases: ArrayList<Pair<Expression<Boolean>, Expression<T>>> = ArrayList()
val cases: ArrayList<Pair<Expression<Boolean>, Expression<out T>>> = ArrayList()

fun When (cond: Expression<Boolean>, result: Expression<T>) : CaseWhen<T> {
@Suppress("UNCHECKED_CAST")
fun <R:T> When (cond: Expression<Boolean>, result: Expression<R>) : CaseWhen<R> {
cases.add( cond to result )
return this
return this as CaseWhen<R>
}

fun Else(e: Expression<T>) : Expression<T> = CaseWhenElse(this, e)
fun <R:T> Else(e: Expression<R>) : Expression<R> = CaseWhenElse(this, e)
}

class CaseWhenElse<T> (val caseWhen: CaseWhen<T>, val elseResult: Expression<T>) : Expression<T>() {
class CaseWhenElse<T, R:T> (val caseWhen: CaseWhen<T>, val elseResult: Expression<R>) : Expression<R>() {
override fun toSQL(queryBuilder: QueryBuilder): String = buildString {
append("CASE")
if (caseWhen.value != null)
Expand Down
22 changes: 11 additions & 11 deletions src/main/kotlin/org/jetbrains/exposed/sql/Op.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.jetbrains.exposed.sql
import org.jetbrains.exposed.dao.EntityID
import org.joda.time.DateTime

abstract class Op<out T> : Expression<T>() {
abstract class Op<T> : Expression<T>() {
companion object {
inline fun <T> build(op: SqlExpressionBuilder.()-> Op<T>): Op<T> = SqlExpressionBuilder.op()
}
Expand All @@ -21,7 +21,7 @@ class IsNotNullOp(val expr: Expression<*>): Op<Boolean>() {
override fun toSQL(queryBuilder: QueryBuilder):String = "${expr.toSQL(queryBuilder)} IS NOT NULL"
}

class LiteralOp<out T>(override val columnType: IColumnType, val value: T): ExpressionWithColumnType<T>() {
class LiteralOp<T>(override val columnType: IColumnType, val value: T): ExpressionWithColumnType<T>() {
override fun toSQL(queryBuilder: QueryBuilder):String = columnType.valueToString(value)
}

Expand All @@ -30,11 +30,11 @@ class Between(val expr: Expression<*>, val from: LiteralOp<*>, val to: LiteralOp
"${expr.toSQL(queryBuilder)} BETWEEN ${from.toSQL(queryBuilder)} AND ${to.toSQL(queryBuilder)}"
}

class NoOpConversion<out T, out S>(val expr: Expression<T>, override val columnType: IColumnType): ExpressionWithColumnType<S>() {
class NoOpConversion<T, S>(val expr: Expression<T>, override val columnType: IColumnType): ExpressionWithColumnType<S>() {
override fun toSQL(queryBuilder: QueryBuilder): String = expr.toSQL(queryBuilder)
}

class InListOrNotInListOp<out T>(val expr: ExpressionWithColumnType<T>, val list: Iterable<T>, val isInList: Boolean = true): Op<Boolean>() {
class InListOrNotInListOp<T>(val expr: ExpressionWithColumnType<T>, val list: Iterable<T>, val isInList: Boolean = true): Op<Boolean>() {

override fun toSQL(queryBuilder: QueryBuilder): String = buildString{
list.iterator().let { i ->
Expand Down Expand Up @@ -66,11 +66,11 @@ class InListOrNotInListOp<out T>(val expr: ExpressionWithColumnType<T>, val list
}
}

class QueryParameter<out T>(val value: T, val sqlType: IColumnType) : Expression<T>() {
class QueryParameter<T>(val value: T, val sqlType: IColumnType) : Expression<T>() {
override fun toSQL(queryBuilder: QueryBuilder): String = queryBuilder.registerArgument(sqlType, value)
}

fun <T:Any> idParam(value: EntityID<T>, column: Column<EntityID<T>>): Expression<EntityID<T>> = QueryParameter(value, EntityIDColumnType(column))
fun <T:Comparable<T>> idParam(value: EntityID<T>, column: Column<EntityID<T>>): Expression<EntityID<T>> = QueryParameter(value, EntityIDColumnType(column))
fun booleanParam(value: Boolean): Expression<Boolean> = QueryParameter(value, BooleanColumnType())
fun intParam(value: Int): Expression<Int> = QueryParameter(value, IntegerColumnType())
fun longParam(value: Long): Expression<Long> = QueryParameter(value, LongColumnType())
Expand Down Expand Up @@ -128,7 +128,7 @@ class AndOp(val expr1: Expression<Boolean>, val expr2: Expression<Boolean>): Op<
}
}

class OrOp<out T>(val expr1: Expression<T>, val expr2: Expression<T>): Op<Boolean>() {
class OrOp<T>(val expr1: Expression<T>, val expr2: Expression<T>): Op<Boolean>() {
override fun toSQL(queryBuilder: QueryBuilder) = "(${expr1.toSQL(queryBuilder)}) OR (${expr2.toSQL(queryBuilder)})"
}

Expand All @@ -140,19 +140,19 @@ class notExists(val query: Query) : Op<Boolean>() {
override fun toSQL(queryBuilder: QueryBuilder) = "NOT EXISTS (${query.prepareSQL(queryBuilder)})"
}

class PlusOp<out T, out S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
class PlusOp<T, S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
override fun toSQL(queryBuilder: QueryBuilder) = "${expr1.toSQL(queryBuilder)}+${expr2.toSQL(queryBuilder)}"
}

class MinusOp<out T, out S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
class MinusOp<T, S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
override fun toSQL(queryBuilder: QueryBuilder) = "${expr1.toSQL(queryBuilder)}-${expr2.toSQL(queryBuilder)}"
}

class TimesOp<out T, out S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
class TimesOp<T, S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
override fun toSQL(queryBuilder: QueryBuilder):String = "(${expr1.toSQL(queryBuilder)}) * (${expr2.toSQL(queryBuilder)})"
}

class DivideOp<out T, out S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
class DivideOp<T, S: T>(val expr1: Expression<T>, val expr2: Expression<S>, override val columnType: IColumnType): ExpressionWithColumnType<T>() {
override fun toSQL(queryBuilder: QueryBuilder):String =
"(${expr1.toSQL(queryBuilder)}) / (${expr2.toSQL(queryBuilder)})"
}
4 changes: 2 additions & 2 deletions src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fun <T:Table> T.insert(body: T.(InsertStatement<Number>)->Unit): InsertStatement
/**
* @sample org.jetbrains.exposed.sql.tests.shared.DMLTests.testGeneratedKey03
*/
fun <Key:Any, T: IdTable<Key>> T.insertAndGetId(ignore: Boolean = false, body: T.(InsertStatement<EntityID<Key>>)->Unit) =
fun <Key:Comparable<Key>, T: IdTable<Key>> T.insertAndGetId(ignore: Boolean = false, body: T.(InsertStatement<EntityID<Key>>)->Unit) =
InsertStatement<EntityID<Key>>(this, ignore).run {
body(this)
execute(TransactionManager.current())
Expand Down Expand Up @@ -97,7 +97,7 @@ fun <T:Table> T.insertIgnore(body: T.(UpdateBuilder<*>)->Unit): InsertStatement<
execute(TransactionManager.current())
}

fun <Key:Any, T: IdTable<Key>> T.insertIgnoreAndGetId(body: T.(UpdateBuilder<*>)->Unit) = InsertStatement<EntityID<Key>>(this, isIgnore = true).run {
fun <Key:Comparable<Key>, T: IdTable<Key>> T.insertIgnoreAndGetId(body: T.(UpdateBuilder<*>)->Unit) = InsertStatement<EntityID<Key>>(this, isIgnore = true).run {
body(this)
execute(TransactionManager.current())
generatedKey
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jetbrains/exposed/sql/Query.kt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ResultRow(size: Int, private val fieldIndex: Map<Expression<*>, Int>) {
internal fun create(columns : List<Column<*>>): ResultRow =
ResultRow(columns.size, columns.mapIndexed { i, c -> c to i }.toMap()).apply {
columns.forEach {
this[it] = it.defaultValueFun?.invoke() ?: if (!it.columnType.nullable) NotInitializedValue else null
this[it as Expression<Any?>] = it.defaultValueFun?.invoke() ?: if (!it.columnType.nullable) NotInitializedValue else null
}
}
}
Expand Down
Loading

0 comments on commit 008db32

Please sign in to comment.