diff --git a/Sources/SwiftGRPC/Core/Channel.swift b/Sources/SwiftGRPC/Core/Channel.swift index 7e386d447..ebde440da 100644 --- a/Sources/SwiftGRPC/Core/Channel.swift +++ b/Sources/SwiftGRPC/Core/Channel.swift @@ -14,8 +14,8 @@ * limitations under the License. */ #if SWIFT_PACKAGE - import CgRPC - import Dispatch +import CgRPC +import Dispatch #endif import Foundation @@ -23,9 +23,10 @@ import Foundation public class Channel { /// Pointer to underlying C representation private let underlyingChannel: UnsafeMutableRawPointer - /// Completion queue for channel call operations private let completionQueue: CompletionQueue + /// Observer for connectivity state changes. + private lazy var connectivityObserver = ConnectivityObserver(underlyingChannel: self.underlyingChannel) /// Timeout for new calls public var timeout: TimeInterval = 600.0 @@ -33,9 +34,6 @@ public class Channel { /// Default host to use for new calls public var host: String - /// Connectivity state observers - private var connectivityObservers: [ConnectivityObserver] = [] - /// Initializes a gRPC channel /// /// - Parameter address: the address of the server to be called @@ -47,12 +45,12 @@ public class Channel { let argumentWrappers = arguments.map { $0.toCArg() } underlyingChannel = withExtendedLifetime(argumentWrappers) { - var argumentValues = argumentWrappers.map { $0.wrapped } - if secure { - return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count)) - } else { - return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count)) - } + var argumentValues = argumentWrappers.map { $0.wrapped } + if secure { + return cgrpc_channel_create_secure(address, kRootCertificates, nil, nil, &argumentValues, Int32(arguments.count)) + } else { + return cgrpc_channel_create(address, &argumentValues, Int32(arguments.count)) + } } completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") completionQueue.run() // start a loop that watches the channel's completion queue @@ -66,10 +64,10 @@ public class Channel { gRPC.initialize() host = googleAddress let argumentWrappers = arguments.map { $0.toCArg() } - + underlyingChannel = withExtendedLifetime(argumentWrappers) { - var argumentValues = argumentWrappers.map { $0.wrapped } - return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count)) + var argumentValues = argumentWrappers.map { $0.wrapped } + return cgrpc_channel_create_google(googleAddress, &argumentValues, Int32(arguments.count)) } completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") @@ -89,15 +87,15 @@ public class Channel { let argumentWrappers = arguments.map { $0.toCArg() } underlyingChannel = withExtendedLifetime(argumentWrappers) { - var argumentValues = argumentWrappers.map { $0.wrapped } - return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count)) + var argumentValues = argumentWrappers.map { $0.wrapped } + return cgrpc_channel_create_secure(address, certificates, clientCertificates, clientKey, &argumentValues, Int32(arguments.count)) } completionQueue = CompletionQueue(underlyingCompletionQueue: cgrpc_channel_completion_queue(underlyingChannel), name: "Client") completionQueue.run() // start a loop that watches the channel's completion queue } deinit { - connectivityObservers.forEach { $0.shutdown() } + connectivityObserver.shutdown() cgrpc_channel_destroy(underlyingChannel) completionQueue.shutdown() } @@ -109,7 +107,7 @@ public class Channel { /// - Parameter timeout: a timeout value in seconds /// - Returns: a Call object that can be used to perform the request public func makeCall(_ method: String, host: String = "", timeout: TimeInterval? = nil) -> Call { - let host = (host == "") ? self.host : host + let host = host.isEmpty ? self.host : host let timeout = timeout ?? self.timeout let underlyingCall = cgrpc_channel_create_call(underlyingChannel, method, host, timeout)! return Call(underlyingCall: underlyingCall, owned: true, completionQueue: completionQueue) @@ -126,8 +124,8 @@ public class Channel { /// Subscribe to connectivity state changes /// /// - Parameter callback: block executed every time a new connectivity state is detected - public func subscribe(callback: @escaping (ConnectivityState) -> Void) { - connectivityObservers.append(ConnectivityObserver(underlyingChannel: underlyingChannel, currentState: connectivityState(), callback: callback)) + public func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) { + connectivityObserver.addConnectivityObserver(callback: callback) } } @@ -136,18 +134,16 @@ private extension Channel { private let completionQueue: CompletionQueue private let underlyingChannel: UnsafeMutableRawPointer private let underlyingCompletionQueue: UnsafeMutableRawPointer - private let callback: (ConnectivityState) -> Void - private var lastState: ConnectivityState + private var callbacks = [(ConnectivityState) -> Void]() private var hasBeenShutdown = false - private let stateMutex: Mutex = Mutex() + private let stateMutex = Mutex() - init(underlyingChannel: UnsafeMutableRawPointer, currentState: ConnectivityState, callback: @escaping (ConnectivityState) -> ()) { + init(underlyingChannel: UnsafeMutableRawPointer) { self.underlyingChannel = underlyingChannel self.underlyingCompletionQueue = cgrpc_completion_queue_create_for_next() - self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, name: "Connectivity State") - self.callback = callback - self.lastState = currentState - run() + self.completionQueue = CompletionQueue(underlyingCompletionQueue: self.underlyingCompletionQueue, + name: "Connectivity State") + self.run() } deinit { @@ -156,19 +152,20 @@ private extension Channel { private func run() { let spinloopThreadQueue = DispatchQueue(label: "SwiftGRPC.ConnectivityObserver.run.spinloopThread") - + var lastState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0)) spinloopThreadQueue.async { while true { - guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else { + guard (self.stateMutex.synchronize { !self.hasBeenShutdown }) else { return } - - guard let underlyingState = self.lastState.underlyingState else { return } + + guard let underlyingState = lastState.underlyingState else { return } let deadline: TimeInterval = 0.2 - cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, underlyingState, deadline, nil) + cgrpc_channel_watch_connectivity_state(self.underlyingChannel, self.underlyingCompletionQueue, + underlyingState, deadline, nil) + let event = self.completionQueue.wait(timeout: deadline) - guard (self.stateMutex.synchronize{ !self.hasBeenShutdown }) else { return } @@ -176,11 +173,12 @@ private extension Channel { switch event.type { case .complete: let newState = ConnectivityState(cgrpc_channel_check_connectivity_state(self.underlyingChannel, 0)) + guard newState != lastState else { continue } - if newState != self.lastState { - self.callback(newState) + lastState = newState + self.stateMutex.synchronize { + self.callbacks.forEach { callback in callback(newState) } } - self.lastState = newState case .queueShutdown: return @@ -192,6 +190,12 @@ private extension Channel { } } + func addConnectivityObserver(callback: @escaping (ConnectivityState) -> Void) { + self.stateMutex.synchronize { + self.callbacks.append(callback) + } + } + func shutdown() { stateMutex.synchronize { hasBeenShutdown = true diff --git a/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift b/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift index 2271e15cf..da1165d01 100644 --- a/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift +++ b/Tests/SwiftGRPCTests/ChannelConnectivityTests.swift @@ -21,7 +21,8 @@ final class ChannelConnectivityTests: BasicEchoTestCase { static var allTests: [(String, (ChannelConnectivityTests) -> () throws -> Void)] { return [ - ("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash) + ("testDanglingConnectivityObserversDontCrash", testDanglingConnectivityObserversDontCrash), + ("testMultipleConnectivityObserversAreCalled", testMultipleConnectivityObserversAreCalled), ] } } @@ -30,12 +31,12 @@ extension ChannelConnectivityTests { func testDanglingConnectivityObserversDontCrash() { let completionHandlerExpectation = expectation(description: "completion handler called") - client?.channel.subscribe { connectivityState in + client.channel.addConnectivityObserver { connectivityState in print("ConnectivityState: \(connectivityState)") } let request = Echo_EchoRequest(text: "foo bar baz foo bar baz") - _ = try! client!.expand(request) { callResult in + _ = try! client.expand(request) { callResult in print("callResult.statusCode: \(callResult.statusCode)") completionHandlerExpectation.fulfill() } @@ -46,4 +47,21 @@ extension ChannelConnectivityTests { waitForExpectations(timeout: 0.5) } + + func testMultipleConnectivityObserversAreCalled() { + let completionHandlerExpectation = expectation(description: "completion handler called") + var firstObserverCalled = false + var secondObserverCalled = false + + client.channel.addConnectivityObserver { _ in firstObserverCalled = true } + client.channel.addConnectivityObserver { _ in secondObserverCalled = true } + + _ = try! client.expand(Echo_EchoRequest(text: "foo bar baz foo bar baz")) { _ in + completionHandlerExpectation.fulfill() + } + + waitForExpectations(timeout: 0.5) + XCTAssertTrue(firstObserverCalled) + XCTAssertTrue(secondObserverCalled) + } }