From 586a99339db43596ecee594317dde5bd4f5ebb7f Mon Sep 17 00:00:00 2001 From: Nick Shelley Date: Fri, 5 May 2017 14:48:18 -0600 Subject: [PATCH 1/2] Add cursor type for more safety. --- Sources/SQLite/Core/Statement.swift | 8 ++ Sources/SQLite/Typed/Query.swift | 111 ++++++++++++++++++---------- Tests/SQLiteTests/QueryTests.swift | 10 +++ 3 files changed, 89 insertions(+), 40 deletions(-) diff --git a/Sources/SQLite/Core/Statement.swift b/Sources/SQLite/Core/Statement.swift index 8c13ff6c..c817e887 100644 --- a/Sources/SQLite/Core/Statement.swift +++ b/Sources/SQLite/Core/Statement.swift @@ -191,6 +191,14 @@ public final class Statement { } +extension Statement { + + func rowCursorNext() throws -> [Binding?]? { + return try step() ? Array(row) : nil + } + +} + extension Statement : Sequence { public func makeIterator() -> Statement { diff --git a/Sources/SQLite/Typed/Query.swift b/Sources/SQLite/Typed/Query.swift index 7bcda40b..2a11a910 100644 --- a/Sources/SQLite/Typed/Query.swift +++ b/Sources/SQLite/Typed/Query.swift @@ -894,58 +894,89 @@ public struct Delete : ExpressionType { } +public struct RowCursor { + let statement: Statement + let columnNames: [String: Int] + + public func next() throws -> Row? { + return try statement.rowCursorNext().flatMap { Row(columnNames, $0) } + } + + public func map(_ transform: (Row) throws -> T) throws -> [T] { + var elements = [T]() + while true { + if let row = try next() { + elements.append(try transform(row)) + } else { + break + } + } + + return elements + } +} + extension Connection { + + public func prepareCursor(_ query: QueryType) throws -> RowCursor { + let expression = query.expression + let statement = try prepare(expression.template, expression.bindings) + return RowCursor(statement: statement, columnNames: try columnNamesForQuery(query)) + } public func prepare(_ query: QueryType) throws -> AnySequence { let expression = query.expression let statement = try prepare(expression.template, expression.bindings) - let columnNames: [String: Int] = try { - var (columnNames, idx) = ([String: Int](), 0) - column: for each in query.clauses.select.columns { - var names = each.expression.template.characters.split { $0 == "." }.map(String.init) - let column = names.removeLast() - let namespace = names.joined(separator: ".") - - func expandGlob(_ namespace: Bool) -> ((QueryType) throws -> Void) { - return { (query: QueryType) throws -> (Void) in - var q = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database) - q.clauses.select = query.clauses.select - let e = q.expression - var names = try self.prepare(e.template, e.bindings).columnNames.map { $0.quote() } - if namespace { names = names.map { "\(query.tableName().expression.template).\($0)" } } - for name in names { columnNames[name] = idx; idx += 1 } - } - } + let columnNames = try columnNamesForQuery(query) - if column == "*" { - var select = query - select.clauses.select = (false, [Expression(literal: "*") as Expressible]) - let queries = [select] + query.clauses.join.map { $0.query } - if !namespace.isEmpty { - for q in queries { - if q.tableName().expression.template == namespace { - try expandGlob(true)(q) - continue column - } + return AnySequence { + AnyIterator { statement.next().map { Row(columnNames, $0) } } + } + } + + private func columnNamesForQuery(_ query: QueryType) throws -> [String: Int] { + var (columnNames, idx) = ([String: Int](), 0) + column: for each in query.clauses.select.columns { + var names = each.expression.template.characters.split { $0 == "." }.map(String.init) + let column = names.removeLast() + let namespace = names.joined(separator: ".") + + func expandGlob(_ namespace: Bool) -> ((QueryType) throws -> Void) { + return { (query: QueryType) throws -> (Void) in + var q = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database) + q.clauses.select = query.clauses.select + let e = q.expression + var names = try self.prepare(e.template, e.bindings).columnNames.map { $0.quote() } + if namespace { names = names.map { "\(query.tableName().expression.template).\($0)" } } + for name in names { columnNames[name] = idx; idx += 1 } + } + } + + if column == "*" { + var select = query + select.clauses.select = (false, [Expression(literal: "*") as Expressible]) + let queries = [select] + query.clauses.join.map { $0.query } + if !namespace.isEmpty { + for q in queries { + if q.tableName().expression.template == namespace { + try expandGlob(true)(q) + continue column } throw QueryError.noSuchTable(name: namespace) } - for q in queries { - try expandGlob(query.clauses.join.count > 0)(q) - } - continue + fatalError("no such table: \(namespace)") } - - columnNames[each.expression.template] = idx - idx += 1 + for q in queries { + try expandGlob(query.clauses.join.count > 0)(q) + } + continue } - return columnNames - }() - - return AnySequence { - AnyIterator { statement.next().map { Row(columnNames, $0) } } + + columnNames[each.expression.template] = idx + idx += 1 } + return columnNames } public func scalar(_ query: ScalarQuery) throws -> V { @@ -971,7 +1002,7 @@ extension Connection { } public func pluck(_ query: QueryType) throws -> Row? { - return try prepare(query.limit(1, query.clauses.limit?.offset)).makeIterator().next() + return try prepareCursor(query.limit(1, query.clauses.limit?.offset)).next() } /// Runs an `Insert` query. diff --git a/Tests/SQLiteTests/QueryTests.swift b/Tests/SQLiteTests/QueryTests.swift index 3ab0419a..c2f48bcb 100644 --- a/Tests/SQLiteTests/QueryTests.swift +++ b/Tests/SQLiteTests/QueryTests.swift @@ -347,6 +347,16 @@ class QueryIntegrationTests : SQLiteTestCase { _ = user[users[managerId]] } } + + func test_prepareCursor() { + let names = ["a", "b", "c"] + try! InsertUsers(names) + + let emailColumn = Expression("email") + let emails = try! db.prepareCursor(users).map { $0[emailColumn] } + + XCTAssertEqual(names.map({ "\($0)@example.com" }), emails.sorted()) + } func test_select_optional() { for _ in try! db.prepare(users) { From 79f80956e0c0763e29a3dae0f964217c0e6e6d02 Mon Sep 17 00:00:00 2001 From: Nick Shelley Date: Wed, 20 Sep 2017 07:53:39 -0600 Subject: [PATCH 2/2] Change fatalError to throw. --- Sources/SQLite/Typed/Query.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/SQLite/Typed/Query.swift b/Sources/SQLite/Typed/Query.swift index 2a11a910..86aff26f 100644 --- a/Sources/SQLite/Typed/Query.swift +++ b/Sources/SQLite/Typed/Query.swift @@ -965,7 +965,7 @@ extension Connection { } throw QueryError.noSuchTable(name: namespace) } - fatalError("no such table: \(namespace)") + throw QueryError.noSuchTable(name: namespace) } for q in queries { try expandGlob(query.clauses.join.count > 0)(q)