Skip to content

Commit

Permalink
Fix: RETURNING multiple values on INSERT, UPDATE, DELETE generates wr…
Browse files Browse the repository at this point in the history
…ong code

The code generation was correct for scenarios where a single column is returned, as well as when all columns are returned. However, if multiple, but not all, columns are returned, then the generated code still used the table type, rather than generating a new interface for the projection.

With this change, the pureTable computation more closely resembles that of SelectQueryable, which contains more complex logic for purity identification.
  • Loading branch information
MariusVolkhart committed Feb 13, 2023
1 parent 94a2a13 commit d9089c4
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.INTEGER
import app.cash.sqldelight.dialect.api.PrimitiveType.REAL
import app.cash.sqldelight.dialect.api.PrimitiveType.TEXT
import app.cash.sqldelight.dialect.api.QueryWithResults
import app.cash.sqldelight.dialect.api.ReturningQueryable
import app.cash.sqldelight.dialect.api.TypeResolver
import app.cash.sqldelight.dialect.api.encapsulatingType
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.BIG_INT
Expand All @@ -18,7 +19,6 @@ import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlDeleteStmtL
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlInsertStmt
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypeName
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlUpdateStmtLimited
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlCreateTableStmt
import com.alecstrong.sql.psi.core.psi.SqlFunctionExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -94,33 +94,15 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
override fun queryWithResults(sqlStmt: SqlStmt): QueryWithResults? {
sqlStmt.insertStmt?.let { insert ->
check(insert is PostgreSqlInsertStmt)
insert.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = insert
override val select = it
override val pureTable = insert.tableName
}
}
insert.returningClause?.let { return ReturningQueryable(insert, it, insert.tableName) }
}
sqlStmt.updateStmtLimited?.let { update ->
check(update is PostgreSqlUpdateStmtLimited)
update.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = update
override val select = it
override val pureTable = update.qualifiedTableName.tableName
}
}
update.returningClause?.let { return ReturningQueryable(update, it, update.qualifiedTableName.tableName) }
}
sqlStmt.deleteStmtLimited?.let { delete ->
check(delete is PostgreSqlDeleteStmtLimited)
delete.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = delete
override val select = it
override val pureTable = delete.qualifiedTableName?.tableName
}
}
delete.returningClause?.let { return ReturningQueryable(delete, it, delete.qualifiedTableName?.tableName) }
}
return parentResolver.queryWithResults(sqlStmt)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,45 +1,27 @@
package app.cash.sqldelight.dialects.sqlite_3_35

import app.cash.sqldelight.dialect.api.QueryWithResults
import app.cash.sqldelight.dialect.api.ReturningQueryable
import app.cash.sqldelight.dialect.api.TypeResolver
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteDeleteStmtLimited
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteInsertStmt
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteUpdateStmtLimited
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlStmt
import app.cash.sqldelight.dialects.sqlite_3_24.SqliteTypeResolver as Sqlite324TypeResolver

class SqliteTypeResolver(private val parentResolver: TypeResolver) : Sqlite324TypeResolver(parentResolver) {
override fun queryWithResults(sqlStmt: SqlStmt): QueryWithResults? {
sqlStmt.insertStmt?.let { insert ->
check(insert is SqliteInsertStmt)
insert.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = insert
override val select = it
override val pureTable = insert.tableName
}
}
insert.returningClause?.let { return ReturningQueryable(insert, it, insert.tableName) }
}
sqlStmt.updateStmtLimited?.let { update ->
check(update is SqliteUpdateStmtLimited)
update.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = update
override val select = it
override val pureTable = update.qualifiedTableName.tableName
}
}
update.returningClause?.let { return ReturningQueryable(update, it, update.qualifiedTableName.tableName) }
}
sqlStmt.deleteStmtLimited?.let { delete ->
check(delete is SqliteDeleteStmtLimited)
delete.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = delete
override val select = it
override val pureTable = delete.qualifiedTableName?.tableName
}
}
delete.returningClause?.let { return ReturningQueryable(delete, it, delete.qualifiedTableName?.tableName) }
}
return parentResolver.queryWithResults(sqlStmt)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package app.cash.sqldelight.dialect.api

import com.alecstrong.sql.psi.core.psi.QueryElement

internal fun List<QueryElement.QueryColumn>.flattenCompounded(): List<QueryElement.QueryColumn> {
return map { column ->
if (column.compounded.none { it.element != column.element || it.nullable != column.nullable }) {
column.copy(compounded = emptyList())
} else {
column
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package app.cash.sqldelight.dialect.api

import com.alecstrong.sql.psi.core.psi.QueryElement
import com.alecstrong.sql.psi.core.psi.Queryable
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlTableName
import com.intellij.psi.util.PsiTreeUtil

/**
* Query deriving from the `RETURNING` clause of an expression.
*
* Typical use cases include `INSERT`, `UPDATE`, and `DELETE`. This class is similar to [SelectQueryable] but differs in
* the fact that only 1 table can be part of the query, and that table is guaranteed to be "real".
*
* @param statement Parent statement. Typically, this is the `INSERT`, `UPDATE`, or `DELETE` statement.
* @param select The `RETURNING` clause of the statement. Represented as a query since it returns values to the caller.
* @param tableName Name of the table the [statement] is operating on.
*/
class ReturningQueryable(
override var statement: SqlAnnotatedElement,
override val select: QueryElement,
private val tableName: SqlTableName?,
) : QueryWithResults {

override val pureTable by lazy {
val pureColumns = select.queryExposed().singleOrNull()?.columns?.flattenCompounded()
val resolvedTable = tableName?.reference?.resolve()
val table = PsiTreeUtil.getParentOfType(resolvedTable, Queryable::class.java)?.tableExposed()
?: return@lazy null
val requestedColumnsAreIdenticalToTable = table.query.columns.flattenCompounded() == pureColumns
if (requestedColumnsAreIdenticalToTable) {
tableName
} else {
null
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package app.cash.sqldelight.dialect.api

import com.alecstrong.sql.psi.core.psi.NamedElement
import com.alecstrong.sql.psi.core.psi.QueryElement.QueryColumn
import com.alecstrong.sql.psi.core.psi.Queryable
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt
Expand All @@ -21,16 +20,6 @@ class SelectQueryable(
* which points to that table (Pure meaning it has exactly the same columns in the same order).
*/
override val pureTable: NamedElement? by lazy {
fun List<QueryColumn>.flattenCompounded(): List<QueryColumn> {
return map { column ->
if (column.compounded.none { it.element != column.element || it.nullable != column.nullable }) {
column.copy(compounded = emptyList())
} else {
column
}
}
}

val pureColumns = select.queryExposed().singleOrNull()?.columns?.flattenCompounded()

// First check to see if its just the table we're observing directly.
Expand Down

0 comments on commit d9089c4

Please sign in to comment.