Skip to content

Commit

Permalink
Ensure updates and deletes with RETURNING statements execute queries
Browse files Browse the repository at this point in the history
  • Loading branch information
AlecKazakova committed Apr 14, 2022
1 parent baf4663 commit 895e499
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.BIG_INT
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.SMALL_INT
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.TIMESTAMP
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.TIMESTAMP_TIMEZONE
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlDeleteStmtLimited
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlInsertStmt
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypeName
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlUpdateStmtLimited
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlFunctionExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -94,6 +96,26 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
}
}
}
sqlStmt.updateStmtLimited?.let { update ->
check(update is PostgreSqlUpdateStmtLimited)
update.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = update
override val select = it
override val pureTable = update.qualifiedTableName.tableName
}
}
}
sqlStmt.deleteStmtLimited?.let { delete ->
check(delete is PostgreSqlDeleteStmtLimited)
delete.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = delete
override val select = it
override val pureTable = delete.qualifiedTableName?.tableName
}
}
}
return parentResolver.queryWithResults(sqlStmt)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package app.cash.sqldelight.dialects.sqlite_3_35

import app.cash.sqldelight.dialect.api.QueryWithResults
import app.cash.sqldelight.dialect.api.TypeResolver
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteDeleteStmtLimited
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteInsertStmt
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteUpdateStmtLimited
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlStmt
import app.cash.sqldelight.dialects.sqlite_3_24.SqliteTypeResolver as Sqlite324TypeResolver
Expand All @@ -19,6 +21,26 @@ class SqliteTypeResolver(private val parentResolver: TypeResolver) : Sqlite324Ty
}
}
}
sqlStmt.updateStmtLimited?.let { update ->
check(update is SqliteUpdateStmtLimited)
update.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = update
override val select = it
override val pureTable = update.qualifiedTableName.tableName
}
}
}
sqlStmt.deleteStmtLimited?.let { delete ->
check(delete is SqliteDeleteStmtLimited)
delete.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = delete
override val select = it
override val pureTable = delete.qualifiedTableName?.tableName
}
}
}
return parentResolver.queryWithResults(sqlStmt)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,53 @@ class SelectQueryTypeTest {
)
}

@Test fun `returning clause in an update correctly generates a query function`(dialect: TestDialect) {
assumeTrue(dialect in listOf(TestDialect.POSTGRESQL, TestDialect.SQLITE_3_35))
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE IF NOT EXISTS users(
| id ${dialect.intType} PRIMARY KEY,
| firstname ${dialect.textType} NOT NULL,
| lastname ${dialect.textType} NOT NULL
|);
|
|update:
|UPDATE users SET
| firstname = :firstname,
| lastname = :lastname
|WHERE id = :id
|RETURNING id, firstname, lastname;
|""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect()
)

val query = file.namedQueries.first()
val generator = SelectQueryGenerator(query)

assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun <T : kotlin.Any> update(
| firstname: kotlin.String,
| lastname: kotlin.String,
| id: kotlin.Int,
| mapper: (
| id: kotlin.Int,
| firstname: kotlin.String,
| lastname: kotlin.String,
| ) -> T,
|): app.cash.sqldelight.ExecutableQuery<T> = UpdateQuery(firstname, lastname, id) { cursor ->
| check(cursor is app.cash.sqldelight.driver.jdbc.JdbcCursor)
| mapper(
| cursor.getLong(0)!!.toInt(),
| cursor.getString(1)!!,
| cursor.getString(2)!!
| )
|}
|""".trimMargin()
)
}

@Test fun `query type generates properly`() {
val file = FixtureCompiler.parseSql(
"""
Expand Down

0 comments on commit 895e499

Please sign in to comment.