Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize integer numeric driver bindings #4588

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,8 @@ import com.squareup.kotlinpoet.SHORT
import com.squareup.kotlinpoet.TypeName

internal enum class PostgreSqlType(override val javaType: TypeName) : DialectType {
SMALL_INT(SHORT) {
override fun decode(value: CodeBlock) = CodeBlock.of("%L.toShort()", value)

override fun encode(value: CodeBlock) = CodeBlock.of("%L.toLong()", value)
},
INTEGER(INT) {
override fun decode(value: CodeBlock) = CodeBlock.of("%L.toInt()", value)

override fun encode(value: CodeBlock) = CodeBlock.of("%L.toLong()", value)
},
SMALL_INT(SHORT),
INTEGER(INT),
BIG_INT(LONG),
DATE(ClassName("java.time", "LocalDate")),
TIME(ClassName("java.time", "LocalTime")),
Expand All @@ -33,7 +25,9 @@ internal enum class PostgreSqlType(override val javaType: TypeName) : DialectTyp
return CodeBlock.builder()
.add(
when (this) {
SMALL_INT, INTEGER, BIG_INT -> "bindLong"
SMALL_INT -> "bindShort"
INTEGER -> "bindInt"
BIG_INT -> "bindLong"
DATE, TIME, TIMESTAMP, TIMESTAMP_TIMEZONE, INTERVAL, UUID -> "bindObject"
NUMERIC -> "bindBigDecimal"
},
Expand All @@ -45,7 +39,9 @@ internal enum class PostgreSqlType(override val javaType: TypeName) : DialectTyp
override fun cursorGetter(columnIndex: Int, cursorName: String): CodeBlock {
return CodeBlock.of(
when (this) {
SMALL_INT, INTEGER, BIG_INT -> "$cursorName.getLong($columnIndex)"
SMALL_INT -> "$cursorName.getShort($columnIndex)"
INTEGER -> "$cursorName.getInt($columnIndex)"
BIG_INT -> "$cursorName.getLong($columnIndex)"
DATE, TIME, TIMESTAMP, TIMESTAMP_TIMEZONE, INTERVAL, UUID -> "$cursorName.getObject<%T>($columnIndex)"
NUMERIC -> "$cursorName.getBigDecimal($columnIndex)"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,22 @@ class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStat
}
}

fun bindShort(index: Int, short: Short?) {
if (short == null) {
statement.bindNull(index, Short::class.javaObjectType)
} else {
statement.bind(index, short)
}
}

fun bindInt(index: Int, int: Int?) {
if (int == null) {
statement.bindNull(index, Int::class.javaObjectType)
} else {
statement.bind(index, int)
}
}

override fun bindLong(index: Int, long: Long?) {
if (long == null) {
statement.bindNull(index, Long::class.javaObjectType)
Expand Down Expand Up @@ -264,6 +280,8 @@ internal constructor(private val results: AsyncPublisherIterator<List<Any?>>) :
}

override fun getString(index: Int): String? = get(index)
fun getShort(index: Int): Short? = get<Number>(index)?.toShort()
fun getInt(index: Int): Int? = get<Number>(index)?.toInt()

override fun getLong(index: Int): Long? = get<Number>(index)?.toLong()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ class InterfaceGeneration {
| public fun insertUser(slack_user_id: String): Query<Int> = InsertUserQuery(slack_user_id) {
| cursor ->
| check(cursor is JdbcCursor)
| cursor.getLong(0)!!.toInt()
| cursor.getInt(0)!!
| }
|
| public fun insertSubscription(user_id2: Int) {
Expand All @@ -908,7 +908,7 @@ class InterfaceGeneration {
| |VALUES (?)
| ""${'"'}.trimMargin(), 1) {
| check(this is JdbcPreparedStatement)
| bindLong(0, user_id2.toLong())
| bindInt(0, user_id2)
| }
| notifyQueries(${result.compiledFile.namedMutators[0].id.withUnderscores}) { emit ->
| emit("subscriptionEntity")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@ import app.cash.sqldelight.core.compiler.MutatorQueryGenerator
import app.cash.sqldelight.core.compiler.SelectQueryGenerator
import app.cash.sqldelight.core.dialects.binderCheck
import app.cash.sqldelight.core.dialects.cursorCheck
import app.cash.sqldelight.core.dialects.intKotlinType
import app.cash.sqldelight.core.dialects.textType
import app.cash.sqldelight.test.util.FixtureCompiler
import app.cash.sqldelight.test.util.withUnderscores
import com.google.common.truth.Truth.assertThat
import com.squareup.burst.BurstJUnit4
import com.squareup.kotlinpoet.INT
import com.squareup.kotlinpoet.LONG
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
Expand Down Expand Up @@ -175,8 +172,7 @@ class JavadocTest {
createTable(testDialect) + """
|/** Queries all values. */
|selectAll:
|SELECT *
|FROM test;
|SELECT CAST(:input AS ${testDialect.textType});
|
""".trimMargin(),
tempFolder,
Expand All @@ -189,35 +185,23 @@ class JavadocTest {
|/**
| * Queries all values.
| */
|public fun selectAll(): app.cash.sqldelight.Query<com.example.Test> = selectAll { _id, value_ ->
| com.example.Test(
| _id,
| value_
|public fun selectAll(input: kotlin.String?): app.cash.sqldelight.ExecutableQuery<com.example.SelectAll> = selectAll(input) { expr ->
| com.example.SelectAll(
| expr
| )
|}
|
""".trimMargin(),
)

val int = testDialect.intKotlinType
val toInt = when (int) {
LONG -> ""
INT -> ".toInt()"
else -> error("Unknown kotlinType $int")
}

assertThat(selectGenerator.customResultTypeFunction().toString()).isEqualTo(
"""
|/**
| * Queries all values.
| */
|public fun <T : kotlin.Any> selectAll(mapper: (_id: $int, value_: kotlin.String) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(-585_795_480, arrayOf("test"), driver, "Test.sq", "selectAll", ""${'"'}
||SELECT *
||FROM test
|""${'"'}.trimMargin()) { cursor ->
|public fun <T : kotlin.Any> selectAll(input: kotlin.String?, mapper: (expr: kotlin.String?) -> T): app.cash.sqldelight.ExecutableQuery<T> = SelectAllQuery(input) { cursor ->
| ${testDialect.cursorCheck(2)}mapper(
| cursor.getLong(0)!!$toInt,
| cursor.getString(1)!!
| cursor.getString(0)
| )
|}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -715,14 +715,14 @@ class MutatorQueryTypeTest {
| |VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
| ""${'"'}.trimMargin(), 12) {
| check(this is ${dialect.dialect.runtimeTypes.preparedStatementType})
| bindLong(0, smallint0.toLong())
| bindLong(1, smallint1?.let { it.toLong() })
| bindLong(2, data_Adapter.smallint2Adapter.encode(smallint2).toLong())
| bindLong(3, smallint3?.let { data_Adapter.smallint3Adapter.encode(it).toLong() })
| bindLong(4, int0.toLong())
| bindLong(5, int1?.let { it.toLong() })
| bindLong(6, data_Adapter.int2Adapter.encode(int2).toLong())
| bindLong(7, int3?.let { data_Adapter.int3Adapter.encode(it).toLong() })
| bindShort(0, smallint0)
| bindShort(1, smallint1)
| bindShort(2, data_Adapter.smallint2Adapter.encode(smallint2))
| bindShort(3, smallint3?.let { data_Adapter.smallint3Adapter.encode(it) })
| bindInt(4, int0)
| bindInt(5, int1)
| bindInt(6, data_Adapter.int2Adapter.encode(int2))
| bindInt(7, int3?.let { data_Adapter.int3Adapter.encode(it) })
| bindLong(8, bigint0)
| bindLong(9, bigint1)
| bindLong(10, data_Adapter.bigint2Adapter.encode(bigint2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class PgInsertOnConflictTest {
| |ON CONFLICT (id) DO UPDATE SET col1 = ?
| ""${'"'}.trimMargin(), 3) {
| check(this is app.cash.sqldelight.driver.jdbc.JdbcPreparedStatement)
| bindLong(0, id?.let { it.toLong() })
| bindInt(0, id)
| bindString(1, c1)
| bindString(2, c1)
| }
Expand Down Expand Up @@ -92,7 +92,7 @@ class PgInsertOnConflictTest {
| |ON CONFLICT (id) DO UPDATE SET col1 = ?, col2 = ?
| ""${'"'}.trimMargin(), 6) {
| check(this is app.cash.sqldelight.driver.jdbc.JdbcPreparedStatement)
| bindLong(0, id?.let { it.toLong() })
| bindInt(0, id)
| bindString(1, c1)
| bindString(2, c2)
| bindString(3, c3)
Expand Down Expand Up @@ -145,7 +145,7 @@ class PgInsertOnConflictTest {
| |ON CONFLICT (id) DO UPDATE SET col1 = ?, col2 = ?, col3 = ?
| ""${'"'}.trimMargin(), 7) {
| check(this is app.cash.sqldelight.driver.jdbc.JdbcPreparedStatement)
| bindLong(0, id?.let { it.toLong() })
| bindInt(0, id)
| bindString(1, c1)
| bindString(2, c2)
| bindString(3, c3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PgInsertReturningTest {
|public fun <T : kotlin.Any> insertReturn(data_: com.example.Data_, mapper: (id: kotlin.Int, data_: kotlin.String?) -> T): app.cash.sqldelight.ExecutableQuery<T> = InsertReturnQuery(data_) { cursor ->
| check(cursor is app.cash.sqldelight.driver.jdbc.JdbcCursor)
| mapper(
| cursor.getLong(0)!!.toInt(),
| cursor.getInt(0)!!,
| cursor.getString(1)
| )
|}
Expand Down Expand Up @@ -96,7 +96,7 @@ class PgInsertReturningTest {
| check(cursor is app.cash.sqldelight.driver.jdbc.JdbcCursor)
| mapper(
| cursor.getString(0),
| cursor.getLong(1)!!.toInt()
| cursor.getInt(1)!!
| )
|}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,14 +766,14 @@ class SelectQueryFunctionTest {
| check(cursor is ${dialect.dialect.runtimeTypes.cursorType})
| mapper(
| cursor.getArray<kotlin.Short>(0),
| cursor.getLong(1)!!.toShort(),
| cursor.getLong(2)?.let { it.toShort() },
| data_Adapter.smallint2Adapter.decode(cursor.getLong(3)!!.toShort()),
| cursor.getLong(4)?.let { data_Adapter.smallint3Adapter.decode(it.toShort()) },
| cursor.getLong(5)!!.toInt(),
| cursor.getLong(6)?.let { it.toInt() },
| data_Adapter.int2Adapter.decode(cursor.getLong(7)!!.toInt()),
| cursor.getLong(8)?.let { data_Adapter.int3Adapter.decode(it.toInt()) },
| cursor.getShort(1)!!,
| cursor.getShort(2),
| data_Adapter.smallint2Adapter.decode(cursor.getShort(3)!!),
| cursor.getShort(4)?.let { data_Adapter.smallint3Adapter.decode(it) },
| cursor.getInt(5)!!,
| cursor.getInt(6),
| data_Adapter.int2Adapter.decode(cursor.getInt(7)!!),
| cursor.getInt(8)?.let { data_Adapter.int3Adapter.decode(it) },
| cursor.getLong(9)!!,
| cursor.getLong(10),
| data_Adapter.bigint2Adapter.decode(cursor.getLong(11)!!),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class SelectQueryTypeTest {
)
}

@Test fun `returning clause in an update correctly generates a query function`(dialect: TestDialect) {
assumeTrue(dialect in listOf(TestDialect.POSTGRESQL, TestDialect.SQLITE_3_35))
@Test fun `returning clause in an update correctly generates a query function - postgres`(dialect: TestDialect) {
assumeTrue(dialect == TestDialect.POSTGRESQL)
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE IF NOT EXISTS users(
Expand All @@ -82,7 +82,7 @@ class SelectQueryTypeTest {
|
""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect(),
dialect = dialect.dialect,
)

val query = file.namedQueries.first()
Expand All @@ -102,7 +102,55 @@ class SelectQueryTypeTest {
|): app.cash.sqldelight.ExecutableQuery<T> = UpdateQuery(firstname, lastname, id) { cursor ->
| check(cursor is app.cash.sqldelight.driver.jdbc.JdbcCursor)
| mapper(
| cursor.getLong(0)!!.toInt(),
| cursor.getInt(0)!!,
| cursor.getString(1)!!,
| cursor.getString(2)!!
| )
|}
|
""".trimMargin(),
)
}

@Test fun `returning clause in an update correctly generates a query function - sqlite`(dialect: TestDialect) {
assumeTrue(dialect == TestDialect.SQLITE_3_35)
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE IF NOT EXISTS users(
| id ${dialect.intType} PRIMARY KEY,
| firstname ${dialect.textType} NOT NULL,
| lastname ${dialect.textType} NOT NULL
|);
|
|update:
|UPDATE users SET
| firstname = :firstname,
| lastname = :lastname
|WHERE id = :id
|RETURNING id, firstname, lastname;
|
""".trimMargin(),
tempFolder,
dialect = dialect.dialect,
)

val query = file.namedQueries.first()
val generator = SelectQueryGenerator(query)

assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun <T : kotlin.Any> update(
| firstname: kotlin.String,
| lastname: kotlin.String,
| id: kotlin.Long,
| mapper: (
| id: kotlin.Long,
| firstname: kotlin.String,
| lastname: kotlin.String,
| ) -> T,
|): app.cash.sqldelight.ExecutableQuery<T> = UpdateQuery(firstname, lastname, id) { cursor ->
| mapper(
| cursor.getLong(0)!!,
| cursor.getString(1)!!,
| cursor.getString(2)!!
| )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AsyncSelectQueryTypeTest {
|
""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect(),
dialect = dialect.dialect,
generateAsync = true,
)

Expand All @@ -96,7 +96,7 @@ class AsyncSelectQueryTypeTest {
|): app.cash.sqldelight.ExecutableQuery<T> = UpdateQuery(firstname, lastname, id) { cursor ->
| check(cursor is app.cash.sqldelight.driver.r2dbc.R2dbcCursor)
| mapper(
| cursor.getLong(0)!!.toInt(),
| cursor.getInt(0)!!,
| cursor.getString(1)!!,
| cursor.getString(2)!!
| )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ class OptimisticLockTest {
| ""${'"'}.trimMargin(), 4) {
| check(this is app.cash.sqldelight.driver.jdbc.JdbcPreparedStatement)
| bindString(0, text)
| bindLong(1, version.version.toLong())
| bindLong(2, id.id.toLong())
| bindLong(3, version.version.toLong())
| bindInt(1, version.version)
| bindInt(2, id.id)
| bindInt(3, version.version)
| }
| if (result.value == 0L) throw app.cash.sqldelight.db.OptimisticLockException("UPDATE on test failed because optimistic lock version did not match")
| notifyQueries(${mutator.id.withUnderscores}) { emit ->
Expand Down