From 10807072fc00f195a6f40c90b291a203d962851b Mon Sep 17 00:00:00 2001 From: Eric Rosenberg Date: Thu, 6 Jun 2024 17:02:12 -0700 Subject: [PATCH 1/3] Implement RFC8441 Extended CONNECT --- .../ConnectionStateMachine.swift | 29 +++++++++++++ .../ReceivingHeadersState.swift | 12 ++++-- .../SendingHeadersState.swift | 12 ++++-- .../HTTP2SettingsState.swift | 5 +++ .../HasExtendedConnectSettings.swift | 21 +++++++++ .../HasLocalSettings.swift | 6 +++ .../HasRemoteSettings.swift | 14 ++++++ .../NIOHTTP2/HPACKHeaders+Validation.swift | 31 +++++++++---- .../NIOHTTP2/HTTP2ConnectionStateChange.swift | 4 ++ Sources/NIOHTTP2/HTTP2Error.swift | 43 +++++++++++++++++++ Sources/NIOHTTP2/StreamStateMachine.swift | 19 +++++--- .../ConnectionStateMachineTests.swift | 42 +++++++++++++++--- Tests/NIOHTTP2Tests/HTTP2ErrorTests.swift | 1 + ...eClientServerFramePayloadStreamTests.swift | 24 +++++++++++ Tests/NIOHTTP2Tests/TestUtilities.swift | 4 +- 15 files changed, 237 insertions(+), 30 deletions(-) create mode 100644 Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStateMachine.swift b/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStateMachine.swift index fa6510f6..ac7fd0a8 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStateMachine.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/ConnectionStateMachine.swift @@ -87,6 +87,10 @@ struct HTTP2ConnectionStateMachine { return self.localSettings.initialWindowSize } + var remoteSupportsExtendedConnect: Bool { + false + } + init(fromIdle idleState: IdleConnectionState, localSettings settings: HTTP2SettingsState) { self.role = idleState.role self.headerBlockValidation = idleState.headerBlockValidation @@ -117,6 +121,10 @@ struct HTTP2ConnectionStateMachine { return HTTP2SettingsState.defaultInitialWindowSize } + var localSupportsExtendedConnect: Bool { + false + } + init(fromIdle idleState: IdleConnectionState, remoteSettings settings: HTTP2SettingsState) { self.role = idleState.role self.headerBlockValidation = idleState.headerBlockValidation @@ -198,6 +206,10 @@ struct HTTP2ConnectionStateMachine { return self.role == .client } + var localSupportsExtendedConnect: Bool { + false + } + init(fromPrefaceReceived state: PrefaceReceivedState, lastStreamID: HTTP2StreamID) { self.role = state.role self.headerBlockValidation = state.headerBlockValidation @@ -236,6 +248,10 @@ struct HTTP2ConnectionStateMachine { return self.role == .server } + var remoteSupportsExtendedConnect: Bool { + false + } + init(fromPrefaceSent state: PrefaceSentState, lastStreamID: HTTP2StreamID) { self.role = state.role self.headerBlockValidation = state.headerBlockValidation @@ -412,6 +428,14 @@ struct HTTP2ConnectionStateMachine { var lastLocalStreamID: HTTP2StreamID var lastRemoteStreamID: HTTP2StreamID + var localSupportsExtendedConnect: Bool { + false + } + + var remoteSupportsExtendedConnect: Bool { + false + } + init(previousState: PreviousState) { self.role = previousState.role self.headerBlockValidation = previousState.headerBlockValidation @@ -1630,6 +1654,11 @@ extension HTTP2ConnectionStateMachine { guard setting._value >= (1 << 14) && setting._value <= ((1 << 24) - 1) else { return .connectionError(underlyingError: NIOHTTP2Errors.invalidSetting(setting: setting), type: .protocolError) } + case .enableConnectProtocol: + // Must be 0 or 1 + guard setting._value <= 1 else { + return .connectionError(underlyingError: NIOHTTP2Errors.invalidSetting(setting: setting), type: .protocolError) + } default: // All other settings have unrestricted ranges. break diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingHeadersState.swift b/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingHeadersState.swift index 4bac538b..fed681a8 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingHeadersState.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/FrameReceivingStates/ReceivingHeadersState.swift @@ -17,7 +17,7 @@ import NIOHPACK /// can validly accept headers. /// /// This protocol should only be conformed to by states for the HTTP/2 connection state machine. -protocol ReceivingHeadersState: HasFlowControlWindows { +protocol ReceivingHeadersState: HasFlowControlWindows, HasLocalExtendedConnectSettings, HasRemoteExtendedConnectSettings { var role: HTTP2ConnectionStateMachine.ConnectionRole { get } var headerBlockValidation: HTTP2ConnectionStateMachine.ValidationState { get } @@ -37,11 +37,13 @@ extension ReceivingHeadersState { let result: StateMachineResultWithStreamEffect let validateHeaderBlock = self.headerBlockValidation == .enabled let validateContentLength = self.contentLengthValidation == .enabled + let localSupportsExtendedConnect = self.localSupportsExtendedConnect + let remoteSupportsExtendedConnect = self.remoteSupportsExtendedConnect if self.role == .server && streamID.mayBeInitiatedBy(.client) { do { result = try self.streamState.modifyStreamStateCreateIfNeeded(streamID: streamID, localRole: .server, localInitialWindowSize: self.localInitialWindowSize, remoteInitialWindowSize: self.remoteInitialWindowSize) { - $0.receiveHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, isEndStreamSet: endStream) + $0.receiveHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, localSupportsExtendedConnect: localSupportsExtendedConnect, remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, isEndStreamSet: endStream) } } catch { return StateMachineResultWithEffect(result: .connectionError(underlyingError: error, type: .protocolError), effect: nil) @@ -49,7 +51,7 @@ extension ReceivingHeadersState { } else { // HEADERS cannot create streams for servers, so this must be for a stream we already know about. result = self.streamState.modifyStreamState(streamID: streamID, ignoreRecentlyReset: true) { - $0.receiveHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, isEndStreamSet: endStream) + $0.receiveHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, localSupportsExtendedConnect: localSupportsExtendedConnect, remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, isEndStreamSet: endStream) } } @@ -69,6 +71,8 @@ extension ReceivingHeadersState where Self: LocallyQuiescingState { mutating func receiveHeaders(streamID: HTTP2StreamID, headers: HPACKHeaders, isEndStreamSet endStream: Bool) -> StateMachineResultWithEffect { let validateHeaderBlock = self.headerBlockValidation == .enabled let validateContentLength = self.contentLengthValidation == .enabled + let localSupportsExtendedConnect = self.localSupportsExtendedConnect + let remoteSupportsExtendedConnect = self.remoteSupportsExtendedConnect if streamID.mayBeInitiatedBy(.client) && streamID > self.lastRemoteStreamID { return StateMachineResultWithEffect(result: .ignoreFrame, effect: nil) @@ -76,7 +80,7 @@ extension ReceivingHeadersState where Self: LocallyQuiescingState { // At this stage we've quiesced, so the remote peer is not allowed to create new streams. let result = self.streamState.modifyStreamState(streamID: streamID, ignoreRecentlyReset: true) { - $0.receiveHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, isEndStreamSet: endStream) + $0.receiveHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, localSupportsExtendedConnect: localSupportsExtendedConnect, remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, isEndStreamSet: endStream) } return StateMachineResultWithEffect(result, inboundFlowControlWindow: self.inboundFlowControlWindow, diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingHeadersState.swift b/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingHeadersState.swift index 5a32d626..bcbdffff 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingHeadersState.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/FrameSendingStates/SendingHeadersState.swift @@ -18,7 +18,7 @@ import NIOHPACK /// can validly send headers. /// /// This protocol should only be conformed to by states for the HTTP/2 connection state machine. -protocol SendingHeadersState: HasFlowControlWindows { +protocol SendingHeadersState: HasFlowControlWindows, HasLocalExtendedConnectSettings, HasRemoteExtendedConnectSettings { var role: HTTP2ConnectionStateMachine.ConnectionRole { get } var headerBlockValidation: HTTP2ConnectionStateMachine.ValidationState { get } @@ -38,6 +38,8 @@ extension SendingHeadersState { let result: StateMachineResultWithStreamEffect let validateHeaderBlock = self.headerBlockValidation == .enabled let validateContentLength = self.contentLengthValidation == .enabled + let localSupportsExtendedConnect = self.localSupportsExtendedConnect + let remoteSupportsExtendedConnect = self.remoteSupportsExtendedConnect if self.role == .client && streamID.mayBeInitiatedBy(.client) { do { @@ -45,7 +47,7 @@ extension SendingHeadersState { localRole: .client, localInitialWindowSize: self.localInitialWindowSize, remoteInitialWindowSize: self.remoteInitialWindowSize) { - $0.sendHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, isEndStreamSet: endStream) + $0.sendHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, localSupportsExtendedConnect: localSupportsExtendedConnect, remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, isEndStreamSet: endStream) } } catch { return StateMachineResultWithEffect(result: .connectionError(underlyingError: error, type: .protocolError), effect: nil) @@ -53,7 +55,7 @@ extension SendingHeadersState { } else { // HEADERS cannot create streams for servers, so this must be for a stream we already know about. result = self.streamState.modifyStreamState(streamID: streamID, ignoreRecentlyReset: false) { - $0.sendHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, isEndStreamSet: endStream) + $0.sendHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, localSupportsExtendedConnect: localSupportsExtendedConnect, remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, isEndStreamSet: endStream) } } @@ -70,6 +72,8 @@ extension SendingHeadersState where Self: RemotelyQuiescingState { mutating func sendHeaders(streamID: HTTP2StreamID, headers: HPACKHeaders, isEndStreamSet endStream: Bool) -> StateMachineResultWithEffect { let validateHeaderBlock = self.headerBlockValidation == .enabled let validateContentLength = self.contentLengthValidation == .enabled + let localSupportsExtendedConnect = self.localSupportsExtendedConnect + let remoteSupportsExtendedConnect = self.remoteSupportsExtendedConnect if streamID.mayBeInitiatedBy(.client) && self.role == .client && streamID > self.streamState.lastClientStreamID { @@ -77,7 +81,7 @@ extension SendingHeadersState where Self: RemotelyQuiescingState { return StateMachineResultWithEffect(result: .connectionError(underlyingError: error, type: .protocolError), effect: nil) } let result = self.streamState.modifyStreamState(streamID: streamID, ignoreRecentlyReset: false) { - $0.sendHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, isEndStreamSet: endStream) + $0.sendHeaders(headers: headers, validateHeaderBlock: validateHeaderBlock, validateContentLength: validateContentLength, localSupportsExtendedConnect: localSupportsExtendedConnect, remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, isEndStreamSet: endStream) } return StateMachineResultWithEffect(result, inboundFlowControlWindow: self.inboundFlowControlWindow, diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/HTTP2SettingsState.swift b/Sources/NIOHTTP2/ConnectionStateMachine/HTTP2SettingsState.swift index 57c90971..e5796156 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/HTTP2SettingsState.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/HTTP2SettingsState.swift @@ -70,6 +70,11 @@ struct HTTP2SettingsState { return self[.enablePush]! } + /// The current value of SETTINGS_ENABLE_CONNECT_PROTOCOL + var enableConnectProtocol: UInt32? { + return self[.enableConnectProtocol] + } + /// The default value of SETTINGS_INITIAL_WINDOW_SIZE. static let defaultInitialWindowSize: UInt32 = 65535 diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift b/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift new file mode 100644 index 00000000..8099e64f --- /dev/null +++ b/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2024 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +protocol HasLocalExtendedConnectSettings { + var localSupportsExtendedConnect: Bool { get } +} + +protocol HasRemoteExtendedConnectSettings { + var remoteSupportsExtendedConnect: Bool { get } +} diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/HasLocalSettings.swift b/Sources/NIOHTTP2/ConnectionStateMachine/HasLocalSettings.swift index 8f4f3ee3..e5fd855c 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/HasLocalSettings.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/HasLocalSettings.swift @@ -25,6 +25,12 @@ protocol HasLocalSettings { var inboundFlowControlWindow: HTTP2FlowControlWindow { get set } } +extension HasLocalExtendedConnectSettings where Self: HasLocalSettings { + var localSupportsExtendedConnect: Bool { + self.localSettings.enableConnectProtocol == 1 + } +} + extension HasLocalSettings { mutating func receiveSettingsAck(frameDecoder: inout HTTP2FrameDecoder) -> StateMachineResultWithEffect { // We do a little switcheroo here to avoid problems with overlapping accesses to diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/HasRemoteSettings.swift b/Sources/NIOHTTP2/ConnectionStateMachine/HasRemoteSettings.swift index ea2b4e10..93281a65 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/HasRemoteSettings.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/HasRemoteSettings.swift @@ -25,6 +25,12 @@ protocol HasRemoteSettings { var outboundFlowControlWindow: HTTP2FlowControlWindow { get set } } +extension HasRemoteExtendedConnectSettings where Self: HasRemoteSettings { + var remoteSupportsExtendedConnect: Bool { + self.remoteSettings.enableConnectProtocol == 1 + } +} + extension HasRemoteSettings { mutating func receiveSettingsChange(_ settings: HTTP2Settings, frameEncoder: inout HTTP2FrameEncoder) -> (StateMachineResultWithEffect, PostFrameOperation) { // We do a little switcheroo here to avoid problems with overlapping accesses to @@ -65,6 +71,12 @@ extension HasRemoteSettings { effect.streamWindowSizeChange += Int(delta) case .maxFrameSize: effect.newMaxFrameSize = newValue + case .enableConnectProtocol: + // Must not transition from 1 -> 0 + if originalValue == 1 && newValue == 0 { + throw NIOHTTP2Errors.invalidSetting(setting: HTTP2Setting(parameter: setting, value: Int(newValue))) + } + effect.enableConnectProtocol = newValue == 1 default: // No operation required return @@ -73,6 +85,8 @@ extension HasRemoteSettings { return (.init(result: .succeed, effect: .remoteSettingsChanged(effect)), .sendAck) } catch let err where err is NIOHTTP2Errors.InvalidFlowControlWindowSize { return (.init(result: .connectionError(underlyingError: err, type: .flowControlError), effect: nil), .nothing) + } catch let err where err is NIOHTTP2Errors.InvalidSetting { + return (.init(result: .connectionError(underlyingError: err, type: .protocolError), effect: nil), .nothing) } catch { preconditionFailure("Unexpected error thrown: \(error)") } diff --git a/Sources/NIOHTTP2/HPACKHeaders+Validation.swift b/Sources/NIOHTTP2/HPACKHeaders+Validation.swift index 83aebbaa..0f9149be 100644 --- a/Sources/NIOHTTP2/HPACKHeaders+Validation.swift +++ b/Sources/NIOHTTP2/HPACKHeaders+Validation.swift @@ -17,8 +17,8 @@ extension HPACKHeaders { /// Checks that a given HPACKHeaders block is a valid request header block, meeting all of the constraints of RFC 7540. /// /// If the header block is not valid, throws an error. - internal func validateRequestBlock() throws { - return try RequestBlockValidator.validateBlock(self) + internal func validateRequestBlock(supportsExtendedConnect: Bool) throws { + return try RequestBlockValidator.validateBlock(self, supportsExtendedConnect: supportsExtendedConnect) } /// Checks that a given HPACKHeaders block is a valid response header block, meeting all of the constraints of RFC 7540. @@ -78,7 +78,7 @@ fileprivate protocol HeaderBlockValidator { extension HeaderBlockValidator { /// Validates that a header block meets the requirements of this `HeaderBlockValidator`. - fileprivate static func validateBlock(_ block: HPACKHeaders) throws { + fileprivate static func validateBlock(_ block: HPACKHeaders, supportsExtendedConnect: Bool = false) throws { var validator = Self() var blockSection = BlockSection.pseudoHeaders var seenPseudoHeaders = PseudoHeaders(rawValue: 0) @@ -88,7 +88,7 @@ extension HeaderBlockValidator { try blockSection.validField(fieldName) try fieldName.legalHeaderField(value: value) - let thisPseudoHeaderFieldType = try seenPseudoHeaders.seenNewHeaderField(fieldName) + let thisPseudoHeaderFieldType = try seenPseudoHeaders.seenNewHeaderField(fieldName, supportsExtendedConnect: supportsExtendedConnect) try validator.validateNextField(name: fieldName, value: value, pseudoHeaderType: thisPseudoHeaderFieldType) } @@ -106,6 +106,7 @@ extension HeaderBlockValidator { /// An object that can be used to validate if a given header block is a valid request header block. fileprivate struct RequestBlockValidator { private var isConnectRequest: Bool = false + private var containsProtocolPseudoheader: Bool = false } extension RequestBlockValidator: HeaderBlockValidator { @@ -141,8 +142,6 @@ extension RequestBlockValidator: HeaderBlockValidator { // - On CONNECT requests without the :protocol pseudo-header, :method and :authority are mandatory, no others are allowed. // // This is a bit awkward. - // - // For now we don't support extended-CONNECT, but when we do we'll need to update the logic here. if let pseudoHeaderType = pseudoHeaderType { assert(name.fieldType == .pseudoHeaderField) @@ -150,6 +149,8 @@ extension RequestBlockValidator: HeaderBlockValidator { case .method: // This is a method pseudo-header. Check if the value is CONNECT. self.isConnectRequest = value == "CONNECT" + case .extConnectProtocol: + self.containsProtocolPseudoheader = true case .path: // This is a path pseudo-header. It must not be empty. if value.utf8.count == 0 { @@ -171,7 +172,11 @@ extension RequestBlockValidator: HeaderBlockValidator { var allowedPseudoHeaderFields: PseudoHeaders { // For the logic behind this if statement, see the comment in validateNextField. if self.isConnectRequest { - return .allowedConnectRequestHeaders + if self.containsProtocolPseudoheader { + return .allowedExtendedConnectRequestHeaders + } else { + return .allowedConnectRequestHeaders + } } else { return .allowedRequestHeaders } @@ -179,7 +184,7 @@ extension RequestBlockValidator: HeaderBlockValidator { var mandatoryPseudoHeaderFields: PseudoHeaders { // For the logic behind this if statement, see the comment in validateNextField. - if self.isConnectRequest { + if self.isConnectRequest && !self.containsProtocolPseudoheader { return .mandatoryConnectRequestHeaders } else { return .mandatoryRequestHeaders @@ -339,9 +344,11 @@ fileprivate struct PseudoHeaders: OptionSet { static let scheme = PseudoHeaders(rawValue: 1 << 2) static let authority = PseudoHeaders(rawValue: 1 << 3) static let status = PseudoHeaders(rawValue: 1 << 4) + static let extConnectProtocol = PseudoHeaders(rawValue: 1 << 5) static let mandatoryRequestHeaders: PseudoHeaders = [.path, .method, .scheme] static let allowedRequestHeaders: PseudoHeaders = [.path, .method, .scheme, .authority] + static let allowedExtendedConnectRequestHeaders: PseudoHeaders = [.path, .method, .scheme, .authority, .extConnectProtocol] static let mandatoryConnectRequestHeaders: PseudoHeaders = [.method, .authority] static let allowedConnectRequestHeaders: PseudoHeaders = [.method, .authority] static let mandatoryResponseHeaders: PseudoHeaders = [.status] @@ -365,6 +372,8 @@ extension PseudoHeaders { self = .authority case "status": self = .status + case "protocol": + self = .extConnectProtocol default: return nil } @@ -374,7 +383,7 @@ extension PseudoHeaders { extension PseudoHeaders { /// Updates this set of PseudoHeaders with any new pseudo headers we've seen. Also returns a PseudoHeaders that marks /// the type of this specific header field. - mutating func seenNewHeaderField(_ name: HeaderFieldName) throws -> PseudoHeaders? { + mutating func seenNewHeaderField(_ name: HeaderFieldName, supportsExtendedConnect: Bool) throws -> PseudoHeaders? { // We need to check if this is a pseudo-header field we've seen before and one we recognise. // We only want to see a pseudo-header field once. guard name.fieldType == .pseudoHeaderField else { @@ -385,6 +394,10 @@ extension PseudoHeaders { throw NIOHTTP2Errors.unknownPseudoHeader(":\(name.baseName)") } + if pseudoHeaderType == .extConnectProtocol && !supportsExtendedConnect { + throw NIOHTTP2Errors.unsupportedPseudoHeader(":\(name.baseName)") + } + if self.contains(pseudoHeaderType) { throw NIOHTTP2Errors.duplicatePseudoHeader(":\(name.baseName)") } diff --git a/Sources/NIOHTTP2/HTTP2ConnectionStateChange.swift b/Sources/NIOHTTP2/HTTP2ConnectionStateChange.swift index 870fd3b4..09b34fda 100644 --- a/Sources/NIOHTTP2/HTTP2ConnectionStateChange.swift +++ b/Sources/NIOHTTP2/HTTP2ConnectionStateChange.swift @@ -150,6 +150,8 @@ internal enum NIOHTTP2ConnectionStateChange: Hashable { internal var newMaxFrameSize: UInt32? internal var newMaxConcurrentStreams: UInt32? + + internal var enableConnectProtocol: Bool? } /// The local peer's settings have changed in a way that is not trivial to decode. @@ -162,6 +164,8 @@ internal enum NIOHTTP2ConnectionStateChange: Hashable { internal var newMaxFrameSize: UInt32? internal var newMaxHeaderListSize: UInt32? + + internal var enableConnectProtocol: Bool? } } diff --git a/Sources/NIOHTTP2/HTTP2Error.swift b/Sources/NIOHTTP2/HTTP2Error.swift index c32c43f1..5bb0ab4b 100644 --- a/Sources/NIOHTTP2/HTTP2Error.swift +++ b/Sources/NIOHTTP2/HTTP2Error.swift @@ -185,6 +185,10 @@ public enum NIOHTTP2Errors { return UnknownPseudoHeader(name, file: file, line: line) } + public static func unsupportedPseudoHeader(_ name: String, file: String = #fileID, line: UInt = #line) -> UnsupportedPseudoHeader { + return UnsupportedPseudoHeader(name, file: file, line: line) + } + /// Creates a ``InvalidPseudoHeaders`` error with appropriate source context. /// /// - Parameters: @@ -1111,6 +1115,45 @@ public enum NIOHTTP2Errors { } } + /// An unsupported pseudo-header was received. + public struct UnsupportedPseudoHeader: NIOHTTP2Error, CustomStringConvertible, @unchecked Sendable { + // @unchecked Sendable because access is controlled by getters and copy-on-write setters giving this value semantics + + private var storage: StringAndLocationStorage + + private mutating func copyStorageIfNotUniquelyReferenced() { + if !isKnownUniquelyReferenced(&self.storage) { + self.storage = self.storage.copy() + } + } + + /// The name of the unsupported pseudo-header field. + public var name: String { + get { + return self.storage.value + } + set { + self.copyStorageIfNotUniquelyReferenced() + self.storage.value = newValue + } + } + + /// The file and line where the error was created. + public var location: String { + get { + return self.storage.location + } + } + + public var description: String { + return "UnsupportedPseudoHeader(name: \(self.name), location: \(self.location))" + } + + fileprivate init(_ name: String, file: String, line: UInt) { + self.storage = .init(name, file: file, line: line) + } + } + /// A header block was received with an invalid set of pseudo-headers for the block type. public struct InvalidPseudoHeaders: NIOHTTP2Error { /// The header block containing the invalid set of pseudo-headers. diff --git a/Sources/NIOHTTP2/StreamStateMachine.swift b/Sources/NIOHTTP2/StreamStateMachine.swift index 809b6db8..35e12893 100644 --- a/Sources/NIOHTTP2/StreamStateMachine.swift +++ b/Sources/NIOHTTP2/StreamStateMachine.swift @@ -265,7 +265,7 @@ extension HTTP2StreamStateMachine { /// it meets the requirements of RFC 7540 for containing a well-formed header block, and additionally /// checks whether the value of the end stream bit is acceptable. If all checks pass, transitions the /// state to the appropriate next entry. - mutating func sendHeaders(headers: HPACKHeaders, validateHeaderBlock: Bool, validateContentLength: Bool, isEndStreamSet endStream: Bool) -> StateMachineResultWithStreamEffect { + mutating func sendHeaders(headers: HPACKHeaders, validateHeaderBlock: Bool, validateContentLength: Bool, localSupportsExtendedConnect: Bool, remoteSupportsExtendedConnect: Bool, isEndStreamSet endStream: Bool) -> StateMachineResultWithStreamEffect { do { // We can send headers in the following states: // @@ -297,6 +297,9 @@ extension HTTP2StreamStateMachine { let targetEffect: StreamStateChange = .streamCreated(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, + localRole: .client, + localSupportsExtendedConnect: localSupportsExtendedConnect, + remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, targetState: targetState, targetEffect: targetEffect) @@ -396,7 +399,7 @@ extension HTTP2StreamStateMachine { } } - mutating func receiveHeaders(headers: HPACKHeaders, validateHeaderBlock: Bool, validateContentLength: Bool, isEndStreamSet endStream: Bool) -> StateMachineResultWithStreamEffect { + mutating func receiveHeaders(headers: HPACKHeaders, validateHeaderBlock: Bool, validateContentLength: Bool, localSupportsExtendedConnect: Bool, remoteSupportsExtendedConnect: Bool, isEndStreamSet endStream: Bool) -> StateMachineResultWithStreamEffect { do { // We can receive headers in the following states: // @@ -428,6 +431,9 @@ extension HTTP2StreamStateMachine { let targetEffect: StreamStateChange = .streamCreated(.init(streamID: self.streamID, localStreamWindowSize: Int(localWindow), remoteStreamWindowSize: Int(remoteWindow))) return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, + localRole: .server, + localSupportsExtendedConnect: localSupportsExtendedConnect, + remoteSupportsExtendedConnect: remoteSupportsExtendedConnect, targetState: targetState, targetEffect: targetEffect) @@ -686,7 +692,7 @@ extension HTTP2StreamStateMachine { .halfOpenRemoteLocalIdle(localWindow: _, remoteContentLength: _, remoteWindow: _, requestVerb: _), .halfClosedRemoteLocalIdle(localWindow: _), .halfClosedRemoteLocalActive(localRole: .server, initiatedBy: .client, localContentLength: _, localWindow: _): - return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, targetState: self.state, targetEffect: nil) + return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, localRole: .server, localSupportsExtendedConnect: false, remoteSupportsExtendedConnect: false, targetState: self.state, targetEffect: nil) // Sending a PUSH_PROMISE frame outside any of these states is a stream error of type PROTOCOL_ERROR. // Authors note: I cannot find a citation for this in RFC 7540, but this seems a sensible choice. @@ -713,7 +719,7 @@ extension HTTP2StreamStateMachine { .halfOpenLocalPeerIdle(localWindow: _, localContentLength: _, remoteWindow: _, requestVerb: _), .halfClosedLocalPeerIdle(remoteWindow: _), .halfClosedLocalPeerActive(localRole: .client, initiatedBy: .client, remoteContentLength: _, remoteWindow: _): - return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, targetState: self.state, targetEffect: nil) + return self.processRequestHeaders(headers, validateHeaderBlock: validateHeaderBlock, localRole: .client, localSupportsExtendedConnect: false, remoteSupportsExtendedConnect: false, targetState: self.state, targetEffect: nil) // Receiving a PUSH_PROMISE frame outside any of these states is a stream error of type PROTOCOL_ERROR. // Authors note: I cannot find a citation for this in RFC 7540, but this seems a sensible choice. @@ -958,10 +964,11 @@ extension HTTP2StreamStateMachine { extension HTTP2StreamStateMachine { /// Validate that the request headers meet the requirements of RFC 7540. If they do, /// transitions to the target state. - private mutating func processRequestHeaders(_ headers: HPACKHeaders, validateHeaderBlock: Bool, targetState target: State, targetEffect effect: StreamStateChange?) -> StateMachineResultWithStreamEffect { + private mutating func processRequestHeaders(_ headers: HPACKHeaders, validateHeaderBlock: Bool, localRole: StreamRole, localSupportsExtendedConnect: Bool, remoteSupportsExtendedConnect: Bool, targetState target: State, targetEffect effect: StreamStateChange?) -> StateMachineResultWithStreamEffect { if validateHeaderBlock { do { - try headers.validateRequestBlock() + let supportsExtendedConnect = (localRole == .client && remoteSupportsExtendedConnect) || (localRole == .server && localSupportsExtendedConnect) + try headers.validateRequestBlock(supportsExtendedConnect: supportsExtendedConnect) } catch { return StateMachineResultWithStreamEffect(result: .streamError(streamID: self.streamID, underlyingError: error, type: .protocolError), effect: nil) } diff --git a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift index 0d11b4d1..d6083fc6 100644 --- a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift +++ b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift @@ -155,12 +155,12 @@ class ConnectionStateMachineTests: XCTestCase { self.clientDecoder = HTTP2FrameDecoder(allocator: ByteBufferAllocator(), expectClientMagic: false) } - private func exchangePreamble() { - assertSucceeds(self.client.sendSettings(HTTP2Settings())) - assertSucceeds(self.server.receiveSettings(.settings(HTTP2Settings()), frameEncoder: &self.serverEncoder, frameDecoder: &self.serverDecoder)) + private func exchangePreamble(client: HTTP2Settings = HTTP2Settings(), server: HTTP2Settings = HTTP2Settings()) { + assertSucceeds(self.client.sendSettings(client)) + assertSucceeds(self.server.receiveSettings(.settings(client), frameEncoder: &self.serverEncoder, frameDecoder: &self.serverDecoder)) - assertSucceeds(self.server.sendSettings(HTTP2Settings())) - assertSucceeds(self.client.receiveSettings(.settings(HTTP2Settings()), frameEncoder: &self.serverEncoder, frameDecoder: &self.serverDecoder)) + assertSucceeds(self.server.sendSettings(server)) + assertSucceeds(self.client.receiveSettings(.settings(server), frameEncoder: &self.serverEncoder, frameDecoder: &self.serverDecoder)) assertSucceeds(self.client.receiveSettings(.ack, frameEncoder: &self.serverEncoder, frameDecoder: &self.serverDecoder)) assertSucceeds(self.server.receiveSettings(.ack, frameEncoder: &self.serverEncoder, frameDecoder: &self.serverDecoder)) @@ -2368,6 +2368,38 @@ class ConnectionStateMachineTests: XCTestCase { assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) } + func testProtocolPseudoheaderWithoutEnableConnectProtocolSetting() { + let streamOne = HTTP2StreamID(1) + + self.exchangePreamble() + + let headers = HPACKHeaders([(":method", "CONNECT"), (":protocol", "websocket"), (":scheme", "https"), (":path", "/chat") ]) + + assertStreamError(type: .protocolError, self.client.sendHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) + assertStreamError(type: .protocolError, self.server.receiveHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) + } + + func testProtocolPseudoheaderWithEnableConnectProtocolSetting() { + let streamOne = HTTP2StreamID(1) + + self.exchangePreamble(server: [HTTP2Setting(parameter: .enableConnectProtocol, value: 1)]) + + let headers = HPACKHeaders([(":method", "CONNECT"), (":protocol", "websocket"), (":scheme", "https"), (":path", "/chat") ]) + + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) + } + + func testRejectProtocolPseudoHeaderWithoutConnectMethod() { + let streamOne = HTTP2StreamID(1) + + self.exchangePreamble(server: [HTTP2Setting(parameter: .enableConnectProtocol, value: 1)]) + + let headers = HPACKHeaders([(":method", "GET"), (":protocol", "websocket"), (":scheme", "https"), (":path", "/chat") ]) + assertStreamError(type: .protocolError, self.client.sendHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) + assertStreamError(type: .protocolError, self.server.receiveHeaders(streamID: streamOne, headers: headers, isEndStreamSet: true)) + } + func testRejectHeadersWithConnectionHeader() { let streamOne = HTTP2StreamID(1) let streamThree = HTTP2StreamID(3) diff --git a/Tests/NIOHTTP2Tests/HTTP2ErrorTests.swift b/Tests/NIOHTTP2Tests/HTTP2ErrorTests.swift index b0b79e79..b9d51705 100644 --- a/Tests/NIOHTTP2Tests/HTTP2ErrorTests.swift +++ b/Tests/NIOHTTP2Tests/HTTP2ErrorTests.swift @@ -219,6 +219,7 @@ class HTTP2ErrorTests: XCTestCase { XCTAssertLessThanOrEqual(MemoryLayout.size, 24) XCTAssertLessThanOrEqual(MemoryLayout.size, 24) XCTAssertLessThanOrEqual(MemoryLayout.size, 24) + XCTAssertLessThanOrEqual(MemoryLayout.size, 24) XCTAssertLessThanOrEqual(MemoryLayout.size, 24) XCTAssertLessThanOrEqual(MemoryLayout.size, 24) XCTAssertLessThanOrEqual(MemoryLayout.size, 24) diff --git a/Tests/NIOHTTP2Tests/SimpleClientServerFramePayloadStreamTests.swift b/Tests/NIOHTTP2Tests/SimpleClientServerFramePayloadStreamTests.swift index 4a254dcc..4e8d814c 100644 --- a/Tests/NIOHTTP2Tests/SimpleClientServerFramePayloadStreamTests.swift +++ b/Tests/NIOHTTP2Tests/SimpleClientServerFramePayloadStreamTests.swift @@ -2099,4 +2099,28 @@ class SimpleClientServerFramePayloadStreamTests: XCTestCase { XCTAssertNoThrow(try self.clientChannel.finish()) XCTAssertNoThrow(try self.serverChannel.finish()) } + + func testExtendedConnect() throws { + // Begin by getting the connection up. + try self.basicHTTP2Connection(serverSettings: [HTTP2Setting(parameter: .enableConnectProtocol, value: 1)]) + + // We're now going to try to send a request from the client to the server with the protocol pseudoheader present + let headers = HPACKHeaders([(":path", "/"), (":method", "CONNECT"), (":scheme", "https"), (":authority", "localhost"), (":protocol", "foo")]) + var requestBody = self.clientChannel.allocator.buffer(capacity: 128) + requestBody.writeStaticString("A simple HTTP/2 request.") + + let clientStreamID = HTTP2StreamID(1) + let reqFrame = HTTP2Frame(streamID: clientStreamID, payload: .headers(.init(headers: headers))) + let reqBodyFrame = HTTP2Frame(streamID: clientStreamID, payload: .data(.init(data: .byteBuffer(requestBody), endStream: true))) + + let serverStreamID = try self.assertFramesRoundTrip(frames: [reqFrame, reqBodyFrame], sender: self.clientChannel, receiver: self.serverChannel).first!.streamID + + // Let's send a quick response back. + let responseHeaders = HPACKHeaders([(":status", "200"), ("content-length", "0")]) + let respFrame = HTTP2Frame(streamID: serverStreamID, payload: .headers(.init(headers: responseHeaders, endStream: true))) + try self.assertFramesRoundTrip(frames: [respFrame], sender: self.serverChannel, receiver: self.clientChannel) + + XCTAssertNoThrow(try self.clientChannel.finish()) + XCTAssertNoThrow(try self.serverChannel.finish()) + } } diff --git a/Tests/NIOHTTP2Tests/TestUtilities.swift b/Tests/NIOHTTP2Tests/TestUtilities.swift index 2cd6c11c..0f13cc87 100644 --- a/Tests/NIOHTTP2Tests/TestUtilities.swift +++ b/Tests/NIOHTTP2Tests/TestUtilities.swift @@ -516,7 +516,7 @@ extension HTTP2Frame.FramePayload { switch type { case .some(.request): - XCTAssertNoThrow(try actualPayload.headers.validateRequestBlock(), + XCTAssertNoThrow(try actualPayload.headers.validateRequestBlock(supportsExtendedConnect: true), "\(actualPayload.headers) not a valid \(type!) headers block", file: (file), line: line) case .some(.response): XCTAssertNoThrow(try actualPayload.headers.validateResponseBlock(), @@ -527,7 +527,7 @@ extension HTTP2Frame.FramePayload { case .some(.doNotValidate): () // alright, let's not validate then case .none: - XCTAssertTrue((try? actualPayload.headers.validateRequestBlock()) != nil || + XCTAssertTrue((try? actualPayload.headers.validateRequestBlock(supportsExtendedConnect: true)) != nil || (try? actualPayload.headers.validateResponseBlock()) != nil || (try? actualPayload.headers.validateTrailersBlock()) != nil, "\(actualPayload.headers) not a valid request/response/trailers header block", From 7cfcbd56bf628b3504781744482e4b9b94b34777 Mon Sep 17 00:00:00 2001 From: Eric Rosenberg Date: Thu, 20 Jun 2024 09:32:15 -0700 Subject: [PATCH 2/3] Update Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift Co-authored-by: Cory Benfield --- .../ConnectionStateMachine/HasExtendedConnectSettings.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift b/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift index 8099e64f..e3ef322a 100644 --- a/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift +++ b/Sources/NIOHTTP2/ConnectionStateMachine/HasExtendedConnectSettings.swift @@ -2,7 +2,7 @@ // // This source file is part of the SwiftNIO open source project // -// Copyright (c) 2017-2024 Apple Inc. and the SwiftNIO project authors +// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors // Licensed under Apache License v2.0 // // See LICENSE.txt for license information From e87b94001999799fb7bc27d615c006b8c91b0830 Mon Sep 17 00:00:00 2001 From: Eric Rosenberg Date: Fri, 21 Jun 2024 09:17:28 -0700 Subject: [PATCH 3/3] s/Pseudoheader/PseudoHeader/ --- Sources/NIOHTTP2/HPACKHeaders+Validation.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/NIOHTTP2/HPACKHeaders+Validation.swift b/Sources/NIOHTTP2/HPACKHeaders+Validation.swift index 0f9149be..88887f45 100644 --- a/Sources/NIOHTTP2/HPACKHeaders+Validation.swift +++ b/Sources/NIOHTTP2/HPACKHeaders+Validation.swift @@ -106,7 +106,7 @@ extension HeaderBlockValidator { /// An object that can be used to validate if a given header block is a valid request header block. fileprivate struct RequestBlockValidator { private var isConnectRequest: Bool = false - private var containsProtocolPseudoheader: Bool = false + private var containsProtocolPseudoHeader: Bool = false } extension RequestBlockValidator: HeaderBlockValidator { @@ -150,7 +150,7 @@ extension RequestBlockValidator: HeaderBlockValidator { // This is a method pseudo-header. Check if the value is CONNECT. self.isConnectRequest = value == "CONNECT" case .extConnectProtocol: - self.containsProtocolPseudoheader = true + self.containsProtocolPseudoHeader = true case .path: // This is a path pseudo-header. It must not be empty. if value.utf8.count == 0 { @@ -172,7 +172,7 @@ extension RequestBlockValidator: HeaderBlockValidator { var allowedPseudoHeaderFields: PseudoHeaders { // For the logic behind this if statement, see the comment in validateNextField. if self.isConnectRequest { - if self.containsProtocolPseudoheader { + if self.containsProtocolPseudoHeader { return .allowedExtendedConnectRequestHeaders } else { return .allowedConnectRequestHeaders @@ -184,7 +184,7 @@ extension RequestBlockValidator: HeaderBlockValidator { var mandatoryPseudoHeaderFields: PseudoHeaders { // For the logic behind this if statement, see the comment in validateNextField. - if self.isConnectRequest && !self.containsProtocolPseudoheader { + if self.isConnectRequest && !self.containsProtocolPseudoHeader { return .mandatoryConnectRequestHeaders } else { return .mandatoryRequestHeaders