Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix codegen for RETURNING clause #3872

Merged
merged 2 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.INTEGER
import app.cash.sqldelight.dialect.api.PrimitiveType.REAL
import app.cash.sqldelight.dialect.api.PrimitiveType.TEXT
import app.cash.sqldelight.dialect.api.QueryWithResults
import app.cash.sqldelight.dialect.api.ReturningQueryable
import app.cash.sqldelight.dialect.api.TypeResolver
import app.cash.sqldelight.dialect.api.encapsulatingType
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.BIG_INT
Expand All @@ -18,7 +19,6 @@ import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlDeleteStmtL
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlInsertStmt
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypeName
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlUpdateStmtLimited
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlCreateTableStmt
import com.alecstrong.sql.psi.core.psi.SqlFunctionExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -94,33 +94,15 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
override fun queryWithResults(sqlStmt: SqlStmt): QueryWithResults? {
sqlStmt.insertStmt?.let { insert ->
check(insert is PostgreSqlInsertStmt)
insert.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = insert
override val select = it
override val pureTable = insert.tableName
}
}
insert.returningClause?.let { return ReturningQueryable(insert, it, insert.tableName) }
}
sqlStmt.updateStmtLimited?.let { update ->
check(update is PostgreSqlUpdateStmtLimited)
update.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = update
override val select = it
override val pureTable = update.qualifiedTableName.tableName
}
}
update.returningClause?.let { return ReturningQueryable(update, it, update.qualifiedTableName.tableName) }
}
sqlStmt.deleteStmtLimited?.let { delete ->
check(delete is PostgreSqlDeleteStmtLimited)
delete.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = delete
override val select = it
override val pureTable = delete.qualifiedTableName?.tableName
}
}
delete.returningClause?.let { return ReturningQueryable(delete, it, delete.qualifiedTableName?.tableName) }
}
return parentResolver.queryWithResults(sqlStmt)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,45 +1,27 @@
package app.cash.sqldelight.dialects.sqlite_3_35

import app.cash.sqldelight.dialect.api.QueryWithResults
import app.cash.sqldelight.dialect.api.ReturningQueryable
import app.cash.sqldelight.dialect.api.TypeResolver
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteDeleteStmtLimited
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteInsertStmt
import app.cash.sqldelight.dialects.sqlite_3_35.grammar.psi.SqliteUpdateStmtLimited
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlStmt
import app.cash.sqldelight.dialects.sqlite_3_24.SqliteTypeResolver as Sqlite324TypeResolver

class SqliteTypeResolver(private val parentResolver: TypeResolver) : Sqlite324TypeResolver(parentResolver) {
override fun queryWithResults(sqlStmt: SqlStmt): QueryWithResults? {
sqlStmt.insertStmt?.let { insert ->
check(insert is SqliteInsertStmt)
insert.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = insert
override val select = it
override val pureTable = insert.tableName
}
}
insert.returningClause?.let { return ReturningQueryable(insert, it, insert.tableName) }
}
sqlStmt.updateStmtLimited?.let { update ->
check(update is SqliteUpdateStmtLimited)
update.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = update
override val select = it
override val pureTable = update.qualifiedTableName.tableName
}
}
update.returningClause?.let { return ReturningQueryable(update, it, update.qualifiedTableName.tableName) }
}
sqlStmt.deleteStmtLimited?.let { delete ->
check(delete is SqliteDeleteStmtLimited)
delete.returningClause?.let {
return object : QueryWithResults {
override var statement: SqlAnnotatedElement = delete
override val select = it
override val pureTable = delete.qualifiedTableName?.tableName
}
}
delete.returningClause?.let { return ReturningQueryable(delete, it, delete.qualifiedTableName?.tableName) }
}
return parentResolver.queryWithResults(sqlStmt)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package app.cash.sqldelight.dialect.api

import com.alecstrong.sql.psi.core.psi.QueryElement

