Skip to content

Commit

Permalink
Expose more JDBC/R2DBC statement methods for dialect authors (#5098)
Browse files Browse the repository at this point in the history
  • Loading branch information
hfhbd authored Apr 4, 2024
1 parent 4051c34 commit 8b4f448
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.INT
import com.squareup.kotlinpoet.LONG
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.SHORT
import com.squareup.kotlinpoet.STRING
import com.squareup.kotlinpoet.TypeName
Expand All @@ -24,19 +25,24 @@ internal enum class PostgreSqlType(override val javaType: TypeName) : DialectTyp
;

override fun prepareStatementBinder(columnIndex: CodeBlock, value: CodeBlock): CodeBlock {
return CodeBlock.builder()
.add(
when (this) {
SMALL_INT -> "bindShort"
INTEGER -> "bindInt"
BIG_INT -> "bindLong"
DATE, TIME, TIMESTAMP, TIMESTAMP_TIMEZONE, INTERVAL, UUID -> "bindObject"
NUMERIC -> "bindBigDecimal"
JSON -> "bindObjectOther"
},
return when (this) {
SMALL_INT -> CodeBlock.of("bindShort(%L, %L)\n", columnIndex, value)
INTEGER -> CodeBlock.of("bindInt(%L, %L)\n", columnIndex, value)
BIG_INT -> CodeBlock.of("bindLong(%L, %L)\n", columnIndex, value)
DATE, TIME, TIMESTAMP, TIMESTAMP_TIMEZONE, INTERVAL, UUID -> CodeBlock.of(
"bindObject(%L, %L)\n",
columnIndex,
value,
)
.add("(%L, %L)\n", columnIndex, value)
.build()

NUMERIC -> CodeBlock.of("bindBigDecimal(%L, %L)\n", columnIndex, value)
JSON -> CodeBlock.of(
"bindObject(%L, %L, %M)\n",
columnIndex,
value,
MemberName(ClassName("java.sql", "Types"), "OTHER"),
)
}
}

override fun cursorGetter(columnIndex: Int, cursorName: String): CodeBlock {
Expand Down
8 changes: 7 additions & 1 deletion drivers/jdbc-driver/api/jdbc-driver.api
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ public final class app/cash/sqldelight/driver/jdbc/JdbcCursor : app/cash/sqldeli
public fun getBoolean (I)Ljava/lang/Boolean;
public final fun getByte (I)Ljava/lang/Byte;
public fun getBytes (I)[B
public final fun getDate (I)Ljava/sql/Date;
public fun getDouble (I)Ljava/lang/Double;
public final fun getFloat (I)Ljava/lang/Float;
public final fun getInt (I)Ljava/lang/Integer;
public fun getLong (I)Ljava/lang/Long;
public final fun getResultSet ()Ljava/sql/ResultSet;
public final fun getShort (I)Ljava/lang/Short;
public fun getString (I)Ljava/lang/String;
public final fun getTime (I)Ljava/sql/Time;
public final fun getTimestamp (I)Ljava/sql/Timestamp;
public synthetic fun next ()Lapp/cash/sqldelight/db/QueryResult;
public fun next-mlR-ZEE ()Ljava/lang/Object;
}
Expand Down Expand Up @@ -58,14 +61,17 @@ public final class app/cash/sqldelight/driver/jdbc/JdbcPreparedStatement : app/c
public fun bindBoolean (ILjava/lang/Boolean;)V
public final fun bindByte (ILjava/lang/Byte;)V
public fun bindBytes (I[B)V
public final fun bindDate (ILjava/sql/Date;)V
public fun bindDouble (ILjava/lang/Double;)V
public final fun bindFloat (ILjava/lang/Float;)V
public final fun bindInt (ILjava/lang/Integer;)V
public fun bindLong (ILjava/lang/Long;)V
public final fun bindObject (ILjava/lang/Object;)V
public final fun bindObjectOther (ILjava/lang/Object;)V
public final fun bindObject (ILjava/lang/Object;I)V
public final fun bindShort (ILjava/lang/Short;)V
public fun bindString (ILjava/lang/String;)V
public final fun bindTime (ILjava/sql/Time;)V
public final fun bindTimestamp (ILjava/sql/Timestamp;)V
public final fun execute ()J
public final fun executeQuery (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,7 @@ class JdbcPreparedStatement(
private val preparedStatement: PreparedStatement,
) : SqlPreparedStatement {
override fun bindBytes(index: Int, bytes: ByteArray?) {
if (bytes == null) {
preparedStatement.setNull(index + 1, Types.BLOB)
} else {
preparedStatement.setBytes(index + 1, bytes)
}
preparedStatement.setBytes(index + 1, bytes)
}

override fun bindBoolean(index: Int, boolean: Boolean?) {
Expand Down Expand Up @@ -246,11 +242,7 @@ class JdbcPreparedStatement(
}

fun bindBigDecimal(index: Int, decimal: BigDecimal?) {
if (decimal == null) {
preparedStatement.setNull(index + 1, Types.NUMERIC)
} else {
preparedStatement.setBigDecimal(index + 1, decimal)
}
preparedStatement.setBigDecimal(index + 1, decimal)
}

fun bindObject(index: Int, obj: Any?) {
Expand All @@ -261,20 +253,28 @@ class JdbcPreparedStatement(
}
}

fun bindObjectOther(index: Int, obj: Any?) {
fun bindObject(index: Int, obj: Any?, type: Int) {
if (obj == null) {
preparedStatement.setNull(index + 1, Types.OTHER)
preparedStatement.setNull(index + 1, type)
} else {
preparedStatement.setObject(index + 1, obj, Types.OTHER)
preparedStatement.setObject(index + 1, obj, type)
}
}

override fun bindString(index: Int, string: String?) {
if (string == null) {
preparedStatement.setNull(index + 1, Types.VARCHAR)
} else {
preparedStatement.setString(index + 1, string)
}
preparedStatement.setString(index + 1, string)
}

fun bindDate(index: Int, date: java.sql.Date?) {
preparedStatement.setDate(index, date)
}

fun bindTime(index: Int, date: java.sql.Time?) {
preparedStatement.setTime(index, date)
}

fun bindTimestamp(index: Int, timestamp: java.sql.Timestamp?) {
preparedStatement.setTimestamp(index, timestamp)
}

fun <R> executeQuery(mapper: (SqlCursor) -> R): R {
Expand Down Expand Up @@ -312,6 +312,9 @@ class JdbcCursor(val resultSet: ResultSet) : SqlCursor {
override fun getDouble(index: Int): Double? = getAtIndex(index, resultSet::getDouble)
fun getBigDecimal(index: Int): BigDecimal? = resultSet.getBigDecimal(index + 1)
inline fun <reified T : Any> getObject(index: Int): T? = resultSet.getObject(index + 1, T::class.java)
fun getDate(index: Int): java.sql.Date? = resultSet.getDate(index)
fun getTime(index: Int): java.sql.Time? = resultSet.getTime(index)
fun getTimestamp(index: Int): java.sql.Timestamp? = resultSet.getTimestamp(index)

@Suppress("UNCHECKED_CAST")
fun <T> getArray(index: Int) = getAtIndex(index, resultSet::getArray)?.array as Array<T>?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.db.SqlPreparedStatement
import io.r2dbc.spi.Connection
import io.r2dbc.spi.Statement
import java.math.BigDecimal
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
Expand Down Expand Up @@ -155,7 +156,7 @@ fun CoroutineScope.R2dbcDriver(
}

// R2DBC uses boxed Java classes instead primitives: https://r2dbc.io/spec/1.0.0.RELEASE/spec/html/#datatypes
class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStatement {
class R2dbcPreparedStatement(val statement: Statement) : SqlPreparedStatement {
override fun bindBytes(index: Int, bytes: ByteArray?) {
if (bytes == null) {
statement.bindNull(index, ByteArray::class.java)
Expand All @@ -164,6 +165,22 @@ class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStat
}
}

override fun bindBoolean(index: Int, boolean: Boolean?) {
if (boolean == null) {
statement.bindNull(index, Boolean::class.javaObjectType)
} else {
statement.bind(index, boolean)
}
}

fun bindByte(index: Int, byte: Byte?) {
if (byte == null) {
statement.bindNull(index, Byte::class.javaObjectType)
} else {
statement.bind(index, byte)
}
}

fun bindShort(index: Int, short: Short?) {
if (short == null) {
statement.bindNull(index, Short::class.javaObjectType)
Expand All @@ -188,6 +205,14 @@ class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStat
}
}

fun bindFloat(index: Int, float: Float?) {
if (float == null) {
statement.bindNull(index, Float::class.javaObjectType)
} else {
statement.bind(index, float)
}
}

override fun bindDouble(index: Int, double: Double?) {
if (double == null) {
statement.bindNull(index, Double::class.javaObjectType)
Expand All @@ -196,29 +221,38 @@ class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStat
}
}

override fun bindString(index: Int, string: String?) {
if (string == null) {
statement.bindNull(index, String::class.java)
fun bindBigDecimal(index: Int, decimal: BigDecimal?) {
if (decimal == null) {
statement.bindNull(index, BigDecimal::class.java)
} else {
statement.bind(index, string)
statement.bind(index, decimal)
}
}

override fun bindBoolean(index: Int, boolean: Boolean?) {
if (boolean == null) {
statement.bindNull(index, Boolean::class.javaObjectType)
fun bindObject(index: Int, any: Any?, ignoredSqlType: Int = 0) {
if (any == null) {
statement.bindNull(index, Any::class.java)
} else {
statement.bind(index, boolean)
statement.bind(index, any)
}
}

fun bindObject(index: Int, any: Any?) {
@JvmName("bindTypedObject")
inline fun <reified T : Any> bindObject(index: Int, any: T?) {
if (any == null) {
statement.bindNull(index, Any::class.java)
statement.bindNull(index, T::class.java)
} else {
statement.bind(index, any)
}
}

override fun bindString(index: Int, string: String?) {
if (string == null) {
statement.bindNull(index, String::class.java)
} else {
statement.bind(index, string)
}
}
}

internal fun <T : Any> Publisher<T>.asIterator(): AsyncPublisherIterator<T> =
Expand Down

0 comments on commit 8b4f448

Please sign in to comment.