diff --git a/Sources/SQLite/Core/Statement.swift b/Sources/SQLite/Core/Statement.swift index a9232bba..b763e40e 100644 --- a/Sources/SQLite/Core/Statement.swift +++ b/Sources/SQLite/Core/Statement.swift @@ -191,6 +191,14 @@ public final class Statement { } +extension Statement { + + func rowCursorNext() throws -> [Binding?]? { + return try step() ? Array(row) : nil + } + +} + extension Statement : Sequence { public func makeIterator() -> Statement { diff --git a/Sources/SQLite/Typed/Query.swift b/Sources/SQLite/Typed/Query.swift index c9d2ea9c..f3584f46 100644 --- a/Sources/SQLite/Typed/Query.swift +++ b/Sources/SQLite/Typed/Query.swift @@ -890,58 +890,88 @@ public struct Delete : ExpressionType { } +public struct RowCursor { + let statement: Statement + let columnNames: [String: Int] + + public func next() throws -> Row? { + return try statement.rowCursorNext().flatMap { Row(columnNames, $0) } + } + + public func map(_ transform: (Row) throws -> T) throws -> [T] { + var elements = [T]() + while true { + if let row = try next() { + elements.append(try transform(row)) + } else { + break + } + } + + return elements + } +} + extension Connection { + + public func prepareCursor(_ query: QueryType) throws -> RowCursor { + let expression = query.expression + let statement = try prepare(expression.template, expression.bindings) + return RowCursor(statement: statement, columnNames: try columnNamesForQuery(query)) + } public func prepare(_ query: QueryType) throws -> AnySequence { let expression = query.expression let statement = try prepare(expression.template, expression.bindings) - let columnNames: [String: Int] = try { - var (columnNames, idx) = ([String: Int](), 0) - column: for each in query.clauses.select.columns { - var names = each.expression.template.characters.split { $0 == "." }.map(String.init) - let column = names.removeLast() - let namespace = names.joined(separator: ".") - - func expandGlob(_ namespace: Bool) -> ((QueryType) throws -> Void) { - return { (query: QueryType) throws -> (Void) in - var q = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database) - q.clauses.select = query.clauses.select - let e = q.expression - var names = try self.prepare(e.template, e.bindings).columnNames.map { $0.quote() } - if namespace { names = names.map { "\(query.tableName().expression.template).\($0)" } } - for name in names { columnNames[name] = idx; idx += 1 } - } - } + let columnNames = try columnNamesForQuery(query) - if column == "*" { - var select = query - select.clauses.select = (false, [Expression(literal: "*") as Expressible]) - let queries = [select] + query.clauses.join.map { $0.query } - if !namespace.isEmpty { - for q in queries { - if q.tableName().expression.template == namespace { - try expandGlob(true)(q) - continue column - } - } - fatalError("no such table: \(namespace)") - } + return AnySequence { + AnyIterator { statement.next().map { Row(columnNames, $0) } } + } + } + + private func columnNamesForQuery(_ query: QueryType) throws -> [String: Int] { + var (columnNames, idx) = ([String: Int](), 0) + column: for each in query.clauses.select.columns { + var names = each.expression.template.characters.split { $0 == "." }.map(String.init) + let column = names.removeLast() + let namespace = names.joined(separator: ".") + + func expandGlob(_ namespace: Bool) -> ((QueryType) throws -> Void) { + return { (query: QueryType) throws -> (Void) in + var q = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database) + q.clauses.select = query.clauses.select + let e = q.expression + var names = try self.prepare(e.template, e.bindings).columnNames.map { $0.quote() } + if namespace { names = names.map { "\(query.tableName().expression.template).\($0)" } } + for name in names { columnNames[name] = idx; idx += 1 } + } + } + + if column == "*" { + var select = query + select.clauses.select = (false, [Expression(literal: "*") as Expressible]) + let queries = [select] + query.clauses.join.map { $0.query } + if !namespace.isEmpty { for q in queries { - try expandGlob(query.clauses.join.count > 0)(q) + if q.tableName().expression.template == namespace { + try expandGlob(true)(q) + continue column + } } - continue + fatalError("no such table: \(namespace)") } - - columnNames[each.expression.template] = idx - idx += 1 + for q in queries { + try expandGlob(query.clauses.join.count > 0)(q) + } + continue } - return columnNames - }() - - return AnySequence { - AnyIterator { statement.next().map { Row(columnNames, $0) } } + + columnNames[each.expression.template] = idx + idx += 1 } + return columnNames } public func scalar(_ query: ScalarQuery) throws -> V { @@ -967,7 +997,7 @@ extension Connection { } public func pluck(_ query: QueryType) throws -> Row? { - return try prepare(query.limit(1, query.clauses.limit?.offset)).makeIterator().next() + return try prepareCursor(query.limit(1, query.clauses.limit?.offset)).next() } /// Runs an `Insert` query. diff --git a/Tests/SQLiteTests/QueryTests.swift b/Tests/SQLiteTests/QueryTests.swift index 2cf164c6..bbf1ee42 100644 --- a/Tests/SQLiteTests/QueryTests.swift +++ b/Tests/SQLiteTests/QueryTests.swift @@ -333,6 +333,16 @@ class QueryIntegrationTests : SQLiteTestCase { _ = user[users[managerId]] } } + + func test_prepareCursor() { + let names = ["a", "b", "c"] + try! InsertUsers(names) + + let emailColumn = Expression("email") + let emails = try! db.prepareCursor(users).map { $0[emailColumn] } + + XCTAssertEqual(names.map({ "\($0)@example.com" }), emails.sorted()) + } func test_scalar() { XCTAssertEqual(0, try! db.scalar(users.count))