Skip to content

Commit

Permalink
Defaults (#3559)
Browse files Browse the repository at this point in the history
* Fix DEFAULT in binding

Add test for PostgreSQL

Spotless

Move mixin to sqldelight

Use mixin instead

Don't add DEFAULT bindings to query

Fix DEFAULT in binding

* Remove isDefault api

* Fix grammar

* Add mysql async test

* Add postgresql async test

* Fix wildards

* Fix code style

* Add default to mysql too

* Downgrade r2dbc spi due binary incompatible with drivers

Fix docker test dependencies

* Fix DEFAULT in binding

Add test for PostgreSQL

Spotless

Move mixin to sqldelight

Use mixin instead

Don't add DEFAULT bindings to query

Fix DEFAULT in binding

* Remove isDefault api

* Fix grammar

* Add mysql async test

* Add postgresql async test

* Fix wildards

* Fix code style

* Add default to mysql too

* Downgrade r2dbc spi due binary incompatible with drivers

Fix docker test dependencies

* Remove useless check for default

Co-authored-by: hfhbd <hfhbd@users.noreply.github.com>
  • Loading branch information
hfhbd and hfhbd authored Oct 3, 2022
1 parent 0dfca55 commit 328daf8
Show file tree
Hide file tree
Showing 30 changed files with 418 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,16 @@ type_name ::= (
implements = "com.alecstrong.sql.psi.core.psi.SqlTypeName"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
}

identity_clause ::= 'IDENTITY'

generated_clause ::= GENERATED ( (ALWAYS AS <<expr '-1'>> 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
generated_clause ::= GENERATED ( (ALWAYS AS LP <<expr '-1'>> RP 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlGeneratedClauseImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlGeneratedClause"
override = true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : BindParameterMixin(node) {
override fun replaceWith(isAsync: Boolean, index: Int): String = when {
isAsync -> "$$index"
else -> "?"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class PostgreSqlFixturesTest(name: String, fixtureRoot: File) : FixturesTest(nam
"?1" to "?",
"?2" to "?",
"BLOB" to "TEXT",
"id TEXT GENERATED ALWAYS AS (2) UNIQUE NOT NULL" to "id TEXT GENERATED ALWAYS AS (2) STORED UNIQUE NOT NULL",
)

override fun setupDialect() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ int_data_type ::= 'INTEGER'
real_data_type ::= 'REAL'

bind_parameter ::= ( '?' [digit] | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class R2dbcDriver(private val connection: Connection) : SqlDriver {

return QueryResult.AsyncValue {
val result = prepared.execute().awaitSingle()
return@AsyncValue result.rowsUpdated.awaitFirstOrNull() ?: 0
return@AsyncValue result.rowsUpdated.awaitFirstOrNull()?.toLong() ?: 0
}
}

Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ testhelp = { module = "co.touchlab:testhelp", version.ref = "testhelp" }
burst = { module = "com.squareup.burst:burst-junit4", version = "1.2.0" }
testParameterInjector = { module = "com.google.testparameterinjector:test-parameter-injector", version = "1.8" }

r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "1.0.0.RELEASE" }
r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "0.9.1.RELEASE" }

[plugins]
android-library = { id = "com.android.library", version.ref = "agp" }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package app.cash.sqldelight.dialect.grammar.mixins

import com.alecstrong.sql.psi.core.psi.SqlBindParameter
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : SqlCompositeElementImpl(node), SqlBindParameter {
/**
* Overwrite, if the user provided sql parameter should be overwritten by sqldelight with [replaceWith].
*
* Some sql dialects support other bind parameter besides `?`, but sqldelight should still replace the
* user provided parameter with [replaceWith] for a homogen generated code.
*/
open fun replaceWith(isAsync: Boolean, index: Int): String = "?"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import app.cash.sqldelight.core.lang.util.rawSqlText
import app.cash.sqldelight.core.lang.util.sqFile
import app.cash.sqldelight.core.psi.SqlDelightStmtClojureStmtList
import app.cash.sqldelight.dialect.api.IntermediateType
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -174,36 +175,39 @@ abstract class QueryGenerator(
precedingArrays.add(type.name)
argumentCounts.add("${type.name}.size")
} else {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
val bindParameter = bindArg?.bindParameter as? BindParameterMixin
if (bindParameter == null || bindParameter.text != "DEFAULT") {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}
}

// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))
// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))

// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindArg != null) {
replacements.add(bindArg.range to "?")
// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindParameter != null) {
replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount))
}
}
}
}
Expand Down Expand Up @@ -293,7 +297,7 @@ abstract class QueryGenerator(
"""
if (result%L == 0L) throw %T(%S)
""".trimIndent(),
if (generateAsync) ".await()" else ".value",
if (generateAsync) "" else ".value",
ClassName("app.cash.sqldelight.db", "OptimisticLockException"),
"UPDATE on ${query.tablesAffected.single().name} failed because optimistic lock ${optimisticLock.name} did not match",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.ARGUMENT
import app.cash.sqldelight.dialect.api.PrimitiveType.BOOLEAN
import app.cash.sqldelight.dialect.api.PrimitiveType.INTEGER
import app.cash.sqldelight.dialect.api.PrimitiveType.NULL
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlBindParameter
Expand Down Expand Up @@ -92,29 +93,32 @@ abstract class BindableQuery(
val namesSeen = mutableSetOf<String>()
var maxIndexSeen = 0
statement.findChildrenOfType<SqlBindExpr>().forEach { bindArg ->
bindArg.bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
val bindParameter = bindArg.bindParameter
if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") {
bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
bindArg.bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}
val index = ++maxIndexSeen
indexesSeen.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}

// If there are still naming conflicts (edge case where the name we generate is the same as
Expand Down
2 changes: 2 additions & 0 deletions sqldelight-gradle-plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ tasks.named('dockerTest') {
":sqlite-migrations:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-compiler:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-gradle-plugin:publishAllPublicationsToInstallLocallyRepository",
":drivers:r2dbc-driver:publishAllPublicationsToInstallLocallyRepository",
":extensions:async-extensions:publishAllPublicationsToInstallLocallyRepository",
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlSchemaDefinitions() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-schema"))
Expand All @@ -34,6 +43,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsPostgreSqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-postgresql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun `dialect accepts version catalog dependency`() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-catalog"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog (name, breed, is_good, id)
VALUES (?, ?, ?, ?);
VALUES (?, ?, DEFAULT, ?);

selectDogs:
SELECT *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class HsqlTest {
}

@Test fun simpleSelect() {
database.dogQueries.insertDog("Tilda", "Pomeranian", true, 1)
database.dogQueries.insertDog("Tilda", "Pomeranian", 1)
assertThat(database.dogQueries.selectDogs().executeAsOne())
.isEqualTo(
Dog(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ apply plugin: 'app.cash.sqldelight'

sqldelight {
MyDatabase {
packageName = "app.cash.sqldelight.mysql.integration"
packageName = "app.cash.sqldelight.mysql.integration.async"
dialect("app.cash.sqldelight:mysql-dialect:${app.cash.sqldelight.VersionKt.VERSION}")
generateAsync = true
}
Expand All @@ -26,7 +26,7 @@ dependencies {
implementation "org.testcontainers:r2dbc:1.16.2"
implementation "dev.miku:r2dbc-mysql:0.8.2.RELEASE"
implementation "app.cash.sqldelight:r2dbc-driver:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:coroutines-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:async-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation libs.truth
implementation libs.kotlin.coroutines.core
implementation libs.kotlin.coroutines.test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
apply from: "../settings.gradle"

rootProject.name = 'sqldelight-mysql-integration'
rootProject.name = 'sqldelight-mysql-integration-async'
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog
VALUES (?, ?, ?);
VALUES (?, ?, DEFAULT);

selectDogs:
SELECT *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package app.cash.sqldelight.mysql.integration
package app.cash.sqldelight.mysql.integration.async

import app.cash.sqldelight.async.coroutines.awaitAsList
import app.cash.sqldelight.async.coroutines.awaitAsOne
import app.cash.sqldelight.async.coroutines.awaitCreate
import app.cash.sqldelight.driver.r2dbc.R2dbcDriver
import com.google.common.truth.Truth.assertThat
import io.r2dbc.spi.ConnectionFactories
Expand All @@ -13,13 +16,13 @@ class MySqlTest {
val connection = factory.create().awaitSingle()
val driver = R2dbcDriver(connection)

val db = MyDatabase(driver).also { MyDatabase.Schema.create(driver) }
val db = MyDatabase(driver).also { MyDatabase.Schema.awaitCreate(driver) }
block(db)
}

@Test fun simpleSelect() = runTest { database ->
database.dogQueries.insertDog("Tilda", "Pomeranian", true)
assertThat(database.dogQueries.selectDogs().executeAsOne())
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.selectDogs().awaitAsOne())
.isEqualTo(
Dog(
name = "Tilda",
Expand All @@ -32,15 +35,15 @@ class MySqlTest {
@Test
fun simpleSelectWithIn() = runTest { database ->
with(database) {
dogQueries.insertDog("Tilda", "Pomeranian", true)
dogQueries.insertDog("Tucker", "Portuguese Water Dog", true)
dogQueries.insertDog("Cujo", "Pomeranian", false)
dogQueries.insertDog("Buddy", "Pomeranian", true)
dogQueries.insertDog("Tilda", "Pomeranian")
dogQueries.insertDog("Tucker", "Portuguese Water Dog")
dogQueries.insertDog("Cujo", "Pomeranian")
dogQueries.insertDog("Buddy", "Pomeranian")
assertThat(
dogQueries.selectDogsByBreedAndNames(
breed = "Pomeranian",
name = listOf("Tilda", "Buddy"),
).executeAsList(),
).awaitAsList(),
)
.containsExactly(
Dog(
Expand Down
Loading

0 comments on commit 328daf8

Please sign in to comment.