Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres: Add window function support #4283

Merged
merged 21 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.BIG_INT
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.SMALL_INT
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.TIMESTAMP
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.TIMESTAMP_TIMEZONE
import app.cash.sqldelight.dialects.postgresql.grammar.mixins.WindowFunctionMixin
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlDeleteStmtLimited
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlExtensionExpr
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlInsertStmt
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypeName
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlUpdateStmtLimited
Expand Down Expand Up @@ -90,7 +92,19 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
"min" -> encapsulatingType(exprList, BLOB, TEXT, SMALL_INT, INTEGER, PostgreSqlType.INTEGER, BIG_INT, REAL, TIMESTAMP_TIMEZONE, TIMESTAMP).asNullable()
"date_trunc" -> encapsulatingType(exprList, TIMESTAMP_TIMEZONE, TIMESTAMP)
"date_part" -> IntermediateType(REAL)
"percentile_disc" -> IntermediateType(REAL).asNullable()
"now" -> IntermediateType(TIMESTAMP_TIMEZONE)
"corr", "covar_pop", "covar_samp", "regr_avgx", "regr_avgy", "regr_intercept",
"regr_r2", "regr_slope", "regr_sxx", "regr_sxy", "regr_syy",
-> IntermediateType(REAL).asNullable()
"stddev", "stddev_pop", "stddev_samp", "variance",
"var_pop", "var_samp",
-> if (resolvedType(exprList[0]).dialectType == REAL) {
IntermediateType(REAL).asNullable()
} else IntermediateType(
PostgreSqlType.NUMERIC,
).asNullable()
"regr_count" -> IntermediateType(BIG_INT).asNullable()
"gen_random_uuid" -> IntermediateType(PostgreSqlType.UUID)
"length", "character_length", "char_length" -> IntermediateType(PostgreSqlType.INTEGER).nullableIf(resolvedType(exprList[0]).javaType.isNullable)
else -> null
Expand Down Expand Up @@ -141,6 +155,13 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
literalValue.text.startsWith("INTERVAL") -> IntermediateType(PostgreSqlType.INTERVAL)
else -> parentResolver.resolvedType(this)
}
is PostgreSqlExtensionExpr -> when {
windowFunctionExpr != null -> {
val windowFunctionExpr = windowFunctionExpr as WindowFunctionMixin
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised this works with the typo in PostgreSql.bnf

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or this code is never executed because we don't try to generate types based on the window function in tests anywhere

Copy link
Collaborator Author

@hfhbd hfhbd Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I guess, this explains the failures... I will fix it.

functionType(windowFunctionExpr.functionExpr)!!
}
else -> parentResolver.resolvedType(this)
}

else -> parentResolver.resolvedType(this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"static com.alecstrong.sql.psi.core.psi.SqlTypes.FOREIGN"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.FROM"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.GENERATED"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.GROUP"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.ID"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.IGNORE"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INSERT"
Expand Down Expand Up @@ -310,12 +311,16 @@ compound_select_stmt ::= [ {with_clause} ] {select_stmt} ( {compound_operator}
override = true
}

extension_expr ::= json_expression | boolean_literal | boolean_not_expression {
extension_expr ::= json_expression | boolean_literal | boolean_not_expression | window_function_expr {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlExtensionExprImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlExtensionExpr"
override = true
}

window_function_expr ::= {function_expr} 'WITHIN' GROUP LP ORDER BY <<expr '-1'>> ( COMMA <<expr '-1'>> ) * RP {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.WindowFunctionMixin"
}

boolean_not_expression ::= NOT (boolean_literal | {column_name})

boolean_literal ::= TRUE | FALSE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlWindowFunctionExpr
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.alecstrong.sql.psi.core.psi.SqlFunctionExpr
import com.intellij.lang.ASTNode

internal abstract class WindowFunctionMixin(
node: ASTNode,
) : SqlCompositeElementImpl(node),
PostgreSqlWindowFunctionExpr {
val functionExpr get() = children.filterIsInstance<SqlFunctionExpr>().single()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE TABLE myTable(
myColumn REAL NOT NULL
);

SELECT percentile_disc(.5) WITHIN GROUP (ORDER BY myTable.myColumn) AS P5
FROM myTable;
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE myTable(
foo REAL NOT NULL,
bar NUMERIC NOT NULL
);

SELECT
corr(foo),
stddev(bar),
stddev(foo),
regr_count(foo)
FROM myTable GROUP BY foo, bar;
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.intellij.psi.PsiElement
interface TypeResolver {
/**
* @param expr The expression to be resolved to a type.
* @return The type for [expr] for null if this resolver cannot solve.
* @return The resolved type
*/
fun resolvedType(expr: SqlExpr): IntermediateType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ const val SQLDELIGHT_EXTENSION = "sq"
object SqlDelightFileType : LanguageFileType(SqlDelightLanguage) {
private val ICON = AllIcons.Providers.Sqlite

const val FOLDER_NAME = "sqldelight"

override fun getName() = "SqlDelight"
override fun getDescription() = "SqlDelight"
override fun getDefaultExtension() = SQLDELIGHT_EXTENSION
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
CREATE TABLE myTable(
foo REAL NOT NULL,
bar NUMERIC NOT NULL
);

INSERT INTO myTable VALUES (1, 1), (2, 2), (3, 3);

selectPercentile:
SELECT percentile_disc(.5) WITHIN GROUP (ORDER BY foo) AS P5
FROM myTable;

selectStats:
SELECT
corr(foo, bar),
stddev(foo),
regr_count(foo, bar)
FROM myTable
GROUP BY foo, bar;
Original file line number Diff line number Diff line change
Expand Up @@ -312,4 +312,18 @@ class PostgreSqlTest {
val desc = database.charactersQueries.selectDescriptionLength().executeAsOne()
assertThat(desc.length).isNull()
}

@Test fun statFunctions() {
val percentile: SelectPercentile = database.functionsQueries.selectPercentile().executeAsOne()
val result: Double? = 2.0
assertThat(percentile).isEqualTo(SelectPercentile(result))
val stats: List<SelectStats> = database.functionsQueries.selectStats().executeAsList()
assertThat(stats).isEqualTo(
listOf(
SelectStats(null, null, 1),
SelectStats(null, null, 1),
SelectStats(null, null, 1),
),
)
}
}
Loading