Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve async support for Cursors #4102

Merged
merged 10 commits into from
Apr 25, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class AndroidSqliteDriver private constructor(
createStatement: () -> AndroidStatement,
binders: (SqlPreparedStatement.() -> Unit)?,
result: AndroidStatement.() -> T,
): QueryResult<T> {
): QueryResult.Value<T> {
var statement: AndroidStatement? = null
if (identifier != null) {
statement = statements.remove(identifier)
Expand Down Expand Up @@ -171,7 +171,7 @@ class AndroidSqliteDriver private constructor(
override fun <R> executeQuery(
identifier: Int?,
sql: String,
mapper: (SqlCursor) -> R,
mapper: (SqlCursor) -> QueryResult<R>,
parameters: Int,
binders: (SqlPreparedStatement.() -> Unit)?,
) = execute(identifier, { AndroidQuery(sql, database, parameters) }, binders) { executeQuery(mapper) }
Expand Down Expand Up @@ -207,7 +207,7 @@ class AndroidSqliteDriver private constructor(

internal interface AndroidStatement : SqlPreparedStatement {
fun execute(): Long
fun <R> executeQuery(mapper: (SqlCursor) -> R): R
fun <R> executeQuery(mapper: (SqlCursor) -> QueryResult<R>): R
fun close()
}

Expand Down Expand Up @@ -238,7 +238,7 @@ private class AndroidPreparedStatement(
}
}

override fun <R> executeQuery(mapper: (SqlCursor) -> R): R = throw UnsupportedOperationException()
override fun <R> executeQuery(mapper: (SqlCursor) -> QueryResult<R>): R = throw UnsupportedOperationException()

override fun execute(): Long {
return statement.executeUpdateDelete().toLong()
Expand Down Expand Up @@ -284,9 +284,9 @@ private class AndroidQuery(

override fun execute() = throw UnsupportedOperationException()

override fun <R> executeQuery(mapper: (SqlCursor) -> R): R {
override fun <R> executeQuery(mapper: (SqlCursor) -> QueryResult<R>): R {
return database.query(this)
.use { cursor -> mapper(AndroidCursor(cursor)) }
.use { cursor -> mapper(AndroidCursor(cursor)).value }
}

override fun bindTo(statement: SupportSQLiteProgram) {
Expand All @@ -303,7 +303,7 @@ private class AndroidQuery(
private class AndroidCursor(
private val cursor: Cursor,
) : SqlCursor {
override fun next() = cursor.moveToNext()
override fun next(): QueryResult.Value<Boolean> = QueryResult.Value(cursor.moveToNext())
override fun getString(index: Int) = if (cursor.isNull(index)) null else cursor.getString(index)
override fun getLong(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index)
override fun getBytes(index: Int) = if (cursor.isNull(index)) null else cursor.getBlob(index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ class AndroidDriverTest : DriverTest() {
fun `cached statement can be reused`() {
val driver = AndroidSqliteDriver(schema, getApplicationContext(), cacheSize = 1)
lateinit var bindable: SqlPreparedStatement
driver.executeQuery(1, "SELECT * FROM test", {}, 0, { bindable = this })
driver.executeQuery(1, "SELECT * FROM test", { QueryResult.Unit }, 0, { bindable = this })

driver.executeQuery(
1,
"SELECT * FROM test",
{},
{ QueryResult.Unit },
0,
{
assertSame(bindable, this)
Expand All @@ -43,14 +43,14 @@ class AndroidDriverTest : DriverTest() {
fun `cached statement is evicted and closed`() {
val driver = AndroidSqliteDriver(schema, getApplicationContext(), cacheSize = 1)
lateinit var bindable: SqlPreparedStatement
driver.executeQuery(1, "SELECT * FROM test", {}, 0, { bindable = this })
driver.executeQuery(1, "SELECT * FROM test", { QueryResult.Unit }, 0, { bindable = this })

driver.executeQuery(2, "SELECT * FROM test", {}, 0)
driver.executeQuery(2, "SELECT * FROM test", { QueryResult.Unit }, 0)

driver.executeQuery(
1,
"SELECT * FROM test",
{},
{ QueryResult.Unit },
0,
{
assertNotSame(bindable, this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ abstract class DriverTest {
private fun changes(): Long? {
// wrap in a transaction to ensure read happens on transaction thread/connection
return transacter.value!!.transactionWithResult {
val mapper: (SqlCursor) -> Long? = {
val mapper: (SqlCursor) -> QueryResult<Long?> = {
it.next()
it.getLong(0)
QueryResult.Value(it.getLong(0))
}
driver.executeQuery(null, "SELECT changes()", mapper, 0).value
}
Expand All @@ -87,12 +87,13 @@ abstract class DriverTest {
val insert = { binders: SqlPreparedStatement.() -> Unit ->
driver.execute(2, "INSERT INTO test VALUES (?, ?);", 2, binders)
}
fun query(mapper: (SqlCursor) -> Unit) {
fun query(mapper: (SqlCursor) -> QueryResult<Unit>) {
driver.executeQuery(3, "SELECT * FROM test", mapper, 0)
}

query {
assertFalse(it.next())
assertFalse(it.next().value)
QueryResult.Unit
}

insert {
Expand All @@ -101,16 +102,18 @@ abstract class DriverTest {
}

query {
assertTrue(it.next())
assertFalse(it.next())
assertTrue(it.next().value)
assertFalse(it.next().value)
QueryResult.Unit
}

assertEquals(1, changes())

query {
assertTrue(it.next())
assertTrue(it.next().value)
assertEquals(1, it.getLong(0))
assertEquals("Alec", it.getString(1))
QueryResult.Unit
}

insert {
Expand All @@ -120,19 +123,21 @@ abstract class DriverTest {
assertEquals(1, changes())

query {
assertTrue(it.next())
assertTrue(it.next().value)
assertEquals(1, it.getLong(0))
assertEquals("Alec", it.getString(1))
assertTrue(it.next())
assertTrue(it.next().value)
assertEquals(2, it.getLong(0))
assertEquals("Jake", it.getString(1))
QueryResult.Unit
}

driver.execute(5, "DELETE FROM test", 0)
assertEquals(2, changes())

query {
assertFalse(it.next())
assertFalse(it.next().value)
QueryResult.Unit
}
}

Expand All @@ -153,7 +158,7 @@ abstract class DriverTest {
}
assertEquals(1, changes())

fun query(binders: SqlPreparedStatement.() -> Unit, mapper: (SqlCursor) -> Unit) {
fun query(binders: SqlPreparedStatement.() -> Unit, mapper: (SqlCursor) -> QueryResult<Unit>) {
driver.executeQuery(6, "SELECT * FROM test WHERE value = ?", mapper, 1, binders)
}

Expand All @@ -162,9 +167,10 @@ abstract class DriverTest {
bindString(0, "Jake")
},
mapper = {
assertTrue(it.next())
assertTrue(it.next().value)
assertEquals(2, it.getLong(0))
assertEquals("Jake", it.getString(1))
QueryResult.Unit
},
)

Expand All @@ -174,9 +180,10 @@ abstract class DriverTest {
bindString(0, "Jake")
},
mapper = {
assertTrue(it.next())
assertTrue(it.next().value)
assertEquals(2, it.getLong(0))
assertEquals("Jake", it.getString(1))
QueryResult.Unit
},
)
}
Expand All @@ -194,13 +201,14 @@ abstract class DriverTest {
}
assertEquals(1, changes())

val mapper: (SqlCursor) -> Unit = {
assertTrue(it.next())
val mapper: (SqlCursor) -> QueryResult<Unit> = {
assertTrue(it.next().value)
assertEquals(1, it.getLong(0))
assertNull(it.getLong(1))
assertNull(it.getString(2))
assertNull(it.getBytes(3))
assertNull(it.getDouble(4))
QueryResult.Unit
}
driver.executeQuery(8, "SELECT * FROM nullability_test", mapper, 0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ abstract class QueryTest {

private fun testDataQuery(): Query<TestData> {
return object : Query<TestData>(mapper) {
override fun <R> execute(mapper: (SqlCursor) -> R): QueryResult<R> {
override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> {
return driver.executeQuery(0, "SELECT * FROM test", mapper, 0, null)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,15 @@ abstract class JdbcDriver : SqlDriver, ConnectionManager {
override fun <R> executeQuery(
identifier: Int?,
sql: String,
mapper: (SqlCursor) -> R,
mapper: (SqlCursor) -> QueryResult<R>,
parameters: Int,
binders: (SqlPreparedStatement.() -> Unit)?,
): QueryResult<R> {
val (connection, onClose) = connectionAndClose()
try {
return QueryResult.Value(
JdbcPreparedStatement(connection.prepareStatement(sql))
.apply { if (binders != null) this.binders() }
.executeQuery(mapper),
)
return JdbcPreparedStatement(connection.prepareStatement(sql))
.apply { if (binders != null) this.binders() }
.executeQuery(mapper)
} finally {
onClose()
}
Expand Down Expand Up @@ -310,5 +308,5 @@ class JdbcCursor(val resultSet: ResultSet) : SqlCursor {
private fun <T> getAtIndex(index: Int, converter: (Int) -> T): T? =
converter(index + 1).takeUnless { resultSet.wasNull() }

override fun next() = resultSet.next()
override fun next(): QueryResult.Value<Boolean> = QueryResult.Value(resultSet.next())
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,12 @@ sealed class ConnectionWrapper : SqlDriver {
final override fun <R> executeQuery(
identifier: Int?,
sql: String,
mapper: (SqlCursor) -> R,
mapper: (SqlCursor) -> QueryResult<R>,
parameters: Int,
binders: (SqlPreparedStatement.() -> Unit)?,
): QueryResult<R> = QueryResult.Value(
accessStatement(true, identifier, sql, binders) { statement ->
mapper(SqliterSqlCursor(statement.query()))
},
)
): QueryResult<R> = accessStatement(true, identifier, sql, binders) { statement ->
mapper(SqliterSqlCursor(statement.query()))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package app.cash.sqldelight.driver.native

import app.cash.sqldelight.db.QueryResult
import app.cash.sqldelight.db.SqlCursor
import co.touchlab.sqliter.Cursor
import co.touchlab.sqliter.getBytesOrNull
Expand All @@ -25,5 +26,5 @@ internal class SqliterSqlCursor(private val cursor: Cursor) : SqlCursor {
return (cursor.getLongOrNull(index) ?: return null) == 1L
}

override fun next(): Boolean = cursor.next()
override fun next(): QueryResult.Value<Boolean> = QueryResult.Value(cursor.next())
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ abstract class BaseConcurrencyTest {
return myDriver.executeQuery(
0,
"SELECT count(*) FROM test",
{ it.next(); it.getLong(0)!! },
{ it.next(); QueryResult.Value(it.getLong(0)!!) },
0,
).value
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class WalConcurrencyTest : BaseConcurrencyTest() {

private fun testDataQuery(): Query<TestData> {
return object : Query<TestData>(mapper) {
override fun <R> execute(mapper: (SqlCursor) -> R): QueryResult<R> {
override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> {
return driver.executeQuery(0, "SELECT * FROM test", mapper, 0, null)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class R2dbcDriver(val connection: Connection) : SqlDriver {
override fun <R> executeQuery(
identifier: Int?,
sql: String,
mapper: (SqlCursor) -> R,
mapper: (SqlCursor) -> QueryResult<R>,
parameters: Int,
binders: (SqlPreparedStatement.() -> Unit)?,
): QueryResult<R> {
Expand All @@ -32,7 +32,7 @@ class R2dbcDriver(val connection: Connection) : SqlDriver {
List(rowMetadata.columnMetadatas.size) { index -> row.get(index) }
}.asFlow().toList()

return@AsyncValue mapper(R2dbcCursor(rowSet))
return@AsyncValue mapper(R2dbcCursor(rowSet)).await()
}
}

Expand Down Expand Up @@ -157,7 +157,7 @@ class R2dbcCursor(val rowSet: List<List<Any?>>) : SqlCursor {
var row = -1
private set

override fun next(): Boolean = ++row < rowSet.size
override fun next(): QueryResult.Value<Boolean> = QueryResult.Value(++row < rowSet.size)

override fun getString(index: Int): String? = rowSet[row][index] as String?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class JsSqlDriver(private val db: Database) : SqlDriver {
override fun <R> executeQuery(
identifier: Int?,
sql: String,
mapper: (SqlCursor) -> R,
mapper: (SqlCursor) -> QueryResult<R>,
parameters: Int,
binders: (SqlPreparedStatement.() -> Unit)?,
): QueryResult<R> {
Expand All @@ -60,7 +60,7 @@ class JsSqlDriver(private val db: Database) : SqlDriver {
}

return try {
QueryResult.Value(mapper(cursor))
mapper(cursor)
} finally {
cursor.close()
}
Expand Down Expand Up @@ -120,7 +120,7 @@ class JsSqlDriver(private val db: Database) : SqlDriver {
}

private class JsSqlCursor(private val statement: Statement) : SqlCursor {
override fun next(): Boolean = statement.step()
override fun next(): QueryResult.Value<Boolean> = QueryResult.Value(statement.step())
override fun getString(index: Int): String? = statement.get()[index]
override fun getLong(index: Int): Long? = (statement.get()[index] as? Double)?.toLong()
override fun getBytes(index: Int): ByteArray? = (statement.get()[index] as? Uint8Array)?.let {
Expand Down
Loading