diff --git a/Package.resolved b/Package.resolved index 5a1c5ca..b079764 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,160 +1,167 @@ { - "object": { - "pins": [ - { - "package": "async-http-client", - "repositoryURL": "https://github.com/swift-server/async-http-client.git", - "state": { - "branch": null, - "revision": "16f7e62c08c6969899ce6cc277041e868364e5cf", - "version": "1.19.0" - } - }, - { - "package": "smoke-aws", - "repositoryURL": "https://github.com/amzn/smoke-aws.git", - "state": { - "branch": null, - "revision": "35edcb634ac6f6cb25e2bb449acf5f60201c8338", - "version": "2.44.298" - } - }, - { - "package": "smoke-aws-support", - "repositoryURL": "https://github.com/amzn/smoke-aws-support.git", - "state": { - "branch": null, - "revision": "4f77513b76d28694dc51dffd0dfd6de6602724a6", - "version": "1.3.1" - } - }, - { - "package": "smoke-http", - "repositoryURL": "https://github.com/amzn/smoke-http.git", - "state": { - "branch": null, - "revision": "84e6805ca07f9e4d7c39fbf61b7feff0484ee67f", - "version": "2.21.0" - } - }, - { - "package": "swift-atomics", - "repositoryURL": "https://github.com/apple/swift-atomics.git", - "state": { - "branch": null, - "revision": "6c89474e62719ddcc1e9614989fff2f68208fe10", - "version": "1.1.0" - } - }, - { - "package": "swift-collections", - "repositoryURL": "https://github.com/apple/swift-collections.git", - "state": { - "branch": null, - "revision": "937e904258d22af6e447a0b72c0bc67583ef64a2", - "version": "1.0.4" - } - }, - { - "package": "swift-crypto", - "repositoryURL": "https://github.com/apple/swift-crypto.git", - "state": { - "branch": null, - "revision": "ddb07e896a2a8af79512543b1c7eb9797f8898a5", - "version": "1.1.7" - } - }, - { - "package": "swift-distributed-tracing", - "repositoryURL": "https://github.com/apple/swift-distributed-tracing.git", - "state": { - "branch": null, - "revision": "49b7617717a09f6b781c9a11e1628e3315d8d4fe", - "version": "1.0.1" - } - }, - { - "package": "swift-log", - "repositoryURL": "https://github.com/apple/swift-log.git", - "state": { - "branch": null, - "revision": "532d8b529501fb73a2455b179e0bbb6d49b652ed", - "version": "1.5.3" - } - }, - { - "package": "swift-metrics", - "repositoryURL": "https://github.com/apple/swift-metrics.git", - "state": { - "branch": null, - "revision": "971ba26378ab69c43737ee7ba967a896cb74c0d1", - "version": "2.4.1" - } - }, - { - "package": "swift-nio", - "repositoryURL": "https://github.com/apple/swift-nio.git", - "state": { - "branch": null, - "revision": "cf281631ff10ec6111f2761052aa81896a83a007", - "version": "2.58.0" - } - }, - { - "package": "swift-nio-extras", - "repositoryURL": "https://github.com/apple/swift-nio-extras.git", - "state": { - "branch": null, - "revision": "0e0d0aab665ff1a0659ce75ac003081f2b1c8997", - "version": "1.19.0" - } - }, - { - "package": "swift-nio-http2", - "repositoryURL": "https://github.com/apple/swift-nio-http2.git", - "state": { - "branch": null, - "revision": "a8ccf13fa62775277a5d56844878c828bbb3be1a", - "version": "1.27.0" - } - }, - { - "package": "swift-nio-ssl", - "repositoryURL": "https://github.com/apple/swift-nio-ssl.git", - "state": { - "branch": null, - "revision": "320bd978cceb8e88c125dcbb774943a92f6286e9", - "version": "2.25.0" - } - }, - { - "package": "swift-nio-transport-services", - "repositoryURL": "https://github.com/apple/swift-nio-transport-services.git", - "state": { - "branch": null, - "revision": "e7403c35ca6bb539a7ca353b91cc2d8ec0362d58", - "version": "1.19.0" - } - }, - { - "package": "swift-service-context", - "repositoryURL": "https://github.com/apple/swift-service-context.git", - "state": { - "branch": null, - "revision": "ce0141c8f123132dbd02fd45fea448018762df1b", - "version": "1.0.0" - } - }, - { - "package": "XMLCoding", - "repositoryURL": "https://github.com/LiveUI/XMLCoding.git", - "state": { - "branch": null, - "revision": "f0fbfe17e73f329e13a6133ff5437f7b174049fd", - "version": "0.4.1" - } - } - ] - }, - "version": 1 + "pins" : [ + { + "identity" : "async-http-client", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swift-server/async-http-client.git", + "state" : { + "revision" : "16f7e62c08c6969899ce6cc277041e868364e5cf", + "version" : "1.19.0" + } + }, + { + "identity" : "smoke-aws", + "kind" : "remoteSourceControl", + "location" : "https://github.com/amzn/smoke-aws.git", + "state" : { + "revision" : "07e658e0fdc8a46923156c30a76fdd5d7427465e", + "version" : "2.46.5" + } + }, + { + "identity" : "smoke-aws-support", + "kind" : "remoteSourceControl", + "location" : "https://github.com/amzn/smoke-aws-support.git", + "state" : { + "revision" : "141efadb31e399736b23cfd2478af3dbdc170259", + "version" : "1.5.0" + } + }, + { + "identity" : "smoke-http", + "kind" : "remoteSourceControl", + "location" : "https://github.com/amzn/smoke-http.git", + "state" : { + "revision" : "2f27d29b863c797f74318f87579bac77935df2eb", + "version" : "2.22.2" + } + }, + { + "identity" : "swift-atomics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-atomics.git", + "state" : { + "revision" : "cd142fd2f64be2100422d658e7411e39489da985", + "version" : "1.2.0" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "a902f1823a7ff3c9ab2fba0f992396b948eda307", + "version" : "1.0.5" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "ddb07e896a2a8af79512543b1c7eb9797f8898a5", + "version" : "1.1.7" + } + }, + { + "identity" : "swift-distributed-tracing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-distributed-tracing.git", + "state" : { + "revision" : "49b7617717a09f6b781c9a11e1628e3315d8d4fe", + "version" : "1.0.1" + } + }, + { + "identity" : "swift-http-types", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-http-types", + "state" : { + "revision" : "99d066e29effa8845e4761dd3f2f831edfdf8925", + "version" : "1.0.0" + } + }, + { + "identity" : "swift-log", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-log.git", + "state" : { + "revision" : "532d8b529501fb73a2455b179e0bbb6d49b652ed", + "version" : "1.5.3" + } + }, + { + "identity" : "swift-metrics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-metrics.git", + "state" : { + "revision" : "971ba26378ab69c43737ee7ba967a896cb74c0d1", + "version" : "2.4.1" + } + }, + { + "identity" : "swift-nio", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio.git", + "state" : { + "revision" : "702cd7c56d5d44eeba73fdf83918339b26dc855c", + "version" : "2.62.0" + } + }, + { + "identity" : "swift-nio-extras", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-extras.git", + "state" : { + "revision" : "798c962495593a23fdea0c0c63fd55571d8dff51", + "version" : "1.20.0" + } + }, + { + "identity" : "swift-nio-http2", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-http2.git", + "state" : { + "revision" : "3bd9004b9d685ed6b629760fc84903e48efec806", + "version" : "1.29.0" + } + }, + { + "identity" : "swift-nio-ssl", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-ssl.git", + "state" : { + "revision" : "320bd978cceb8e88c125dcbb774943a92f6286e9", + "version" : "2.25.0" + } + }, + { + "identity" : "swift-nio-transport-services", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-nio-transport-services.git", + "state" : { + "revision" : "ebf8b9c365a6ce043bf6e6326a04b15589bd285e", + "version" : "1.20.0" + } + }, + { + "identity" : "swift-service-context", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-service-context.git", + "state" : { + "revision" : "ce0141c8f123132dbd02fd45fea448018762df1b", + "version" : "1.0.0" + } + }, + { + "identity" : "xmlcoding", + "kind" : "remoteSourceControl", + "location" : "https://github.com/LiveUI/XMLCoding.git", + "state" : { + "revision" : "f0fbfe17e73f329e13a6133ff5437f7b174049fd", + "version" : "0.4.1" + } + } + ], + "version" : 2 } diff --git a/Package.swift b/Package.swift index 76569cf..0d7d49a 100644 --- a/Package.swift +++ b/Package.swift @@ -28,12 +28,14 @@ let package = Package( dependencies: [ .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.19.0"), .package(url: "https://github.com/amzn/smoke-aws.git", from: "2.44.174"), + .package(url: "https://github.com/amzn/smoke-aws-support.git", from: "1.5.0"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), ], targets: [ .target( name: "SmokeAWSCredentials", dependencies: [ + .product(name: "AWSCore", package: "smoke-aws-support"), .product(name: "SecurityTokenClient", package: "smoke-aws"), .product(name: "Logging", package: "swift-log"), .product(name: "NIO", package: "swift-nio"), diff --git a/Sources/SmokeAWSCredentials/AwsRotatingCredentialsProviderV2.swift b/Sources/SmokeAWSCredentials/AwsRotatingCredentialsProviderV2.swift index 87eb2ec..2d0045a 100644 --- a/Sources/SmokeAWSCredentials/AwsRotatingCredentialsProviderV2.swift +++ b/Sources/SmokeAWSCredentials/AwsRotatingCredentialsProviderV2.swift @@ -11,13 +11,13 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. // -// AwsRotatingCredentials.swift +// AwsRotatingCredentialsV2.swift // SmokeAWSCredentials // import Foundation import Logging -import SmokeAWSCore +import AWSCore import SmokeHTTPClient private let secondsToNanoSeconds: UInt64 = 1_000_000_000 @@ -33,10 +33,250 @@ internal extension NSLocking { } } +/** + An actor that manages the current credentials. CurrentCredentials will attempt to always + keep the credentials valid firstly by scheduling a background task and in the worst case + fetching updated credentials when credentials are requested. + */ +private actor CurrentCredentials { + private var state: State + private let expiringCredentialsRetriever: ExpiringCredentialsAsyncRetriever + + // the task to schedule a background refresh + private var backgroundRefreshTask: Task? + // the task to refresh the credentials in the background + // this is held seperately to `state` so the existing credentials can continue to + // be used until the background refresh is complete + private var backgroundPendingCredentialsTask: Task? + private let backgroundLogger: Logger + private let credentialsStreamContinuation: AsyncStream.Continuation + + private enum State { + case present(ExpiringCredentials) + case pending(Task) + case missing // the credentials have previously expired and new credentials have failed to be retrieved + } + + private let expirationBufferSeconds: Double + private let backgroundExpirationBufferSeconds: Double + + /** + Initializes the actor. + + - Parameters: + - credentials: the initial credentials + - expiringCredentialsRetriever: used to retrieve refreshed credentials when required + - backgroundLogger: the logger to use for background credential refreshes + - credentialsStreamContinuation: the continuation for a stream of credential updates. + */ + init( + credentials: ExpiringCredentials, + expiringCredentialsRetriever: ExpiringCredentialsAsyncRetriever, + backgroundLogger: Logger, + credentialsStreamContinuation: AsyncStream.Continuation, + expirationBufferSeconds: Double, + backgroundExpirationBufferSeconds: Double + ) { + self.state = .present(credentials) + self.expiringCredentialsRetriever = expiringCredentialsRetriever + self.backgroundLogger = backgroundLogger + self.credentialsStreamContinuation = credentialsStreamContinuation + self.expirationBufferSeconds = expirationBufferSeconds + self.backgroundExpirationBufferSeconds = backgroundExpirationBufferSeconds + } + + /** + Starts a task to manage refreshing the current credentials just before they are about to expire. + */ + func startBackgroundRefreshTaskIfRequired() { + switch self.state { + case .present(let presentValue): + if let currentExpiration = presentValue.expiration { + self.backgroundRefreshTask = scheduleRefreshBeforeExpiration(currentExpiration) + } + case .pending, .missing: + // nothing to do + break + } + } + + /** + Gets the current credentials, ensuring that these credentials are always valid + */ + func get( + isBackgroundRefresh: Bool = false, + logger: Logger = Logger(label: "com.azmn.smoke-aws-credentials.CurrentCredentials.get") + ) async throws -> AWSCore.Credentials { + switch self.state { + case .present(let presentValue): + // if not within the buffer period and about to become expired + if !isBackgroundRefresh, let expiration = presentValue.expiration, + expiration > Date(timeIntervalSinceNow: self.expirationBufferSeconds) { + // these credentials can be used + logger.trace("Current credentials used.") + + return presentValue + } else if let backgroundPendingCredentialsTask = self.backgroundPendingCredentialsTask { + // if there is an-progress background refresh + // normally we wouldn't wait on this task but the current credentials are now expired + // so they can't be used + return try await backgroundPendingCredentialsTask.value + } + + logger.trace("Replacing current credentials.") + case .pending(let task): + // There is a pending credentials refresh + logger.trace("Waiting on existing credentials refresh") + + return try await task.value + case .missing: + logger.trace("Fetching new credentials.") + } + + // get the task for this entry + let task = self.handleGetFromRetriever(isBackgroundRefresh: isBackgroundRefresh) + + // if this is a background refresh, continue to use + // the existing credentials until the refreshed credentials + // are available (in other words don't hold up getting credentials + // for a client while the background refresh is in progress) + if !isBackgroundRefresh { + // cancel any background refresh task + backgroundRefreshTask?.cancel() + backgroundRefreshTask = nil + + // update the entry + // any concurrent credential gets will also wait for this task + self.state = .pending(task) + } else { + self.backgroundPendingCredentialsTask = task + } + + return try await task.value + } + + func stop() async { + self.backgroundRefreshTask?.cancel() + self.backgroundPendingCredentialsTask?.cancel() + + do { + try await self.expiringCredentialsRetriever.shutdown() + } catch { + self.backgroundLogger.warning("ExpiringCredentialsRetriever failed to shutdown cleanly", + metadata: ["cause": "\(error)"]) + } + + switch self.state { + case .pending(let task): + task.cancel() + case .present, .missing: + // nothing to do + break + } + } + + private func handleGetFromRetriever(isBackgroundRefresh: Bool) -> Task { + Task.detached { + let result: Result + do { + // wait for the value of the entry to be retrieved + let value = try await self.expiringCredentialsRetriever.getCredentials() + + result = .success(value) + } catch { + result = .failure(error) + } + + await self.addEntry(isBackgroundRefresh: isBackgroundRefresh, result: result) + + switch result { + case .success(let newEntry): + return newEntry + case .failure(let error): + throw error + } + } + } + + private func addEntry( + isBackgroundRefresh: Bool, + result: Result + ) { + self.backgroundPendingCredentialsTask = nil + + guard case .success(let credentials) = result else { + // we ignore the failure of a background refresh, now relying on a refresh initiated by a credentials get + // if a refresh initiated by a credentials get fails, we potentially just don't have any valid credentials + // set the state is `.missing` so any future credentials get can try again to refresh the credentials + if !isBackgroundRefresh { + self.state = .missing + } + return + } + + self.credentialsStreamContinuation.yield(credentials) + + if let currentExpiration = credentials.expiration { + // there are new credentials, schedule their refresh before they expire + self.backgroundRefreshTask = scheduleRefreshBeforeExpiration(currentExpiration) + } + + // update the entry + self.state = .present(credentials) + } + + // creates a task that will suspend until just before the current credentials expire + // and then initiates a refresh of the current credentials + private nonisolated func scheduleRefreshBeforeExpiration(_ currentExpiration: Date) -> Task { + return Task { + // create a deadline 5 minutes before the expiration + let waitDurationInSeconds = (currentExpiration - self.backgroundExpirationBufferSeconds).timeIntervalSinceNow + let waitDurationInMinutes = waitDurationInSeconds / 60 + + let wholeNumberOfHours = Int(waitDurationInMinutes) / 60 + // the total number of minutes minus the number of minutes + // that can be expressed in a whole number of hours + // Can also be expressed as: let overflowMinutes = waitDurationInMinutes - (wholeNumberOfHours * 60) + let overflowMinutes = Int(waitDurationInMinutes) % 60 + + if waitDurationInSeconds > 0 { + self.backgroundLogger.trace( + "Credentials updated; rotation scheduled in \(wholeNumberOfHours) hours, \(overflowMinutes) minutes.") + do { + try await Task.sleep(nanoseconds: UInt64(waitDurationInSeconds) * secondsToNanoSeconds) + } catch is CancellationError { + self.backgroundLogger.trace( + "Background credentials rotation cancelled.") + return + } catch { + self.backgroundLogger.error( + "Background credentials rotation failed due to error \(error).") + return + } + } + + do { + _ = try await self.get(isBackgroundRefresh: true, logger: self.backgroundLogger) + } catch is CancellationError { + self.backgroundLogger.trace( + "Background credentials rotation cancelled.") + return + } catch { + self.backgroundLogger.error( + "Background credentials rotation failed due to error \(error).") + return + } + + self.backgroundLogger.trace( + "Background credentials rotation completed.") + } + } +} + /** Class that manages the rotating credentials. */ -public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { +public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider, CredentialsProviderV2 { public var credentials: Credentials { // the provider returns a copy of the current // credentials which is used within a request. @@ -49,13 +289,13 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { } private var expiringCredentials: ExpiringCredentials - - let expirationBufferSeconds = 300.0 // 5 minutes - let validCredentialsRetrySeconds = 60.0 // 1 minute - let invalidCredentialsRetrySeconds = 3600.0 // 1 hour - - let roleSessionName: String? - let logger: Logger + + private let currentCredentials: CurrentCredentials + // a stream of credentials updates that is used to ensure the `credentials` property required by the original + // `CredentialsProvider` protocol returns the latest set of credentials. Credential instances can be placed into this + // stream either due to a background refresh or initiated by a call to `getCredentials() async throws` that identified + // expired credentials and refreshes them on demand + private let credentialsStream: (stream: AsyncStream, continuation: AsyncStream.Continuation) public enum Status { case initialized @@ -67,7 +307,6 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { public var status: Status let completedSemaphore = DispatchSemaphore(value: 0) var statusLock: NSLock = .init() - let expiringCredentialsRetriever: ExpiringCredentialsAsyncRetriever /** Initializer that accepts the initial ExpiringCredentials instance for this provider. @@ -78,22 +317,46 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { @available(swift, deprecated: 3.0, message: "Migrate to async constructor") public init(expiringCredentialsRetriever: ExpiringCredentialsAsyncRetriever, roleSessionName: String?, - logger: Logger) throws { + logger: Logger, + expirationBufferSeconds: Double = 120.0, // 2 minutes + backgroundExpirationBufferSeconds: Double = 300.0) throws { // 5 minutes self.expiringCredentials = try expiringCredentialsRetriever.get() - self.expiringCredentialsRetriever = expiringCredentialsRetriever - self.roleSessionName = roleSessionName - self.logger = logger self.status = .initialized + + var decoratedLogger = logger + if let roleSessionName { + decoratedLogger[metadataKey: "roleSessionName"] = "\(roleSessionName)" + } + + self.credentialsStream = AsyncStream.makeStream(of: ExpiringCredentials.self) + self.currentCredentials = CurrentCredentials(credentials: self.expiringCredentials, + expiringCredentialsRetriever: expiringCredentialsRetriever, + backgroundLogger: decoratedLogger, + credentialsStreamContinuation: self.credentialsStream.continuation, + expirationBufferSeconds: expirationBufferSeconds, + backgroundExpirationBufferSeconds: backgroundExpirationBufferSeconds) } public init(expiringCredentialsRetriever: ExpiringCredentialsAsyncRetriever, roleSessionName: String?, - logger: Logger) async throws { + logger: Logger, + expirationBufferSeconds: Double = 120.0, // 2 minutes + backgroundExpirationBufferSeconds: Double = 300.0) async throws { // 5 minutes self.expiringCredentials = try await expiringCredentialsRetriever.getCredentials() - self.expiringCredentialsRetriever = expiringCredentialsRetriever - self.roleSessionName = roleSessionName - self.logger = logger self.status = .initialized + + var decoratedLogger = logger + if let roleSessionName { + decoratedLogger[metadataKey: "roleSessionName"] = "\(roleSessionName)" + } + + self.credentialsStream = AsyncStream.makeStream(of: ExpiringCredentials.self) + self.currentCredentials = CurrentCredentials(credentials: self.expiringCredentials, + expiringCredentialsRetriever: expiringCredentialsRetriever, + backgroundLogger: decoratedLogger, + credentialsStreamContinuation: self.credentialsStream.continuation, + expirationBufferSeconds: expirationBufferSeconds, + backgroundExpirationBufferSeconds: backgroundExpirationBufferSeconds) } deinit { @@ -127,17 +390,17 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { Gracefully shuts down credentials rotation, letting any ongoing work complete.. */ public func stop() throws { - try self.statusLock.withLock { + self.statusLock.withLock { // if there is currently a worker to shutdown switch status { case .initialized: // no worker ever started, can just go straight to stopped status = .stopped - try expiringCredentialsRetriever.syncShutdown() + self.credentialsStream.continuation.finish() completedSemaphore.signal() case .running: status = .shuttingDown - try expiringCredentialsRetriever.syncShutdown() + self.credentialsStream.continuation.finish() default: // nothing to do break @@ -166,7 +429,7 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { } if isShutdown { - try await self.expiringCredentialsRetriever.shutdown() + self.credentialsStream.continuation.finish() } } @@ -191,106 +454,40 @@ public class AwsRotatingCredentialsProviderV2: StoppableCredentialsProvider { self.completedSemaphore.wait() } - - private func verifyWorkerNotCancelled() -> Bool { - return self.statusLock.withLock { - guard case .running = status else { - status = .stopped - completedSemaphore.signal() - return false - } - - return true - } + + public func getCredentials() async throws -> Credentials { + return try await self.currentCredentials.get() } func run() async { - var expiration: Date? = self.expiringCredentials.expiration - - while let currentExpiration = expiration { - guard self.verifyWorkerNotCancelled() else { - return - } - - // create a deadline 5 minutes before the expiration - let waitDurationInSeconds = (currentExpiration - self.expirationBufferSeconds).timeIntervalSinceNow - let waitDurationInMinutes = waitDurationInSeconds / 60 - - let wholeNumberOfHours = Int(waitDurationInMinutes) / 60 - // the total number of minutes minus the number of minutes - // that can be expressed in a whole number of hours - // Can also be expressed as: let overflowMinutes = waitDurationInMinutes - (wholeNumberOfHours * 60) - let overflowMinutes = Int(waitDurationInMinutes) % 60 - - let logEntryPrefix: String - if let roleSessionName = self.roleSessionName { - logEntryPrefix = "Credentials for session '\(roleSessionName)'" - } else { - logEntryPrefix = "Credentials" - } - - self.logger.trace( - "\(logEntryPrefix) updated; rotation scheduled in \(wholeNumberOfHours) hours, \(overflowMinutes) minutes.") - do { - try await Task.sleep(nanoseconds: UInt64(waitDurationInSeconds) * secondsToNanoSeconds) - } catch { - self.logger.error( - "\(logEntryPrefix) rotation stopped due to error \(error).") - } - - expiration = await self.updateCredentials(roleSessionName: roleSessionName, logger: self.logger) - } - } - - private func updateCredentials(roleSessionName: String?, - logger _: Logger) async - -> Date? { - let logEntryPrefix: String - if let roleSessionName = roleSessionName { - logEntryPrefix = "Credentials for session '\(roleSessionName)'" - } else { - logEntryPrefix = "Credentials" - } - - self.logger.trace("\(logEntryPrefix) about to expire; rotating.") - - let expiration: Date? - do { - let expiringCredentials = try await self.expiringCredentialsRetriever.getCredentials() - + await self.currentCredentials.startBackgroundRefreshTaskIfRequired() + + for await credentials in self.credentialsStream.stream { self.statusLock.withLock { - self.expiringCredentials = expiringCredentials - } - - expiration = expiringCredentials.expiration - } catch { - let timeIntervalSinceNow = - self.expiringCredentials.expiration?.timeIntervalSinceNow ?? 0 - - let retryDuration: Double - let logPrefix = "\(logEntryPrefix) rotation failed." - - // if the expiry of the current credentials is still in the future - if timeIntervalSinceNow > 0 { - // try again relatively soon (still within the 5 minute credentials - // expirary buffer) to get new credentials - retryDuration = self.validCredentialsRetrySeconds - - self.logger.warning( - "\(logPrefix) Credentials still valid. Attempting credentials refresh in 1 minute.") - } else { - // at this point, we have tried multiple times to get new credentials - // something is quite wrong; try again in the future but at - // a reduced frequency - retryDuration = self.invalidCredentialsRetrySeconds - - self.logger.error( - "\(logPrefix) Credentials no longer valid. Attempting credentials refresh in 1 hour.") + self.expiringCredentials = credentials } - - expiration = Date(timeIntervalSinceNow: retryDuration) } + + // cancel any background tasks + await self.currentCredentials.stop() + + self.statusLock.withLock { + status = .stopped + completedSemaphore.signal() + } + } +} - return expiration +#if swift(<5.9.0) +// This should be removed once we support Swift 5.9+ +extension AsyncStream { + fileprivate static func makeStream( + of elementType: Element.Type = Element.self, + bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { + var continuation: AsyncStream.Continuation! + let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } + return (stream: stream, continuation: continuation!) } } +#endif diff --git a/Tests/SmokeAWSCredentialsTests/AwsRotatingCredentialsProviderV2Tests.swift b/Tests/SmokeAWSCredentialsTests/AwsRotatingCredentialsProviderV2Tests.swift new file mode 100644 index 0000000..22ab58a --- /dev/null +++ b/Tests/SmokeAWSCredentialsTests/AwsRotatingCredentialsProviderV2Tests.swift @@ -0,0 +1,314 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). +// You may not use this file except in compliance with the License. +// A copy of the License is located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// +// AwsContainerRotatingCredentialsV2Tests.swift +// SmokeAWSCredentials +// + +import Logging +@testable import SmokeAWSCredentials +import SmokeHTTPClient +import XCTest +import Logging +import SmokeAWSCore + +private enum TestErrors: Swift.Error { + case retrieverError +} + +private actor TestExpiringCredentialsAsyncRetriever: ExpiringCredentialsAsyncRetriever { + enum Result { + case credentials(SmokeAWSCredentials.ExpiringCredentials) + case error(Swift.Error) + } + var results: [Result] + + init(results: [Result]) { + self.results = results.reversed() + } + func getCredentials() async throws -> SmokeAWSCredentials.ExpiringCredentials { + let result = self.results.popLast()! + + switch result { + case .credentials(let expiringCredentials): + return expiringCredentials + case .error(let error): + throw error + } + } + + nonisolated func close() throws { + // nothing to do + } + + func shutdown() async throws { + // nothing to do + } + nonisolated func get() throws -> SmokeAWSCredentials.ExpiringCredentials { + fatalError("Not implemented") + } + + +} + +class AwsContainerRotatingCredentialsV2Tests: XCTestCase { + private let accessKeyId1 = "accessKeyId1" + private let accessKeyId2 = "accessKeyId2" + private let accessKeyId3 = "accessKeyId3" + private let secretAccessKey1 = "secretAccessKey1" + private let secretAccessKey2 = "secretAccessKey2" + private let secretAccessKey3 = "secretAccessKey3" + private let sessionToken1 = "sessionToken1" + private let sessionToken2 = "sessionToken2" + private let sessionToken3 = "sessionToken3" + + func testBackgroundRefresh() async throws { + let firstCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId1, + expiration: Date() + 10, + secretAccessKey: secretAccessKey1, + sessionToken: sessionToken1) + let secondCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId2, + expiration: Date() + 20, + secretAccessKey: secretAccessKey2, + sessionToken: sessionToken2) + let thirdCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId3, + expiration: Date() + 3600, + secretAccessKey: secretAccessKey3, + sessionToken: sessionToken3) + + let retriever = TestExpiringCredentialsAsyncRetriever(results: [.credentials(firstCredentials), + .credentials(secondCredentials), + .credentials(thirdCredentials)]) + let provider = try await AwsRotatingCredentialsProviderV2( + expiringCredentialsRetriever: retriever, + roleSessionName: nil, + logger: Logger(label: "test.logger"), + expirationBufferSeconds: 2, + backgroundExpirationBufferSeconds: 5) + + provider.start() + + // will return credentials retrieved from the first time the credentials are called + let retrievedCredentials1 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials1.accessKeyId, firstCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials1.secretAccessKey, firstCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials1.sessionToken, firstCredentials.sessionToken) + + + // legacy property should match + let retrievedCredentials1_1 = provider.credentials + XCTAssertEqual(retrievedCredentials1_1.accessKeyId, firstCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials1_1.secretAccessKey, firstCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials1_1.sessionToken, firstCredentials.sessionToken) + + // the background credentials refresh should happen after 5 seconds (five seconds before the expiration) + try await Task.sleep(for: .seconds(6)) + + // will return credentials retrieved from the background refresh + // even through the first credentials haven't expired yet + let retrievedCredentials2 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials2.accessKeyId, secondCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials2.secretAccessKey, secondCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials2.sessionToken, secondCredentials.sessionToken) + + // legacy property should match + let retrievedCredentials2_1 = provider.credentials + XCTAssertEqual(retrievedCredentials2_1.accessKeyId, secondCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials2_1.secretAccessKey, secondCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials2_1.sessionToken, secondCredentials.sessionToken) + + // sleep until after the first credentials have expired + try await Task.sleep(for: .seconds(6)) + + // should still be the second credentials + let retrievedCredentials3 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials3.accessKeyId, secondCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials3.secretAccessKey, secondCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials3.sessionToken, secondCredentials.sessionToken) + + // legacy property should match + let retrievedCredentials3_1 = provider.credentials + XCTAssertEqual(retrievedCredentials3_1.accessKeyId, secondCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials3_1.secretAccessKey, secondCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials3_1.sessionToken, secondCredentials.sessionToken) + + // the next background credentials refresh should happen after 15 seconds (five seconds before the expiration) + try await Task.sleep(for: .seconds(4)) + + // will return credentials retrieved from the second background refresh + // even through the second credentials haven't expired yet + let retrievedCredentials4 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials4.accessKeyId, thirdCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials4.secretAccessKey, thirdCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials4.sessionToken, thirdCredentials.sessionToken) + + // legacy property should match + let retrievedCredentials4_1 = provider.credentials + XCTAssertEqual(retrievedCredentials4_1.accessKeyId, thirdCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials4_1.secretAccessKey, thirdCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials4_1.sessionToken, thirdCredentials.sessionToken) + + // sleep until after the second credentials have expired + try await Task.sleep(for: .seconds(6)) + + // should still be the third credentials + let retrievedCredentials5 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials5.accessKeyId, thirdCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials5.secretAccessKey, thirdCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials5.sessionToken, thirdCredentials.sessionToken) + + // legacy property should match + let retrievedCredentials5_1 = provider.credentials + XCTAssertEqual(retrievedCredentials5_1.accessKeyId, thirdCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials5_1.secretAccessKey, thirdCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials5_1.sessionToken, thirdCredentials.sessionToken) + + try await provider.shutdown() + provider.wait() + } + + func testFailedBackgroundRefresh() async throws { + let firstCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId1, + expiration: Date() + 10, + secretAccessKey: secretAccessKey1, + sessionToken: sessionToken1) + let secondCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId2, + expiration: Date() + 20, + secretAccessKey: secretAccessKey2, + sessionToken: sessionToken2) + let thirdCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId3, + expiration: Date() + 3600, + secretAccessKey: secretAccessKey3, + sessionToken: sessionToken3) + + let retriever = TestExpiringCredentialsAsyncRetriever(results: [.credentials(firstCredentials), + .error(TestErrors.retrieverError), + .credentials(secondCredentials), + .credentials(thirdCredentials)]) + let provider = try await AwsRotatingCredentialsProviderV2( + expiringCredentialsRetriever: retriever, + roleSessionName: nil, + logger: Logger(label: "test.logger"), + expirationBufferSeconds: 2, + backgroundExpirationBufferSeconds: 5) + + provider.start() + + // will return credentials retrieved from the first time the credentials are called + let retrievedCredentials1 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials1.accessKeyId, firstCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials1.secretAccessKey, firstCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials1.sessionToken, firstCredentials.sessionToken) + + // the background credentials refresh should happen after 5 seconds (five seconds before the expiration) + try await Task.sleep(for: .seconds(6)) + + // will return the first credentials as the background refresh failed + // and not within expirationBufferSeconds of the credentials expiry + let retrievedCredentials2 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials2.accessKeyId, firstCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials2.secretAccessKey, firstCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials2.sessionToken, firstCredentials.sessionToken) + + // sleep to within the expirationBufferSeconds + try await Task.sleep(for: .seconds(3)) + + // will actually go and retrieve refreshed credentials + let retrievedCredentials2_2 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials2_2.accessKeyId, secondCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials2_2.secretAccessKey, secondCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials2_2.sessionToken, secondCredentials.sessionToken) + + // sleep until after the first credentials have expired + try await Task.sleep(for: .seconds(3)) + + // should still be the second credentials + let retrievedCredentials3 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials3.accessKeyId, secondCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials3.secretAccessKey, secondCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials3.sessionToken, secondCredentials.sessionToken) + + // the next background credentials refresh should happen after 15 seconds (five seconds before the expiration) + // the failure of the first background refresh shound not impact this occurring + try await Task.sleep(for: .seconds(4)) + + // will return credentials retrieved from the second background refresh + // even through the second credentials haven't expired yet + let retrievedCredentials4 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials4.accessKeyId, thirdCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials4.secretAccessKey, thirdCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials4.sessionToken, thirdCredentials.sessionToken) + + // sleep until after the second credentials have expired + try await Task.sleep(for: .seconds(6)) + + // should still be the third credentials + let retrievedCredentials5 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials5.accessKeyId, thirdCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials5.secretAccessKey, thirdCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials5.sessionToken, thirdCredentials.sessionToken) + + try await provider.shutdown() + provider.wait() + } + + func testFailedRetrieval() async throws { + let firstCredentials = SmokeAWSCredentials.ExpiringCredentials(accessKeyId: accessKeyId1, + expiration: Date() + 10, + secretAccessKey: secretAccessKey1, + sessionToken: sessionToken1) + + let retriever = TestExpiringCredentialsAsyncRetriever(results: [.credentials(firstCredentials), + .error(TestErrors.retrieverError), + .error(TestErrors.retrieverError)]) + let provider = try await AwsRotatingCredentialsProviderV2( + expiringCredentialsRetriever: retriever, + roleSessionName: nil, + logger: Logger(label: "test.logger"), + expirationBufferSeconds: 2, + backgroundExpirationBufferSeconds: 5) + + provider.start() + + // will return credentials retrieved from the first time the credentials are called + let retrievedCredentials1 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials1.accessKeyId, firstCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials1.secretAccessKey, firstCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials1.sessionToken, firstCredentials.sessionToken) + + // the background credentials refresh should happen after 5 seconds (five seconds before the expiration) + try await Task.sleep(for: .seconds(6)) + + // will return the first credentials as the background refresh failed + // and not within expirationBufferSeconds of the credentials expiry + let retrievedCredentials2 = try await provider.getCredentials() + XCTAssertEqual(retrievedCredentials2.accessKeyId, firstCredentials.accessKeyId) + XCTAssertEqual(retrievedCredentials2.secretAccessKey, firstCredentials.secretAccessKey) + XCTAssertEqual(retrievedCredentials2.sessionToken, firstCredentials.sessionToken) + + // sleep to within the expirationBufferSeconds + try await Task.sleep(for: .seconds(3)) + + // will actually go and retrieve refreshed credentials + do { + _ = try await provider.getCredentials() + + XCTFail("Expected error not thrown") + } catch TestErrors.retrieverError { + // expected error + } + + try await provider.shutdown() + provider.wait() + } +}