Skip to content

Commit

Permalink
Bind arguments with SqlBinaryExpr (#4604)
Browse files Browse the repository at this point in the history
  • Loading branch information
griffio authored Nov 17, 2023
1 parent 3e51f0c commit 371f1b8
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.NULL
import app.cash.sqldelight.dialect.api.PrimitiveType.TEXT
import app.cash.sqldelight.dialect.api.SelectQueryable
import com.alecstrong.sql.psi.core.psi.SqlBetweenExpr
import com.alecstrong.sql.psi.core.psi.SqlBinaryBooleanExpr
import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr
import com.alecstrong.sql.psi.core.psi.SqlBinaryExpr
import com.alecstrong.sql.psi.core.psi.SqlBinaryLikeExpr
import com.alecstrong.sql.psi.core.psi.SqlBinaryPipeExpr
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlCaseExpr
import com.alecstrong.sql.psi.core.psi.SqlCastExpr
import com.alecstrong.sql.psi.core.psi.SqlCollateExpr
import com.alecstrong.sql.psi.core.psi.SqlColumnExpr
import com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt
import com.alecstrong.sql.psi.core.psi.SqlExpr
import com.alecstrong.sql.psi.core.psi.SqlFunctionExpr
Expand Down Expand Up @@ -114,11 +118,24 @@ internal fun SqlExpr.argumentType(argument: SqlExpr): IntermediateType {
IntermediateType(PrimitiveType.BOOLEAN)
}
}
is SqlBetweenExpr, is SqlIsExpr, is SqlBinaryExpr -> {

is SqlBinaryPipeExpr, is SqlBinaryEqualityExpr, is SqlIsExpr, is SqlBinaryBooleanExpr, is SqlBetweenExpr -> {
val validArg = children.lastOrNull { it is SqlExpr && it !== argument && it !is SqlBindExpr }
validArg?.type() ?: children.last { it is SqlExpr && it !== argument }.type()
}

is SqlBinaryExpr -> {
val validArg = children.lastOrNull {
it is SqlCastExpr && it == argument
} ?: children.lastOrNull {
it is SqlColumnExpr
} ?: parent.children.lastOrNull {
it is SqlExpr && it !== argument && it !is SqlBinaryExpr
}

validArg?.type() ?: children.last { it is SqlExpr && it !== argument }.type()
}

is SqlNullExpr -> IntermediateType(NULL).asNullable()
is SqlBinaryLikeExpr -> {
val other = children.last { it is SqlExpr && it !== argument }.type()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package app.cash.sqldelight.core

import app.cash.sqldelight.core.lang.argumentType
import app.cash.sqldelight.core.lang.types.typeResolver
import app.cash.sqldelight.core.lang.util.argumentType
import app.cash.sqldelight.core.lang.util.findChildrenOfType
import app.cash.sqldelight.core.lang.util.isArrayParameter
import app.cash.sqldelight.dialect.api.PrimitiveType
import app.cash.sqldelight.dialects.postgresql.PostgreSqlDialect
import app.cash.sqldelight.dialects.sqlite_3_24.SqliteDialect
import app.cash.sqldelight.test.util.FixtureCompiler
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlColumnDef
import com.google.common.truth.Truth.assertThat
import com.squareup.kotlinpoet.asClassName
import java.time.Instant
import kotlin.test.assertFailsWith
import org.junit.Rule
import org.junit.Test
Expand Down Expand Up @@ -386,5 +389,121 @@ class BindArgsTest {
}
}

@Test fun `bind arg type in binary expression can be inferred from column`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| datum INTEGER NOT NULL
|);
|
|selectData:
|SELECT *
|FROM data
|WHERE datum > :datum1 - 2.5 AND datum < :datum2 + 2.5;
""".trimMargin(),
tempFolder,
)
val column = file.namedQueries.first()
column.parameters.let { args ->
assertThat(args[0].dialectType).isEqualTo(PrimitiveType.INTEGER)
assertThat(args[0].javaType).isEqualTo(Long::class.asClassName())
assertThat(args[0].name).isEqualTo("datum1")

assertThat(args[1].dialectType).isEqualTo(PrimitiveType.INTEGER)
assertThat(args[1].javaType).isEqualTo(Long::class.asClassName())
assertThat(args[1].name).isEqualTo("datum2")
}
}

@Test fun `bind arg in arithmetic binary expression can be cast as type`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| datum INTEGER NOT NULL,
| point INTEGER NOT NULL
|);
|
|selectData:
|SELECT *, (datum + CAST(:datum1 AS REAL) * point) AS expected_datum
|FROM data;
""".trimMargin(),
tempFolder,
)

