diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift index 5e4ae6e01..77e4835f6 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -35,6 +35,9 @@ final class HTTP2Connection { let multiplexer: HTTP2StreamMultiplexer let logger: Logger + /// A method with access to the stream channel that is called when creating the stream. + let streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + /// the connection pool that created the connection let delegate: HTTP2ConnectionDelegate @@ -95,7 +98,8 @@ final class HTTP2Connection { decompression: HTTPClient.Decompression, maximumConnectionUses: Int?, delegate: HTTP2ConnectionDelegate, - logger: Logger + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil ) { self.channel = channel self.id = connectionID @@ -114,6 +118,7 @@ final class HTTP2Connection { ) self.delegate = delegate self.state = .initialized + self.streamChannelDebugInitializer = streamChannelDebugInitializer } deinit { @@ -128,7 +133,8 @@ final class HTTP2Connection { delegate: HTTP2ConnectionDelegate, decompression: HTTPClient.Decompression, maximumConnectionUses: Int?, - logger: Logger + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil ) -> EventLoopFuture<(HTTP2Connection, Int)> { let connection = HTTP2Connection( channel: channel, @@ -136,7 +142,8 @@ final class HTTP2Connection { decompression: decompression, maximumConnectionUses: maximumConnectionUses, delegate: delegate, - logger: logger + logger: logger, + streamChannelDebugInitializer: streamChannelDebugInitializer ) return connection._start0().map { maxStreams in (connection, maxStreams) } } @@ -259,8 +266,14 @@ final class HTTP2Connection { self.openStreams.remove(box) } - channel.write(request, promise: nil) - return channel.eventLoop.makeSucceededVoidFuture() + if let streamChannelDebugInitializer = self.streamChannelDebugInitializer { + return streamChannelDebugInitializer(channel).map { _ in + channel.write(request, promise: nil) + } + } else { + channel.write(request, promise: nil) + return channel.eventLoop.makeSucceededVoidFuture() + } } catch { return channel.eventLoop.makeFailedFuture(error) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 32af23830..9a3d66a3a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -84,7 +84,19 @@ extension HTTPConnectionPool.ConnectionFactory { decompression: self.clientConfiguration.decompression, logger: logger ) - requester.http1ConnectionCreated(connection) + + if let connectionDebugInitializer = self.clientConfiguration.http1_1ConnectionDebugInitializer { + connectionDebugInitializer(channel).whenComplete { debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http1ConnectionCreated(connection) + case .failure(let error): + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + } + } else { + requester.http1ConnectionCreated(connection) + } } catch { requester.failedToCreateHTTPConnection(connectionID, error: error) } @@ -95,11 +107,34 @@ extension HTTPConnectionPool.ConnectionFactory { delegate: http2ConnectionDelegate, decompression: self.clientConfiguration.decompression, maximumConnectionUses: self.clientConfiguration.maximumUsesPerConnection, - logger: logger + logger: logger, + streamChannelDebugInitializer: + self.clientConfiguration.http2StreamChannelDebugInitializer ).whenComplete { result in switch result { case .success((let connection, let maximumStreams)): - requester.http2ConnectionCreated(connection, maximumStreams: maximumStreams) + if let connectionDebugInitializer = self.clientConfiguration.http2ConnectionDebugInitializer { + connectionDebugInitializer(channel).whenComplete { + debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http2ConnectionCreated( + connection, + maximumStreams: maximumStreams + ) + case .failure(let error): + requester.failedToCreateHTTPConnection( + connectionID, + error: error + ) + } + } + } else { + requester.http2ConnectionCreated( + connection, + maximumStreams: maximumStreams + ) + } case .failure(let error): requester.failedToCreateHTTPConnection(connectionID, error: error) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index eebe4d029..ebcecbdc5 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -324,7 +324,9 @@ final class HTTPConnectionPool: connection.executeRequest(request.req) case .executeRequests(let requests, let connection): - for request in requests { connection.executeRequest(request.req) } + for request in requests { + connection.executeRequest(request.req) + } case .failRequest(let request, let error): request.req.fail(error) diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index f1655c7c5..ff222bd6f 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -847,6 +847,15 @@ public class HTTPClient { /// By default, don't use it public var enableMultipath: Bool + /// A method with access to the HTTP/1 connection channel that is called when creating the connection. + public var http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 connection channel that is called when creating the connection. + public var http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 stream channel that is called when creating the stream. + public var http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + public init( tlsConfiguration: TLSConfiguration? = nil, redirectConfiguration: RedirectConfiguration? = nil, @@ -949,6 +958,32 @@ public class HTTPClient { decompression: decompression ) } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: ConnectionPool = ConnectionPool(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: connectionPool, + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + self.http1_1ConnectionDebugInitializer = http1_1ConnectionDebugInitializer + self.http2ConnectionDebugInitializer = http2ConnectionDebugInitializer + self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer + } } /// Specifies how `EventLoopGroup` will be created and establishes lifecycle ownership. diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 546d1c3f4..360632cdd 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -4306,4 +4306,174 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { request.setBasicAuth(username: "foo", password: "bar") XCTAssertEqual(request.headers.first(name: "Authorization"), "Basic Zm9vOmJhcg==") } + + func runBaseTestForHTTP1ConnectionDebugInitializer(ssl: Bool) { + let connectionDebugInitializerUtil = CountingDebugInitializerUtil() + + // Initializing even with just `http1_1ConnectionDebugInitializer` (rather than manually + // modifying `config`) to ensure that the matching `init` actually wires up this argument + // with the respective property. This is necessary as these parameters are defaulted and can + // be easy to miss. + var config = HTTPClient.Configuration( + http1_1ConnectionDebugInitializer: { channel in + connectionDebugInitializerUtil.initialize(channel: channel) + } + ) + config.httpVersion = .http1Only + + if ssl { + config.tlsConfiguration = .clientDefault + config.tlsConfiguration?.certificateVerification = .none + } + + let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100) + var configWithHigherTimeout = config + configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout) + + let clientWithHigherTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithHigherTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) } + + let bin = HTTPBin(.http1_1(ssl: ssl, compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let scheme = ssl ? "https" : "http" + + for _ in 0..<3 { + XCTAssertNoThrow( + try clientWithHigherTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait() + ) + } + + // Even though multiple requests were made, the connection debug initializer must be called + // only once. + XCTAssertEqual(connectionDebugInitializerUtil.executionCount, 1) + + let lowerConnectTimeout = CountingDebugInitializerUtil.duration - .milliseconds(100) + var configWithLowerTimeout = config + configWithLowerTimeout.timeout = .init(connect: lowerConnectTimeout) + + let clientWithLowerTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithLowerTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithLowerTimeout.syncShutdown()) } + + XCTAssertThrowsError( + try clientWithLowerTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait() + ) { + XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + } + } + + func testHTTP1PlainTextConnectionDebugInitializer() { + runBaseTestForHTTP1ConnectionDebugInitializer(ssl: false) + } + + func testHTTP1EncryptedConnectionDebugInitializer() { + runBaseTestForHTTP1ConnectionDebugInitializer(ssl: true) + } + + func testHTTP2ConnectionAndStreamChannelDebugInitializers() { + let connectionDebugInitializerUtil = CountingDebugInitializerUtil() + let streamChannelDebugInitializerUtil = CountingDebugInitializerUtil() + + // Initializing even with just `http2ConnectionDebugInitializer` and + // `http2StreamChannelDebugInitializer` (rather than manually modifying `config`) to ensure + // that the matching `init` actually wires up these arguments with the respective + // properties. This is necessary as these parameters are defaulted and can be easy to miss. + var config = HTTPClient.Configuration( + http2ConnectionDebugInitializer: { channel in + connectionDebugInitializerUtil.initialize(channel: channel) + }, + http2StreamChannelDebugInitializer: { channel in + streamChannelDebugInitializerUtil.initialize(channel: channel) + } + ) + config.tlsConfiguration = .clientDefault + config.tlsConfiguration?.certificateVerification = .none + config.httpVersion = .automatic + + let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100) + var configWithHigherTimeout = config + configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout) + + let clientWithHigherTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithHigherTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) } + + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let numberOfRequests = 3 + + for _ in 0..(0) + var executionCount: Int { self._executionCount.withLockedValue { $0 } } + + /// The minimum time to spend running the debug initializer. + static let duration: TimeAmount = .milliseconds(300) + + /// The actual debug initializer. + func initialize(channel: Channel) -> EventLoopFuture { + self._executionCount.withLockedValue { $0 += 1 } + + let someScheduledTask = channel.eventLoop.scheduleTask(in: Self.duration) { + channel.eventLoop.makeSucceededVoidFuture() + } + + return someScheduledTask.futureResult.flatMap { $0 } + } }