diff --git a/NOTICE.txt b/NOTICE.txt index 473df88..b74911c 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -32,3 +32,21 @@ This product contains derivations of various scripts from SwiftNIO. * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/apple/swift-nio + +------------------------------------------------------------------------------- + +This product contains AsyncSequence implementations from Swift Async Algorithms. + + * LICENSE (Apache License 2.0): + * https://github.com/apple/swift-async-algorithms/blob/main/LICENSE.txt + * HOMEPAGE: + * https://github.com/apple/swift-async-algorithms + +------------------------------------------------------------------------------- + +This product contains AsyncSequence implementations from Swift. + + * LICENSE (Apache License 2.0): + * https://github.com/apple/swift/blob/main/LICENSE.txt + * HOMEPAGE: + * https://github.com/apple/swift diff --git a/Package.swift b/Package.swift index f1816f0..de0994e 100644 --- a/Package.swift +++ b/Package.swift @@ -24,8 +24,16 @@ swiftSettings.append( // Require `any` for existential types. .enableUpcomingFeature("ExistentialAny") ) + +// Strict concurrency is enabled in CI; use this environment variable to enable it locally. +if ProcessInfo.processInfo.environment["SWIFT_OPENAPI_STRICT_CONCURRENCY"].flatMap(Bool.init) ?? false { + swiftSettings.append(contentsOf: [ + .define("SWIFT_OPENAPI_STRICT_CONCURRENCY"), .enableExperimentalFeature("StrictConcurrency"), + ]) +} #endif + let package = Package( name: "swift-openapi-urlsession", platforms: [ @@ -40,19 +48,29 @@ let package = Package( dependencies: [ .package(url: "https://github.com/apple/swift-openapi-runtime", .upToNextMinor(from: "0.3.0")), .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-collections", from: "1.0.0"), ], targets: [ .target( name: "OpenAPIURLSession", dependencies: [ + .product(name: "DequeModule", package: "swift-collections"), .product(name: "OpenAPIRuntime", package: "swift-openapi-runtime"), ], swiftSettings: swiftSettings ), .testTarget( name: "OpenAPIURLSessionTests", - dependencies: ["OpenAPIURLSession"], + dependencies: [ + "OpenAPIURLSession", + .product(name: "NIOTestUtils", package: "swift-nio"), + ], swiftSettings: swiftSettings ), ] ) + +// Test-only dependencies. +package.dependencies += [ + .package(url: "https://github.com/apple/swift-nio", from: "2.62.0") +] diff --git a/README.md b/README.md index 2b0ae2d..138c8c9 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,14 @@ A client transport that uses the [URLSession](https://developer.apple.com/docume Use the transport with client code generated by [Swift OpenAPI Generator](https://github.com/apple/swift-openapi-generator). ## Supported platforms and minimum versions - | macOS | Linux | iOS | tvOS | watchOS | - | :-: | :-: | :-: | :-: | :-: | - | ✅ 10.15+ | ✅ | ✅ 13+ | ✅ 13+ | ✅ 6+ | + +| macOS | Linux | iOS | tvOS | watchOS | +| :-: | :-: | :-: | :-: | :-: | +| ✅ 10.15+ | ✅ | ✅ 13+ | ✅ 13+ | ✅ 6+ | + +Note: Streaming support only available on macOS 12+, iOS 15+, tvOS 15+, and +watchOS 8+.For streaming support on Linux, please use the [AsyncHTTPClient +Transport](https://github.com/swift-server/swift-openapi-async-http-client) ## Usage diff --git a/Sources/OpenAPIURLSession/AsyncBackpressuredStream/AsyncBackpressuredStream.swift b/Sources/OpenAPIURLSession/AsyncBackpressuredStream/AsyncBackpressuredStream.swift new file mode 100644 index 0000000..6aa7cd8 --- /dev/null +++ b/Sources/OpenAPIURLSession/AsyncBackpressuredStream/AsyncBackpressuredStream.swift @@ -0,0 +1,1324 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore-file +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020-2021 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +import DequeModule + +struct AsyncBackpressuredStream: Sendable { + /// A mechanism to interface between producer code and an asynchronous stream. + /// + /// Use this source to provide elements to the stream by calling one of the `write` methods, then terminate the stream normally + /// by calling the `finish()` method. You can also use the source's `finish(throwing:)` method to terminate the stream by + /// throwing an error. + struct Source: Sendable { + /// A strategy that handles the back pressure of the asynchronous stream. + struct BackPressureStrategy: Sendable { + var internalBackPressureStrategy: InternalBackPressureStrategy + + /// When the high water mark is reached producers will be suspended. All producers will be resumed again once + /// the low water mark is reached. + static func highLowWatermark(lowWatermark: Int, highWatermark: Int) -> BackPressureStrategy { + .init( + internalBackPressureStrategy: .highLowWatermark( + .init(lowWatermark: lowWatermark, highWatermark: highWatermark) + ) + ) + } + + /// When the high water mark is reached producers will be suspended. All producers will be resumed again once + /// the low water mark is reached. When `usingElementCounts` is true, the counts of the element types will + /// be used to compute the watermark. + static func highLowWatermarkWithElementCounts(lowWatermark: Int, highWatermark: Int) + -> BackPressureStrategy where Element: RandomAccessCollection + { + .init( + internalBackPressureStrategy: .highLowWatermark( + .init( + lowWatermark: lowWatermark, + highWatermark: highWatermark, + waterLevelForElement: { $0.count } + ) + ) + ) + } + } + + /// A type that indicates the result of writing elements to the source. + enum WriteResult: Sendable { + /// A token that is returned when the asynchronous stream's back pressure strategy indicated that any producer should + /// be suspended. Use this token to enqueue a callback by calling the ``enqueueCallback(_:)`` method. + struct WriteToken: Sendable { + let id: UInt + + init(id: UInt) { self.id = id } + } + /// Indicates that more elements should be produced and written to the source. + case produceMore + /// Indicates that a callback should be enqueued. + /// + /// The associated token should be passed to the ``enqueueCallback(_:)`` method. + case enqueueCallback(WriteToken) + } + + private var storage: Storage + + init(storage: Storage) { self.storage = storage } + + /// Write new elements to the asynchronous stream. + /// + /// If there is a task consuming the stream and awaiting the next element then the task will get resumed with the + /// first element of the provided sequence. If the asynchronous stream already terminated then this method will throw an error + /// indicating the failure. + /// + /// - Parameter sequence: The elements to write to the asynchronous stream. + /// - Returns: The result that indicates if more elements should be produced at this time. + func write(contentsOf sequence: S) throws -> WriteResult where S.Element == Element { + try self.storage.write(contentsOf: sequence) + } + + /// Enqueues a callback that will be invoked once more elements should be produced. + /// + /// Call this method after ``write(contentsOf:)`` returned a ``WriteResult/enqueueCallback(_:)``. + /// + /// - Parameters: + /// - writeToken: The write token produced by ``write(contentsOf:)``. + /// - onProduceMore: The callback which gets invoked once more elements should be produced. + func enqueueCallback( + writeToken: WriteResult.WriteToken, + onProduceMore: @escaping @Sendable (Result) -> Void + ) { self.storage.enqueueProducer(writeToken: writeToken, onProduceMore: onProduceMore) } + + /// Cancel an enqueued callback. + /// + /// Call this method to cancel a callback enqueued by the ``enqueueCallback(writeToken:onProduceMore:)`` method. + /// + /// > Note: This methods supports being called before ``enqueueCallback(writeToken:onProduceMore:)`` is called and + /// will mark the passed `writeToken` as cancelled. + /// - Parameter writeToken: The write token produced by ``write(contentsOf:)``. + func cancelCallback(writeToken: WriteResult.WriteToken) { + self.storage.cancelProducer(writeToken: writeToken) + } + + /// Write new elements to the asynchronous stream and provide a callback which will be invoked once more elements should be produced. + /// + /// - Parameters: + /// - sequence: The elements to write to the asynchronous stream. + /// - onProduceMore: The callback which gets invoked once more elements should be produced. This callback might be + /// invoked during the call to ``write(contentsOf:onProduceMore:)``. + func write( + contentsOf sequence: S, + onProduceMore: @escaping @Sendable (Result) -> Void + ) where S.Element == Element { + do { + let writeResult = try self.write(contentsOf: sequence) + + switch writeResult { + case .produceMore: onProduceMore(.success(())) + + case .enqueueCallback(let writeToken): + self.enqueueCallback(writeToken: writeToken, onProduceMore: onProduceMore) + } + } catch { onProduceMore(.failure(error)) } + } + + /// Write new elements to the asynchronous stream. + /// + /// This method returns once more elements should be produced. + /// + /// - Parameters: + /// - sequence: The elements to write to the asynchronous stream. + func asyncWrite(contentsOf sequence: S) async throws where S.Element == Element { + let writeResult = try self.write(contentsOf: sequence) + + switch writeResult { + case .produceMore: return + + case .enqueueCallback(let writeToken): + try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in + self.enqueueCallback( + writeToken: writeToken, + onProduceMore: { result in + switch result { + case .success(): continuation.resume(returning: ()) + case .failure(let error): continuation.resume(throwing: error) + } + } + ) + } + } onCancel: { + self.cancelCallback(writeToken: writeToken) + } + + } + } + + func finish(throwing failure: Failure?) { self.storage.finish(failure) } + } + + private var storage: Storage + + init(storage: Storage) { self.storage = storage } + + static func makeStream( + of elementType: Element.Type = Element.self, + backPressureStrategy: Source.BackPressureStrategy, + onTermination: (@Sendable () -> Void)? = nil + ) -> (Self, Source) where Failure == any Error { + let storage = Storage( + backPressureStrategy: backPressureStrategy.internalBackPressureStrategy, + onTerminate: onTermination + ) + let source = Source(storage: storage) + + return (.init(storage: storage), source) + } +} + +extension AsyncBackpressuredStream: AsyncSequence { + struct AsyncIterator: AsyncIteratorProtocol { + private var storage: Storage + + init(storage: Storage) { self.storage = storage } + + mutating func next() async throws -> Element? { return try await storage.next() } + } + + func makeAsyncIterator() -> AsyncIterator { return AsyncIterator(storage: self.storage) } +} + +extension AsyncBackpressuredStream { + struct HighLowWatermarkBackPressureStrategy { + private let lowWatermark: Int + private let highWatermark: Int + private(set) var currentWatermark: Int + + typealias CustomWaterLevelForElement = @Sendable (Element) -> Int + private let waterLevelForElement: CustomWaterLevelForElement? + + /// Initializes a new ``HighLowWatermarkBackPressureStrategy``. + /// + /// - Parameters: + /// - lowWatermark: The low watermark where demand should start. + /// - highWatermark: The high watermark where demand should be stopped. + init(lowWatermark: Int, highWatermark: Int, waterLevelForElement: CustomWaterLevelForElement? = nil) { + precondition(lowWatermark <= highWatermark, "Low watermark must be <= high watermark") + self.lowWatermark = lowWatermark + self.highWatermark = highWatermark + self.currentWatermark = 0 + self.waterLevelForElement = waterLevelForElement + } + + mutating func didYield(elements: Deque.SubSequence) -> Bool { + if let waterLevelForElement { + self.currentWatermark += elements.reduce(0) { $0 + waterLevelForElement($1) } + } else { + self.currentWatermark += elements.count + } + precondition(self.currentWatermark >= 0, "Watermark below zero") + // We are demanding more until we reach the high watermark + return self.currentWatermark < self.highWatermark + } + + mutating func didConsume(elements: Deque.SubSequence) -> Bool { + if let waterLevelForElement { + self.currentWatermark -= elements.reduce(0) { $0 + waterLevelForElement($1) } + } else { + self.currentWatermark -= elements.count + } + precondition(self.currentWatermark >= 0, "Watermark below zero") + // We start demanding again once we are below the low watermark + return self.currentWatermark < self.lowWatermark + } + + mutating func didConsume(element: Element) -> Bool { + if let waterLevelForElement { + self.currentWatermark -= waterLevelForElement(element) + } else { + self.currentWatermark -= 1 + } + precondition(self.currentWatermark >= 0, "Watermark below zero") + // We start demanding again once we are below the low watermark + return self.currentWatermark < self.lowWatermark + } + } + + enum InternalBackPressureStrategy { + case highLowWatermark(HighLowWatermarkBackPressureStrategy) + + mutating func didYield(elements: Deque.SubSequence) -> Bool { + switch self { + case .highLowWatermark(var strategy): + let result = strategy.didYield(elements: elements) + self = .highLowWatermark(strategy) + return result + } + } + + mutating func didConsume(elements: Deque.SubSequence) -> Bool { + switch self { + case .highLowWatermark(var strategy): + let result = strategy.didConsume(elements: elements) + self = .highLowWatermark(strategy) + return result + } + } + + mutating func didConsume(element: Element) -> Bool { + switch self { + case .highLowWatermark(var strategy): + let result = strategy.didConsume(element: element) + self = .highLowWatermark(strategy) + return result + } + } + } +} + +extension AsyncBackpressuredStream { + final class Storage: @unchecked Sendable { + /// The lock that protects the state machine and the nextProducerID. + let lock = NIOLock() + + /// The state machine. + var stateMachine: StateMachine + + /// The next producer's id. + var nextProducerID: UInt = 0 + + init(backPressureStrategy: InternalBackPressureStrategy, onTerminate: (() -> Void)?) { + self.stateMachine = .init(backPressureStrategy: backPressureStrategy, onTerminate: onTerminate) + } + + func sequenceDeinitialized() { + let onTerminate = self.lock.withLock { + let action = self.stateMachine.sequenceDeinitialized() + + switch action { + case .callOnTerminate(let onTerminate): + // We have to call onTerminate without the lock to avoid potential deadlocks + return onTerminate + + case .none: return nil + } + } + + onTerminate?() + } + + func iteratorInitialized() { self.lock.withLock { self.stateMachine.iteratorInitialized() } } + + func iteratorDeinitialized() { + let onTerminate = self.lock.withLock { + let action = self.stateMachine.iteratorDeinitialized() + + switch action { + case .callOnTerminate(let onTerminate): + // We have to call onTerminate without the lock to avoid potential deadlocks + return onTerminate + + case .none: return nil + } + } + + onTerminate?() + } + + func write(contentsOf sequence: S) throws -> Source.WriteResult where S.Element == Element { + let action = self.lock.withLock { return self.stateMachine.write(sequence) } + + switch action { + case .returnProduceMore: return .produceMore + + case .returnEnqueue: + // TODO: Move the id into the state machine or use an atomic + let id = self.lock.withLock { + let id = self.nextProducerID + self.nextProducerID += 1 + return id + } + return .enqueueCallback(.init(id: id)) + + case .resumeConsumerContinuationAndReturnProduceMore(let continuation, let element): + continuation.resume(returning: element) + return .produceMore + + case .resumeConsumerContinuationAndReturnEnqueue(let continuation, let element): + continuation.resume(returning: element) + // TODO: Move the id into the state machine or use an atomic + let id = self.lock.withLock { + let id = self.nextProducerID + self.nextProducerID += 1 + return id + } + return .enqueueCallback(.init(id: id)) + + case .throwFinishedError: + // TODO: Introduce new Error + throw CancellationError() + } + } + + func enqueueProducer( + writeToken: Source.WriteResult.WriteToken, + onProduceMore: @escaping @Sendable (Result) -> Void + ) { + let action = self.lock.withLock { + return self.stateMachine.enqueueProducer(writeToken: writeToken, onProduceMore: onProduceMore) + } + + switch action { + case .resumeProducer(let onProduceMore): onProduceMore(.success(())) + + case .resumeProducerWithCancellationError(let onProduceMore): onProduceMore(.failure(CancellationError())) + + case .none: break + } + } + + func cancelProducer(writeToken: Source.WriteResult.WriteToken) { + let action = self.lock.withLock { return self.stateMachine.cancelProducer(writeToken: writeToken) } + + switch action { + case .resumeProducerWithCancellationError(let onProduceMore): onProduceMore(.failure(CancellationError())) + + case .none: break + } + } + + func finish(_ failure: Failure?) { + let onTerminate = self.lock.withLock { + let action = self.stateMachine.finish(failure) + + switch action { + case .resumeAllContinuationsAndCallOnTerminate( + let consumerContinuation, + let failure, + let producerContinuations, + let onTerminate + ): + // It is safe to resume the continuation while holding the lock + // since the task will get enqueued on its executor and the resume method + // is returning immediately + switch failure { + case .some(let error): consumerContinuation.resume(throwing: error) + case .none: consumerContinuation.resume(returning: nil) + } + + for producerContinuation in producerContinuations { + // TODO: Throw a new cancelled error + producerContinuation(.failure(CancellationError())) + } + + return onTerminate + + case .resumeProducerContinuations(let producerContinuations): + for producerContinuation in producerContinuations { + // TODO: Throw a new cancelled error + producerContinuation(.failure(CancellationError())) + } + + return nil + + case .none: return nil + } + } + + onTerminate?() + } + + func next() async throws -> Element? { + let action = self.lock.withLock { return self.stateMachine.next() } + + switch action { + case .returnElement(let element): return element + + case .returnElementAndResumeProducers(let element, let producerContinuations): + for producerContinuation in producerContinuations { producerContinuation(.success(())) } + + return element + + case .returnFailureAndCallOnTerminate(let failure, let onTerminate): + onTerminate?() + switch failure { + case .some(let error): throw error + + case .none: return nil + } + + case .returnNil: return nil + + case .suspendTask: return try await suspendNext() + } + } + + func suspendNext() async throws -> Element? { + return try await withTaskCancellationHandler { + return try await withCheckedThrowingContinuation { continuation in + let action = self.lock.withLock { return self.stateMachine.suspendNext(continuation: continuation) } + + switch action { + case .resumeContinuationWithElement(let continuation, let element): + continuation.resume(returning: element) + + case .resumeContinuationWithElementAndProducers( + let continuation, + let element, + let producerContinuations + ): + continuation.resume(returning: element) + for producerContinuation in producerContinuations { producerContinuation(.success(())) } + + case .resumeContinuationWithFailureAndCallOnTerminate( + let continuation, + let failure, + let onTerminate + ): + onTerminate?() + switch failure { + case .some(let error): continuation.resume(throwing: error) + + case .none: continuation.resume(returning: nil) + } + + case .resumeContinuationWithNil(let continuation): continuation.resume(returning: nil) + + case .none: break + } + } + } onCancel: { + self.lock.withLockVoid { + let action = self.stateMachine.cancelNext() + + switch action { + case .resumeContinuationWithCancellationErrorAndFinishProducersAndCallOnTerminate( + let continuation, + let producerContinuations, + let onTerminate + ): + onTerminate?() + continuation.resume(throwing: CancellationError()) + for producerContinuation in producerContinuations { + // TODO: Throw a new cancelled error + producerContinuation(.failure(CancellationError())) + } + + case .finishProducersAndCallOnTerminate(let producerContinuations, let onTerminate): + onTerminate?() + for producerContinuation in producerContinuations { + // TODO: Throw a new cancelled error + producerContinuation(.failure(CancellationError())) + } + + case .none: break + } + } + } + } + } +} + +extension AsyncBackpressuredStream { + struct StateMachine { + enum State { + case initial( + backPressureStrategy: InternalBackPressureStrategy, + iteratorInitialized: Bool, + onTerminate: (() -> Void)? + ) + + /// The state once either any element was yielded or `next()` was called. + case streaming( + backPressureStrategy: InternalBackPressureStrategy, + buffer: Deque, + consumerContinuation: CheckedContinuation?, + producerContinuations: Deque<(UInt, (Result) -> Void)>, + cancelledAsyncProducers: Deque, + hasOutstandingDemand: Bool, + iteratorInitialized: Bool, + onTerminate: (() -> Void)? + ) + + /// The state once the underlying source signalled that it is finished. + case sourceFinished( + buffer: Deque, + iteratorInitialized: Bool, + failure: Failure?, + onTerminate: (() -> Void)? + ) + + /// The state once there can be no outstanding demand. This can happen if: + /// 1. The iterator was deinited + /// 2. The underlying source finished and all buffered elements have been consumed + case finished(iteratorInitialized: Bool) + } + + /// The state machine's current state. + var state: State + + var producerContinuationCounter: UInt = 0 + + /// Initializes a new `StateMachine`. + /// + /// We are passing and holding the back-pressure strategy here because + /// it is a customizable extension of the state machine. + /// + /// - Parameter backPressureStrategy: The back-pressure strategy. + init(backPressureStrategy: InternalBackPressureStrategy, onTerminate: (() -> Void)?) { + self.state = .initial( + backPressureStrategy: backPressureStrategy, + iteratorInitialized: false, + onTerminate: onTerminate + ) + } + + /// Actions returned by `sequenceDeinitialized()`. + enum SequenceDeinitializedAction { + /// Indicates that `onTerminate` should be called. + case callOnTerminate((() -> Void)?) + /// Indicates that nothing should be done. + case none + } + + mutating func sequenceDeinitialized() -> SequenceDeinitializedAction { + switch self.state { + case .initial(_, iteratorInitialized: false, let onTerminate), + .streaming(_, _, _, _, _, _, iteratorInitialized: false, let onTerminate), + .sourceFinished(_, iteratorInitialized: false, _, let onTerminate): + // No iterator was created so we can transition to finished right away. + self.state = .finished(iteratorInitialized: false) + + return .callOnTerminate(onTerminate) + + case .initial(_, iteratorInitialized: true, _), .streaming(_, _, _, _, _, _, iteratorInitialized: true, _), + .sourceFinished(_, iteratorInitialized: true, _, _): + // An iterator was created and we deinited the sequence. + // This is an expected pattern and we just continue on normal. + return .none + + case .finished: + // We are already finished so there is nothing left to clean up. + // This is just the references dropping afterwards. + return .none + } + } + + mutating func iteratorInitialized() { + switch self.state { + case .initial(_, iteratorInitialized: true, _), .streaming(_, _, _, _, _, _, iteratorInitialized: true, _), + .sourceFinished(_, iteratorInitialized: true, _, _), .finished(iteratorInitialized: true): + // Our sequence is a unicast sequence and does not support multiple AsyncIterator's + fatalError("Only a single AsyncIterator can be created") + + case .initial(let backPressureStrategy, iteratorInitialized: false, let onTerminate): + // The first and only iterator was initialized. + self.state = .initial( + backPressureStrategy: backPressureStrategy, + iteratorInitialized: true, + onTerminate: onTerminate + ) + + case .streaming( + let backPressureStrategy, + let buffer, + let consumerContinuation, + let producerContinuations, + let cancelledAsyncProducers, + let hasOutstandingDemand, + false, + let onTerminate + ): + // The first and only iterator was initialized. + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: consumerContinuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: true, + onTerminate: onTerminate + ) + + case .sourceFinished(let buffer, false, let failure, let onTerminate): + // The first and only iterator was initialized. + self.state = .sourceFinished( + buffer: buffer, + iteratorInitialized: true, + failure: failure, + onTerminate: onTerminate + ) + + case .finished(iteratorInitialized: false): + // It is strange that an iterator is created after we are finished + // but it can definitely happen, e.g. + // Sequence.init -> source.finish -> sequence.makeAsyncIterator + self.state = .finished(iteratorInitialized: true) + } + } + + /// Actions returned by `iteratorDeinitialized()`. + enum IteratorDeinitializedAction { + /// Indicates that `onTerminate` should be called. + case callOnTerminate((() -> Void)?) + /// Indicates that nothing should be done. + case none + } + + mutating func iteratorDeinitialized() -> IteratorDeinitializedAction { + switch self.state { + case .initial(_, iteratorInitialized: false, _), + .streaming(_, _, _, _, _, _, iteratorInitialized: false, _), + .sourceFinished(_, iteratorInitialized: false, _, _): + // An iterator needs to be initialized before it can be deinitialized. + preconditionFailure("Internal inconsistency") + + case .initial(_, iteratorInitialized: true, let onTerminate), + .streaming(_, _, _, _, _, _, iteratorInitialized: true, let onTerminate), + .sourceFinished(_, iteratorInitialized: true, _, let onTerminate): + // An iterator was created and deinited. Since we only support + // a single iterator we can now transition to finish and inform the delegate. + self.state = .finished(iteratorInitialized: true) + + return .callOnTerminate(onTerminate) + + case .finished: + // We are already finished so there is nothing left to clean up. + // This is just the references dropping afterwards. + return .none + } + } + + /// Actions returned by `yield()`. + enum WriteAction { + /// Indicates that the producer should be notified to produce more. + case returnProduceMore + /// Indicates that the producer should be suspended to stop producing. + case returnEnqueue + /// Indicates that the consumer continuation should be resumed and the producer should be notified to produce more. + case resumeConsumerContinuationAndReturnProduceMore( + continuation: CheckedContinuation, + element: Element + ) + /// Indicates that the consumer continuation should be resumed and the producer should be suspended. + case resumeConsumerContinuationAndReturnEnqueue( + continuation: CheckedContinuation, + element: Element + ) + /// Indicates that the producer has been finished. + case throwFinishedError + + init( + shouldProduceMore: Bool, + continuationAndElement: (CheckedContinuation, Element)? = nil + ) { + switch (shouldProduceMore, continuationAndElement) { + case (true, .none): self = .returnProduceMore + + case (false, .none): self = .returnEnqueue + + case (true, .some((let continuation, let element))): + self = .resumeConsumerContinuationAndReturnProduceMore(continuation: continuation, element: element) + + case (false, .some((let continuation, let element))): + self = .resumeConsumerContinuationAndReturnEnqueue(continuation: continuation, element: element) + } + } + } + + mutating func write(_ sequence: S) -> WriteAction where S.Element == Element { + switch self.state { + case .initial(var backPressureStrategy, let iteratorInitialized, let onTerminate): + let buffer = Deque(sequence) + let shouldProduceMore = backPressureStrategy.didYield(elements: buffer[...]) + + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, + producerContinuations: .init(), + cancelledAsyncProducers: .init(), + hasOutstandingDemand: shouldProduceMore, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .init(shouldProduceMore: shouldProduceMore) + + case .streaming( + var backPressureStrategy, + var buffer, + .some(let consumerContinuation), + let producerContinuations, + let cancelledAsyncProducers, + let hasOutstandingDemand, + let iteratorInitialized, + let onTerminate + ): + // The buffer should always be empty if we hold a continuation + precondition(buffer.isEmpty, "Expected an empty buffer") + + let bufferEndIndexBeforeAppend = buffer.endIndex + buffer.append(contentsOf: sequence) + _ = backPressureStrategy.didYield(elements: buffer[bufferEndIndexBeforeAppend...]) + + guard let element = buffer.popFirst() else { + // We got a yield of an empty sequence. We just tolerate this. + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: consumerContinuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + return .init(shouldProduceMore: hasOutstandingDemand) + } + + // We have an element and can resume the continuation + + let shouldProduceMore = backPressureStrategy.didConsume(element: element) + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, // Setting this to nil since we are resuming the continuation + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: shouldProduceMore, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .init( + shouldProduceMore: shouldProduceMore, + continuationAndElement: (consumerContinuation, element) + ) + + case .streaming( + var backPressureStrategy, + var buffer, + consumerContinuation: .none, + let producerContinuations, + let cancelledAsyncProducers, + _, + let iteratorInitialized, + let onTerminate + ): + let bufferEndIndexBeforeAppend = buffer.endIndex + buffer.append(contentsOf: sequence) + let shouldProduceMore = backPressureStrategy.didYield(elements: buffer[bufferEndIndexBeforeAppend...]) + + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: shouldProduceMore, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .init(shouldProduceMore: shouldProduceMore) + + case .sourceFinished, .finished: + // If the source has finished we are dropping the elements. + return .throwFinishedError + } + } + + /// Actions returned by `suspendYield()`. + @usableFromInline enum EnqueueProducerAction { + case resumeProducer((Result) -> Void) + case resumeProducerWithCancellationError((Result) -> Void) + case none + } + + @inlinable mutating func enqueueProducer( + writeToken: Source.WriteResult.WriteToken, + onProduceMore: @escaping (Result) -> Void + ) -> EnqueueProducerAction { + switch self.state { + case .initial: + // We need to transition to streaming before we can suspend + preconditionFailure("Internal inconsistency") + + case .streaming( + let backPressureStrategy, + let buffer, + let consumerContinuation, + var producerContinuations, + var cancelledAsyncProducers, + let hasOutstandingDemand, + let iteratorInitialized, + let onTerminate + ): + if let index = cancelledAsyncProducers.firstIndex(of: writeToken.id) { + cancelledAsyncProducers.remove(at: index) + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: consumerContinuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .resumeProducerWithCancellationError(onProduceMore) + } else if hasOutstandingDemand { + // We hit an edge case here where we yielded but got suspended afterwards + // and in-between yielding and suspending the yield we got consumption which lead us + // to produce more again. + return .resumeProducer(onProduceMore) + } else { + producerContinuations.append((writeToken.id, onProduceMore)) + + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: consumerContinuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .none + } + + case .sourceFinished, .finished: + // Since we are unlocking between yielding and suspending the yield + // It can happen that the source got finished or the consumption fully finishes. + return .none + } + } + + /// Actions returned by `cancelYield()`. + enum CancelYieldAction { + case resumeProducerWithCancellationError((Result) -> Void) + case none + } + + mutating func cancelProducer(writeToken: Source.WriteResult.WriteToken) -> CancelYieldAction { + switch self.state { + case .initial: + // We need to transition to streaming before we can suspend + preconditionFailure("Internal inconsistency") + + case .streaming( + let backPressureStrategy, + let buffer, + let consumerContinuation, + var producerContinuations, + var cancelledAsyncProducers, + let hasOutstandingDemand, + let iteratorInitialized, + let onTerminate + ): + guard let index = producerContinuations.firstIndex(where: { $0.0 == writeToken.id }) else { + // The task that yields was cancelled before yielding so the cancellation handler + // got invoked right away + cancelledAsyncProducers.append(writeToken.id) + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: consumerContinuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .none + } + let continuation = producerContinuations.remove(at: index).1 + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: consumerContinuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .resumeProducerWithCancellationError(continuation) + + case .sourceFinished, .finished: + // Since we are unlocking between yielding and suspending the yield + // It can happen that the source got finished or the consumption fully finishes. + return .none + } + } + + /// Actions returned by `finish()`. + @usableFromInline enum FinishAction { + /// Indicates that the consumer continuation should be resumed with the failure, the producer continuations + /// should be resumed with an error and `onTerminate` should be called. + case resumeAllContinuationsAndCallOnTerminate( + consumerContinuation: CheckedContinuation, + failure: Failure?, + producerContinuations: [(Result) -> Void], + onTerminate: (() -> Void)? + ) + /// Indicates that the producer continuations should be resumed with an error. + case resumeProducerContinuations(producerContinuations: [(Result) -> Void]) + /// Indicates that nothing should be done. + case none + } + + @inlinable mutating func finish(_ failure: Failure?) -> FinishAction { + switch self.state { + case .initial(_, let iteratorInitialized, let onTerminate): + // TODO: Should we call onTerminate here + // Nothing was yielded nor did anybody call next + // This means we can transition to sourceFinished and store the failure + self.state = .sourceFinished( + buffer: .init(), + iteratorInitialized: iteratorInitialized, + failure: failure, + onTerminate: onTerminate + ) + + return .none + + case .streaming( + _, + let buffer, + .some(let consumerContinuation), + let producerContinuations, + _, + _, + let iteratorInitialized, + let onTerminate + ): + // We have a continuation, this means our buffer must be empty + // Furthermore, we can now transition to finished + // and resume the continuation with the failure + precondition(buffer.isEmpty, "Expected an empty buffer") + + self.state = .finished(iteratorInitialized: iteratorInitialized) + + return .resumeAllContinuationsAndCallOnTerminate( + consumerContinuation: consumerContinuation, + failure: failure, + producerContinuations: Array(producerContinuations.map { $0.1 }), + onTerminate: onTerminate + ) + + case .streaming( + _, + let buffer, + consumerContinuation: .none, + let producerContinuations, + _, + _, + let iteratorInitialized, + let onTerminate + ): + self.state = .sourceFinished( + buffer: buffer, + iteratorInitialized: iteratorInitialized, + failure: failure, + onTerminate: onTerminate + ) + + return .resumeProducerContinuations(producerContinuations: Array(producerContinuations.map { $0.1 })) + + case .sourceFinished, .finished: + // If the source has finished, finishing again has no effect. + return .none + } + } + + /// Actions returned by `next()`. + enum NextAction { + /// Indicates that the element should be returned to the caller. + case returnElement(Element) + /// Indicates that the element should be returned to the caller and that all producers should be called. + case returnElementAndResumeProducers(Element, [(Result) -> Void]) + /// Indicates that the `Failure` should be returned to the caller and that `onTerminate` should be called. + case returnFailureAndCallOnTerminate(Failure?, (() -> Void)?) + /// Indicates that the `nil` should be returned to the caller. + case returnNil + /// Indicates that the `Task` of the caller should be suspended. + case suspendTask + } + + mutating func next() -> NextAction { + switch self.state { + case .initial(let backPressureStrategy, let iteratorInitialized, let onTerminate): + // We are not interacting with the back-pressure strategy here because + // we are doing this inside `next(:)` + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: Deque(), + consumerContinuation: nil, + producerContinuations: .init(), + cancelledAsyncProducers: .init(), + hasOutstandingDemand: false, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .suspendTask + + case .streaming(_, _, .some, _, _, _, _, _): + // We have multiple AsyncIterators iterating the sequence + preconditionFailure("This should never happen since we only allow a single Iterator to be created") + + case .streaming( + var backPressureStrategy, + var buffer, + .none, + var producerContinuations, + let cancelledAsyncProducers, + let hasOutstandingDemand, + let iteratorInitialized, + let onTerminate + ): + guard let element = buffer.popFirst() else { + // There is nothing in the buffer to fulfil the demand so we need to suspend. + // We are not interacting with the back-pressure strategy here because + // we are doing this inside `suspendNext` + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .suspendTask + } + // We have an element to fulfil the demand right away. + + let shouldProduceMore = backPressureStrategy.didConsume(element: element) + + guard shouldProduceMore else { + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: shouldProduceMore, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + // We don't have any new demand, so we can just return the element. + return .returnElement(element) + } + let producers = Array(producerContinuations.map { $0.1 }) + producerContinuations.removeAll() + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: shouldProduceMore, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + return .returnElementAndResumeProducers(element, producers) + + case .sourceFinished(var buffer, let iteratorInitialized, let failure, let onTerminate): + // Check if we have an element left in the buffer and return it + guard let element = buffer.popFirst() else { + // We are returning the queued failure now and can transition to finished + self.state = .finished(iteratorInitialized: iteratorInitialized) + + return .returnFailureAndCallOnTerminate(failure, onTerminate) + } + self.state = .sourceFinished( + buffer: buffer, + iteratorInitialized: iteratorInitialized, + failure: failure, + onTerminate: onTerminate + ) + + return .returnElement(element) + + case .finished: return .returnNil + } + } + + /// Actions returned by `suspendNext()`. + enum SuspendNextAction { + /// Indicates that the continuation should be resumed. + case resumeContinuationWithElement(CheckedContinuation, Element) + /// Indicates that the continuation and all producers should be resumed. + case resumeContinuationWithElementAndProducers( + CheckedContinuation, + Element, + [(Result) -> Void] + ) + /// Indicates that the continuation should be resumed with the failure and that `onTerminate` should be called. + case resumeContinuationWithFailureAndCallOnTerminate( + CheckedContinuation, + Failure?, + (() -> Void)? + ) + /// Indicates that the continuation should be resumed with `nil`. + case resumeContinuationWithNil(CheckedContinuation) + /// Indicates that nothing should be done. + case none + } + + mutating func suspendNext(continuation: CheckedContinuation) -> SuspendNextAction { + switch self.state { + case .initial: + // We need to transition to streaming before we can suspend + preconditionFailure("Internal inconsistency") + + case .streaming(_, _, .some, _, _, _, _, _): + // We have multiple AsyncIterators iterating the sequence + preconditionFailure("This should never happen since we only allow a single Iterator to be created") + + case .streaming( + var backPressureStrategy, + var buffer, + .none, + var producerContinuations, + let cancelledAsyncProducers, + let hasOutstandingDemand, + let iteratorInitialized, + let onTerminate + ): + // We have to check here again since we might have a producer interleave next and suspendNext + guard let element = buffer.popFirst() else { + // There is nothing in the buffer to fulfil the demand so we to store the continuation. + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: continuation, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: hasOutstandingDemand, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + + return .none + } + // We have an element to fulfil the demand right away. + + let shouldProduceMore = backPressureStrategy.didConsume(element: element) + + guard shouldProduceMore else { + // We don't have any new demand, so we can just return the element. + return .resumeContinuationWithElement(continuation, element) + } + let producers = Array(producerContinuations.map { $0.1 }) + producerContinuations.removeAll() + self.state = .streaming( + backPressureStrategy: backPressureStrategy, + buffer: buffer, + consumerContinuation: nil, + producerContinuations: producerContinuations, + cancelledAsyncProducers: cancelledAsyncProducers, + hasOutstandingDemand: shouldProduceMore, + iteratorInitialized: iteratorInitialized, + onTerminate: onTerminate + ) + return .resumeContinuationWithElementAndProducers(continuation, element, producers) + + case .sourceFinished(var buffer, let iteratorInitialized, let failure, let onTerminate): + // Check if we have an element left in the buffer and return it + guard let element = buffer.popFirst() else { + // We are returning the queued failure now and can transition to finished + self.state = .finished(iteratorInitialized: iteratorInitialized) + + return .resumeContinuationWithFailureAndCallOnTerminate(continuation, failure, onTerminate) + } + self.state = .sourceFinished( + buffer: buffer, + iteratorInitialized: iteratorInitialized, + failure: failure, + onTerminate: onTerminate + ) + + return .resumeContinuationWithElement(continuation, element) + + case .finished: return .resumeContinuationWithNil(continuation) + } + } + + /// Actions returned by `cancelNext()`. + enum CancelNextAction { + /// Indicates that the continuation should be resumed with a cancellation error, the producers should be finished and call onTerminate. + case resumeContinuationWithCancellationErrorAndFinishProducersAndCallOnTerminate( + CheckedContinuation, + [(Result) -> Void], + (() -> Void)? + ) + /// Indicates that the producers should be finished and call onTerminate. + case finishProducersAndCallOnTerminate([(Result) -> Void], (() -> Void)?) + /// Indicates that nothing should be done. + case none + } + + mutating func cancelNext() -> CancelNextAction { + switch self.state { + case .initial: + // We need to transition to streaming before we can suspend + preconditionFailure("Internal inconsistency") + + case .streaming( + _, + _, + let consumerContinuation, + let producerContinuations, + _, + _, + let iteratorInitialized, + let onTerminate + ): + self.state = .finished(iteratorInitialized: iteratorInitialized) + + guard let consumerContinuation = consumerContinuation else { + return .finishProducersAndCallOnTerminate(Array(producerContinuations.map { $0.1 }), onTerminate) + } + return .resumeContinuationWithCancellationErrorAndFinishProducersAndCallOnTerminate( + consumerContinuation, + Array(producerContinuations.map { $0.1 }), + onTerminate + ) + + case .sourceFinished, .finished: return .none + } + } + } +} diff --git a/Sources/OpenAPIURLSession/AsyncBackpressuredStream/NIOLock.swift b/Sources/OpenAPIURLSession/AsyncBackpressuredStream/NIOLock.swift new file mode 100644 index 0000000..783d37e --- /dev/null +++ b/Sources/OpenAPIURLSession/AsyncBackpressuredStream/NIOLock.swift @@ -0,0 +1,213 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore-file +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Darwin) +import Darwin +#elseif os(Windows) +import ucrt +import WinSDK +#elseif canImport(Glibc) +import Glibc +#elseif canImport(Musl) +import Musl +#else +#error("The concurrency NIOLock module was unable to identify your C library.") +#endif + +#if os(Windows) +@usableFromInline typealias LockPrimitive = SRWLOCK +#else +@usableFromInline typealias LockPrimitive = pthread_mutex_t +#endif + +@usableFromInline enum LockOperations {} + +extension LockOperations { + @inlinable static func create(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + InitializeSRWLock(mutex) + #else + var attr = pthread_mutexattr_t() + pthread_mutexattr_init(&attr) + + let err = pthread_mutex_init(mutex, &attr) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } + + @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + // SRWLOCK does not need to be free'd + #else + let err = pthread_mutex_destroy(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } + + @inlinable static func lock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + AcquireSRWLockExclusive(mutex) + #else + let err = pthread_mutex_lock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } + + @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { + mutex.assertValidAlignment() + + #if os(Windows) + ReleaseSRWLockExclusive(mutex) + #else + let err = pthread_mutex_unlock(mutex) + precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") + #endif + } +} + +// Tail allocate both the mutex and a generic value using ManagedBuffer. +// Both the header pointer and the elements pointer are stable for +// the class's entire lifetime. +// +// However, for safety reasons, we elect to place the lock in the "elements" +// section of the buffer instead of the head. The reasoning here is subtle, +// so buckle in. +// +// _As a practical matter_, the implementation of ManagedBuffer ensures that +// the pointer to the header is stable across the lifetime of the class, and so +// each time you call `withUnsafeMutablePointers` or `withUnsafeMutablePointerToHeader` +// the value of the header pointer will be the same. This is because ManagedBuffer uses +// `Builtin.addressOf` to load the value of the header, and that does ~magic~ to ensure +// that it does not invoke any weird Swift accessors that might copy the value. +// +// _However_, the header is also available via the `.header` field on the ManagedBuffer. +// This presents a problem! The reason there's an issue is that `Builtin.addressOf` and friends +// do not interact with Swift's exclusivity model. That is, the various `with` functions do not +// conceptually trigger a mutating access to `.header`. For elements this isn't a concern because +// there's literally no other way to perform the access, but for `.header` it's entirely possible +// to accidentally recursively read it. +// +// Our implementation is free from these issues, so we don't _really_ need to worry about it. +// However, out of an abundance of caution, we store the Value in the header, and the LockPrimitive +// in the trailing elements. We still don't use `.header`, but it's better to be safe than sorry, +// and future maintainers will be happier that we were cautious. +// +// See also: https://github.com/apple/swift/pull/40000 +@usableFromInline final class LockStorage: ManagedBuffer { + + @inlinable static func create(value: Value) -> Self { + let buffer = Self.create(minimumCapacity: 1) { _ in return value } + let storage = unsafeDowncast(buffer, to: Self.self) + + storage.withUnsafeMutablePointers { _, lockPtr in LockOperations.create(lockPtr) } + + return storage + } + + @inlinable func lock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.lock(lockPtr) } } + + @inlinable func unlock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.unlock(lockPtr) } } + + @inlinable deinit { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.destroy(lockPtr) } } + + @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointerToElements { lockPtr in return try body(lockPtr) } + } + + @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { + try self.withUnsafeMutablePointers { valuePtr, lockPtr in LockOperations.lock(lockPtr) + defer { LockOperations.unlock(lockPtr) } + return try mutate(&valuePtr.pointee) + } + } +} + +extension LockStorage: @unchecked Sendable {} + +/// A threading lock based on `libpthread` instead of `libdispatch`. +/// +/// - note: ``NIOLock`` has reference semantics. +/// +/// This object provides a lock on top of a single `pthread_mutex_t`. This kind +/// of lock is safe to use with `libpthread`-based threading models, such as the +/// one used by NIO. On Windows, the lock is based on the substantially similar +/// `SRWLOCK` type. +public struct NIOLock { + @usableFromInline internal let _storage: LockStorage + + /// Create a new lock. + @inlinable public init() { self._storage = .create(value: ()) } + + /// Acquire the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `unlock`, to simplify lock handling. + @inlinable public func lock() { self._storage.lock() } + + /// Release the lock. + /// + /// Whenever possible, consider using `withLock` instead of this method and + /// `lock`, to simplify lock handling. + @inlinable public func unlock() { self._storage.unlock() } + + @inlinable internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows + -> T + { return try self._storage.withLockPrimitive(body) } +} + +extension NIOLock { + /// Acquire the lock for the duration of the given block. + /// + /// This convenience method should be preferred to `lock` and `unlock` in + /// most situations, as it ensures that the lock will be released regardless + /// of how `body` exits. + /// + /// - Parameter body: The block to execute while holding the lock. + /// - Returns: The value returned by the block. + @inlinable public func withLock(_ body: () throws -> T) rethrows -> T { + self.lock() + defer { self.unlock() } + return try body() + } + + @inlinable public func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } +} + +extension NIOLock: Sendable {} + +extension UnsafeMutablePointer { + @inlinable func assertValidAlignment() { + assert(UInt(bitPattern: self) % UInt(MemoryLayout.alignment) == 0) + } +} diff --git a/Sources/OpenAPIURLSession/Documentation.docc/Documentation.md b/Sources/OpenAPIURLSession/Documentation.docc/Documentation.md index 551611b..4c34441 100644 --- a/Sources/OpenAPIURLSession/Documentation.docc/Documentation.md +++ b/Sources/OpenAPIURLSession/Documentation.docc/Documentation.md @@ -9,9 +9,14 @@ A client transport that uses the [URLSession](https://developer.apple.com/docume Use the transport with client code generated by [Swift OpenAPI Generator](https://github.com/apple/swift-openapi-generator). ### Supported platforms and minimum versions -| macOS | Linux | iOS | tvOS | watchOS | -| :-: | :-: | :-: | :-: | :-: | -| ✅ 10.15+ | ✅ | ✅ 13+ | ✅ 13+ | ✅ 6+ | + +| macOS | Linux | iOS | tvOS | watchOS | +| :-: | :-: | :-: | :-: | :-: | +| ✅ 10.15+ | ✅ | ✅ 13+ | ✅ 13+ | ✅ 6+ | + +Note: Streaming support only available on macOS 12+, iOS 15+, tvOS 15+, and +watchOS 8+.For streaming support on Linux, please use the [AsyncHTTPClient +Transport](https://github.com/swift-server/swift-openapi-async-http-client) ### Usage diff --git a/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/BidirectionalStreamingURLSessionDelegate.swift b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/BidirectionalStreamingURLSessionDelegate.swift new file mode 100644 index 0000000..bb9918b --- /dev/null +++ b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/BidirectionalStreamingURLSessionDelegate.swift @@ -0,0 +1,191 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import OpenAPIRuntime +import HTTPTypes +#if canImport(Darwin) +import Foundation + +/// Delegate that supports bidirectional streaming of request and response bodies. +/// +/// While URLSession provides a high-level API that returns an async sequence of +/// bytes, `bytes(for:delegate:)`, but does not provide an API that takes an async sequence +/// as a request body. For instance, `upload(for:delegate:)` and `upload(fromFile:delegate:)` +/// both buffer the entire response body and return `Data`. +/// +/// Additionally, bridging `URLSession.AsyncBytes`, which is an `AsyncSequence` to +/// `OpenAPIRuntime.HTTPBody`, an `AsyncSequence`, is problematic and will +/// incur an allocation for every byte. +/// +/// This delegate vends the response body as a `HTTBody` with one chunk for each +/// `urlSession(_:didReceive data:)` callback. It also provides backpressure, which will +/// suspend and resume the URLSession task based on a configurable high and low watermark. +/// +/// When performing requests without a body, this delegate should be used with a +/// `URLSessionDataTask` to stream the response body. +/// +/// When performing requests with a body, this delegate should be used with a +/// `URLSessionUploadTask` using `uploadTask(withStreamedRequest:delegate:)`, which will +/// ask the delegate for a `InputStream` for the request body via the +/// `urlSession(_:needNewBodyStreamForTask:)` callback. +/// +/// The `urlSession(_:needNewBodyStreamForTask:)` callback will create a pair of bound +/// streams, bridge the `HTTPBody` request body to the `OutputStream` and return the +/// `InputStream` to URLSession. Backpressure for the request body stream is provided +/// as an implementation detail of how URLSession reads from the `InputStream`. +/// +/// Note that `urlSession(_:needNewBodyStreamForTask:)` may be called more than once, e.g. +/// when performing a HTTP redirect, upon which the delegate is expected to create a new +/// `InputStream` for the request body. This is only possible if the underlying `HTTPBody` +/// request body can be iterated multiple times, i.e. `iterationBehavior == .multiple`. +/// If the request body cannot be iterated multiple times, then the URLSession task will be cancelled. +final class BidirectionalStreamingURLSessionDelegate: NSObject, URLSessionTaskDelegate, URLSessionDataDelegate { + + let requestBody: HTTPBody? + var hasAlreadyIteratedRequestBody: Bool + var hasSuspendedURLSessionTask: Bool + let requestStreamBufferSize: Int + var requestStream: HTTPBodyOutputStreamBridge? + + typealias ResponseContinuation = CheckedContinuation + var responseContinuation: ResponseContinuation? + + typealias ResponseBodyStream = AsyncBackpressuredStream + var responseBodyStream: ResponseBodyStream + var responseBodyStreamSource: ResponseBodyStream.Source + + /// This lock is taken for the duration of all delegate callbacks to protect the mutable delegate state. + /// + /// Although all the delegate callbacks are performed on the session's `delegateQueue`, there is no guarantee that + /// this is a _serial_ queue. + /// + /// Regardless of the type of delegate queue, URLSession will attempt to order the callbacks for each task in a + /// sensible way, but it cannot be guaranteed, specifically when the URLSession task is cancelled. + /// + /// Therefore, even though the `suspend()`, `resume()`, and `cancel()` URLSession methods are thread-safe, we need + /// to protect any mutable state within the delegate itself. + let callbackLock = NIOLock() + + /// In addition to the callback lock, there is one point of rentrancy, where the response stream callback gets fired + /// immediately, for this we have a different lock, which protects `hasSuspendedURLSessionTask`. + let hasSuspendedURLSessionTaskLock = NIOLock() + + /// Use `bidirectionalStreamingRequest(for:baseURL:requestBody:requestStreamBufferSize:responseStreamWatermarks:)`. + init(requestBody: HTTPBody?, requestStreamBufferSize: Int, responseStreamWatermarks: (low: Int, high: Int)) { + self.requestBody = requestBody + self.hasAlreadyIteratedRequestBody = false + self.hasSuspendedURLSessionTask = false + self.requestStreamBufferSize = requestStreamBufferSize + (self.responseBodyStream, self.responseBodyStreamSource) = AsyncBackpressuredStream.makeStream( + backPressureStrategy: .highLowWatermarkWithElementCounts( + lowWatermark: responseStreamWatermarks.low, + highWatermark: responseStreamWatermarks.high + ) + ) + } + + func urlSession(_ session: URLSession, needNewBodyStreamForTask task: URLSessionTask) async -> InputStream? { + callbackLock.withLock { + debug("Task delegate: needNewBodyStreamForTask") + // If the HTTP body cannot be iterated multiple times then bad luck; the only thing + // we can do is cancel the task and return nil. + if hasAlreadyIteratedRequestBody { + guard requestBody!.iterationBehavior == .multiple else { + debug("Task delegate: Cannot rewind request body, cancelling task") + task.cancel() + return nil + } + } + hasAlreadyIteratedRequestBody = true + + // Create a fresh pair of streams. + let (inputStream, outputStream) = createStreamPair(withBufferSize: requestStreamBufferSize) + + // Bridge the output stream to the request body (which opens the output stream). + requestStream = HTTPBodyOutputStreamBridge(outputStream, requestBody!) + + // Return the new input stream (unopened, it gets opened by URLSession). + return inputStream + } + } + + func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { + callbackLock.withLock { + debug("Task delegate: didReceive data (numBytes: \(data.count))") + do { + switch try responseBodyStreamSource.write(contentsOf: CollectionOfOne(ArraySlice(data))) { + case .produceMore: break + case .enqueueCallback(let writeToken): + let shouldActuallyEnqueueCallback = hasSuspendedURLSessionTaskLock.withLock { + if hasSuspendedURLSessionTask { + debug("Task delegate: already suspended task, not enqueing another writer callback") + return false + } + debug("Task delegate: response stream backpressure, suspending task and enqueing callback") + dataTask.suspend() + hasSuspendedURLSessionTask = true + return true + } + if shouldActuallyEnqueueCallback { + responseBodyStreamSource.enqueueCallback(writeToken: writeToken) { result in + self.hasSuspendedURLSessionTaskLock.withLock { + switch result { + case .success: + debug("Task delegate: response stream callback, resuming task") + dataTask.resume() + self.hasSuspendedURLSessionTask = false + case .failure(let error): + debug("Task delegate: response stream callback, cancelling task, error: \(error)") + dataTask.cancel() + } + } + } + } + } + } catch { + debug("Task delegate: response stream consumer terminated, cancelling task") + dataTask.cancel() + } + } + } + + func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse) async + -> URLSession.ResponseDisposition + { + callbackLock.withLock { + debug("Task delegate: didReceive response") + self.responseContinuation?.resume(returning: response) + return .allow + } + } + + func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: (any Error)?) { + callbackLock.withLock { + debug("Task delegate: didCompleteWithError (error: \(String(describing: error)))") + responseBodyStreamSource.finish(throwing: error) + if let error { responseContinuation?.resume(throwing: error) } + } + } +} + +extension BidirectionalStreamingURLSessionDelegate: @unchecked Sendable {} // State synchronized using DispatchQueue. + +private func createStreamPair(withBufferSize bufferSize: Int) -> (InputStream, OutputStream) { + var inputStream: InputStream? + var outputStream: OutputStream? + Stream.getBoundStreams(withBufferSize: bufferSize, inputStream: &inputStream, outputStream: &outputStream) + guard let inputStream, let outputStream else { fatalError("getBoundStreams did not return non-nil streams") } + return (inputStream, outputStream) +} + +#endif // canImport(Darwin) diff --git a/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/HTTPBodyOutputStreamBridge.swift b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/HTTPBodyOutputStreamBridge.swift new file mode 100644 index 0000000..3b89d62 --- /dev/null +++ b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/HTTPBodyOutputStreamBridge.swift @@ -0,0 +1,287 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import OpenAPIRuntime +import HTTPTypes +#if canImport(Darwin) +import Foundation + +final class HTTPBodyOutputStreamBridge: NSObject, StreamDelegate { + static let streamQueue = DispatchQueue(label: "HTTPBodyStreamDelegate", autoreleaseFrequency: .workItem) + + let httpBody: HTTPBody + let outputStream: OutputStream + private(set) var state: State { + didSet { debug("Output stream delegate state transition: \(oldValue) -> \(state)") } + } + + /// Creates a new `HTTPBodyOutputStreamBridge` and opens the output stream. + init(_ outputStream: OutputStream, _ httpBody: HTTPBody) { + self.httpBody = httpBody + self.outputStream = outputStream + self.state = .initial + super.init() + self.outputStream.delegate = self + CFWriteStreamSetDispatchQueue(self.outputStream as CFWriteStream, Self.streamQueue) + self.outputStream.open() + } + + deinit { debug("Output stream delegate deinit") } + + func performAction(_ action: State.Action) { + debug("Output stream delegate performing action from state machine: \(action)") + dispatchPrecondition(condition: .onQueue(Self.streamQueue)) + switch action { + case .none: return + case .resumeProducer(let producerContinuation): + producerContinuation.resume() + performAction(self.state.resumedProducer()) + case .writeBytes(let chunk): writePendingBytes(chunk) + case .cancelProducerAndCloseStream(let producerContinuation): + producerContinuation.resume(throwing: CancellationError()) + outputStream.close() + case .cancelProducer(let producerContinuation): producerContinuation.resume(throwing: CancellationError()) + case .closeStream: outputStream.close() + } + } + + func startWriterTask() { + dispatchPrecondition(condition: .onQueue(Self.streamQueue)) + let task = Task { + dispatchPrecondition(condition: .notOnQueue(Self.streamQueue)) + for try await chunk in httpBody { + try await withCheckedThrowingContinuation { continuation in + Self.streamQueue.async { + debug("Output stream delegate produced chunk and suspended producer.") + self.performAction(self.state.producedChunkAndSuspendedProducer(chunk, continuation)) + } + } + } + Self.streamQueue.async { + debug("Output stream delegate wrote final chunk.") + self.performAction(self.state.wroteFinalChunk()) + } + } + self.performAction(self.state.startedProducerTask(task)) + } + + private func writePendingBytes(_ bytesToWrite: Chunk) { + dispatchPrecondition(condition: .onQueue(Self.streamQueue)) + precondition(!bytesToWrite.isEmpty, "\(#function) must be called with non-empty bytes") + guard outputStream.streamStatus == .open else { + debug("Output stream closed unexpectedly.") + performAction(self.state.wroteBytes(numBytesWritten: 0, streamStillHasSpaceAvailable: false)) + return + } + switch bytesToWrite.withUnsafeBytes({ outputStream.write($0.baseAddress!, maxLength: bytesToWrite.count) }) { + case 0: + debug("Output stream delegate reached end of stream when writing.") + performAction(self.state.endEncountered()) + case -1: + debug("Output stream delegate encountered error writing to stream: \(outputStream.streamError!).") + performAction(self.state.errorOccurred(outputStream.streamError!)) + case let written where written > 0: + debug("Output stream delegate wrote \(written) bytes to stream.") + performAction( + self.state.wroteBytes( + numBytesWritten: written, + streamStillHasSpaceAvailable: outputStream.hasSpaceAvailable + ) + ) + default: preconditionFailure("OutputStream.write(_:maxLength:) returned undocumented value") + } + } + + func stream(_ stream: Stream, handle event: Stream.Event) { + dispatchPrecondition(condition: .onQueue(Self.streamQueue)) + debug("Output stream delegate received event: \(event).") + switch event { + case .openCompleted: + guard case .initial = state else { + debug("Output stream delegate ignoring duplicate openCompleted event.") + return + } + startWriterTask() + case .hasSpaceAvailable: performAction(self.state.spaceBecameAvailable()) + case .errorOccurred: performAction(self.state.errorOccurred(stream.streamError!)) + case .endEncountered: performAction(self.state.endEncountered()) + default: + debug("Output stream ignoring event: \(event).") + break + } + } +} + +extension HTTPBodyOutputStreamBridge { + typealias Chunk = ArraySlice + typealias ProducerTask = Task + typealias ProducerContinuation = CheckedContinuation + + enum State { + case initial + case waitingForBytes(spaceAvailable: Bool) + case haveBytes(spaceAvailable: Bool, Chunk, ProducerContinuation) + case needBytes(spaceAvailable: Bool, ProducerContinuation) + case closed((any Error)?) + + mutating func startedProducerTask(_ producerTask: ProducerTask) -> Action { + switch self { + case .initial: + self = .waitingForBytes(spaceAvailable: false) + return .none + case .waitingForBytes, .haveBytes, .needBytes, .closed: + preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func producedChunkAndSuspendedProducer(_ chunk: Chunk, _ producerContinuation: ProducerContinuation) + -> Action + { + switch self { + case .waitingForBytes(let spaceAvailable): + self = .haveBytes(spaceAvailable: spaceAvailable, chunk, producerContinuation) + guard spaceAvailable else { return .none } + return .writeBytes(chunk) + case .closed: return .cancelProducer(producerContinuation) + case .initial, .haveBytes, .needBytes: preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func wroteBytes(numBytesWritten: Int, streamStillHasSpaceAvailable: Bool) -> Action { + switch self { + case .haveBytes(let spaceAvailable, let chunk, let producerContinuation): + guard spaceAvailable, numBytesWritten <= chunk.count else { preconditionFailure() } + let remaining = chunk.dropFirst(numBytesWritten) + guard remaining.isEmpty else { + self = .haveBytes(spaceAvailable: streamStillHasSpaceAvailable, remaining, producerContinuation) + guard streamStillHasSpaceAvailable else { return .none } + return .writeBytes(remaining) + } + self = .needBytes(spaceAvailable: streamStillHasSpaceAvailable, producerContinuation) + return .resumeProducer(producerContinuation) + case .initial, .needBytes, .waitingForBytes, .closed: + preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func resumedProducer() -> Action { + switch self { + case .needBytes(let spaceAvailable, _): + self = .waitingForBytes(spaceAvailable: spaceAvailable) + return .none + case .initial, .haveBytes, .waitingForBytes, .closed: + preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func errorOccurred(_ error: any Error) -> Action { + switch self { + case .initial: + self = .closed(error) + return .none + case .waitingForBytes(_): + self = .closed(error) + return .closeStream + case .haveBytes(_, _, let producerContinuation): + self = .closed(error) + return .cancelProducerAndCloseStream(producerContinuation) + case .needBytes(_, let producerContinuation): + self = .closed(error) + return .cancelProducerAndCloseStream(producerContinuation) + case .closed: preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func wroteFinalChunk() -> Action { + switch self { + case .waitingForBytes(_): + self = .closed(nil) + return .closeStream + case .initial, .haveBytes, .needBytes, .closed: + preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func endEncountered() -> Action { + switch self { + case .waitingForBytes(_): + self = .closed(nil) + return .closeStream + case .haveBytes(_, _, let producerContinuation): + self = .closed(nil) + return .cancelProducerAndCloseStream(producerContinuation) + case .needBytes(_, let producerContinuation): + self = .closed(nil) + return .cancelProducerAndCloseStream(producerContinuation) + case .initial, .closed: preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + mutating func spaceBecameAvailable() -> Action { + switch self { + case .waitingForBytes(_): + self = .waitingForBytes(spaceAvailable: true) + return .none + case .haveBytes(_, let chunk, let producerContinuation): + self = .haveBytes(spaceAvailable: true, chunk, producerContinuation) + return .writeBytes(chunk) + case .needBytes(_, let producerContinuation): + self = .needBytes(spaceAvailable: true, producerContinuation) + return .none + case .closed: + debug("Ignoring space available event in closed state") + return .none + case .initial: preconditionFailure("\(#function) called in invalid state: \(self)") + } + } + + enum Action { + case none + case resumeProducer(ProducerContinuation) + case writeBytes(Chunk) + case cancelProducerAndCloseStream(ProducerContinuation) + case cancelProducer(ProducerContinuation) + case closeStream + } + } +} + +extension HTTPBodyOutputStreamBridge: @unchecked Sendable {} // State synchronized using DispatchQueue. + +extension HTTPBodyOutputStreamBridge.State: CustomStringConvertible { + var description: String { + switch self { + case .initial: return "initial" + case .waitingForBytes(let spaceAvailable): return "waitingForBytes(spaceAvailable: \(spaceAvailable))" + case .haveBytes(let spaceAvailable, let chunk, _): + return "haveBytes(spaceAvailable: \(spaceAvailable), [\(chunk.count) bytes])" + case .needBytes(let spaceAvailable, _): return "needBytes (spaceAvailable: \(spaceAvailable), _)" + case .closed(let error): return "closed (error: \(String(describing: error)))" + } + } +} + +extension HTTPBodyOutputStreamBridge.State.Action: CustomStringConvertible { + var description: String { + switch self { + case .none: return "none" + case .resumeProducer: return "resumeProducer" + case .writeBytes: return "writeBytes" + case .cancelProducerAndCloseStream: return "cancelProducerAndCloseStream" + case .cancelProducer: return "cancelProducer" + case .closeStream: return "closeStream" + } + } +} + +#endif // canImport(Darwin) diff --git a/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift new file mode 100644 index 0000000..007b9f2 --- /dev/null +++ b/Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import OpenAPIRuntime +import HTTPTypes +#if canImport(Darwin) +import Foundation + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) extension URLSession { + func bidirectionalStreamingRequest( + for request: HTTPRequest, + baseURL: URL, + requestBody: HTTPBody?, + requestStreamBufferSize: Int, + responseStreamWatermarks: (low: Int, high: Int) + ) async throws -> (HTTPResponse, HTTPBody?) { + let urlRequest = try URLRequest(request, baseURL: baseURL) + let task: URLSessionTask + if requestBody != nil { + task = uploadTask(withStreamedRequest: urlRequest) + } else { + task = dataTask(with: urlRequest) + } + return try await withTaskCancellationHandler { + let delegate = BidirectionalStreamingURLSessionDelegate( + requestBody: requestBody, + requestStreamBufferSize: requestStreamBufferSize, + responseStreamWatermarks: responseStreamWatermarks + ) + let response = try await withCheckedThrowingContinuation { continuation in + delegate.responseContinuation = continuation + task.delegate = delegate + task.resume() + } + let responseBody = HTTPBody( + delegate.responseBodyStream, + length: .init(from: response), + iterationBehavior: .single + ) + return (try HTTPResponse(response), responseBody) + } onCancel: { + task.cancel() + } + } +} + +#endif // canImport(Darwin) diff --git a/Sources/OpenAPIURLSession/URLSessionTransport.swift b/Sources/OpenAPIURLSession/URLSessionTransport.swift index 00aec8e..88f3413 100644 --- a/Sources/OpenAPIURLSession/URLSessionTransport.swift +++ b/Sources/OpenAPIURLSession/URLSessionTransport.swift @@ -20,13 +20,13 @@ import Foundation @preconcurrency import struct Foundation.URLComponents @preconcurrency import struct Foundation.Data @preconcurrency import protocol Foundation.LocalizedError -#endif #if canImport(FoundationNetworking) @preconcurrency import struct FoundationNetworking.URLRequest @preconcurrency import class FoundationNetworking.URLSession @preconcurrency import class FoundationNetworking.URLResponse @preconcurrency import class FoundationNetworking.HTTPURLResponse #endif +#endif /// A client transport that performs HTTP operations using the URLSession type /// provided by the Foundation framework. @@ -73,8 +73,23 @@ public struct URLSessionTransport: ClientTransport { /// Creates a new configuration with the provided session. /// - Parameter session: The URLSession used for performing HTTP operations. - /// If none is provided, the system uses the shared URLSession. - public init(session: URLSession = .shared) { self.session = session } + /// If none is provided, the system uses the shared URLSession. + public init(session: URLSession = .shared) { self.init(session: session, implementation: .platformDefault) } + + enum Implementation { + case buffering + case streaming(requestBodyStreamBufferSize: Int, responseBodyStreamWatermarks: (low: Int, high: Int)) + } + + var implemenation: Implementation + + init(session: URLSession = .shared, implementation: Implementation = .platformDefault) { + self.session = session + if case .streaming = implementation { + precondition(Implementation.platformSupportsStreaming, "Streaming not supported on platform") + } + self.implemenation = implementation + } } /// A set of configuration values used by the transport. @@ -84,42 +99,50 @@ public struct URLSessionTransport: ClientTransport { /// - Parameter configuration: A set of configuration values used by the transport. public init(configuration: Configuration = .init()) { self.configuration = configuration } - /// Asynchronously sends an HTTP request and returns the response and body. - /// + /// Sends the underlying HTTP request and returns the received HTTP response. /// - Parameters: - /// - request: The HTTP request to be sent. - /// - body: The HTTP body to include in the request (optional). - /// - baseURL: The base URL for the request. - /// - operationID: An optional identifier for the operation or request. - /// - Returns: A tuple containing the HTTP response and an optional HTTP response body. - /// - Throws: An error if there is a problem sending the request or processing the response. - public func send(_ request: HTTPRequest, body: HTTPBody?, baseURL: URL, operationID: String) async throws -> ( - HTTPResponse, HTTPBody? - ) { - // TODO: https://github.com/apple/swift-openapi-generator/issues/301 - let urlRequest = try await URLRequest(request, body: body, baseURL: baseURL) - let (responseBody, urlResponse) = try await invokeSession(urlRequest) - return try HTTPResponse.response(method: request.method, urlResponse: urlResponse, data: responseBody) + /// - request: An HTTP request. + /// - requestBody: An HTTP request body. + /// - baseURL: A server base URL. + /// - operationID: The identifier of the OpenAPI operation. + /// - Returns: An HTTP response and its body. + /// - Throws: If there was an error performing the HTTP request. + public func send(_ request: HTTPRequest, body requestBody: HTTPBody?, baseURL: URL, operationID: String) + async throws -> (HTTPResponse, HTTPBody?) + { + switch self.configuration.implemenation { + case .streaming(let requestBodyStreamBufferSize, let responseBodyStreamWatermarks): + #if canImport(Darwin) + guard #available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) else { + throw URLSessionTransportError.streamingNotSupported + } + return try await configuration.session.bidirectionalStreamingRequest( + for: request, + baseURL: baseURL, + requestBody: requestBody, + requestStreamBufferSize: requestBodyStreamBufferSize, + responseStreamWatermarks: responseBodyStreamWatermarks + ) + #else + throw URLSessionTransportError.streamingNotSupported + #endif + case .buffering: + return try await configuration.session.bufferedRequest( + for: request, + baseURL: baseURL, + requestBody: requestBody + ) + } } +} - private func invokeSession(_ urlRequest: URLRequest) async throws -> (Data, URLResponse) { - // Using `dataTask(with:completionHandler:)` instead of the async method `data(for:)` of URLSession because the latter is not available on linux platforms - return try await withCheckedThrowingContinuation { continuation in - configuration.session - .dataTask(with: urlRequest) { data, response, error in - if let error { - continuation.resume(with: .failure(error)) - return - } - - guard let response else { - continuation.resume(with: .failure(URLSessionTransportError.noResponse(url: urlRequest.url))) - return - } - - continuation.resume(with: .success((data ?? Data(), response))) - } - .resume() +extension HTTPBody.Length { + init(from urlResponse: URLResponse) { + if urlResponse.expectedContentLength == -1 { + self = .unknown + } else { + // TODO: Content-Length will change to Int64: https://github.com/apple/swift-openapi-generator/issues/354 + self = .known(Int(urlResponse.expectedContentLength)) } } } @@ -135,12 +158,13 @@ internal enum URLSessionTransportError: Error { /// Returned `URLResponse` was nil case noResponse(url: URL?) + + /// Platform does not support streaming. + case streamingNotSupported } extension HTTPResponse { - static func response(method: HTTPRequest.Method, urlResponse: URLResponse, data: Data) throws -> ( - HTTPResponse, HTTPBody? - ) { + init(_ urlResponse: URLResponse) throws { guard let httpResponse = urlResponse as? HTTPURLResponse else { throw URLSessionTransportError.notHTTPResponse(urlResponse) } @@ -151,17 +175,12 @@ extension HTTPResponse { else { continue } headerFields[name] = value } - let body: HTTPBody? - switch method { - case .head, .connect, .trace: body = nil - default: body = .init(data) - } - return (HTTPResponse(status: .init(code: httpResponse.statusCode), headerFields: headerFields), body) + self.init(status: .init(code: httpResponse.statusCode), headerFields: headerFields) } } extension URLRequest { - init(_ request: HTTPRequest, body: HTTPBody?, baseURL: URL) async throws { + init(_ request: HTTPRequest, baseURL: URL) throws { guard var baseUrlComponents = URLComponents(string: baseURL.absoluteString), let requestUrlComponents = URLComponents(string: request.path ?? "") else { @@ -183,20 +202,16 @@ extension URLRequest { for header in request.headerFields { self.setValue(header.value, forHTTPHeaderField: header.name.canonicalName) } - if let body { - // TODO: https://github.com/apple/swift-openapi-generator/issues/301 - self.httpBody = try await Data(collecting: body, upTo: .max) - } } } extension URLSessionTransportError: LocalizedError { - /// A custom error description for `URLSessionTransportError`. + /// A localized message describing what error occurred. public var errorDescription: String? { description } } extension URLSessionTransportError: CustomStringConvertible { - /// A custom textual representation for `URLSessionTransportError`. + /// A textual representation of this instance. public var description: String { switch self { case let .invalidRequestURL(path: path, method: method, baseURL: baseURL): @@ -205,6 +220,73 @@ extension URLSessionTransportError: CustomStringConvertible { case .notHTTPResponse(let response): return "Received a non-HTTP response, of type: \(String(describing: type(of: response)))" case .noResponse(let url): return "Received a nil response for \(url?.absoluteString ?? "")" + case .streamingNotSupported: return "Streaming is not supported on this platform" } } } + +private let _debugLoggingEnabled = LockStorage.create(value: false) +var debugLoggingEnabled: Bool { + get { _debugLoggingEnabled.withLockedValue { $0 } } + set { _debugLoggingEnabled.withLockedValue { $0 = newValue } } +} +func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") { + assert( + { + if debugLoggingEnabled { print(items, separator: separator, terminator: terminator) } + return true + }() + ) +} + +extension URLSession { + func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> ( + HTTPResponse, HTTPBody? + ) { + var urlRequest = try URLRequest(request, baseURL: baseURL) + if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) } + + /// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on + /// Darwin platforms newer than our minimum deployment target, and not at all on Linux. + let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation { + continuation in + let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in + if let error { + continuation.resume(throwing: error) + return + } + guard let response else { + continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url)) + return + } + continuation.resume(with: .success((response, data))) + } + task.resume() + } + + let maybeResponseBody = maybeResponseBodyData.map { data in + HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple) + } + return (try HTTPResponse(response), maybeResponseBody) + } +} + +extension URLSessionTransport.Configuration.Implementation { + static var platformSupportsStreaming: Bool { + #if canImport(Darwin) + guard #available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) else { return false } + _ = URLSession.bidirectionalStreamingRequest + return true + #else + return false + #endif + } + + static var platformDefault: Self { + guard platformSupportsStreaming else { return .buffering } + return .streaming( + requestBodyStreamBufferSize: 16 * 1024, + responseBodyStreamWatermarks: (low: 16 * 1024, high: 32 * 1024) + ) + } +} diff --git a/Tests/OpenAPIURLSessionTests/AsyncBackpressuredStreamTests/AsyncBackpressuredStreamTests.swift b/Tests/OpenAPIURLSessionTests/AsyncBackpressuredStreamTests/AsyncBackpressuredStreamTests.swift new file mode 100644 index 0000000..bdb474a --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/AsyncBackpressuredStreamTests/AsyncBackpressuredStreamTests.swift @@ -0,0 +1,208 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020-2021 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +import XCTest +@testable import OpenAPIURLSession + +final class AsyncBackpressuredStreamTests: XCTestCase { + func testYield() async throws { + let (stream, source) = AsyncBackpressuredStream.makeStream( + of: Int.self, + backPressureStrategy: .highLowWatermark(lowWatermark: 5, highWatermark: 10) + ) + + try await source.asyncWrite(contentsOf: [1, 2, 3, 4, 5, 6]) + source.finish(throwing: nil) + + let result = try await stream.collect() + XCTAssertEqual(result, [1, 2, 3, 4, 5, 6]) + } + + func testBackPressure() async throws { + let (stream, source) = AsyncBackpressuredStream.makeStream( + of: Int.self, + backPressureStrategy: .highLowWatermark(lowWatermark: 2, highWatermark: 4) + ) + + let (backPressureEventStream, backPressureEventContinuation) = AsyncStream.makeStream(of: Void.self) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + while true { + backPressureEventContinuation.yield(()) + print("Yielding") + try await source.asyncWrite(contentsOf: [1]) + } + } + + var backPressureEventIterator = backPressureEventStream.makeAsyncIterator() + var iterator = stream.makeAsyncIterator() + + await backPressureEventIterator.next() + await backPressureEventIterator.next() + await backPressureEventIterator.next() + await backPressureEventIterator.next() + + print("Waited 4 times") + + _ = try await iterator.next() + _ = try await iterator.next() + _ = try await iterator.next() + print("Consumed three") + + await backPressureEventIterator.next() + await backPressureEventIterator.next() + await backPressureEventIterator.next() + + group.cancelAll() + } + } + + func testBackPressureSync() async throws { + let (stream, source) = AsyncBackpressuredStream.makeStream( + of: Int.self, + backPressureStrategy: .highLowWatermark(lowWatermark: 2, highWatermark: 4) + ) + + let (backPressureEventStream, backPressureEventContinuation) = AsyncStream.makeStream(of: Void.self) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + @Sendable func yield() { + backPressureEventContinuation.yield(()) + print("Yielding") + source.write(contentsOf: [1]) { result in + switch result { + case .success: yield() + + case .failure: print("Stopping to yield") + } + } + } + + yield() + } + + var backPressureEventIterator = backPressureEventStream.makeAsyncIterator() + var iterator = stream.makeAsyncIterator() + + await backPressureEventIterator.next() + await backPressureEventIterator.next() + await backPressureEventIterator.next() + await backPressureEventIterator.next() + + print("Waited 4 times") + + _ = try await iterator.next() + _ = try await iterator.next() + _ = try await iterator.next() + print("Consumed three") + + await backPressureEventIterator.next() + await backPressureEventIterator.next() + await backPressureEventIterator.next() + + group.cancelAll() + } + } + + func testWatermarkBackPressureStrategy() async throws { + typealias Strategy = AsyncBackpressuredStream.HighLowWatermarkBackPressureStrategy + var strategy = Strategy(lowWatermark: 2, highWatermark: 3) + + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didYield(elements: Slice([])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didYield(elements: Slice(["*", "*"])), true) + XCTAssertEqual(strategy.currentWatermark, 2) + XCTAssertEqual(strategy.didYield(elements: Slice(["*"])), false) + XCTAssertEqual(strategy.currentWatermark, 3) + XCTAssertEqual(strategy.didYield(elements: Slice(["*"])), false) + XCTAssertEqual(strategy.currentWatermark, 4) + + XCTAssertEqual(strategy.currentWatermark, 4) + XCTAssertEqual(strategy.didConsume(elements: Slice([])), false) + XCTAssertEqual(strategy.currentWatermark, 4) + XCTAssertEqual(strategy.didConsume(elements: Slice(["*", "*"])), false) + XCTAssertEqual(strategy.currentWatermark, 2) + XCTAssertEqual(strategy.didConsume(elements: Slice(["*"])), true) + XCTAssertEqual(strategy.currentWatermark, 1) + XCTAssertEqual(strategy.didConsume(elements: Slice(["*"])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didConsume(elements: Slice([])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + } + + func testWatermarkWithoutElementCountsBackPressureStrategy() async throws { + typealias Strategy = AsyncBackpressuredStream<[String], any Error>.HighLowWatermarkBackPressureStrategy + var strategy = Strategy(lowWatermark: 2, highWatermark: 3) + + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didYield(elements: Slice([])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didYield(elements: Slice([["*", "*"]])), true) + XCTAssertEqual(strategy.currentWatermark, 1) + XCTAssertEqual(strategy.didYield(elements: Slice([["*", "*"]])), true) + XCTAssertEqual(strategy.currentWatermark, 2) + + XCTAssertEqual(strategy.currentWatermark, 2) + XCTAssertEqual(strategy.didConsume(elements: Slice([])), false) + XCTAssertEqual(strategy.currentWatermark, 2) + XCTAssertEqual(strategy.didConsume(elements: Slice([["*", "*"]])), true) + XCTAssertEqual(strategy.currentWatermark, 1) + XCTAssertEqual(strategy.didConsume(elements: Slice([["*", "*"]])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didConsume(elements: Slice([])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + } + + func testWatermarkWithElementCountsBackPressureStrategy() async throws { + typealias Strategy = AsyncBackpressuredStream<[String], any Error>.HighLowWatermarkBackPressureStrategy + var strategy = Strategy(lowWatermark: 2, highWatermark: 3, waterLevelForElement: { $0.count }) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didYield(elements: Slice([])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didYield(elements: Slice([["*", "*"]])), true) + XCTAssertEqual(strategy.currentWatermark, 2) + XCTAssertEqual(strategy.didYield(elements: Slice([["*", "*"]])), false) + XCTAssertEqual(strategy.currentWatermark, 4) + + XCTAssertEqual(strategy.currentWatermark, 4) + XCTAssertEqual(strategy.didConsume(elements: Slice([])), false) + XCTAssertEqual(strategy.currentWatermark, 4) + XCTAssertEqual(strategy.didConsume(elements: Slice([["*", "*"]])), false) + XCTAssertEqual(strategy.currentWatermark, 2) + XCTAssertEqual(strategy.didConsume(elements: Slice([["*", "*"]])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + XCTAssertEqual(strategy.didConsume(elements: Slice([])), true) + XCTAssertEqual(strategy.currentWatermark, 0) + } +} + +extension AsyncSequence { + /// Collect all elements in the sequence into an array. + fileprivate func collect() async rethrows -> [Element] { + try await self.reduce(into: []) { accumulated, next in accumulated.append(next) } + } +} diff --git a/Tests/OpenAPIURLSessionTests/AsyncSyncSequence.swift b/Tests/OpenAPIURLSessionTests/AsyncSyncSequence.swift new file mode 100644 index 0000000..ed83714 --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/AsyncSyncSequence.swift @@ -0,0 +1,86 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore-file +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +extension Sequence { + /// An asynchronous sequence containing the same elements as this sequence, + /// but on which operations, such as `map` and `filter`, are + /// implemented asynchronously. + @inlinable + var async: AsyncSyncSequence { + AsyncSyncSequence(self) + } +} + +/// An asynchronous sequence composed from a synchronous sequence. +/// +/// Asynchronous lazy sequences can be used to interface existing or pre-calculated +/// data to interoperate with other asynchronous sequences and algorithms based on +/// asynchronous sequences. +/// +/// This functions similarly to `LazySequence` by accessing elements sequentially +/// in the iterator's `next()` method. +@frozen +public struct AsyncSyncSequence: AsyncSequence { + public typealias Element = Base.Element + + @frozen + public struct Iterator: AsyncIteratorProtocol { + @usableFromInline + var iterator: Base.Iterator? + + @usableFromInline + init(_ iterator: Base.Iterator) { + self.iterator = iterator + } + + @inlinable + public mutating func next() async -> Base.Element? { + if !Task.isCancelled, let value = iterator?.next() { + return value + } else { + iterator = nil + return nil + } + } + } + + @usableFromInline + let base: Base + + @usableFromInline + init(_ base: Base) { + self.base = base + } + + @inlinable + public func makeAsyncIterator() -> Iterator { + Iterator(base.makeIterator()) + } +} + +extension AsyncSyncSequence: Sendable where Base: Sendable { } + +@available(*, unavailable) +extension AsyncSyncSequence.Iterator: Sendable { } diff --git a/Tests/OpenAPIURLSessionTests/Locking.swift b/Tests/OpenAPIURLSessionTests/Locking.swift deleted file mode 100644 index 9d22065..0000000 --- a/Tests/OpenAPIURLSessionTests/Locking.swift +++ /dev/null @@ -1,45 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftOpenAPIGenerator open source project -// -// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Foundation - -/// A wrapper providing locked access to a value. -/// -/// Marked as @unchecked Sendable due to the synchronization being -/// performed manually using locks. -/// -/// Note: Use the `package` access modifier once min Swift version is increased. -@_spi(Locking) public final class LockedValueBox: @unchecked Sendable { - private let lock: NSLock = { - let lock = NSLock() - lock.name = "com.apple.swift-openapi-urlsession.lock.LockedValueBox" - return lock - }() - private var value: Value - /// Initializes a new `LockedValueBox` instance with the provided initial value. - /// - /// - Parameter value: The initial value to store in the `LockedValueBox`. - public init(_ value: Value) { self.value = value } - /// Perform an operation on the value in a synchronized manner. - /// - /// - Parameter work: A closure that takes an inout reference to the wrapped value and returns a result. - /// - /// - Returns: The result of the provided closure. - /// - Returns: The result of the closure passed to `work`. - public func withValue(_ work: (inout Value) throws -> R) rethrows -> R { - lock.lock() - defer { lock.unlock() } - return try work(&value) - } -} diff --git a/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift b/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift new file mode 100644 index 0000000..98e88b2 --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/NIOAsyncHTTP1TestServer.swift @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIOCore +import NIOPosix +import NIOHTTP1 + +final class AsyncTestHTTP1Server { + + typealias ConnectionHandler = @Sendable (NIOAsyncChannel) + async throws -> Void + + /// Use `start(host:port:connectionHandler:)` instead. + private init() {} + + /// Start a localhost HTTP1 server with a given connection handler. + /// + /// - Parameters: + /// - connectionTaskGroup: Task group used to run the connection handler on new connections. + /// - connectionHandler: Handler to run for each new connection. + /// - Returns: The port on which the server is running. + /// - Throws: If there was an error starting the server. + static func start( + connectionTaskGroup: inout ThrowingTaskGroup, + connectionHandler: @escaping ConnectionHandler + ) async throws -> Int { + let group: MultiThreadedEventLoopGroup = .singleton + let channel = try await ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: 0) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.configureHTTPServerPipeline() + try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponseChannelHandler()) + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerByteBufferResponsePart.self + ) + ) + } + } + + connectionTaskGroup.addTask { + // NOTE: it would be better to use `withThrowingDiscardingTaskGroup` here, but this would require some availablity dance and this is just used in tests. + try await withThrowingTaskGroup(of: Void.self) { group in + try await channel.executeThenClose { inbound, outbound in + for try await connectionChannel in inbound { + group.addTask { + do { + print("Sevrer handling new connection") + try await connectionHandler(connectionChannel) + print("Server done handling connection") + } catch { print("Server error handling connection: \(error)") } + } + } + } + } + } + return channel.channel.localAddress!.port! + } +} + +/// Because `HTTPServerResponsePart` is not sendable because its body type is `IOData`, which is an abstraction over a +/// `ByteBuffer` or `FileRegion`. The latter is not sendable, so we need a channel handler that deals in terms of only +/// `ByteBuffer`. +extension AsyncTestHTTP1Server { + typealias HTTPServerByteBufferResponsePart = HTTPPart + + final class HTTPByteBufferResponseChannelHandler: ChannelOutboundHandler, RemovableChannelHandler { + typealias OutboundIn = HTTPServerByteBufferResponsePart + typealias OutboundOut = HTTPServerResponsePart + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let part = unwrapOutboundIn(data) + switch part { + case .head(let head): context.write(self.wrapOutboundOut(.head(head)), promise: promise) + case .body(let buffer): context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) + case .end(let headers): context.write(self.wrapOutboundOut(.end(headers)), promise: promise) + } + } + } + +} diff --git a/Tests/OpenAPIURLSessionTests/TestUtils.swift b/Tests/OpenAPIURLSessionTests/TestUtils.swift new file mode 100644 index 0000000..7798e77 --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/TestUtils.swift @@ -0,0 +1,166 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import Foundation +#if !canImport(Darwin) && canImport(FoundationNetworking) +import FoundationNetworking +#endif +import OpenAPIRuntime +import XCTest + +func XCTAssertThrowsError( + _ expression: @autoclosure () async throws -> T, + _ message: @autoclosure () -> String = "", + file: StaticString = #filePath, + line: UInt = #line, + _ errorHandler: (_ error: any Error) -> Void = { _ in } +) async { + do { + _ = try await expression() + XCTFail("expression did not throw", file: file, line: line) + } catch { errorHandler(error) } +} + +func XCTSkipUnlessAsync( + _ expression: @autoclosure () async throws -> Bool, + _ message: @autoclosure () -> String? = nil, + file: StaticString = #filePath, + line: UInt = #line +) async throws { + let result = try await expression() + try XCTSkipUnless(result, message(), file: file, line: line) +} + +func XCTUnwrapAsync( + _ expression: @autoclosure () async throws -> T?, + _ message: @autoclosure () -> String = "", + file: StaticString = #filePath, + line: UInt = #line +) async throws -> T { + let maybeValue = try await expression() + return try XCTUnwrap(maybeValue, message(), file: file, line: line) +} + +func XCTAssertNilAsync( + _ expression: @autoclosure () async throws -> Any?, + _ message: @autoclosure () -> String = "", + file: StaticString = #filePath, + line: UInt = #line +) async throws { + let maybeValue = try await expression() + XCTAssertNil(maybeValue, message(), file: file, line: line) +} + +extension URL { + var withoutPath: URL { + var components = URLComponents(url: self, resolvingAgainstBaseURL: false)! + components.path = "" + return components.url! + } +} + +extension Collection { + func chunks(of size: Int) -> [[Element]] { + precondition(size > 0) + var chunkStart = startIndex + var results = [[Element]]() + results.reserveCapacity((count - 1) / size + 1) + while chunkStart < endIndex { + let chunkEnd = index(chunkStart, offsetBy: size, limitedBy: endIndex) ?? endIndex + results.append(Array(self[chunkStart..: @unchecked Sendable where Value: Sendable { + private let lock: NSLock = { + let lock = NSLock() + lock.name = "com.apple.swift-openapi-urlsession.lock.LockedValueBox" + return lock + }() + private var value: Value + init(_ value: Value) { self.value = value } + func withValue(_ work: (inout Value) throws -> R) rethrows -> R { + lock.lock() + defer { lock.unlock() } + return try work(&value) + } +} + +extension AsyncStream { + // We have this here until we drop 5.8, since it's in the standard library in Swift 5.9+. + static func makeStream( + of elementType: Element.Type = Element.self, + bufferingPolicy limit: Self.Continuation.BufferingPolicy = .unbounded + ) -> (stream: Self, continuation: Self.Continuation) { + var continuation: Self.Continuation! + let stream = Self(elementType, bufferingPolicy: limit) { continuation = $0 } + return (stream, continuation) + } +} diff --git a/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/HTTPBodyOutputStreamTests.swift b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/HTTPBodyOutputStreamTests.swift new file mode 100644 index 0000000..b173ec0 --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/HTTPBodyOutputStreamTests.swift @@ -0,0 +1,274 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if canImport(Darwin) + +import OpenAPIRuntime +import XCTest +@testable import OpenAPIURLSession + +// swift-format-ignore: AllPublicDeclarationsHaveDocumentation +class HTTPBodyOutputStreamBridgeTests: XCTestCase { + static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = true } + + func testHTTPBodyOutputStreamInputOutput() async throws { + let chunkSize = 71 + let streamBufferSize = 37 + let numBytes: UInt8 = .max + + // Create a HTTP body with one byte per chunk. + let requestBytes = (0...numBytes).map { UInt8($0) } + let requestChunks = requestBytes.chunks(of: chunkSize) + let requestByteSequence = MockAsyncSequence(elementsToVend: requestChunks, gatingProduction: false) + let requestBody = HTTPBody(requestByteSequence, length: .known(requestBytes.count), iterationBehavior: .single) + + // Create a pair of bound streams with a tiny buffer to be the bottleneck for backpressure. + var inputStream: InputStream? + var outputStream: OutputStream? + Stream.getBoundStreams(withBufferSize: streamBufferSize, inputStream: &inputStream, outputStream: &outputStream) + guard let inputStream, let outputStream else { fatalError("getBoundStreams did not return non-nil streams") } + + // Bridge the HTTP body to the output stream. + let requestStream = HTTPBodyOutputStreamBridge(outputStream, requestBody) + + // Set up a mock delegate to drive the stream pair. + let delegate = MockInputStreamDelegate(inputStream: inputStream) + + // Read all the data from the input stream using max bytes > stream buffer size. + var data = [UInt8]() + data.reserveCapacity(requestBytes.count) + while let inputStreamBytes = try await delegate.waitForBytes(maxBytes: 4096) { + data.append(contentsOf: inputStreamBytes) + } + XCTAssertEqual(data, requestBytes) + + // Check all bytes have been vended. + XCTAssertEqual(requestByteSequence.elementsVended.count, requestByteSequence.elementsToVend.count) + + // Input stream delegate will have reached end of stream and closed the input stream. + XCTAssertEqual(inputStream.streamStatus, .closed) + XCTAssertNil(inputStream.streamError) + + // Check the output stream closes gracefully in response to the input stream closing. + HTTPBodyOutputStreamBridge.streamQueue.asyncAndWait { + XCTAssertEqual(requestStream.outputStream.streamStatus, .closed) + XCTAssertNil(requestStream.outputStream.streamError) + } + } + + func testHTTPBodyOutputStreamBridgeBackpressure() async throws { + let chunkSize = 71 + let streamBufferSize = 37 + let numBytes: UInt8 = .max + + // Create a HTTP body with one byte per chunk. + let requestBytes = (0...numBytes).map { UInt8($0) } + let requestChunks = requestBytes.chunks(of: chunkSize) + let requestByteSequence = MockAsyncSequence(elementsToVend: requestChunks, gatingProduction: true) + let requestBody = HTTPBody(requestByteSequence, length: .known(requestBytes.count), iterationBehavior: .single) + + // Create a pair of bound streams with a tiny buffer to be the bottleneck for backpressure. + var inputStream: InputStream? + var outputStream: OutputStream? + Stream.getBoundStreams(withBufferSize: streamBufferSize, inputStream: &inputStream, outputStream: &outputStream) + guard let inputStream, let outputStream else { fatalError("getBoundStreams did not return non-nil streams") } + + // Bridge the HTTP body to the output stream. + let requestStream = HTTPBodyOutputStreamBridge(outputStream, requestBody) + + // Set up a mock delegate to drive the stream pair. + let delegate = MockInputStreamDelegate(inputStream: inputStream) + _ = delegate + + // Check both streams have been opened. + XCTAssertEqual(outputStream.streamStatus, .open) + XCTAssertEqual(inputStream.streamStatus, .open) + + // At this point, because our mock async sequence that's backing the output stream is gated: + // - The mock async sequence has vended zero elements. + // - The output stream bridge has read nothing from from the async sequence. + // - The output stream bridge has written nothing to the output stream. + // - The output stream should have space available, the entire size of the buffer. + XCTAssert(requestByteSequence.elementsVended.isEmpty) + XCTAssertEqual(outputStream.streamStatus, .open) + // XCTAssert(requestStream.bytesToWrite.isEmpty) + XCTAssert(outputStream.hasSpaceAvailable) + + // Now we'll tell our mock sequence to let through as many bytes as it can. + requestByteSequence.openGate() + + // After some time, the buffer will be full. + let expectation = expectation(description: "output stream has no space available") + HTTPBodyOutputStreamBridge.streamQueue.asyncAfter(deadline: .now() + .milliseconds(100)) { + if !requestStream.outputStream.hasSpaceAvailable { expectation.fulfill() } + } + await fulfillment(of: [expectation], timeout: 0.5) + + // The underlying sequence should only have vended enough chunks to fill the buffer. + XCTAssertEqual(requestByteSequence.elementsVended.count, (streamBufferSize - 1) / chunkSize + 1) + } + + func testHTTPBodyOutputStreamPullThroughBufferOneByteBig() async throws { + let chunkSize = 1 + let streamBufferSize = 1 + let numBytes: UInt8 = .max + + // Create a HTTP body with one byte per chunk. + let requestBytes = (0...numBytes).map { UInt8($0) } + let requestChunks = requestBytes.chunks(of: chunkSize) + let requestByteSequence = MockAsyncSequence(elementsToVend: requestChunks, gatingProduction: true) + let requestBody = HTTPBody(requestByteSequence, length: .known(requestBytes.count), iterationBehavior: .single) + + // Create a pair of bound streams with a tiny buffer to be the bottleneck for backpressure. + var inputStream: InputStream? + var outputStream: OutputStream? + Stream.getBoundStreams(withBufferSize: streamBufferSize, inputStream: &inputStream, outputStream: &outputStream) + guard let inputStream, let outputStream else { fatalError("getBoundStreams did not return non-nil streams") } + + // Bridge the HTTP body to the output stream. + let requestStream = HTTPBodyOutputStreamBridge(outputStream, requestBody) + + // Set up a mock delegate to drive the stream pair. + let delegate = MockInputStreamDelegate(inputStream: inputStream) + + // Read one byte at a time from the input sequence, which will make space in the buffer. + for i in 0..: AsyncSequence, Sendable where Element: Sendable { + var elementsToVend: [Element] + private let _elementsVended: LockedValueBox<[Element]> + var elementsVended: [Element] { _elementsVended.withValue { $0 } } + private let semaphore: DispatchSemaphore? + + init(elementsToVend: [Element], gatingProduction: Bool) { + self.elementsToVend = elementsToVend + self._elementsVended = LockedValueBox([]) + self.semaphore = gatingProduction ? DispatchSemaphore(value: 0) : nil + } + + func openGate(for count: Int) { for _ in 0.. AsyncIterator { + AsyncIterator(elementsToVend: elementsToVend[...], semaphore: semaphore, elementsVended: _elementsVended) + } + + final class AsyncIterator: AsyncIteratorProtocol { + var elementsToVend: ArraySlice + var semaphore: DispatchSemaphore? + var elementsVended: LockedValueBox<[Element]> + + init( + elementsToVend: ArraySlice, + semaphore: DispatchSemaphore?, + elementsVended: LockedValueBox<[Element]> + ) { + self.elementsToVend = elementsToVend + self.semaphore = semaphore + self.elementsVended = elementsVended + } + + func next() async throws -> Element? { + await withCheckedContinuation { continuation in + semaphore?.wait() + continuation.resume() + } + guard let element = elementsToVend.popFirst() else { return nil } + elementsVended.withValue { $0.append(element) } + return element + } + } +} + +#endif // #if canImport(Darwin) diff --git a/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockInputStreamDelegate.swift b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockInputStreamDelegate.swift new file mode 100644 index 0000000..78d9325 --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/MockInputStreamDelegate.swift @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if canImport(Darwin) + +import Foundation + +/// Reads one byte at a time from the stream, regardless of how many bytes are available. +/// +/// Used for testing the HTTPOutputStreamBridge backpressure behaviour, without URLSession. +final class MockInputStreamDelegate: NSObject, StreamDelegate { + static let streamQueue = DispatchQueue(label: "MockInputStreamDelegate", autoreleaseFrequency: .workItem) + + private var inputStream: InputStream + + enum State { + case noWaiter + case haveWaiter(CheckedContinuation<[UInt8]?, any Error>, maxBytes: Int) + case closed((any Error)?) + } + private(set) var state: State + + init(inputStream: InputStream) { + self.inputStream = inputStream + self.state = .noWaiter + super.init() + self.inputStream.delegate = self + CFReadStreamSetDispatchQueue(self.inputStream as CFReadStream, Self.streamQueue) + self.inputStream.open() + } + + deinit { print("Input stream delegate deinit") } + + private func readAndResumeContinuation() { + dispatchPrecondition(condition: .onQueue(Self.streamQueue)) + guard case .haveWaiter(let continuation, let maxBytes) = state else { + preconditionFailure("Invalid state: \(state)") + } + guard inputStream.hasBytesAvailable else { return } + let buffer = [UInt8](unsafeUninitializedCapacity: maxBytes) { buffer, count in + count = inputStream.read(buffer.baseAddress!, maxLength: maxBytes) + } + switch buffer.count { + case -1: + print("Input stream delegate error reading from stream: \(inputStream.streamError!)") + inputStream.close() + continuation.resume(throwing: inputStream.streamError!) + case 0: + print("Input stream delegate reached end of stream; will close stream") + self.close() + continuation.resume(returning: nil) + case let numBytesRead where numBytesRead > 0: + print("Input stream delegate read \(numBytesRead) bytes from stream: \(buffer)") + continuation.resume(returning: buffer) + default: preconditionFailure() + } + state = .noWaiter + } + + func waitForBytes(maxBytes: Int) async throws -> [UInt8]? { + if inputStream.streamStatus == .closed { + state = .closed(inputStream.streamError) + guard let error = inputStream.streamError else { return nil } + throw error + } + return try await withCheckedThrowingContinuation { continuation in + Self.streamQueue.async { + guard case .noWaiter = self.state else { preconditionFailure() } + self.state = .haveWaiter(continuation, maxBytes: maxBytes) + self.readAndResumeContinuation() + } + } + } + + func close(withError error: (any Error)? = nil) { + self.inputStream.close() + Self.streamQueue.async { self.state = .closed(error) } + print("Input stream delegate closed stream with error: \(String(describing: error))") + } + + func stream(_ stream: Stream, handle event: Stream.Event) { + dispatchPrecondition(condition: .onQueue(Self.streamQueue)) + print("Input stream delegate received event: \(event)") + switch event { + case .hasBytesAvailable: + switch state { + case .haveWaiter: readAndResumeContinuation() + case .noWaiter: break + case .closed: preconditionFailure() + } + case .errorOccurred: self.close() + default: break + } + } +} + +extension MockInputStreamDelegate: @unchecked Sendable {} // State synchronized using DispatchQueue. + +#endif // canImport(Darwin) diff --git a/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/URLSessionBidirectionalStreamingTests.swift b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/URLSessionBidirectionalStreamingTests.swift new file mode 100644 index 0000000..f4fd6fd --- /dev/null +++ b/Tests/OpenAPIURLSessionTests/URLSessionBidirectionalStreamingTests/URLSessionBidirectionalStreamingTests.swift @@ -0,0 +1,399 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if canImport(Darwin) + +import Foundation +import HTTPTypes +import NIO +import NIOHTTP1 +import OpenAPIRuntime +import XCTest +@testable import OpenAPIURLSession + +class URLSessionBidirectionalStreamingTests: XCTestCase { + // swift-format-ignore: AllPublicDeclarationsHaveDocumentation + static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = true } + + func testBidirectionalEcho_PerChunkRatchet_1BChunk_1Chunks_1BUploadBuffer_1BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 1)[...], + numRequestBodyChunks: 1, + uploadBufferSize: 1, + responseStreamWatermarks: (low: 1, high: 1) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_1BChunk_10Chunks_1BUploadBuffer_1BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 1)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 1, + responseStreamWatermarks: (low: 1, high: 1) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_1BChunk_10Chunks_10BUploadBuffer_1BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 1)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 10, + responseStreamWatermarks: (low: 1, high: 1) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_1BChunk_10Chunks_1BUploadBuffer_10BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 1)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 1, + responseStreamWatermarks: (low: 10, high: 10) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_1BChunk_10Chunks_10BUploadBuffer_10BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 1)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 10, + responseStreamWatermarks: (low: 10, high: 10) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_10BChunk_10Chunks_1BUploadBuffer_1BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 10)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 1, + responseStreamWatermarks: (low: 1, high: 1) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_10BChunk_10Chunks_10BUploadBuffer_1BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 10)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 10, + responseStreamWatermarks: (low: 1, high: 1) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_10BChunk_10Chunks_1BUploadBuffer_10BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 10)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 1, + responseStreamWatermarks: (low: 10, high: 10) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_10BChunk_10Chunks_10BUploadBuffer_10BDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 10)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 10, + responseStreamWatermarks: (low: 10, high: 10) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_4kChunk_10Chunks_16kUploadBuffer_4kDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 4 * 1024)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 16 * 1024, + responseStreamWatermarks: (low: 4096, high: 4096) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_1MChunk_10Chunks_16kUploadBuffer_4kDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 1 * 1024 * 1024)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 16 * 1024, + responseStreamWatermarks: (low: 4096, high: 4096) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_10MChunk_10Chunks_1MUploadBuffer_1MDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 10 * 1024 * 1024)[...], + numRequestBodyChunks: 10, + uploadBufferSize: 1 * 1024 * 1024, + responseStreamWatermarks: (low: 1 * 1024 * 1024, high: 1 * 1024 * 1024) + ) + } + + func testBidirectionalEcho_PerChunkRatchet_100kChunk_100Chunks_1MUploadBuffer_1MDownloadWatermark() async throws { + try await testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: Array(repeating: UInt8(ascii: "*"), count: 100 * 1024)[...], + numRequestBodyChunks: 100, + uploadBufferSize: 1 * 1024 * 1024, + responseStreamWatermarks: (low: 1 * 1024 * 1024, high: 1 * 1024 * 1024) + ) + } + + func testBidirectionalEchoPerChunkRatchet( + requestBodyChunk: HTTPBody.ByteChunk, + numRequestBodyChunks: Int, + uploadBufferSize: Int, + responseStreamWatermarks: (low: Int, high: Int) + ) async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + // Server task. + let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in + try await connectionChannel.executeThenClose { inbound, outbound in + for try await requestPart in inbound { + switch requestPart { + case .head(_): + try await outbound.write( + .head( + .init( + version: .http1_1, + status: .ok, + headers: ["Content-Type": "application/octet-stream"] + ) + ) + ) + case .body(let buffer): try await outbound.write(.body(buffer)) + case .end(_): try await outbound.write(.end(nil)) + } + } + } + } + + // Set up the request body. + let (requestBodyStream, requestBodyStreamContinuation) = AsyncStream.makeStream() + let requestBody = HTTPBody(requestBodyStream, length: .unknown, iterationBehavior: .single) + + // Start the request. + async let asyncResponse = URLSession.shared.bidirectionalStreamingRequest( + for: HTTPRequest( + method: .post, + scheme: nil, + authority: nil, + path: "/some/path", + headerFields: [.contentType: "application/octet-stream"] + ), + baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!, + requestBody: requestBody, + requestStreamBufferSize: uploadBufferSize, + responseStreamWatermarks: responseStreamWatermarks + ) + + /// At this point in the test, the server has sent the response head, which can be verified in Wireshark. + /// + /// A quirk of URLSession is that it won't fire the `didReceive response` callback, even if it has received + /// the response head, until it has received at least one body byte, even when the server response headers + /// indicate that the content-type is `application/octet-stream` and the transfer encoding is chunked. + /// + /// It's also worth noting that URLSession implements content sniffing so, if the content-type is absent, + /// it will not call the `didReceive response` callback until it has received many more bytes. + /// + /// Additionally, there's no requirement on client libraries (or any intermediaries) to deliver partial + /// responses to users, so the ability to affect this particular request response pattern entirely depends + /// on the implementation details of the HTTP client libary. + /// + /// So... we send the first request chunk here, and have the server echo it back. + requestBodyStreamContinuation.yield(requestBodyChunk) + + // We can now get the response head and the response body stream. + let (response, responseBody) = try await asyncResponse + XCTAssertEqual(response.status, .ok) + + // Consume and verify the first response chunk. + var responseBodyIterator = responseBody!.makeAsyncIterator() + var pendingExpectedResponseBytes = requestBodyChunk + while !pendingExpectedResponseBytes.isEmpty { + let responseBodyChunk = try await responseBodyIterator.next()! + XCTAssertEqual(responseBodyChunk, pendingExpectedResponseBytes.prefix(responseBodyChunk.count)) + pendingExpectedResponseBytes.removeFirst(responseBodyChunk.count) + } + + // Send the remaining request chunks, one at a time, and check the echoed response chunk. + for _ in 1..= responseChunk.count { + print("Client reconstructing and verifying chunk \(numProcessedChunks+1)/\(numResponseChunks)") + XCTAssertEqual( + ArraySlice(unprocessedBytes.readBytes(length: responseChunk.count)!), + responseChunk + ) + unprocessedBytes.discardReadBytes() + numProcessedChunks += 1 + } + } + XCTAssertEqual(unprocessedBytes.readableBytes, 0) + XCTAssertEqual(numProcessedChunks, numResponseChunks) + case .count: + var numBytesReceived = 0 + for try await receivedResponseChunk in responseBody! { + print("Client received some response body bytes (numBytes: \(receivedResponseChunk.count))") + numBytesReceived += receivedResponseChunk.count + } + XCTAssertEqual(numBytesReceived, responseChunk.count * numResponseChunks) + case .delay(let delay): + for try await receivedResponseChunk in responseBody! { + print("Client received some response body bytes (numBytes: \(receivedResponseChunk.count))") + print("Client doing fake work for \(delay)s") + try await Task.sleep(for: delay) + } + } + + group.cancelAll() + } + } +} + +#endif // canImport(Darwin) diff --git a/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift b/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift index 5b97ec6..a14c14d 100644 --- a/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift +++ b/Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift @@ -11,25 +11,20 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -import XCTest -import OpenAPIRuntime -#if canImport(Darwin) import Foundation -#else -@preconcurrency import struct Foundation.URL +#if !canImport(Darwin) && canImport(FoundationNetworking) +import FoundationNetworking #endif -#if canImport(FoundationNetworking) -@preconcurrency import struct FoundationNetworking.URLRequest -@preconcurrency import class FoundationNetworking.URLProtocol -@preconcurrency import class FoundationNetworking.URLSession -@preconcurrency import class FoundationNetworking.HTTPURLResponse -@preconcurrency import class FoundationNetworking.URLResponse -@preconcurrency import class FoundationNetworking.URLSessionConfiguration -#endif -@testable import OpenAPIURLSession import HTTPTypes +import NIO +import NIOHTTP1 +import OpenAPIRuntime +import XCTest +@testable import OpenAPIURLSession -class URLSessionTransportTests: XCTestCase { +// swift-format-ignore: AllPublicDeclarationsHaveDocumentation +class URLSessionTransportConverterTests: XCTestCase { + static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = true } func testRequestConversion() async throws { let request = HTTPRequest( @@ -39,13 +34,11 @@ class URLSessionTransportTests: XCTestCase { path: "/hello%20world/Maria?greeting=Howdy", headerFields: [.init("x-mumble2")!: "mumble"] ) - let body: HTTPBody = "👋" - let urlRequest = try await URLRequest(request, body: body, baseURL: URL(string: "http://example.com/api")!) + let urlRequest = try URLRequest(request, baseURL: URL(string: "http://example.com/api")!) XCTAssertEqual(urlRequest.url, URL(string: "http://example.com/api/hello%20world/Maria?greeting=Howdy")) XCTAssertEqual(urlRequest.httpMethod, "POST") XCTAssertEqual(urlRequest.allHTTPHeaderFields?.count, 1) XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "x-mumble2"), "mumble") - XCTAssertEqual(urlRequest.httpBody, Data("👋".utf8)) } func testResponseConversion() async throws { @@ -55,87 +48,265 @@ class URLSessionTransportTests: XCTestCase { httpVersion: "HTTP/1.1", headerFields: ["x-mumble3": "mumble"] )! - let (response, maybeResponseBody) = try HTTPResponse.response( - method: .get, - urlResponse: urlResponse, - data: Data("👋".utf8) - ) - let responseBody = try XCTUnwrap(maybeResponseBody) + let response = try HTTPResponse(urlResponse) XCTAssertEqual(response.status.code, 201) XCTAssertEqual(response.headerFields, [.init("x-mumble3")!: "mumble"]) - let bufferedResponseBody = try await String(collecting: responseBody, upTo: .max) - XCTAssertEqual(bufferedResponseBody, "👋") } +} + +// swift-format-ignore: AllPublicDeclarationsHaveDocumentation +class URLSessionTransportBufferedTests: XCTestCase { + var transport: (any ClientTransport)! + + static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = true } + + override func setUp() async throws { + transport = URLSessionTransport(configuration: .init(implementation: .buffering)) + } + + func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) } + + func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) } + + #if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307. + func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws { + try await testHTTPRedirect( + transport: transport, + requestBodyIterationBehavior: .multiple, + expectFailureDueToIterationBehavior: false + ) + } + + func testHTTPRedirect_singleIterationBehavior_succeeds() async throws { + try await testHTTPRedirect( + transport: transport, + requestBodyIterationBehavior: .single, + expectFailureDueToIterationBehavior: false + ) + } + #endif +} + +// swift-format-ignore: AllPublicDeclarationsHaveDocumentation +class URLSessionTransportStreamingTests: XCTestCase { + var transport: (any ClientTransport)! - func testSend() async throws { - let endpointURL = URL(string: "http://example.com/api/hello%20world/Maria?greeting=Howdy")! - MockURLProtocol.mockHTTPResponses.withValue { map in - map[endpointURL] = .success( - ( - HTTPURLResponse(url: endpointURL, statusCode: 201, httpVersion: nil, headerFields: [:])!, - body: Data("👋".utf8) + static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = true } + + override func setUpWithError() throws { + try XCTSkipUnless(URLSessionTransport.Configuration.Implementation.platformSupportsStreaming) + self.transport = URLSessionTransport( + configuration: .init( + implementation: .streaming( + requestBodyStreamBufferSize: 16 * 1024, + responseBodyStreamWatermarks: (low: 16 * 1024, high: 32 * 1024) ) ) - } - let transport: any ClientTransport = URLSessionTransport( - configuration: .init(session: MockURLProtocol.mockURLSession) ) - let request = HTTPRequest( - method: .post, - scheme: nil, - authority: nil, - path: "/hello%20world/Maria?greeting=Howdy", - headerFields: [.init("x-mumble1")!: "mumble"] + } + + func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) } + + func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) } + + #if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307. + func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws { + try await testHTTPRedirect( + transport: transport, + requestBodyIterationBehavior: .multiple, + expectFailureDueToIterationBehavior: false ) - let requestBody: HTTPBody = "👋" - let (response, maybeResponseBody) = try await transport.send( - request, - body: requestBody, - baseURL: URL(string: "http://example.com/api")!, - operationID: "postGreeting" + } + + func testHTTPRedirect_singleIterationBehavior_fails() async throws { + try await testHTTPRedirect( + transport: transport, + requestBodyIterationBehavior: .single, + expectFailureDueToIterationBehavior: true ) - let responseBody = try XCTUnwrap(maybeResponseBody) - XCTAssertEqual(response.status.code, 201) - let bufferedResponseBody = try await String(collecting: responseBody, upTo: .max) - XCTAssertEqual(bufferedResponseBody, "👋") } + #endif } -class MockURLProtocol: URLProtocol { - typealias MockHTTPResponseMap = [URL: Result<(response: HTTPURLResponse, body: Data?), any Error>] - static let mockHTTPResponses = LockedValueBox([:]) +class URLSessionTransportPlatformSupportTests: XCTestCase { + func testDefaultsToStreamingIfSupported() { + if URLSessionTransport.Configuration.Implementation.platformSupportsStreaming { + guard case .streaming = URLSessionTransport.Configuration.Implementation.platformDefault else { + XCTFail() + return + } + } else { + guard case .buffering = URLSessionTransport.Configuration.Implementation.platformDefault else { + XCTFail() + return + } + } + } +} - static let recordedHTTPRequests = LockedValueBox<[URLRequest]>([]) +func testHTTPRedirect( + transport: any ClientTransport, + requestBodyIterationBehavior: HTTPBody.IterationBehavior, + expectFailureDueToIterationBehavior: Bool +) async throws { + let requestBodyChunks = ["✊", "✊", " ", "knock", " ", "knock!"] + let requestBody = HTTPBody( + requestBodyChunks.async, + length: .known(requestBodyChunks.joined().lengthOfBytes(using: .utf8)), + iterationBehavior: requestBodyIterationBehavior + ) - /// Determines whether this protocol can handle the given request. - override class func canInit(with request: URLRequest) -> Bool { true } + try await withThrowingTaskGroup(of: Void.self) { group in + let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in + try await connectionChannel.executeThenClose { inbound, outbound in + var requestPartIterator = inbound.makeAsyncIterator() + var currentURI: String? = nil + var accumulatedBody = ByteBuffer() + while let requestPart = try await requestPartIterator.next() { + switch requestPart { + case .head(let head): + print("Server received head for \(head.uri)") + currentURI = head.uri + case .body(let buffer): + let currentURI = try XCTUnwrap(currentURI) + print("Server received body bytes for \(currentURI) (numBytes: \(buffer.readableBytes))") + accumulatedBody.writeImmutableBuffer(buffer) + case .end: + let currentURI = try XCTUnwrap(currentURI) + print("Server received end for \(currentURI)") + XCTAssertEqual(accumulatedBody, ByteBuffer(string: requestBodyChunks.joined())) + switch currentURI { + case "/old": + print("Server reseting body buffer") + accumulatedBody = ByteBuffer() + try await outbound.write( + .head( + .init(version: .http1_1, status: .temporaryRedirect, headers: ["Location": "/new"]) + ) + ) + print("Server sent head for \(currentURI)") + try await outbound.write(.end(nil)) + print("Server sent end for \(currentURI)") + case "/new": + try await outbound.write(.head(.init(version: .http1_1, status: .ok))) + print("Server sent head for \(currentURI)") + try await outbound.write(.end(nil)) + print("Server sent end for \(currentURI)") + default: preconditionFailure() + } + } + } + } + } + print("Server running on 127.0.0.1:\(serverPort)") - /// Returns a canonical version of the given request. - override class func canonicalRequest(for request: URLRequest) -> URLRequest { request } + // Send the request. + print("Client starting request") + if expectFailureDueToIterationBehavior { + await XCTAssertThrowsError( + try await transport.send( + HTTPRequest(method: .post, scheme: nil, authority: nil, path: "/old"), + body: requestBody, + baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!, + operationID: "unused" + ) + ) { error in XCTAssertEqual((error as? URLError)?.code, .cancelled, "Unexpected error: \(error)") } + } else { + let (response, _) = try await transport.send( + HTTPRequest(method: .post, scheme: nil, authority: nil, path: "/old"), + body: requestBody, + baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!, + operationID: "unused" + ) + print("Client received response head: \(response)") + XCTAssertEqual(response.status, .ok) + } - /// Stops protocol-specific loading of a request. - override func stopLoading() {} + group.cancelAll() + } +} - /// Starts protocol-specific loading of a request. - override func startLoading() { - Self.recordedHTTPRequests.withValue { $0.append(self.request) } - guard let url = self.request.url else { return } - guard let response = Self.mockHTTPResponses.withValue({ $0[url] }) else { return } - switch response { - case .success(let mockResponse): - client?.urlProtocol(self, didReceive: mockResponse.response, cacheStoragePolicy: .notAllowed) - if let data = mockResponse.body { client?.urlProtocol(self, didLoad: data) } - client?.urlProtocolDidFinishLoading(self) - case let .failure(error): client?.urlProtocol(self, didFailWithError: error) +func testHTTPBasicGet(transport: any ClientTransport) async throws { + let requestPath = "/hello/world" + let responseBodyMessage = "Hey!" + + try await withThrowingTaskGroup(of: Void.self) { group in + let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in + try await connectionChannel.executeThenClose { inbound, outbound in + var requestPartIterator = inbound.makeAsyncIterator() + while let requestPart = try await requestPartIterator.next() { + switch requestPart { + case .head(let head): + XCTAssertEqual(head.uri, requestPath) + XCTAssertEqual(head.method, .GET) + case .body: XCTFail("Didn't expect any request body bytes.") + case .end: + try await outbound.write(.head(.init(version: .http1_1, status: .ok))) + try await outbound.write(.body(ByteBuffer(string: responseBodyMessage))) + try await outbound.write(.end(nil)) + } + } + } } + print("Server running on 127.0.0.1:\(serverPort)") + + // Send the request. + print("Client starting request") + let (response, maybeResponseBody) = try await transport.send( + HTTPRequest(method: .get, scheme: nil, authority: nil, path: requestPath), + body: nil, + baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!, + operationID: "unused" + ) + print("Client received response head: \(response)") + XCTAssertEqual(response.status, .ok) + let receivedMessage = try await String(collecting: try XCTUnwrap(maybeResponseBody), upTo: .max) + XCTAssertEqual(receivedMessage, responseBodyMessage) + + group.cancelAll() } +} + +func testHTTPBasicPost(transport: any ClientTransport) async throws { + let requestPath = "/hello/world" + let requestBodyMessage = "Hello, world!" + let responseBodyMessage = "Hey!" + + try await withThrowingTaskGroup(of: Void.self) { group in + let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in + try await connectionChannel.executeThenClose { inbound, outbound in + var requestPartIterator = inbound.makeAsyncIterator() + var accumulatedBody = ByteBuffer() + while let requestPart = try await requestPartIterator.next() { + switch requestPart { + case .head(let head): + XCTAssertEqual(head.uri, requestPath) + XCTAssertEqual(head.method, .POST) + case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer) + case .end: + XCTAssertEqual(accumulatedBody, ByteBuffer(string: requestBodyMessage)) + try await outbound.write(.head(.init(version: .http1_1, status: .ok))) + try await outbound.write(.body(ByteBuffer(string: responseBodyMessage))) + try await outbound.write(.end(nil)) + } + } + } + } + print("Server running on 127.0.0.1:\(serverPort)") + + // Send the request. + print("Client starting request") + let (response, maybeResponseBody) = try await transport.send( + HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath), + body: HTTPBody(requestBodyMessage), + baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!, + operationID: "unused" + ) + print("Client received response head: \(response)") + XCTAssertEqual(response.status, .ok) + let receivedMessage = try await String(collecting: try XCTUnwrap(maybeResponseBody), upTo: .max) + XCTAssertEqual(receivedMessage, responseBodyMessage) - static var mockURLSession: URLSession { - let configuration: URLSessionConfiguration = .ephemeral - configuration.protocolClasses = [Self.self] - configuration.timeoutIntervalForRequest = 0.1 - configuration.timeoutIntervalForResource = 0.1 - configuration.requestCachePolicy = .reloadIgnoringLocalAndRemoteCacheData - return URLSession(configuration: configuration) + group.cancelAll() } }