val column = file.namedQueries.first()
column.parameters.let { args ->
assertThat(args[0].dialectType).isEqualTo(PrimitiveType.REAL)
assertThat(args[0].javaType).isEqualTo(Double::class.asClassName().copy(nullable = true))
assertThat(args[0].name).isEqualTo("datum1")
}
}

@Test fun `bind arg in binary expression can be cast as type`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| datum INTEGER NOT NULL
|);
|
|selectData:
|SELECT CAST(:datum1 AS REAL) + CAST(:datum2 AS INTEGER) - 10.5
|FROM data;
""".trimMargin(),
tempFolder,
)

val column = file.namedQueries.first()
column.parameters.let { args ->
assertThat(args[0].dialectType).isEqualTo(PrimitiveType.REAL)
assertThat(args[0].javaType).isEqualTo(Double::class.asClassName().copy(nullable = true))
assertThat(args[0].name).isEqualTo("datum1")

assertThat(args[1].dialectType).isEqualTo(PrimitiveType.INTEGER)
assertThat(args[1].javaType).isEqualTo(Long::class.asClassName().copy(nullable = true))
assertThat(args[1].name).isEqualTo("datum2")
}
}

@Test fun `bind arg in binary expression can be cast as custom type`() {
val file = FixtureCompiler.parseSql(
"""
|import java.time.Instant;
|
|CREATE TABLE session (
| id UUID PRIMARY KEY,
| created_at TIMESTAMP AS Instant NOT NULL,
| updated_at TIMESTAMP AS Instant NOT NULL
|);
|
|selectSession1:
|SELECT *
|FROM session
|WHERE created_at = :createdAt - INTERVAL '2 days' OR updated_at = :updatedAt + INTERVAL '2 days';
|
|selectSession2:
|SELECT *
|FROM session
|WHERE created_at BETWEEN :createdAt - INTERVAL '2 days' AND :createdAt + INTERVAL '2 days';
""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect(),
)

val selectSession1 = file.namedQueries[0]
selectSession1.parameters.let { args ->
assertThat(args[0].javaType).isEqualTo(Instant::class.asClassName())
assertThat(args[0].name).isEqualTo("createdAt")

assertThat(args[1].javaType).isEqualTo(Instant::class.asClassName())
assertThat(args[1].name).isEqualTo("updatedAt")
}

val selectSession2 = file.namedQueries[1]
selectSession2.parameters.let { args ->
assertThat(args[0].javaType).isEqualTo(Instant::class.asClassName())
assertThat(args[0].name).isEqualTo("createdAt")
}
}

private fun SqlBindExpr.argumentType() = typeResolver.argumentType(this)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import java.time.Instant;

CREATE TABLE data (
datum INTEGER NOT NULL,
point INTEGER NOT NULL,
created_at TIMESTAMP AS Instant NOT NULL,
updated_at TIMESTAMP AS Instant NOT NULL
);

insertData:
INSERT INTO data(datum, point, created_at, updated_at)
VALUES(?, ?, ?, ?);

selectDataBinaryComparison:
SELECT *
FROM data
WHERE datum > :datum1 - 2.5 AND datum < :datum2 + 2.5;

selectDataBinaryCast1:
SELECT *, (datum + CAST(:datum1 AS REAL) * point) AS expected_datum
FROM data;

selectDataBinaryCast2:
SELECT CAST(:datum1 AS REAL) + CAST(:datum2 AS INTEGER) - 10.5 AS expected_datum
FROM data;

selectDataBinaryIntervalComparison1:
SELECT *
FROM data
WHERE created_at = :createdAt - INTERVAL '2 days' OR updated_at = :updatedAt + INTERVAL '2 days';

