Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: EXPOSED-54 CaseWhen.Else returns narrow Expression<R> #1800

Merged
merged 3 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public final class org/jetbrains/exposed/sql/Case {

public final class org/jetbrains/exposed/sql/CaseWhen {
public fun <init> (Lorg/jetbrains/exposed/sql/Expression;)V
public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/Expression;
public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/ExpressionWithColumnType;
public final fun When (Lorg/jetbrains/exposed/sql/Expression;Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/CaseWhen;
public final fun getCases ()Ljava/util/List;
public final fun getValue ()Lorg/jetbrains/exposed/sql/Expression;
Expand Down
28 changes: 19 additions & 9 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class Random(
class CharLength<T : String?>(
val expr: Expression<T>
) : Function<Int?>(IntegerColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.charLength(expr, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit =
bog-walk marked this conversation as resolved.
Show resolved Hide resolved
currentDialect.functionProvider.charLength(expr, queryBuilder)
}

/**
Expand Down Expand Up @@ -105,7 +106,8 @@ class Concat(
/** Returns the expressions being concatenated. */
vararg val expr: Expression<*>
) : Function<String>(TextColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr)
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit =
currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr)
}

/**
Expand All @@ -121,7 +123,8 @@ class GroupConcat<T : String?>(
/** Returns the order in which the elements of each group are sorted. */
vararg val orderBy: Pair<Expression<*>, SortOrder>
) : Function<T>(TextColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.groupConcat(this, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit =
currentDialect.functionProvider.groupConcat(this, queryBuilder)
}

/**
Expand All @@ -133,7 +136,8 @@ class Substring<T : String?>(
/** Returns the length of the substring. */
val length: Expression<Int>
) : Function<T>(TextColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.substring(expr, start, length, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit =
currentDialect.functionProvider.substring(expr, start, length, queryBuilder)
}

/**
Expand Down Expand Up @@ -346,7 +350,8 @@ sealed class NextVal<T>(
columnType: IColumnType
) : Function<T>(columnType) {

override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.nextVal(seq, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit =
currentDialect.functionProvider.nextVal(seq, queryBuilder)

class IntNextVal(seq: Sequence) : NextVal<Int>(seq, IntegerColumnType())
class LongNextVal(seq: Sequence) : NextVal<Long>(seq, LongColumnType())
Expand All @@ -368,10 +373,13 @@ class CaseWhen<T>(val value: Expression<*>?) {
return this as CaseWhen<R>
}

fun <R : T> Else(e: Expression<R>): Expression<R> = CaseWhenElse(this, e)
fun <R : T> Else(e: Expression<R>): ExpressionWithColumnType<R> = CaseWhenElse(this, e)
}

class CaseWhenElse<T, R : T>(val caseWhen: CaseWhen<T>, val elseResult: Expression<R>) : ExpressionWithColumnType<R>(), ComplexExpression {
class CaseWhenElse<T, R : T>(
val caseWhen: CaseWhen<T>,
val elseResult: Expression<R>
) : ExpressionWithColumnType<R>(), ComplexExpression {

override val columnType: IColumnType =
(elseResult as? ExpressionWithColumnType<R>)?.columnType
Expand All @@ -382,10 +390,11 @@ class CaseWhenElse<T, R : T>(val caseWhen: CaseWhen<T>, val elseResult: Expressi
append("CASE ")
if (caseWhen.value != null) {
+caseWhen.value
+" "
}

for ((first, second) in caseWhen.cases) {
append(" WHEN ", first, " THEN ", second)
append("WHEN ", first, " THEN ", second)
}

append(" ELSE ", elseResult, " END")
Expand Down Expand Up @@ -419,5 +428,6 @@ class Cast<T>(
val expr: Expression<*>,
columnType: IColumnType
) : Function<T>(columnType) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.cast(expr, columnType, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit =
currentDialect.functionProvider.cast(expr, columnType, queryBuilder)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class ConditionsTests : DatabaseTestsBase() {
@Test
fun testTRUEandFALSEOps() {
withCitiesAndUsers { cities, _, _ ->
val allSities = cities.selectAll().toCityNameList()
val allCities = cities.selectAll().toCityNameList()
assertEquals(0L, cities.select { Op.FALSE }.count())
assertEquals(allSities.size.toLong(), cities.select { Op.TRUE }.count())
assertEquals(allCities.size.toLong(), cities.select { Op.TRUE }.count())
}
}

Expand Down Expand Up @@ -158,9 +158,9 @@ class ConditionsTests : DatabaseTestsBase() {
@Test
fun nullOpInCaseTest() {
withCitiesAndUsers { cities, _, _ ->
val caseCondition = Case().
When(Op.build { cities.id eq 1 }, Op.nullOp<String>()).
Else(cities.name)
val caseCondition = Case()
.When(Op.build { cities.id eq 1 }, Op.nullOp<String>())
.Else(cities.name)
var nullBranchWasExecuted = false
cities.slice(cities.id, cities.name, caseCondition).selectAll().forEach {
val result = it[caseCondition]
Expand All @@ -174,4 +174,39 @@ class ConditionsTests : DatabaseTestsBase() {
assertEquals(true, nullBranchWasExecuted)
}
}

@Test
fun testCaseWhenElseAsArgument() {
withCitiesAndUsers { cities, _, _ ->
val original = "ORIGINAL"
val copy = "COPY"
val condition = Op.build { cities.id eq 1 }

val caseCondition1 = Case()
.When(condition, stringLiteral(original))
.Else(Op.nullOp())
// Case().When().Else() invokes CaseWhenElse() so the 2 formats should be interchangeable as arguments
val caseCondition2 = CaseWhenElse(
Case().When(condition, stringLiteral(original)),
Op.nullOp()
)
val function1 = Coalesce(caseCondition1, stringLiteral(copy))
val function2 = Coalesce(caseCondition2, stringLiteral(copy))

// confirm both formats produce identical SQL
val query1 = cities.slice(cities.id, function1).selectAll().prepareSQL(this, prepared = false)
val query2 = cities.slice(cities.id, function2).selectAll().prepareSQL(this, prepared = false)
assertEquals(query1, query2)

val results1 = cities.slice(cities.id, function1).selectAll().toList()
cities.slice(cities.id, function2).selectAll().forEachIndexed { i, row ->
val currentId = row[cities.id]
val functionResult = row[function2]

assertEquals(if (currentId == 1) original else copy, functionResult)
assertEquals(currentId, results1[i][cities.id])
assertEquals(functionResult, results1[i][function1])
}
}
}
}