Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
helje5 committed Nov 27, 2024
2 parents 25c667e + 9c0f562 commit 628fb69
Showing 1 changed file with 124 additions and 52 deletions.
176 changes: 124 additions & 52 deletions Sources/PostgreSQLAdaptor/PostgreSQLAdaptorChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {
case NotImplemented
case ConnectionClosed
}

static let isDebugDefaultOn =
UserDefaults.standard.bool(forKey: "PGDebugEnabled")

public let expressionFactory : SQLExpressionFactory
public var handle : OpaquePointer?
Expand All @@ -39,7 +42,7 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {
init(adaptor: Adaptor, handle: OpaquePointer) {
self.expressionFactory = adaptor.expressionFactory
self.handle = handle
self.logSQL = UserDefaults.standard.bool(forKey: "PGDebugEnabled")
self.logSQL = Self.isDebugDefaultOn
}

deinit {
Expand Down Expand Up @@ -193,67 +196,31 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {
for bind in bindings {
// if logSQL { print(" BIND[\(idx)]: \(bind)") }

struct Bind { // move out
var type : Oid = 0
var length : Int32 = 0
var isBinary : Int32 = BinaryFlag
var rawValue : UnsafePointer<Int8>? = nil
}

if let attr = bind.attribute {
if logSQL { print(" BIND[\(idx)]: \(attr.name)") }

// TODO: ask attribute for OID
}

// FIXME(hh 2024-11-25): Unnested all this stuff.
// TODO: Add a protocol to do this?
func bindAnyValue(_ value: Any?) throws -> Bind {
guard let value = value else {
if logSQL { print(" [\(idx)]> bind NULL") }
// TODO: set value to NULL
return Bind(type: 0 /*Hmmm*/, length: 0, rawValue: nil)
}
switch value {
case let value as String:
if logSQL { print(" [\(idx)]> bind string \"\(value)\"") }
// TODO: include 0 in length?
let rawValue = UnsafePointer(strdup(value))
return Bind(type: OIDs.VARCHAR,
length: rawValue.flatMap { Int32(strlen($0)) } ?? 0,
rawValue: rawValue)
case let value as Int:
if logSQL { print(" [\(idx)]> bind int \(value)") }
let bp = tdup(value.bigEndian)
return Bind(type: MemoryLayout<Int>.size == 8
? OIDs.INT8 : OIDs.INT4,
length: Int32(bp.count), rawValue: bp.baseAddress!)
case let value as Int32: return try bindAnyValue(Int(value))
case let value as Int64: return try bindAnyValue(Int(value))
case let value as any BinaryInteger:
return try bindAnyValue(Int(value))
case let value as GlobalID:
assert(value.keyCount == 1)
switch value.value {
case .singleNil : return try bindAnyValue(nil)
case .int (let value) : return try bindAnyValue(value)
case .string(let value) : return try bindAnyValue(value)
case .uuid (let value) : return try bindAnyValue(value)
case .values(let values):
if values.count > 1 {
throw Error.ExecError(reason: "Invalid multi-gid bind",
sql: sql)
}
if let value = values.first { return try bindAnyValue(value) }
else { return try bindAnyValue(nil) }
}
default: // TODO
if logSQL { print(" [\(idx)]> bind other \(value)") }
assertionFailure("Unexpected value, please add explicit type")
let rawValue = UnsafePointer(strdup(String(describing: value)))
return Bind(type: OIDs.VARCHAR,
length: rawValue.flatMap { Int32(strlen($0)) } ?? 0,
rawValue: rawValue)

if let value = value as? PGBindableValue {
return try value.bind(index: idx, log: logSQL)
}

if logSQL { print(" [\(idx)]> bind other \(value)") }
assertionFailure("Unexpected value, please add explicit type")
let rawValue = UnsafePointer(strdup(String(describing: value)))
return Bind(type: OIDs.VARCHAR,
length: rawValue.flatMap { Int32(strlen($0)) } ?? 0,
rawValue: rawValue)
}

let bindInfo = try bindAnyValue(bind.value)
Expand Down Expand Up @@ -366,7 +333,7 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {
// - Max resolution 1 microsecond
if value.count == 8 {
// 1_000_000
let msecs = Double(UInt64(bigEndian: cast(value.baseAddress!)))
let msecs = Double(Int64(bigEndian: cast(value.baseAddress!)))
let date = Date(timeInterval: TimeInterval(msecs) / 1000000.0,
since: Date.pgReferenceDate)
return date
Expand Down Expand Up @@ -453,17 +420,20 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {

public var isTransactionInProgress : Bool = false

@inlinable
public func begin() throws {
guard !isTransactionInProgress
else { throw AdaptorChannelError.TransactionInProgress }

try performSQL("BEGIN TRANSACTION;")
isTransactionInProgress = true
}
@inlinable
public func commit() throws {
isTransactionInProgress = false
try performSQL("COMMIT TRANSACTION;")
}
@inlinable
public func rollback() throws {
isTransactionInProgress = false
try performSQL("ROLLBACK TRANSACTION;")
Expand Down Expand Up @@ -492,17 +462,21 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {

// MARK: - reflection

@inlinable
public func describeSequenceNames() throws -> [ String ] {
return try PostgreSQLModelFetch(channel: self).describeSequenceNames()
}

@inlinable
public func describeDatabaseNames() throws -> [ String ] {
return try PostgreSQLModelFetch(channel: self).describeDatabaseNames()
}
@inlinable
public func describeTableNames() throws -> [ String ] {
return try PostgreSQLModelFetch(channel: self).describeTableNames()
}

@inlinable
public func describeEntityWithTableName(_ table: String) throws -> Entity? {
return try PostgreSQLModelFetch(channel: self)
.describeEntityWithTableName(table)
Expand All @@ -511,12 +485,11 @@ open class PostgreSQLAdaptorChannel : AdaptorChannel, SmartDescription {

// MARK: - Insert w/ auto-increment support

open func insertRow(_ row: AdaptorRow, _ entity: Entity?, refetchAll: Bool)
@inlinable
open func insertRow(_ row: AdaptorRow, _ entity: Entity, refetchAll: Bool)
throws -> AdaptorRow
{
let attributes : [ Attribute ]? = {
guard let entity = entity else { return nil }

if refetchAll { return entity.attributes }

// TBD: refetch-all if no pkeys are assigned
Expand Down Expand Up @@ -564,3 +537,102 @@ fileprivate extension Date {
// 2000-01-01
static let pgReferenceDate = Date(timeIntervalSince1970: 946684800)
}


// MARK: - Binding

fileprivate struct Bind {
// So this always has the value *malloc*'ed in rawValue, which is not
// particularily great :-)

var type : Oid = 0
var length : Int32 = 0
var isBinary : Int32 = BinaryFlag
var rawValue : UnsafePointer<Int8>? = nil
}

fileprivate protocol PGBindableValue {

func bind(index: Int, log: Bool) throws -> Bind
}

extension Optional: PGBindableValue where Wrapped: PGBindableValue {

fileprivate func bind(index idx: Int, log: Bool) throws -> Bind {
switch self {
case .some(let value): return try value.bind(index: idx, log: log)
case .none:
if log { print(" [\(idx)]> bind NULL") }
return Bind(type: 0 /*Hmmm*/, length: 0, rawValue: nil)
}
}
}

extension String: PGBindableValue {

fileprivate func bind(index idx: Int, log: Bool) throws -> Bind {
if log { print(" [\(idx)]> bind string \"\(self)\"") }
// TODO: include 0 in length?
let rawValue = UnsafePointer(strdup(self))
return Bind(type: OIDs.VARCHAR,
length: rawValue.flatMap { Int32(strlen($0)) } ?? 0,
rawValue: rawValue)
}
}
extension BinaryInteger {

fileprivate func bind(index idx: Int, log: Bool) throws -> Bind {
if log { print(" [\(idx)]> bind int \(self)") }
let value = Int(self) // Hmm
let bp = tdup(value.bigEndian)
return Bind(type: MemoryLayout<Int>.size == 8
? OIDs.INT8 : OIDs.INT4,
length: Int32(bp.count), rawValue: bp.baseAddress!)
}
}
extension Int : PGBindableValue {}
extension UInt : PGBindableValue {}
extension Int32 : PGBindableValue {}
extension Int64 : PGBindableValue {}

extension Date: PGBindableValue {

fileprivate func bind(index idx: Int, log: Bool) throws -> Bind {
let diff = self.timeIntervalSince(Date.pgReferenceDate)
let msecs = Int64(diff * 1000000.0) // seconds to milliseconds
let bp = tdup(msecs.bigEndian)
return Bind(type: OIDs.TIMESTAMPTZ, // 1184
length: 8,
rawValue: bp.baseAddress!)
}
}

extension UUID: PGBindableValue {
fileprivate func bind(index idx: Int, log: Bool) throws -> Bind {
return try uuidString.bind(index: idx, log: log)
}
}

extension KeyGlobalID: PGBindableValue {

fileprivate func bind(index idx: Int, log: Bool) throws -> Bind {
assert(keyCount == 1)
switch value {
case .singleNil:
return try Optional<String>.none.bind(index: idx, log: log)
case .int (let value) : return try value.bind(index: idx, log: log)
case .string(let value) : return try value.bind(index: idx, log: log)
case .uuid (let value) : return try value.bind(index: idx, log: log)
case .values(let values):
if values.count > 1 {
throw PostgreSQLAdaptorChannel.Error
.ExecError(reason: "Invalid multi-gid bind", sql: "")
}
assert(values.first is PGBindableValue)
if let value = values.first as? PGBindableValue {
return try value.bind(index: idx, log: log)
}
else { return try Optional<String>.none.bind(index: idx, log: log) }
}
}
}

0 comments on commit 628fb69

Please sign in to comment.