diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs index f258c5ca8ae80..40884f72085e9 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpRequestMessage.cs @@ -168,7 +168,7 @@ public override string ToString() internal bool WasRedirected() => (_sendStatus & MessageIsRedirect) != 0; - internal bool IsWebSocketH2Request() => _version.Major == 2 && Method == HttpMethod.Connect && HasHeaders && string.Equals(Headers.Protocol, "websocket", StringComparison.OrdinalIgnoreCase); + internal bool IsExtendedConnectRequest => Method == HttpMethod.Connect && _headers?.Protocol != null; #region IDisposable Members diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 25c27d7d0fdc5..3b2019f5af657 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1600,7 +1600,7 @@ private async ValueTask SendHeadersAsync(HttpRequestMessage request // Start the write. This serializes access to write to the connection, and ensures that HEADERS // and CONTINUATION frames stay together, as they must do. We use the lock as well to ensure new // streams are created and started in order. - await PerformWriteAsync(totalSize, (thisRef: this, http2Stream, headerBytes, endStream: (request.Content == null && !http2Stream.ConnectProtocolEstablished), mustFlush), static (s, writeBuffer) => + await PerformWriteAsync(totalSize, (thisRef: this, http2Stream, headerBytes, endStream: (request.Content == null && !request.IsExtendedConnectRequest), mustFlush), static (s, writeBuffer) => { if (NetEventSource.Log.IsEnabled()) s.thisRef.Trace(s.http2Stream.StreamId, $"Started writing. Total header bytes={s.headerBytes.Length}"); @@ -1962,8 +1962,8 @@ public async Task SendAsync(HttpRequestMessage request, boo try { // Send request headers - bool shouldExpectContinue = request.Content != null && request.HasHeaders && request.Headers.ExpectContinue == true; - Http2Stream http2Stream = await SendHeadersAsync(request, cancellationToken, mustFlush: shouldExpectContinue).ConfigureAwait(false); + bool shouldExpectContinue = (request.Content != null && request.HasHeaders && request.Headers.ExpectContinue == true); + Http2Stream http2Stream = await SendHeadersAsync(request, cancellationToken, mustFlush: shouldExpectContinue || request.IsExtendedConnectRequest).ConfigureAwait(false); bool duplex = request.Content != null && request.Content.AllowDuplex; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 1ecb968fa322a..927a8e74b18a1 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -108,7 +108,7 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection) if (_request.Content == null) { _requestCompletionState = StreamCompletionState.Completed; - if (_request.IsWebSocketH2Request()) + if (_request.IsExtendedConnectRequest) { _requestBodyCancellationSource = new CancellationTokenSource(); } @@ -637,7 +637,7 @@ private void OnStatus(int statusCode) } else { - if (statusCode == 200 && _response.RequestMessage!.IsWebSocketH2Request()) + if (statusCode == 200 && _response.RequestMessage!.IsExtendedConnectRequest) { ConnectProtocolEstablished = true; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index e560915288e56..7e6f148c72b9f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -1018,7 +1018,7 @@ public async ValueTask SendWithVersionDetectionAndRetryAsyn // Use HTTP/3 if possible. if (IsHttp3Supported() && // guard to enable trimming HTTP/3 support _http3Enabled && - !request.IsWebSocketH2Request() && + !request.IsExtendedConnectRequest && (request.Version.Major >= 3 || (request.VersionPolicy == HttpVersionPolicy.RequestVersionOrHigher && IsSecure))) { Debug.Assert(async); @@ -1047,7 +1047,7 @@ public async ValueTask SendWithVersionDetectionAndRetryAsyn Debug.Assert(connection is not null || !_http2Enabled); if (connection is not null) { - if (request.IsWebSocketH2Request()) + if (request.IsExtendedConnectRequest) { await connection.InitialSettingsReceived.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false); if (!connection.IsConnectEnabled) @@ -1120,7 +1120,7 @@ public async ValueTask SendWithVersionDetectionAndRetryAsyn if (request.VersionPolicy != HttpVersionPolicy.RequestVersionOrLower) { HttpRequestException exception = new HttpRequestException(SR.Format(SR.net_http_requested_version_server_refused, request.Version, request.VersionPolicy), e); - if (request.IsWebSocketH2Request()) + if (request.IsExtendedConnectRequest) { exception.Data["HTTP2_ENABLED"] = false; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 4130cff1cc65d..4c0c4e8d693b3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -423,7 +423,17 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, // the task, and we're done. if (writeTask.IsCompleted) { - return writeTask; + writeTask.GetAwaiter().GetResult(); + ValueTask flushTask = new ValueTask(_stream.FlushAsync()); + if (flushTask.IsCompleted) + { + return flushTask; + } + else + { + releaseSendBufferAndSemaphore = false; + return WaitForWriteTaskAsync(flushTask, shouldFlush: false); + } } // Up until this point, if an exception occurred (such as when accessing _stream or when @@ -447,14 +457,18 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, } } - return WaitForWriteTaskAsync(writeTask); + return WaitForWriteTaskAsync(writeTask, shouldFlush: true); } - private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask) + private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFlush) { try { await writeTask.ConfigureAwait(false); + if (shouldFlush) + { + await _stream.FlushAsync().ConfigureAwait(false); + } } catch (Exception exc) when (!(exc is OperationCanceledException)) { @@ -478,6 +492,7 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) { await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false); + await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); } } catch (Exception exc) when (!(exc is OperationCanceledException)) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs index b7dfb3ea7f26d..1e416f4dc49cc 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTestStream.cs @@ -206,7 +206,7 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella Write(buffer.Span); } - public override void Flush() => throw new NotSupportedException(); + public override void Flush() { } public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();