internal fun List<QueryElement.QueryColumn>.flattenCompounded(): List<QueryElement.QueryColumn> {
return map { column ->
if (column.compounded.none { it.element != column.element || it.nullable != column.nullable }) {
column.copy(compounded = emptyList())
} else {
column
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package app.cash.sqldelight.dialect.api

import com.alecstrong.sql.psi.core.psi.QueryElement
import com.alecstrong.sql.psi.core.psi.Queryable
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlTableName
import com.intellij.psi.util.PsiTreeUtil

/**
* Query deriving from the `RETURNING` clause of an expression.
*
* Typical use cases include `INSERT`, `UPDATE`, and `DELETE`. This class is similar to [SelectQueryable] but differs in
* the fact that only 1 table can be part of the query, and that table is guaranteed to be "real".
*
* @param statement Parent statement. Typically, this is the `INSERT`, `UPDATE`, or `DELETE` statement.
* @param select The `RETURNING` clause of the statement. Represented as a query since it returns values to the caller.
* @param tableName Name of the table the [statement] is operating on.
*/
class ReturningQueryable(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used by SQLite and Postgres dialects at the moment, but reasonable to think additional dialects would want this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, thanks.

override var statement: SqlAnnotatedElement,
override val select: QueryElement,
private val tableName: SqlTableName?,
) : QueryWithResults {

override val pureTable by lazy {
val pureColumns = select.queryExposed().singleOrNull()?.columns?.flattenCompounded()
val resolvedTable = tableName?.reference?.resolve()
val table = PsiTreeUtil.getParentOfType(resolvedTable, Queryable::class.java)?.tableExposed()
?: return@lazy null
val requestedColumnsAreIdenticalToTable = table.query.columns.flattenCompounded() == pureColumns
if (requestedColumnsAreIdenticalToTable) {
tableName
} else {
null
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package app.cash.sqldelight.dialect.api

import com.alecstrong.sql.psi.core.psi.NamedElement
import com.alecstrong.sql.psi.core.psi.QueryElement.QueryColumn
import com.alecstrong.sql.psi.core.psi.Queryable
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt
Expand All @@ -21,16 +20,6 @@ class SelectQueryable(
* which points to that table (Pure meaning it has exactly the same columns in the same order).
*/
override val pureTable: NamedElement? by lazy {
fun List<QueryColumn>.flattenCompounded(): List<QueryColumn> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to FlattenCompounded with internal visibility, for reuse in new class ReturningQueryable

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Names are okay.

return map { column ->
if (column.compounded.none { it.element != column.element || it.nullable != column.nullable }) {
column.copy(compounded = emptyList())
} else {
column
}
}
}

val pureColumns = select.queryExposed().singleOrNull()?.columns?.flattenCompounded()

// First check to see if its just the table we're observing directly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class PgInsertReturningTest {

assertThat(generator.defaultResultTypeFunction().toString()).isEqualTo(
"""
|public fun insertReturn(data_: com.example.Data_): app.cash.sqldelight.ExecutableQuery<com.example.Data_> = insertReturn(data_) { data__, id ->
| com.example.Data_(
|public fun insertReturn(data_: com.example.Data_): app.cash.sqldelight.ExecutableQuery<com.example.InsertReturn> = insertReturn(data_) { data__, id ->
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test verified that the code got generated, but the generated code didn't compile. I'd recommend actually deleting the test, given the integration tests have, in my opinion, better coverage, but either way, the test passes now.

| com.example.InsertReturn(
| data__,
| id
| )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,50 @@ SELECT *
FROM dog
WHERE :someBoolean AND 1 = 1;

insertAndReturn:
insertAndReturn1:
INSERT INTO dog
VALUES (?, ?, DEFAULT)
RETURNING name;

insertAndReturnMany:
INSERT INTO dog
VALUES (?, ?, DEFAULT)
RETURNING name, breed;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the problem scenario. 1 and All variants are present for coverage.


insertAndReturnAll:
INSERT INTO dog
VALUES (?, ?, DEFAULT)
RETURNING *;

updateAndReturn1:
UPDATE dog
SET is_good = ?
WHERE name = ?
RETURNING name;

updateAndReturnMany:
UPDATE dog
SET is_good = ?
WHERE name = ?
RETURNING name, breed;

updateAndReturnAll:
UPDATE dog
SET is_good = ?
WHERE name = ?
RETURNING *;

deleteAndReturn1:
DELETE FROM dog
WHERE name = ?
RETURNING name;

deleteAndReturnMany:
DELETE FROM dog
WHERE name = ?
RETURNING name, breed;

deleteAndReturnAll:
DELETE FROM dog
WHERE name = ?
RETURNING *;
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,87 @@ class PostgreSqlTest {
)
}

@Test fun returningInsert() {
assertThat(database.dogQueries.insertAndReturn("Tilda", "Pomeranian").executeAsOne())
@Test fun insertReturning1() {
assertThat(database.dogQueries.insertAndReturn1("Tilda", "Pomeranian").executeAsOne())
.isEqualTo(
"Tilda",
)
}

@Test fun insertReturningMany() {
assertThat(database.dogQueries.insertAndReturnMany("Tilda", "Pomeranian").executeAsOne())
.isEqualTo(
InsertAndReturnMany(
name = "Tilda",
breed = "Pomeranian",
),
)
}

@Test fun insertReturningAll() {
assertThat(database.dogQueries.insertAndReturnAll("Tilda", "Pomeranian").executeAsOne())
.isEqualTo(
Dog(
name = "Tilda",
breed = "Pomeranian",
is_good = 1,
),
)
}

@Test fun updateReturning1() {
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.updateAndReturn1(1, "Tilda").executeAsOne())
.isEqualTo(
"Tilda",
)
}

@Test fun updateReturningMany() {
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.updateAndReturnMany(1, "Tilda").executeAsOne())
.isEqualTo(
UpdateAndReturnMany(
name = "Tilda",
breed = "Pomeranian",
),
)
}

@Test fun updateReturningAll() {
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.updateAndReturnAll(1, "Tilda").executeAsOne())
.isEqualTo(
Dog(
name = "Tilda",
breed = "Pomeranian",
is_good = 1,
),
)
}

@Test fun deleteReturning1() {
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.deleteAndReturn1("Tilda").executeAsOne())
.isEqualTo(
"Tilda",
)
}

@Test fun deleteReturningMany() {
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.deleteAndReturnMany("Tilda").executeAsOne())
.isEqualTo(
DeleteAndReturnMany(
name = "Tilda",
breed = "Pomeranian",
),
)
}

@Test fun deleteReturningAll() {
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.deleteAndReturnAll("Tilda").executeAsOne())
.isEqualTo(
Dog(
name = "Tilda",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,50 @@ CREATE TABLE person (
last_name TEXT NOT NULL
);

insertAndReturn:
insertAndReturn1:
INSERT INTO person
VALUES (?, ?, ?)
RETURNING name;

insertAndReturnMany:
INSERT INTO person
VALUES (?, ?, ?)
RETURNING _id, first_name;

insertAndReturnAll:
INSERT INTO person
VALUES (?, ?, ?)
RETURNING *;

updateAndReturn1:
UPDATE person
SET last_name = ?
WHERE last_name = ?
RETURNING name;

updateAndReturnMany:
UPDATE person
SET last_name = ?
WHERE last_name = ?
RETURNING _id, first_name;

updateAndReturnAll:
UPDATE person
SET last_name = ?
WHERE last_name = ?
RETURNING *;

deleteAndReturn1:
DELETE FROM person
WHERE last_name = ?
RETURNING name;

deleteAndReturnMany:
DELETE FROM person
WHERE last_name = ?
RETURNING _id, first_name;

deleteAndReturnAll:
DELETE FROM person
WHERE last_name = ?
RETURNING *;
Loading