diff --git a/Sources/ClientRuntime/Endpoints/ServiceEndpointMetadata.swift b/Sources/ClientRuntime/Endpoints/ServiceEndpointMetadata.swift index 75b045d38..ce44a6fbb 100644 --- a/Sources/ClientRuntime/Endpoints/ServiceEndpointMetadata.swift +++ b/Sources/ClientRuntime/Endpoints/ServiceEndpointMetadata.swift @@ -52,7 +52,7 @@ extension ServiceEndpointMetadata { return SmithyEndpoint(endpoint: Endpoint(host: hostname, path: "/", - protocolType: ProtocolType(rawValue: transportProtocol)), + protocolType: ProtocolType(rawValue: transportProtocol)!), signingName: signingName) } diff --git a/Sources/ClientRuntime/Message/RequestMessage.swift b/Sources/ClientRuntime/Message/RequestMessage.swift index 94cbe7d7b..2000a1d7d 100644 --- a/Sources/ClientRuntime/Message/RequestMessage.swift +++ b/Sources/ClientRuntime/Message/RequestMessage.swift @@ -17,6 +17,9 @@ public protocol RequestMessage { /// The body of the request. var body: ByteStream { get } + // The uri of the request + var destination: URI { get } + /// - Returns: A new builder for this request message, with all properties copied. func toBuilder() -> RequestBuilderType } diff --git a/Sources/ClientRuntime/Message/URI.swift b/Sources/ClientRuntime/Message/URI.swift new file mode 100644 index 000000000..7cb89a7e7 --- /dev/null +++ b/Sources/ClientRuntime/Message/URI.swift @@ -0,0 +1,274 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +import Foundation + +/// A representation of the RFC 3986 Uniform Resource Identifier +/// Note: URIBuilder returns an URI instance with all components percent encoded +public struct URI: Hashable { + public let scheme: Scheme + public let path: String + public let host: String + public let port: Int16? + public var defaultPort: Int16 { + Int16(scheme.port) + } + public let queryItems: [SDKURLQueryItem] + public let username: String? + public let password: String? + public let fragment: String? + public var url: URL? { + self.toBuilder().getUrl() + } + public var queryString: String? { + self.queryItems.queryString + } + + fileprivate init(scheme: Scheme, + path: String, + host: String, + port: Int16?, + queryItems: [SDKURLQueryItem], + username: String? = nil, + password: String? = nil, + fragment: String? = nil) { + self.scheme = scheme + self.path = path + self.host = host + self.port = port + self.queryItems = queryItems + self.username = username + self.password = password + self.fragment = fragment + } + + public func toBuilder() -> URIBuilder { + return URIBuilder() + .withScheme(self.scheme) + .withPath(self.path) + .withHost(self.host) + .withPort(self.port) + .withQueryItems(self.queryItems) + .withUsername(self.username) + .withPassword(self.password) + .withFragment(self.fragment) + } +} + +/// A builder class for URI +/// The builder performs validation to conform with RFC 3986 +/// Note: URIBuilder returns an URI instance with all components percent encoded +public final class URIBuilder { + var urlComponents: URLComponents + + public init() { + self.urlComponents = URLComponents() + self.urlComponents.percentEncodedPath = "/" + self.urlComponents.scheme = Scheme.https.rawValue + self.urlComponents.host = "" + } + + @discardableResult + public func withScheme(_ value: Scheme) -> URIBuilder { + self.urlComponents.scheme = value.rawValue + return self + } + + /// According to https://developer.apple.com/documentation/foundation/nsurlcomponents/1408161-percentencodedpath + /// "Although an unencoded semicolon is a valid character in a percent-encoded path, + /// for compatibility with the NSURL class, you should always percent-encode it." + /// + /// URI also always return a percent-encoded path. + /// If an percent-encoded path is provided, we will replace the semicolon with %3B in the path. + /// If an unencoded path is provided, we should percent-encode the path including semicolon. + @discardableResult + public func withPath(_ value: String) -> URIBuilder { + if value.isPercentEncoded { + if value.contains(";") { + let encodedPath = value.replacingOccurrences( + of: ";", with: "%3B", options: NSString.CompareOptions.literal, range: nil) + self.urlComponents.percentEncodedPath = encodedPath + } else { + self.urlComponents.percentEncodedPath = value + } + } else { + if value.contains(";") { + self.urlComponents.percentEncodedPath = value.percentEncodePathIncludingSemicolon() + } else { + self.urlComponents.path = value + } + } + return self + } + + @discardableResult + public func withHost(_ value: String) -> URIBuilder { + if value.isPercentEncoded { + // URLComponents.percentEncodedHost follows RFC 3986 + // and returns a decoded value if it is set with a percent encoded value + // However on Linux platform, it returns a percent encoded value. + // To ensure consistent behaviour, we will decode it ourselves on Linux platform + if currentOS == .linux { + self.urlComponents.host = value.removingPercentEncoding! + } else { + self.urlComponents.percentEncodedHost = value + } + } else { + self.urlComponents.host = value + } + return self + } + + @discardableResult + public func withPort(_ value: Int16?) -> URIBuilder { + self.urlComponents.port = value.map { Int($0) } + return self + } + + @discardableResult + public func withPort(_ value: Int?) -> URIBuilder { + self.urlComponents.port = value + return self + } + + @discardableResult + public func withQueryItems(_ value: [SDKURLQueryItem]) -> URIBuilder { + if value.isEmpty { + return self + } + if value.containsPercentEncode() { + self.urlComponents.percentEncodedQueryItems = value.toURLQueryItems() + } else { + self.urlComponents.queryItems = value.toURLQueryItems() + } + return self + } + + @discardableResult + public func appendQueryItems(_ items: [SDKURLQueryItem]) -> URIBuilder { + guard !items.isEmpty else { + return self + } + var queryItems = self.urlComponents.percentEncodedQueryItems ?? [] + queryItems += items.toURLQueryItems() + + if queryItems.containsPercentEncode() { + self.urlComponents.percentEncodedQueryItems = queryItems + } else { + self.urlComponents.queryItems = queryItems + } + + return self + } + + @discardableResult + public func appendQueryItem(_ item: SDKURLQueryItem) -> URIBuilder { + self.appendQueryItems([item]) + return self + } + + @discardableResult + public func withUsername(_ value: String?) -> URIBuilder { + if let username = value { + if username.isPercentEncoded { + self.urlComponents.percentEncodedUser = username + } else { + self.urlComponents.user = username + } + } + return self + } + + @discardableResult + public func withPassword(_ value: String?) -> URIBuilder { + if let password = value { + if password.isPercentEncoded { + self.urlComponents.percentEncodedPassword = password + } else { + self.urlComponents.password = password + } + } + return self + } + + @discardableResult + public func withFragment(_ value: String?) -> URIBuilder { + if let fragment = value { + if fragment.isPercentEncoded { + self.urlComponents.percentEncodedFragment = fragment + } else { + self.urlComponents.fragment = fragment + } + } + return self + } + + public func build() -> URI { + return URI(scheme: Scheme(rawValue: self.urlComponents.scheme!)!, + path: self.urlComponents.percentEncodedPath, + host: self.urlComponents.percentEncodedHost!, + port: self.urlComponents.port.map { Int16($0) }, + queryItems: self.urlComponents.percentEncodedQueryItems?.map { + SDKURLQueryItem(name: $0.name, value: $0.value) + } ?? [], + username: self.urlComponents.percentEncodedUser, + password: self.urlComponents.percentEncodedPassword, + fragment: self.urlComponents.percentEncodedFragment) + } + + // We still have to keep 'url' as an optional, since we're + // dealing with dynamic components that could be invalid. + fileprivate func getUrl() -> URL? { + let isInvalidHost = self.urlComponents.host?.isEmpty ?? false + return isInvalidHost && self.urlComponents.path.isEmpty ? nil : self.urlComponents.url + } +} + +extension String { + var isPercentEncoded: Bool { + let decoded = self.removingPercentEncoding + return decoded != nil && decoded != self + } + + public func percentEncodePathIncludingSemicolon() -> String { + let allowed = + // swiftlint:disable force_cast + (CharacterSet.urlPathAllowed as NSCharacterSet).mutableCopy() as! NSMutableCharacterSet + // swiftlint:enable force_cast + allowed.removeCharacters(in: ";") + return self.addingPercentEncoding(withAllowedCharacters: allowed as CharacterSet)! + } + + public func percentEncodeQuery() -> String { + return self.addingPercentEncoding(withAllowedCharacters: CharacterSet.urlQueryAllowed as CharacterSet)! + } +} + +extension Array where Element == SDKURLQueryItem { + public var queryString: String? { + if self.isEmpty { + return nil + } + return self.map { [$0.name, $0.value].compactMap { $0 }.joined(separator: "=") }.joined(separator: "&") + } + + public func toURLQueryItems() -> [URLQueryItem] { + return self.map { URLQueryItem(name: $0.name, value: $0.value) } + } + + public func containsPercentEncode() -> Bool { + return self.contains { item in + return item.name.isPercentEncoded || (item.value?.isPercentEncoded ?? false) + } + } +} + +extension Array where Element == URLQueryItem { + public func containsPercentEncode() -> Bool { + return self.contains { item in + return item.name.isPercentEncoded || (item.value?.isPercentEncoded ?? false) + } + } +} diff --git a/Sources/ClientRuntime/Networking/Endpoint.swift b/Sources/ClientRuntime/Networking/Endpoint.swift index dffe4b5f2..8d4d81d13 100644 --- a/Sources/ClientRuntime/Networking/Endpoint.swift +++ b/Sources/ClientRuntime/Networking/Endpoint.swift @@ -6,18 +6,20 @@ import Foundation public struct Endpoint: Hashable { - public let path: String - public let queryItems: [SDKURLQueryItem]? - public let protocolType: ProtocolType? - public let host: String - public let port: Int16 - public let headers: Headers? - public let properties: [String: AnyHashable] + public let uri: URI + public let headers: Headers + public var protocolType: ProtocolType? { uri.scheme } + public var queryItems: [SDKURLQueryItem] { uri.queryItems } + public var path: String { uri.path } + public var host: String { uri.host } + public var port: Int16? { uri.port } + public var url: URL? { uri.url } + private let properties: [String: AnyHashable] public init(urlString: String, - headers: Headers? = nil, + headers: Headers = Headers(), properties: [String: AnyHashable] = [:]) throws { - guard let url = URL(string: urlString) else { + guard let url = URLComponents(string: urlString)?.url else { throw ClientError.unknownError("invalid url \(urlString)") } @@ -25,18 +27,24 @@ public struct Endpoint: Hashable { } public init(url: URL, - headers: Headers? = nil, + headers: Headers = Headers(), properties: [String: AnyHashable] = [:]) throws { + guard let host = url.host else { throw ClientError.unknownError("invalid host \(String(describing: url.host))") } let protocolType = ProtocolType(rawValue: url.scheme ?? "") ?? .https - self.init(host: host, - path: url.path, - port: Int16(url.port ?? protocolType.port), - queryItems: url.toQueryItems(), - protocolType: protocolType, + + let uri = URIBuilder() + .withScheme(protocolType) + .withPath(url.path) + .withHost(host) + .withPort(url.port) + .withQueryItems(url.getQueryItems() ?? []) + .build() + + self.init(uri: uri, headers: headers, properties: properties) } @@ -45,38 +53,30 @@ public struct Endpoint: Hashable { path: String = "/", port: Int16 = 443, queryItems: [SDKURLQueryItem]? = nil, - protocolType: ProtocolType? = .https, - headers: Headers? = nil, + headers: Headers = Headers(), + protocolType: ProtocolType? = .https) { + + let uri = URIBuilder() + .withScheme(protocolType ?? .https) + .withPath(path) + .withHost(host) + .withPort(port) + .withQueryItems(queryItems ?? []) + .build() + + self.init(uri: uri, headers: headers) + } + + public init(uri: URI, + headers: Headers = Headers(), properties: [String: AnyHashable] = [:]) { - self.host = host - self.path = path - self.port = port - self.queryItems = queryItems - self.protocolType = protocolType + self.uri = uri self.headers = headers self.properties = properties } } extension Endpoint { - // We still have to keep 'url' as an optional, since we're - // dealing with dynamic components that could be invalid. - public var url: URL? { - var components = URLComponents() - components.scheme = protocolType?.rawValue - components.host = host.isEmpty ? nil : host // If host is empty, URL is invalid - components.percentEncodedPath = path - components.percentEncodedQuery = queryItemString - return (components.host == nil || components.scheme == nil) ? nil : components.url - } - - var queryItemString: String? { - guard let queryItems = queryItems else { return nil } - return queryItems.map { queryItem in - return [queryItem.name, queryItem.value].compactMap { $0 }.joined(separator: "=") - }.joined(separator: "&") - } - /// Returns list of auth schemes /// This is an internal API and subject to change without notice /// - Returns: list of auth schemes if present diff --git a/Sources/ClientRuntime/Networking/Http/CRT/CRTClientEngine.swift b/Sources/ClientRuntime/Networking/Http/CRT/CRTClientEngine.swift index 90fe9e8c0..fba6a2b5b 100644 --- a/Sources/ClientRuntime/Networking/Http/CRT/CRTClientEngine.swift +++ b/Sources/ClientRuntime/Networking/Http/CRT/CRTClientEngine.swift @@ -25,9 +25,9 @@ public class CRTClientEngine: HTTPClient { private let port: Int16 init(endpoint: Endpoint) { - self.protocolType = endpoint.protocolType - self.host = endpoint.host - self.port = endpoint.port + self.protocolType = endpoint.uri.scheme + self.host = endpoint.uri.host + self.port = endpoint.uri.port ?? endpoint.uri.defaultPort } } @@ -72,9 +72,9 @@ public class CRTClientEngine: HTTPClient { } private func createConnectionPool(endpoint: Endpoint) throws -> HTTPClientConnectionManager { - let tlsConnectionOptions = endpoint.protocolType == .https ? TLSConnectionOptions( + let tlsConnectionOptions = endpoint.uri.scheme == .https ? TLSConnectionOptions( context: self.crtTLSOptions?.resolveContext() ?? sharedDefaultIO.tlsContext, - serverName: endpoint.host + serverName: endpoint.uri.host ) : nil var socketOptions = SocketOptions(socketType: .stream) @@ -93,9 +93,9 @@ public class CRTClientEngine: HTTPClient { let options = HTTPClientConnectionOptions( clientBootstrap: sharedDefaultIO.clientBootstrap, - hostName: endpoint.host, + hostName: endpoint.uri.host, initialWindowSize: windowSize, - port: UInt32(endpoint.port), + port: UInt32(endpoint.uri.port ?? endpoint.uri.defaultPort), proxyOptions: nil, socketOptions: socketOptions, tlsOptions: tlsConnectionOptions, @@ -104,7 +104,7 @@ public class CRTClientEngine: HTTPClient { enableManualWindowManagement: false ) // not using backpressure yet logger.debug(""" - Creating connection pool for \(String(describing: endpoint.host)) \ + Creating connection pool for \(String(describing: endpoint.uri.host)) \ with max connections: \(maxConnectionsPerEndpoint) """) return try HTTPClientConnectionManager(options: options) @@ -122,20 +122,20 @@ public class CRTClientEngine: HTTPClient { let tlsConnectionOptions = TLSConnectionOptions( context: self.crtTLSOptions?.resolveContext() ?? sharedDefaultIO.tlsContext, alpnList: [ALPNProtocol.http2.rawValue], - serverName: endpoint.host + serverName: endpoint.uri.host ) let options = HTTP2StreamManagerOptions( clientBootstrap: sharedDefaultIO.clientBootstrap, - hostName: endpoint.host, - port: UInt32(endpoint.port), + hostName: endpoint.uri.host, + port: UInt32(endpoint.uri.port ?? endpoint.uri.defaultPort), maxConnections: maxConnectionsPerEndpoint, socketOptions: socketOptions, tlsOptions: tlsConnectionOptions, enableStreamManualWindowManagement: false ) logger.debug(""" - Creating connection pool for \(String(describing: endpoint.host)) \ + Creating connection pool for \(String(describing: endpoint.uri.host)) \ with max connections: \(maxConnectionsPerEndpoint) """) @@ -164,7 +164,7 @@ public class CRTClientEngine: HTTPClient { let connectionMgr = try await serialExecutor.getOrCreateConnectionPool(endpoint: request.endpoint) let connection = try await connectionMgr.acquireConnection() - self.logger.debug("Connection was acquired to: \(String(describing: request.endpoint.url?.absoluteString))") + self.logger.debug("Connection was acquired to: \(String(describing: request.destination.url?.absoluteString))") switch connection.httpVersion { case .version_1_1: self.logger.debug("Using HTTP/1.1 connection") diff --git a/Sources/ClientRuntime/Networking/Http/Headers.swift b/Sources/ClientRuntime/Networking/Http/Headers.swift index 315bbdd28..9e45afdf5 100644 --- a/Sources/ClientRuntime/Networking/Http/Headers.swift +++ b/Sources/ClientRuntime/Networking/Http/Headers.swift @@ -159,6 +159,10 @@ public struct Headers { return first + last } } + + public var isEmpty: Bool { + return self.headers.isEmpty + } } extension Headers: Equatable { diff --git a/Sources/ClientRuntime/Networking/Http/HttpResponse.swift b/Sources/ClientRuntime/Networking/Http/HttpResponse.swift index 672792d75..873f6fde0 100644 --- a/Sources/ClientRuntime/Networking/Http/HttpResponse.swift +++ b/Sources/ClientRuntime/Networking/Http/HttpResponse.swift @@ -13,6 +13,7 @@ public class HttpResponse: HttpUrlResponse, ResponseMessage { public var headers: Headers public var body: ByteStream + public var reason: String? private var _statusCode: HttpStatusCode private let statusCodeQueue = DispatchQueue(label: "statusCodeSerialQueue") @@ -29,13 +30,17 @@ public class HttpResponse: HttpUrlResponse, ResponseMessage { } } - public init(headers: Headers = .init(), statusCode: HttpStatusCode = .processing, body: ByteStream = .noStream) { + public init( + headers: Headers = .init(), + statusCode: HttpStatusCode = .processing, + body: ByteStream = .noStream, + reason: String? = nil) { self.headers = headers self._statusCode = statusCode self.body = body } - public init(headers: Headers = .init(), body: ByteStream, statusCode: HttpStatusCode) { + public init(headers: Headers = .init(), body: ByteStream, statusCode: HttpStatusCode, reason: String? = nil) { self.body = body self._statusCode = statusCode self.headers = headers diff --git a/Sources/ClientRuntime/Networking/Http/HttpUrlResponse.swift b/Sources/ClientRuntime/Networking/Http/HttpUrlResponse.swift index c379e7e24..aebd2ce42 100644 --- a/Sources/ClientRuntime/Networking/Http/HttpUrlResponse.swift +++ b/Sources/ClientRuntime/Networking/Http/HttpUrlResponse.swift @@ -9,4 +9,5 @@ protocol HttpUrlResponse { var headers: Headers { get set } var body: ByteStream { get set} var statusCode: HttpStatusCode {get set} + var reason: String? {get set} } diff --git a/Sources/ClientRuntime/Networking/Http/ProtocolType.swift b/Sources/ClientRuntime/Networking/Http/ProtocolType.swift index a34be6f6a..190482acd 100644 --- a/Sources/ClientRuntime/Networking/Http/ProtocolType.swift +++ b/Sources/ClientRuntime/Networking/Http/ProtocolType.swift @@ -5,6 +5,8 @@ import Foundation +public typealias Scheme = ProtocolType + public enum ProtocolType: String, CaseIterable { case http case https diff --git a/Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift b/Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift index fd6adc82a..76847af03 100644 --- a/Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift +++ b/Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift @@ -17,24 +17,30 @@ import struct Foundation.URLRequest // in the CRT engine so that is why it's a class public final class SdkHttpRequest: RequestMessage { public var body: ByteStream - public let endpoint: Endpoint + public let destination: URI + public var headers: Headers public let method: HttpMethodType - private var additionalHeaders: Headers = Headers() - public var headers: Headers { - var allHeaders = endpoint.headers ?? Headers() - allHeaders.addAll(headers: additionalHeaders) - return allHeaders - } + public var host: String { destination.host } + public var path: String { destination.path } + public var queryItems: [SDKURLQueryItem]? { destination.queryItems } public var trailingHeaders: Headers = Headers() - public var path: String { endpoint.path } - public var host: String { endpoint.host } - public var queryItems: [SDKURLQueryItem]? { endpoint.queryItems } + public var endpoint: Endpoint { + return Endpoint(uri: self.destination, headers: self.headers) + } + + public convenience init(method: HttpMethodType, + endpoint: Endpoint, + body: ByteStream = ByteStream.noStream) { + self.init(method: method, uri: endpoint.uri, headers: endpoint.headers, body: body) + } public init(method: HttpMethodType, - endpoint: Endpoint, + uri: URI, + headers: Headers, body: ByteStream = ByteStream.noStream) { self.method = method - self.endpoint = endpoint + self.destination = uri + self.headers = headers self.body = body } @@ -44,22 +50,20 @@ public final class SdkHttpRequest: RequestMessage { .withMethod(self.method) .withHeaders(self.headers) .withTrailers(self.trailingHeaders) - .withPath(self.path) - .withHost(self.host) - .withPort(self.endpoint.port) - .withProtocol(self.endpoint.protocolType ?? .https) - if let qItems = self.queryItems { - builder.withQueryItems(qItems) - } + .withPath(self.destination.path) + .withHost(self.destination.host) + .withPort(self.destination.port) + .withProtocol(self.destination.scheme) + .withQueryItems(self.destination.queryItems) return builder } public func withHeader(name: String, value: String) { - self.additionalHeaders.add(name: name, value: value) + self.headers.add(name: name, value: value) } public func withoutHeader(name: String) { - self.additionalHeaders.remove(name: name) + self.headers.remove(name: name) } public func withBody(_ body: ByteStream) { @@ -92,7 +96,8 @@ extension SdkHttpRequest { public func toHttpRequest() throws -> HTTPRequest { let httpRequest = try HTTPRequest() httpRequest.method = method.rawValue - httpRequest.path = [endpoint.path, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?") + httpRequest.path = [self.destination.path, self.destination.queryString] + .compactMap { $0 }.joined(separator: "?") httpRequest.addHeaders(headers: headers.toHttpHeaders()) httpRequest.body = isChunked ? nil : StreamableHttpBody(body: body) // body needs to be nil to use writeChunk() return httpRequest @@ -104,7 +109,8 @@ extension SdkHttpRequest { public func toHttp2Request() throws -> HTTPRequestBase { let httpRequest = try HTTPRequest() httpRequest.method = method.rawValue - httpRequest.path = [endpoint.path, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?") + httpRequest.path = [self.destination.path, self.destination.queryString] + .compactMap { $0 }.joined(separator: "?") httpRequest.addHeaders(headers: headers.toHttpHeaders()) // Remove the "Transfer-Encoding" header if it exists since h2 does not support it @@ -120,7 +126,7 @@ extension SdkHttpRequest { public extension URLRequest { init(sdkRequest: SdkHttpRequest) async throws { // Set URL - guard let url = sdkRequest.endpoint.url else { + guard let url = sdkRequest.destination.url else { throw ClientError.dataNotFound("Failed to construct URLRequest due to missing URL.") } self.init(url: url) @@ -155,9 +161,11 @@ extension SdkHttpRequest: CustomDebugStringConvertible, CustomStringConvertible public var description: String { let method = method.rawValue.uppercased() - let protocolType = endpoint.protocolType ?? ProtocolType.https - let query = String(describing: queryItems) - return "\(method) \(protocolType):\(endpoint.port) \n Path: \(endpoint.path) \n \(headers) \n \(query)" + let protocolType = self.destination.scheme + let query = self.destination.queryString ?? "" + let port = self.destination.port.map { String($0) } ?? "" + return "\(method) \(protocolType):\(port) \n " + + "Path: \(endpoint.uri.path) \n Headers: \(headers) \n Query: \(query)" } } @@ -198,8 +206,8 @@ public class SdkHttpRequestBuilder: RequestMessageBuilder { var host: String = "" var path: String = "/" var body: ByteStream = .noStream - var queryItems: [SDKURLQueryItem]? - var port: Int16 = 443 + var queryItems: [SDKURLQueryItem] = [] + var port: Int16? var protocolType: ProtocolType = .https var trailingHeaders: Headers = Headers() @@ -267,8 +275,7 @@ public class SdkHttpRequestBuilder: RequestMessageBuilder { @discardableResult public func withQueryItems(_ value: [SDKURLQueryItem]) -> SdkHttpRequestBuilder { - self.queryItems = self.queryItems ?? [] - self.queryItems?.append(contentsOf: value) + self.queryItems.append(contentsOf: value) return self } @@ -278,7 +285,7 @@ public class SdkHttpRequestBuilder: RequestMessageBuilder { } @discardableResult - public func withPort(_ value: Int16) -> SdkHttpRequestBuilder { + public func withPort(_ value: Int16?) -> SdkHttpRequestBuilder { self.port = value return self } @@ -290,15 +297,14 @@ public class SdkHttpRequestBuilder: RequestMessageBuilder { } public func build() -> SdkHttpRequest { - let endpoint = Endpoint(host: host, - path: path, - port: port, - queryItems: queryItems, - protocolType: protocolType, - headers: headers) - return SdkHttpRequest(method: methodType, - endpoint: endpoint, - body: body) + let uri = URIBuilder() + .withScheme(protocolType) + .withPath(path) + .withHost(host) + .withPort(port) + .withQueryItems(queryItems) + .build() + return SdkHttpRequest(method: methodType, uri: uri, headers: headers, body: body) } } diff --git a/Sources/ClientRuntime/Networking/Http/URLSession/URLSessionHTTPClient.swift b/Sources/ClientRuntime/Networking/Http/URLSession/URLSessionHTTPClient.swift index 433409b50..ef5817aca 100644 --- a/Sources/ClientRuntime/Networking/Http/URLSession/URLSessionHTTPClient.swift +++ b/Sources/ClientRuntime/Networking/Http/URLSession/URLSessionHTTPClient.swift @@ -462,10 +462,10 @@ public final class URLSessionHTTPClient: HTTPClient { /// - Returns: A `URLRequest` ready to be transmitted by `URLSession` for this operation. private func makeURLRequest(from request: SdkHttpRequest, httpBodyStream: InputStream?) throws -> URLRequest { var components = URLComponents() - components.scheme = config.protocolType?.rawValue ?? request.endpoint.protocolType?.rawValue ?? "https" - components.host = request.endpoint.host + components.scheme = config.protocolType?.rawValue ?? request.destination.scheme.rawValue + components.host = request.endpoint.uri.host components.port = port(for: request) - components.percentEncodedPath = request.path + components.percentEncodedPath = request.destination.path if let queryItems = request.queryItems, !queryItems.isEmpty { components.percentEncodedQueryItems = queryItems.map { Foundation.URLQueryItem(name: $0.name, value: $0.value) @@ -484,12 +484,12 @@ public final class URLSessionHTTPClient: HTTPClient { } private func port(for request: SdkHttpRequest) -> Int? { - switch (request.endpoint.protocolType, request.endpoint.port) { + switch (request.destination.scheme, request.destination.port) { case (.https, 443), (.http, 80): // Don't set port explicitly if it's the default port for the scheme return nil default: - return Int(request.endpoint.port) + return request.destination.port.map { Int($0) } } } } diff --git a/Sources/ClientRuntime/PrimitiveTypeExtensions/URL+Extension.swift b/Sources/ClientRuntime/PrimitiveTypeExtensions/URL+Extension.swift index 791ff1fe1..095435642 100644 --- a/Sources/ClientRuntime/PrimitiveTypeExtensions/URL+Extension.swift +++ b/Sources/ClientRuntime/PrimitiveTypeExtensions/URL+Extension.swift @@ -9,7 +9,7 @@ public typealias URL = Foundation.URL extension URL { - func toQueryItems() -> [SDKURLQueryItem]? { + func getQueryItems() -> [SDKURLQueryItem]? { URLComponents(url: self, resolvingAgainstBaseURL: false)? .queryItems? .map { SDKURLQueryItem(name: $0.name, value: $0.value) } diff --git a/Sources/SmithyTestUtil/RequestTestUtil/ExpectedSdkHttpRequest.swift b/Sources/SmithyTestUtil/RequestTestUtil/ExpectedSdkHttpRequest.swift index d8ba0506b..ade90fd6f 100644 --- a/Sources/SmithyTestUtil/RequestTestUtil/ExpectedSdkHttpRequest.swift +++ b/Sources/SmithyTestUtil/RequestTestUtil/ExpectedSdkHttpRequest.swift @@ -12,7 +12,7 @@ public struct ExpectedSdkHttpRequest { public var headers: Headers? public var forbiddenHeaders: [String]? public var requiredHeaders: [String]? - public let queryItems: [SDKURLQueryItem]? + public var queryItems: [SDKURLQueryItem] { endpoint.uri.queryItems } public let forbiddenQueryItems: [SDKURLQueryItem]? public let requiredQueryItems: [SDKURLQueryItem]? public let endpoint: Endpoint @@ -23,7 +23,6 @@ public struct ExpectedSdkHttpRequest { headers: Headers? = nil, forbiddenHeaders: [String]? = nil, requiredHeaders: [String]? = nil, - queryItems: [SDKURLQueryItem]? = nil, forbiddenQueryItems: [SDKURLQueryItem]? = nil, requiredQueryItems: [SDKURLQueryItem]? = nil, body: ByteStream = ByteStream.noStream) { @@ -32,7 +31,6 @@ public struct ExpectedSdkHttpRequest { self.headers = headers self.forbiddenHeaders = forbiddenHeaders self.requiredHeaders = requiredHeaders - self.queryItems = queryItems self.forbiddenQueryItems = forbiddenQueryItems self.requiredQueryItems = requiredQueryItems self.body = body @@ -139,12 +137,14 @@ public class ExpectedSdkHttpRequestBuilder { } public func build() -> ExpectedSdkHttpRequest { - let endpoint = Endpoint(host: host, - path: path, - port: port, - queryItems: queryItems, - protocolType: protocolType) - let queryItems = !queryItems.isEmpty ? queryItems : nil + let uri = URIBuilder() + .withScheme(protocolType) + .withPath(path) + .withHost(host) + .withPort(port) + .withQueryItems(queryItems) + .build() + let endpoint = Endpoint(uri: uri, headers: headers) let forbiddenQueryItems = !forbiddenQueryItems.isEmpty ? forbiddenQueryItems : nil let requiredQueryItems = !requiredQueryItems.isEmpty ? requiredQueryItems : nil @@ -156,7 +156,6 @@ public class ExpectedSdkHttpRequestBuilder { headers: headers, forbiddenHeaders: forbiddenHeaders, requiredHeaders: requiredHeaders, - queryItems: queryItems, forbiddenQueryItems: forbiddenQueryItems, requiredQueryItems: requiredQueryItems, body: body) diff --git a/Sources/SmithyTestUtil/RequestTestUtil/HttpRequestTestBase.swift b/Sources/SmithyTestUtil/RequestTestUtil/HttpRequestTestBase.swift index 32fb6ed7e..c46f0f729 100644 --- a/Sources/SmithyTestUtil/RequestTestUtil/HttpRequestTestBase.swift +++ b/Sources/SmithyTestUtil/RequestTestUtil/HttpRequestTestBase.swift @@ -214,8 +214,8 @@ open class HttpRequestTestBase: XCTestCase { assertQueryItems(expected.queryItems, actual.queryItems, file: file, line: line) - XCTAssertEqual(expected.endpoint.path, actual.path, file: file, line: line) - XCTAssertEqual(expected.endpoint.host, actual.host, file: file, line: line) + XCTAssertEqual(expected.endpoint.uri.path, actual.destination.path, file: file, line: line) + XCTAssertEqual(expected.endpoint.uri.host, actual.destination.host, file: file, line: line) XCTAssertEqual(expected.method, actual.method, file: file, line: line) assertForbiddenQueryItems(expected.forbiddenQueryItems, actual.queryItems, file: file, line: line) diff --git a/Sources/WeatherSDK/EndpointResolver.swift b/Sources/WeatherSDK/EndpointResolver.swift index 6a3e397a0..ee1f5de61 100644 --- a/Sources/WeatherSDK/EndpointResolver.swift +++ b/Sources/WeatherSDK/EndpointResolver.swift @@ -132,8 +132,8 @@ extension EndpointResolverMiddleware: ApplyEndpoint { attributes.set(key: AttributeKeys.signingAlgorithm, value: AWSSigningAlgorithm(rawValue: signingAlgorithm)) } - if let headers = endpoint.headers { - builder.withHeaders(headers) + if !endpoint.headers.isEmpty { + builder.withHeaders(endpoint.headers) } return builder.withMethod(attributes.getMethod()) diff --git a/Tests/ClientRuntimeTests/MessageTests/URITests.swift b/Tests/ClientRuntimeTests/MessageTests/URITests.swift new file mode 100644 index 000000000..db3c14da8 --- /dev/null +++ b/Tests/ClientRuntimeTests/MessageTests/URITests.swift @@ -0,0 +1,199 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation +import XCTest +@testable import ClientRuntime + +class URITests: XCTestCase { + let url = URL(string: "https://xctest.amazonaws.com?abc=def&ghi=jkl&mno=pqr")! + + let unencodedReservedCharacters: String = "!$&'()*+,;=" + + let encodedReservedCharacters: String = "%21%24%26%27%28%29%2A%2B%2C%3B%3D" + + func test_queryItems_setsQueryItemsFromURLInOrder() throws { + let uri = URIBuilder() + .withScheme(Scheme(rawValue: url.scheme!)!) + .withPath(url.path) + .withHost(url.host!) + .withPort(url.port) + .withQueryItems(url.getQueryItems()!) + .build() + + let expectedQueryItems = [ + SDKURLQueryItem(name: "abc", value: "def"), + SDKURLQueryItem(name: "ghi", value: "jkl"), + SDKURLQueryItem(name: "mno", value: "pqr") + ] + XCTAssertEqual(uri.queryItems, expectedQueryItems) + XCTAssertEqual(uri.queryString, "abc=def&ghi=jkl&mno=pqr") + } + + func test_hashableAndEquatable_hashesMatch() throws { + let uri1 = URIBuilder() + .withScheme(Scheme(rawValue: url.scheme!)!) + .withPath(url.path) + .withHost(url.host!) + .withPort(url.port) + .withQueryItems(url.getQueryItems()!) + .build() + let uri2 = URIBuilder() + .withScheme(Scheme(rawValue: url.scheme!)!) + .withPath(url.path) + .withHost(url.host!) + .withPort(url.port) + .withQueryItems(url.getQueryItems()!) + .build() + XCTAssertEqual(uri1, uri2) + XCTAssertEqual(uri1.hashValue, uri2.hashValue) + } + + func test_path_percentEncodedInput() throws { + let uri = URIBuilder() + .withScheme(Scheme(rawValue: url.scheme!)!) + .withPath(url.path) + .withHost(url.host!) + .withPort(443) + .withQueryItems(url.getQueryItems()!) + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.amazonaws.com:443?abc=def&ghi=jkl&mno=pqr") + } + + func test_path_unencodedInput() throws { + let uri = URIBuilder() + .withScheme(.https) + .withPath("/abc+def") + .withHost("xctest.amazonaws.com") + .withPort(443) + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.amazonaws.com:443/abc+def") + } + + func test_modifyURI() throws { + var uri = URIBuilder() + .withScheme(Scheme(rawValue: url.scheme!)!) + .withPath(url.path) + .withHost(url.host!) + .withPort(url.port) + .withQueryItems(url.getQueryItems()!) + .build() + + uri = uri.toBuilder() + .withPath("/x%2Dy%2Dz") + .withHost("%2Bxctest2.com") + .appendQueryItem(SDKURLQueryItem(name: "test", value: "1%2B2")) + .withFragment("fragment%21") + .withUsername("dan%21") + .withPassword("%24008") + .build() + + XCTAssertEqual(uri.url?.absoluteString, + "https://dan%21:%24008@+xctest2.com/x%2Dy%2Dz?abc=def&ghi=jkl&mno=pqr&test=1%2B2#fragment%21") + } + + func test_host_unencodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.\(unencodedReservedCharacters).com") + .withPath("/") + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.!$&\'()*+,;=.com/") + } + + func test_host_encodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.\(encodedReservedCharacters).com") + .withPath("/") + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.!$&\'()*+,;=.com/") + } + + func test_host_encodedAndUnencodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.\(unencodedReservedCharacters)\(encodedReservedCharacters).com") + .withPath("/") + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.!$&\'()*+,;=!$&\'()*+,;=.com/") + } + + func test_path_unencodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.com") + .withPath("/:@\(unencodedReservedCharacters)") + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.com/:@!$&\'()*+,%3B=") + } + + func test_path_encodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.com") + .withPath("/\(encodedReservedCharacters)") + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.com/%21%24%26%27%28%29%2A%2B%2C%3B%3D") + } + + func test_path_encodedAndUnencodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.com") + .withPath("/:@\(unencodedReservedCharacters)\(encodedReservedCharacters)") + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.com/:@!$&\'()*+,%3B=%21%24%26%27%28%29%2A%2B%2C%3B%3D") + } + + func test_query_unencodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.com") + .withPath("/") + .withQueryItems([ + SDKURLQueryItem( + name: "key:@\(unencodedReservedCharacters))", + value: "value:@\(unencodedReservedCharacters)" + ), + ]) + .build() + XCTAssertEqual(uri.url?.absoluteString, "https://xctest.com/?key:@!$%26\'()*+,;%3D)=value:@!$%26\'()*+,;%3D") + } + + func test_query_encodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.com") + .withPath("/") + .withQueryItems([ + SDKURLQueryItem( + name: "key:@\(encodedReservedCharacters))", + value: "value:@\(encodedReservedCharacters)" + ), + ]) + .build() + XCTAssertEqual(uri.url?.absoluteString, + "https://xctest.com/?key:@%21%24%26%27%28%29%2A%2B%2C%3B%3D)=value:@%21%24%26%27%28%29%2A%2B%2C%3B%3D") + } + + func test_query_unencodedAndEncodedReservedCharacters() throws { + let uri = URIBuilder() + .withScheme(.https) + .withHost("xctest.com") + .withPath("/") + .withQueryItems([ + SDKURLQueryItem( + name: "key:@\(encodedReservedCharacters))", + value: "value:@\(unencodedReservedCharacters)" + ), + ]) + .build() + XCTAssertEqual(uri.url?.absoluteString, + "https://xctest.com/?key:@%21%24%26%27%28%29%2A%2B%2C%3B%3D)=value:@!$&\'()*+,;=") + } +} diff --git a/Tests/ClientRuntimeTests/NetworkingTests/EndpointTests.swift b/Tests/ClientRuntimeTests/NetworkingTests/EndpointTests.swift index b86a286a8..e30ac80f4 100644 --- a/Tests/ClientRuntimeTests/NetworkingTests/EndpointTests.swift +++ b/Tests/ClientRuntimeTests/NetworkingTests/EndpointTests.swift @@ -19,7 +19,7 @@ class EndpointTests: XCTestCase { SDKURLQueryItem(name: "ghi", value: "jkl"), SDKURLQueryItem(name: "mno", value: "pqr") ] - XCTAssertEqual(endpoint.queryItems, expectedQueryItems) + XCTAssertEqual(endpoint.uri.queryItems, expectedQueryItems) } func test_hashableAndEquatable_hashesMatchWhenURLQueryItemsAreEqual() throws { @@ -35,9 +35,9 @@ class EndpointTests: XCTestCase { path: "/abc%2Bdef", protocolType: .https ) - let foundationURL = try XCTUnwrap(endpoint.url) + let foundationURL = try XCTUnwrap(endpoint.uri.url) let absoluteString = foundationURL.absoluteString - XCTAssertEqual(absoluteString, "https://xctest.amazonaws.com/abc%2Bdef") + XCTAssertEqual(absoluteString, "https://xctest.amazonaws.com:443/abc%2Bdef") } func test_path_unencodedInput() throws { @@ -46,8 +46,8 @@ class EndpointTests: XCTestCase { path: "/abc+def", protocolType: .https ) - let foundationURL = try XCTUnwrap(endpoint.url) + let foundationURL = try XCTUnwrap(endpoint.uri.url) let absoluteString = foundationURL.absoluteString - XCTAssertEqual(absoluteString, "https://xctest.amazonaws.com/abc+def") + XCTAssertEqual(absoluteString, "https://xctest.amazonaws.com:443/abc+def") } } diff --git a/Tests/ClientRuntimeTests/NetworkingTests/Http/HttpRequestTests.swift b/Tests/ClientRuntimeTests/NetworkingTests/Http/HttpRequestTests.swift index 1b7d5a433..1f3f5ee20 100644 --- a/Tests/ClientRuntimeTests/NetworkingTests/Http/HttpRequestTests.swift +++ b/Tests/ClientRuntimeTests/NetworkingTests/Http/HttpRequestTests.swift @@ -77,7 +77,7 @@ class HttpRequestTests: NetworkingTestUtils { XCTAssertTrue(headersFromRequest.contains { $0.key == "Testname-2" && $0.value == "testvalue-2" }) let expectedBody = try await httpBody.readData() XCTAssertEqual(urlRequest.httpBody, expectedBody) - XCTAssertEqual(urlRequest.url, endpoint.url) + XCTAssertEqual(urlRequest.url, endpoint.uri.url) XCTAssertEqual(urlRequest.httpMethod, mockHttpRequest.method.rawValue) } @@ -118,21 +118,21 @@ class HttpRequestTests: NetworkingTestUtils { .withQueryItem(queryItem2) .withHeader(name: "Content-Length", value: "6") - XCTAssert(builder.queryItems?.count == 2) + XCTAssert(builder.queryItems.count == 2) let httpRequest = try builder.build().toHttpRequest() httpRequest.path = "/hello?foo=bar&quz=bar&signedthing=signed" let updatedRequest = builder.update(from: httpRequest, originalRequest: builder.build()) XCTAssert(updatedRequest.path == "/hello") - XCTAssert(updatedRequest.queryItems?.count == 3) - XCTAssert(updatedRequest.queryItems?.contains(queryItem1) ?? false) - XCTAssert(updatedRequest.queryItems?.contains(queryItem2) ?? false) - XCTAssert(updatedRequest.queryItems?.contains(SDKURLQueryItem(name: "signedthing", value: "signed")) ?? false) + XCTAssert(updatedRequest.queryItems.count == 3) + XCTAssert(updatedRequest.queryItems.contains(queryItem1)) + XCTAssert(updatedRequest.queryItems.contains(queryItem2)) + XCTAssert(updatedRequest.queryItems.contains(SDKURLQueryItem(name: "signedthing", value: "signed"))) } func testPathInInHttpRequestIsNotAltered() throws { - let path = "/space /colon:/dollar$/tilde~/dash-/underscore_/period." + let path = "/space%20/colon:/dollar$/tilde~/dash-/underscore_/period." let builder = SdkHttpRequestBuilder() .withHeader(name: "Host", value: "xctest.amazon.com") .withPath(path) @@ -143,7 +143,7 @@ class HttpRequestTests: NetworkingTestUtils { func testConversionToUrlRequestFailsWithInvalidEndpoint() { // Testing with an invalid endpoint where host is empty, // path is empty, and protocolType is nil. - let endpoint = Endpoint(host: "", path: "", protocolType: nil) + let endpoint = Endpoint(host: "", path: "") XCTAssertNil(endpoint.url, "An invalid endpoint should result in a nil URL.") } } diff --git a/Tests/ClientRuntimeTests/NetworkingTests/Http/SdkRequestBuilderTests.swift b/Tests/ClientRuntimeTests/NetworkingTests/Http/SdkRequestBuilderTests.swift index d6b18d1c3..8c6c96b17 100644 --- a/Tests/ClientRuntimeTests/NetworkingTests/Http/SdkRequestBuilderTests.swift +++ b/Tests/ClientRuntimeTests/NetworkingTests/Http/SdkRequestBuilderTests.swift @@ -17,8 +17,8 @@ class SdkRequestBuilderTests: XCTestCase { crtRequest.path = pathToMatch let updatedRequest = SdkHttpRequestBuilder().update(from: crtRequest, originalRequest: originalRequest).build() - let updatedPath = [updatedRequest.endpoint.path, updatedRequest.endpoint.queryItemString].compactMap { $0 }.joined(separator: "?") + let updatedPath = [updatedRequest.destination.path, updatedRequest.destination.queryString].compactMap { $0 }.joined(separator: "?") XCTAssertEqual(pathToMatch, updatedPath) - XCTAssertEqual(url, updatedRequest.endpoint.url?.absoluteString) + XCTAssertEqual(url, updatedRequest.destination.url?.absoluteString) } } diff --git a/Tests/SmithyTestUtilTests/RequestTestUtilTests/HttpRequestTestBaseTests.swift b/Tests/SmithyTestUtilTests/RequestTestUtilTests/HttpRequestTestBaseTests.swift index 18b593143..b5971b20f 100644 --- a/Tests/SmithyTestUtilTests/RequestTestUtilTests/HttpRequestTestBaseTests.swift +++ b/Tests/SmithyTestUtilTests/RequestTestUtilTests/HttpRequestTestBaseTests.swift @@ -196,7 +196,7 @@ class HttpRequestTestBaseTests: HttpRequestTestBase { let forbiddenQueryParams = ["ForbiddenQuery"] for forbiddenQueryParam in forbiddenQueryParams { XCTAssertFalse( - self.queryItemExists(forbiddenQueryParam, in: actual.endpoint.queryItems), + self.queryItemExists(forbiddenQueryParam, in: actual.destination.queryItems), "Forbidden Query:\(forbiddenQueryParam) exists in query items" ) } @@ -208,7 +208,7 @@ class HttpRequestTestBaseTests: HttpRequestTestBase { let requiredQueryParams = ["RequiredQuery"] for requiredQueryParam in requiredQueryParams { - XCTAssertTrue(self.queryItemExists(requiredQueryParam, in: actual.endpoint.queryItems), + XCTAssertTrue(self.queryItemExists(requiredQueryParam, in: actual.destination.queryItems), "Required Query:\(requiredQueryParam) does not exist in query items") } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/EndpointResolverMiddleware.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/EndpointResolverMiddleware.kt index 7b8a754ef..c80bcef84 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/EndpointResolverMiddleware.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/EndpointResolverMiddleware.kt @@ -103,8 +103,8 @@ open class EndpointResolverMiddleware( attributes.set(key: AttributeKeys.signingAlgorithm, value: AWSSigningAlgorithm(rawValue: signingAlgorithm)) } - if let headers = endpoint.headers { - builder.withHeaders(headers) + if !endpoint.headers.isEmpty { + builder.withHeaders(endpoint.headers) } return builder.withMethod(attributes.getMethod())