Skip to content

Commit

Permalink
chore: Refactor existing identity structure to use new identity proto…
Browse files Browse the repository at this point in the history
…cols in smithy-swift (#1150)

* Refactor existing structure to use new identity protocols in smithy-swift.
---------

Co-authored-by: Sichan Yoo <chanyoo@amazon.com>
  • Loading branch information
sichanyoo and Sichan Yoo authored Oct 4, 2023
1 parent 56298ea commit 668575c
Show file tree
Hide file tree
Showing 24 changed files with 64 additions and 67 deletions.
12 changes: 6 additions & 6 deletions Sources/Core/AWSClientRuntime/AWSClientConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public class AWSClientConfiguration<ServiceSpecificConfiguration: AWSServiceSpec
/// The credentials provider to be used for AWS credentials.
///
/// If no credentials provider is supplied, the SDK will look for credentials in the environment, then in the `~/.aws/credentials` file.
public var credentialsProvider: CredentialsProviding
public var credentialsProvider: any CredentialsProviding

/// The AWS region to use, i.e. `us-east-1` or `us-west-2`, etc.
///
Expand Down Expand Up @@ -111,7 +111,7 @@ public class AWSClientConfiguration<ServiceSpecificConfiguration: AWSServiceSpec
/// All convenience inits should call this.
private init(
// these params have no labels to distinguish this init from the similar convenience inits below
_ credentialsProvider: AWSClientRuntime.CredentialsProviding,
_ credentialsProvider: any AWSClientRuntime.CredentialsProviding,
_ endpoint: Swift.String?,
_ serviceSpecific: ServiceSpecificConfiguration?,
_ region: Swift.String?,
Expand Down Expand Up @@ -163,7 +163,7 @@ extension AWSClientConfiguration {

/// Creates a configuration asynchronously
public convenience init(
credentialsProvider: AWSClientRuntime.CredentialsProviding? = nil,
credentialsProvider: (any AWSClientRuntime.CredentialsProviding)? = nil,
endpoint: Swift.String? = nil,
serviceSpecific: ServiceSpecificConfiguration? = nil,
region: Swift.String? = nil,
Expand All @@ -184,7 +184,7 @@ extension AWSClientConfiguration {
} else {
resolvedRegion = await resolvedRegionResolver.resolveRegion()
}
let resolvedCredentialsProvider: AWSClientRuntime.CredentialsProviding
let resolvedCredentialsProvider: any AWSClientRuntime.CredentialsProviding
if let credentialsProvider = credentialsProvider {
resolvedCredentialsProvider = credentialsProvider
} else {
Expand Down Expand Up @@ -219,7 +219,7 @@ extension AWSClientConfiguration {

public convenience init(
region: Swift.String,
credentialsProvider: AWSClientRuntime.CredentialsProviding? = nil,
credentialsProvider: (any AWSClientRuntime.CredentialsProviding)? = nil,
endpoint: Swift.String? = nil,
serviceSpecific: ServiceSpecificConfiguration? = nil,
signingRegion: Swift.String? = nil,
Expand All @@ -231,7 +231,7 @@ extension AWSClientConfiguration {
connectTimeoutMs: UInt32? = nil
) throws {
let fileBasedConfig = try CRTFileBasedConfiguration.make()
let resolvedCredentialsProvider: CredentialsProviding
let resolvedCredentialsProvider: any CredentialsProviding
if let credentialsProvider = credentialsProvider {
resolvedCredentialsProvider = credentialsProvider
} else {
Expand Down
4 changes: 2 additions & 2 deletions Sources/Core/AWSClientRuntime/Auth/Credentials+CRT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public extension Credentials {
self.init(
accessKey: accessKey,
secret: secret,
expirationTimeout: crtCredentials.getExpiration(),
expiration: crtCredentials.getExpiration(),
sessionToken: crtCredentials.getSessionToken()
)
}
Expand All @@ -39,7 +39,7 @@ public extension CRTCredentials {
accessKey: credentials.accessKey,
secret: credentials.secret,
sessionToken: credentials.sessionToken,
expiration: credentials.expirationTimeout
expiration: credentials.expiration
)
}
}
11 changes: 6 additions & 5 deletions Sources/Core/AWSClientRuntime/Auth/Credentials.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,35 @@
//

import Foundation
import protocol ClientRuntime.Identity

public typealias AWSCredentials = Credentials

/// A type representing credentials for authenticating with an AWS service
///
/// For more information see [AWS security credentials](https://docs.aws.amazon.com/general/latest/gr/aws-security-credentials.html#AccessKeys)
public struct Credentials {
public struct Credentials: Identity {
let accessKey: String
let secret: String
let expirationTimeout: Date?
let sessionToken: String?
public let expiration: Date?

/// Creates credentials with the specified keys and optionally an expiration and session token.
///
/// - Parameters:
/// - accessKey: The access key
/// - secret: The secret for the provided access key
/// - expirationTimeout: The date when the credentials will expire and no longer be valid. If value is `nil` then the credentials never expire. Defaults to `nil`
/// - expiration: The date when the credentials will expire and no longer be valid. If value is `nil` then the credentials never expire. Defaults to `nil`
/// - sessionToken: A session token for this session. Defaults to `nil`
public init(
accessKey: String,
secret: String,
expirationTimeout: Date? = nil,
expiration: Date? = nil,
sessionToken: String? = nil
) {
self.accessKey = accessKey
self.secret = secret
self.expirationTimeout = expirationTimeout
self.expiration = expiration
self.sessionToken = sessionToken
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public struct CachedCredentialsProvider: CredentialsSourcedByCRT {
/// - source: The source credentials provider to get the credentials.
/// - refreshTime: The number of seconds that must pass before new credentials will be fetched again.
public init(
source: CredentialsProviding,
source: any CredentialsProviding,
refreshTime: TimeInterval
) throws {
self.crtCredentialsProvider = try CRTCredentialsProvider(source: .cached(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct CustomCredentialsProvider: CredentialsSourcedByCRT {
/// - Parameter provider: An object confirming to `CredentialsProviding` to source the credentials.
///
/// - Returns: A credentials provider that uses the provided the object confirming to `CredentialsProviding` to source the credentials.
init(_ provider: CredentialsProviding) throws {
init(_ provider: any CredentialsProviding) throws {
self.crtCredentialsProvider = try CRTCredentialsProvider(
provider: CredentialsProvidingCRTAdapter(credentialsProvider: provider)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public struct STSAssumeRoleCredentialsProvider: CredentialsSourcedByCRT {
/// - sessionName: The name to associate with the session. This is used to uniquely identify a session when the same role is assumed by different principals or for different reasons. In cross-account scenarios, the session name is visible to, and can be logged by the account that owns the role. The role session name is also in the ARN of the assumed role principal.
/// - durationSeconds: The expiry duration of the STS credentials. Defaults to 15 minutes if not set.
public init(
credentialsProvider: CredentialsProviding,
credentialsProvider: any CredentialsProviding,
roleArn: String,
sessionName: String,
durationSeconds: TimeInterval = .minutes(15)
Expand Down
12 changes: 7 additions & 5 deletions Sources/Core/AWSClientRuntime/Auth/CredentialsProviding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
import ClientRuntime

/// A type that can provide credentials for authenticating with an AWS service
public protocol CredentialsProviding {
/// Returns credentials for authenticating with an AWS service
func getCredentials() async throws -> Credentials
}
public protocol CredentialsProviding: IdentityResolver where IdentityT == Credentials {}

extension CredentialsProviding {
/// Returns the underlying `CRTCredentialsProvider`.
/// If `self` is not backed by a `CRTCredentialsProvider` then this wraps `self` in a `CustomCredentialsProvider` which will create a `CRTCredentialsProvider`.
func getCRTCredentialsProvider() throws -> CRTCredentialsProvider {
let providerSourcedByCRT = try self as? CredentialsSourcedByCRT ?? CustomCredentialsProvider(self)
let providerSourcedByCRT = try self as? (any CredentialsSourcedByCRT) ?? CustomCredentialsProvider(self)
return providerSourcedByCRT.crtCredentialsProvider
}

public func getIdentity(identityProperties: Attributes? = nil) async throws -> Credentials {
let crtCredentials = try await self.getCRTCredentialsProvider().getCredentials()
return try .init(crtCredentials: crtCredentials)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
//

import AwsCommonRuntimeKit
import ClientRuntime

typealias CRTCredentialsProviding = AwsCommonRuntimeKit.CredentialsProviding
typealias CRTCredentialsProvider = AwsCommonRuntimeKit.CredentialsProvider

/// A credentials provider that adapts a credentials provider to `CRTCredentialsProvding`
struct CredentialsProvidingCRTAdapter: CRTCredentialsProviding {
let credentialsProvider: CredentialsProviding
let credentialsProvider: any CredentialsProviding

func getCredentials() async throws -> CRTCredentials {
let credentials = try await credentialsProvider.getCredentials()
let credentials = try await credentialsProvider.getIdentity(identityProperties: Attributes())
return try .init(credentials: credentials)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,3 @@ import Foundation
protocol CredentialsSourcedByCRT: CredentialsProviding {
var crtCredentialsProvider: CRTCredentialsProvider { get }
}

extension CredentialsSourcedByCRT {
public func getCredentials() async throws -> Credentials {
let crtCredentials = try await crtCredentialsProvider.getCredentials()
return try .init(crtCredentials: crtCredentials)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import ClientRuntime
import struct Foundation.Date

extension HttpContext {
static let credentialsProvider = AttributeKey<CredentialsProviding>(name: "CredentialsProvider")
static let credentialsProvider = AttributeKey<(any CredentialsProviding)>(name: "CredentialsProvider")
static let region = AttributeKey<String>(name: "Region")
public static let signingName = AttributeKey<String>(name: "SigningName")
public static let signingRegion = AttributeKey<String>(name: "SigningRegion")
public static let signingAlgorithm = AttributeKey<String>(name: "SigningAlgorithm")
public static let requestSignature = AttributeKey<String>(name: "AWS_HTTP_SIGNATURE")

func getCredentialsProvider() -> CredentialsProviding? {
func getCredentialsProvider() -> (any CredentialsProviding)? {
return attributes.get(key: HttpContext.credentialsProvider)
}

Expand Down Expand Up @@ -49,7 +49,7 @@ extension HttpContext {
/// - Returns: `AWSSigningConfig` for the event stream message
public func makeEventStreamSigningConfig(date: Date = Date().withoutFractionalSeconds())
async throws -> AWSSigningConfig {
let credentials = try await getCredentialsProvider()?.getCredentials()
let credentials = try await getCredentialsProvider()?.getIdentity(identityProperties: Attributes())
guard let service = getSigningName() else {
fatalError("Signing name must not be nil, it must be set by the middleware during the request")
}
Expand Down Expand Up @@ -100,7 +100,7 @@ extension HttpContextBuilder {
}

@discardableResult
public func withCredentialsProvider(value: CredentialsProviding) -> HttpContextBuilder {
public func withCredentialsProvider(value: any CredentialsProviding) -> HttpContextBuilder {
self.attributes.set(key: HttpContext.credentialsProvider, value: value)
return self
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public struct SigV4Middleware<OperationStackOutput: HttpResponseBinding,
omitSessionToken: config.omitSessionToken)
let signedBodyValue: AWSSignedBodyValue = config.unsignedBody ? .unsignedPayload : .empty

let credentials = try await credentialsProvider.getCredentials()
let credentials = try await credentialsProvider.getIdentity(identityProperties: Attributes())
let signingConfig = AWSSigningConfig(
credentials: credentials,
expiration: config.expiration,
Expand Down
4 changes: 2 additions & 2 deletions Sources/Core/AWSClientRuntime/Signing/AWSSigV4Signer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ public class AWSSigV4Signer {

public static func sigV4SignedURL(
requestBuilder: SdkHttpRequestBuilder,
credentialsProvider: CredentialsProviding,
credentialsProvider: any CredentialsProviding,
signingName: Swift.String,
signingRegion: Swift.String,
date: ClientRuntime.Date,
expiration: TimeInterval,
signingAlgorithm: AWSSigningAlgorithm
) async -> ClientRuntime.URL? {
do {
let credentials = try await credentialsProvider.getCredentials()
let credentials = try await credentialsProvider.getIdentity(identityProperties: Attributes())
let flags = SigningFlags(useDoubleURIEncode: true,
shouldNormalizeURIPath: true,
omitSessionToken: false)
Expand Down
4 changes: 2 additions & 2 deletions Sources/Core/AWSClientRuntime/Signing/AWSSigningConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Foundation

public struct AWSSigningConfig {
public let credentials: AWSCredentials?
public let credentialsProvider: CredentialsProviding?
public let credentialsProvider: (any CredentialsProviding)?
public let expiration: TimeInterval
public let signedBodyHeader: AWSSignedBodyHeader
public let signedBodyValue: AWSSignedBodyValue
Expand All @@ -24,7 +24,7 @@ public struct AWSSigningConfig {

public init(
credentials: AWSCredentials? = nil,
credentialsProvider: CredentialsProviding? = nil,
credentialsProvider: (any CredentialsProviding)? = nil,
expiration: TimeInterval = 0,
signedBodyHeader: AWSSignedBodyHeader = .none,
signedBodyValue: AWSSignedBodyValue,
Expand Down
4 changes: 2 additions & 2 deletions Sources/Core/AWSClientRuntime/Signing/SigV4Config.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import Foundation

public struct SigV4Config {
let credentialsProvider: CredentialsProviding?
let credentialsProvider: (any CredentialsProviding)?
let signingService: String?
let signatureType: AWSSignatureType
let useDoubleURIEncode: Bool
Expand All @@ -20,7 +20,7 @@ public struct SigV4Config {
let signingAlgorithm: AWSSigningAlgorithm

public init(
credentialsProvider: CredentialsProviding? = nil,
credentialsProvider: (any CredentialsProviding)? = nil,
signingService: String? = nil,
signatureType: AWSSignatureType = .requestHeaders,
useDoubleURIEncode: Bool = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import XCTest
@_spi(FileBasedConfig) @testable import AWSClientRuntime

class CachedCredentialsProviderTests: XCTestCase {
func testGetCredentials() async throws {
func testGetIdentity() async throws {
var counter: Int = 0
let coreProvider = MockCredentialsProvider {
counter += 1
Expand All @@ -24,16 +24,16 @@ class CachedCredentialsProviderTests: XCTestCase {
refreshTime: 1
)

_ = try await subject.getCredentials()
_ = try await subject.getCredentials()
_ = try await subject.getCredentials()
_ = try await subject.getCredentials()
_ = try await subject.getIdentity()
_ = try await subject.getIdentity()
_ = try await subject.getIdentity()
_ = try await subject.getIdentity()

XCTAssertEqual(counter, 1)

try! await Task.sleep(nanoseconds: 1 * 1_000_000_000)

let credentials = try! await subject.getCredentials()
let credentials = try! await subject.getIdentity()

XCTAssertEqual(counter, 2)
XCTAssertEqual(credentials.accessKey, "some_access_key")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class CustomCredentialsProviderTests: XCTestCase {
func testGetCredentials() async throws {
let mockProvider = MockCredentialsProvider()
let subject = try CustomCredentialsProvider(mockProvider)
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual(credentials.accessKey, "some_access_key")
XCTAssertEqual(credentials.secret, "some_secret")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DefaultChainCredentialsProviderTests: XCTestCase {
}

let subject = try DefaultChainCredentialsProvider()
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual(credentials.accessKey, "some_access_key_b")
XCTAssertEqual(credentials.secret, "some_secret_b")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class EnvironmentCredentialsProviderTests: XCTestCase {
}

let subject = try EnvironmentCredentialsProvider()
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual(credentials.accessKey, "some_access_key_a")
XCTAssertEqual(credentials.secret, "some_secret_a")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ProcessCredentialsProviderTests: XCTestCase {
configFilePath: configPath,
credentialsFilePath: credentialsPath
)
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual("AccessKey123", credentials.accessKey)
XCTAssertEqual("SecretAccessKey123", credentials.secret)
Expand All @@ -36,7 +36,7 @@ class ProcessCredentialsProviderTests: XCTestCase {
configFilePath: configPath,
credentialsFilePath: credentialsPath
)
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual("AccessKey123", credentials.accessKey)
XCTAssertEqual("SecretAccessKey123", credentials.secret)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ProfileCredentialsProviderTests: XCTestCase {
configFilePath: configPath,
credentialsFilePath: credentialsPath
)
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual(credentials.accessKey, "access_key_default_cred")
XCTAssertEqual(credentials.secret, "secret_default_cred")
Expand All @@ -32,7 +32,7 @@ class ProfileCredentialsProviderTests: XCTestCase {
configFilePath: configPath,
credentialsFilePath: credentialsPath
)
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual(credentials.accessKey, "access_key_profile_config")
XCTAssertEqual(credentials.secret, "secret_profile_config")
Expand All @@ -44,7 +44,7 @@ class ProfileCredentialsProviderTests: XCTestCase {
configFilePath: configPath,
credentialsFilePath: credentialsPath
)
let credentials = try await subject.getCredentials()
let credentials = try await subject.getIdentity()

XCTAssertEqual(credentials.accessKey, "access_key_profile_cred")
XCTAssertEqual(credentials.secret, "secret_profile_cred")
Expand Down
Loading

0 comments on commit 668575c

Please sign in to comment.