Skip to content

Commit

Permalink
JetBrains#623 Add support of window functions in Exposed DSL
Browse files Browse the repository at this point in the history
-Support of partition by and order by clauses
-Support of window frame clause (without EXCLUDE)
-Factories for common window functions
-Support for using aggregate functions as window functions
  • Loading branch information
Dmitry Levin committed Jun 27, 2023
1 parent 0f55b8b commit 5d6e5de
Show file tree
Hide file tree
Showing 10 changed files with 1,037 additions and 10 deletions.
56 changes: 46 additions & 10 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ class Min<T : Comparable<T>, in S : T?>(
/** Returns the expression from which the minimum value is obtained. */
val expr: Expression<in S>,
columnType: IColumnType
) : Function<T?>(columnType) {
) : Function<T?>(columnType), WindowFunction<T?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("MIN(", expr, ")") }

override fun over(): WindowFunctionDefinition<T?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -174,8 +178,12 @@ class Max<T : Comparable<T>, in S : T?>(
/** Returns the expression from which the maximum value is obtained. */
val expr: Expression<in S>,
columnType: IColumnType
) : Function<T?>(columnType) {
) : Function<T?>(columnType), WindowFunction<T?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("MAX(", expr, ")") }

override fun over(): WindowFunctionDefinition<T?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -185,8 +193,12 @@ class Avg<T : Comparable<T>, in S : T?>(
/** Returns the expression from which the average is calculated. */
val expr: Expression<in S>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("AVG(", expr, ")") }

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -196,8 +208,12 @@ class Sum<T>(
/** Returns the expression from which the sum is calculated. */
val expr: Expression<T>,
columnType: IColumnType
) : Function<T?>(columnType) {
) : Function<T?>(columnType), WindowFunction<T?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("SUM(", expr, ")") }

override fun over(): WindowFunctionDefinition<T?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -208,13 +224,17 @@ class Count(
val expr: Expression<*>,
/** Returns whether only distinct element should be count. */
val distinct: Boolean = false
) : Function<Long>(LongColumnType()) {
) : Function<Long>(LongColumnType()), WindowFunction<Long> {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder {
+"COUNT("
if (distinct) +"DISTINCT "
+expr
+")"
}

override fun over(): WindowFunctionDefinition<Long> {
return WindowFunctionDefinition(LongColumnType(), this)
}
}

// Aggregate Functions for Statistics
Expand All @@ -227,7 +247,7 @@ class StdDevPop<T>(
/** Returns the expression from which the population standard deviation is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -237,6 +257,10 @@ class StdDevPop<T>(
functionProvider.stdDevPop(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -247,7 +271,7 @@ class StdDevSamp<T>(
/** Returns the expression from which the sample standard deviation is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -257,6 +281,10 @@ class StdDevSamp<T>(
functionProvider.stdDevSamp(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -267,7 +295,7 @@ class VarPop<T>(
/** Returns the expression from which the population variance is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -277,6 +305,10 @@ class VarPop<T>(
functionProvider.varPop(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

/**
Expand All @@ -287,7 +319,7 @@ class VarSamp<T>(
/** Returns the expression from which the sample variance is calculated. */
val expression: Expression<T>,
scale: Int
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)) {
) : Function<BigDecimal?>(DecimalColumnType(Int.MAX_VALUE, scale)), WindowFunction<BigDecimal?> {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
val functionProvider = when (currentDialect.h2Mode) {
Expand All @@ -297,6 +329,10 @@ class VarSamp<T>(
functionProvider.varSamp(expression, this)
}
}

override fun over(): WindowFunctionDefinition<BigDecimal?> {
return WindowFunctionDefinition(columnType, this)
}
}

// JSON Functions
Expand Down Expand Up @@ -325,7 +361,7 @@ class JsonExtract<T>(
/**
* Represents an SQL function that advances the specified [seq] and returns the new value.
*/
sealed class NextVal<T> (
sealed class NextVal<T>(
/** Returns the sequence from which the next value is obtained. */
val seq: Sequence,
columnType: IColumnType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,68 @@ inline fun <reified T : Any> ExpressionWithColumnType<*>.jsonExtract(vararg path
return JsonExtract(this, path = path, toScalar, this.columnType, columnType)
}

// Window Functions

/** Returns the number of the current row within its partition, counting from 1. */
fun rowNumber(): RowNumber = RowNumber()

/** Returns the rank of the current row, with gaps; that is, the row_number of the first row in its peer group. */
fun rank(): Rank = Rank()

/** Returns the rank of the current row, without gaps; this function effectively counts peer groups. */
fun denseRank(): DenseRank = DenseRank()

/**
* Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1).
* The value thus ranges from 0 to 1 inclusive.
*/
fun percentRank(): PercentRank = PercentRank()

/**
* Returns the cumulative distribution, that is (number of partition rows preceding or peers with current row) /
* (total partition rows). The value thus ranges from 1/N to 1.
*/
fun cumeDist(): CumeDist = CumeDist()

/** Returns an integer ranging from 1 to the [numBuckets], dividing the partition as equally as possible. */
fun ntile(numBuckets: ExpressionWithColumnType<Int>): Ntile = Ntile(numBuckets)

/**
* Returns value evaluated at the row that is [offset] rows before the current row within the partition;
* if there is no such row, instead returns [defaultValue].
* Both [offset] and [defaultValue] are evaluated with respect to the current row.
*/
fun <T> ExpressionWithColumnType<T>.lag(
offset: ExpressionWithColumnType<Int> = intLiteral(1),
defaultValue: ExpressionWithColumnType<T>? = null
): Lag<T> = Lag(this, offset, defaultValue)

/**
* Returns value evaluated at the row that is [offset] rows after the current row within the partition;
* if there is no such row, instead returns [defaultValue].
* Both [offset] and [defaultValue] are evaluated with respect to the current row.
*/
fun <T> ExpressionWithColumnType<T>.lead(
offset: ExpressionWithColumnType<Int> = intLiteral(1),
defaultValue: ExpressionWithColumnType<T>? = null
): Lead<T> = Lead(this, offset, defaultValue)

/**
* Returns value evaluated at the row that is the first row of the window frame.
*/
fun <T> ExpressionWithColumnType<T>.firstValue(): FirstValue<T> = FirstValue(this)

/**
* Returns value evaluated at the row that is the last row of the window frame.
*/
fun <T> ExpressionWithColumnType<T>.lastValue(): LastValue<T> = LastValue(this)

/**
* Returns value evaluated at the row that is the [index] row of the window frame
* (counting from 1); null if no such row
*/
fun <T> ExpressionWithColumnType<T>.nthValue(index: ExpressionWithColumnType<Int>): NthValue<T> = NthValue(this, index)

// Sequence Manipulation Functions

/** Advances this sequence and returns the new value. */
Expand Down
Loading

0 comments on commit 5d6e5de

Please sign in to comment.