selectDataBinaryIntervalComparison2:
SELECT *
FROM data
WHERE created_at BETWEEN :createdAt - INTERVAL '2 days' AND :createdAt + INTERVAL '2 days';
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import app.cash.sqldelight.driver.jdbc.JdbcDriver
import com.google.common.truth.Truth.assertThat
import java.sql.Connection
import java.sql.DriverManager
import java.time.Instant
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.LocalTime
Expand Down Expand Up @@ -38,6 +39,26 @@ class PostgreSqlTest {
value.map { it.toInt() }.toTypedArray()
},
),
data_Adapter = Data_.Adapter(
object : ColumnAdapter<Instant, LocalDateTime> {
override fun encode(value: Instant): LocalDateTime {
return LocalDateTime.ofInstant(value, ZoneOffset.UTC)
}

override fun decode(databaseValue: LocalDateTime): Instant {
return databaseValue.toInstant(ZoneOffset.UTC)
}
},
object : ColumnAdapter<Instant, LocalDateTime> {
override fun encode(value: Instant): LocalDateTime {
return LocalDateTime.ofInstant(value, ZoneOffset.UTC)
}

override fun decode(databaseValue: LocalDateTime): Instant {
return databaseValue.toInstant(ZoneOffset.UTC)
}
},
),
)

@Before fun before() {
Expand Down Expand Up @@ -502,4 +523,51 @@ class PostgreSqlTest {
assertThat(series.first()).isEqualTo(start)
assertThat(series.last()).isEqualTo(finish)
}

@Test
fun testSelectDataBinaryComparison() {
val created = Instant.parse("2017-12-03T10:00:00.00Z")
val updated = Instant.parse("2022-05-01T10:00:00.00Z")
database.binaryArgumentsQueries.insertData(10, 5, created, updated)
val result = database.binaryArgumentsQueries.selectDataBinaryComparison(10, 10).executeAsList()
assertThat(result.first().datum).isEqualTo(10)
}

@Test
fun testSelectDataBinaryCast1() {
val created = Instant.parse("2017-12-03T10:00:00.00Z")
val updated = Instant.parse("2022-05-01T10:00:00.00Z")
database.binaryArgumentsQueries.insertData(10, 5, created, updated)
val result = database.binaryArgumentsQueries.selectDataBinaryCast1(10.0).executeAsOne()
assertThat(result.expected_datum).isEqualTo(60.toDouble())
}

@Test
fun testSelectDataBinaryCast2() {
val created = Instant.parse("2017-12-03T10:00:00.00Z")
val updated = Instant.parse("2022-05-01T10:00:00.00Z")
database.binaryArgumentsQueries.insertData(10, 5, created, updated)
val result = database.binaryArgumentsQueries.selectDataBinaryCast2(10.0, 10).executeAsOne()
assertThat(result.expected_datum).isEqualTo(9.5)
}

@Test
fun testSelectDataBinaryIntervalComparison1() {
val created = Instant.parse("2017-12-03T10:00:00.00Z")
val updated = Instant.parse("2022-05-01T10:00:00.00Z")
val createdAt = Instant.parse("2017-12-05T10:00:00.00Z")
val updatedAt = Instant.parse("2022-05-01T10:00:00.00Z")
database.binaryArgumentsQueries.insertData(10, 5, created, updated)
val result = database.binaryArgumentsQueries.selectDataBinaryIntervalComparison1(createdAt, updatedAt).executeAsList()
assertThat(result.first().datum).isEqualTo(10)
}

@Test
fun testSelectDataBinaryIntervalComparison2() {
val created = Instant.parse("2017-12-03T10:00:00.00Z")
val updated = Instant.parse("2022-05-01T10:00:00.00Z")
database.binaryArgumentsQueries.insertData(10, 5, created, updated)
val result = database.binaryArgumentsQueries.selectDataBinaryIntervalComparison2(created).executeAsList()
assertThat(result.first().datum).isEqualTo(10)
}
}

0 comments on commit 371f1b8

Please sign in to comment.