Skip to content

Commit

Permalink
No end stream on ws connect and flush every message (#73762)
Browse files Browse the repository at this point in the history
* No end stream on ws connect and flush every message

* Apply suggestions from code review

Co-authored-by: Natalia Kondratyeva <knatalia@microsoft.com>

* Await flush in web socket after write

* Flush for web socket tests should be supported

* feedback

* Refactoring write and flush tasks in WebSocket

* feedback

* Replace check for more generic extended connect

* Apply suggestions from code review

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* feedback

Co-authored-by: Natalia Kondratyeva <knatalia@microsoft.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
3 people authored Aug 12, 2022
1 parent 31d5d23 commit 30bec96
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public override string ToString()

internal bool WasRedirected() => (_sendStatus & MessageIsRedirect) != 0;

internal bool IsWebSocketH2Request() => _version.Major == 2 && Method == HttpMethod.Connect && HasHeaders && string.Equals(Headers.Protocol, "websocket", StringComparison.OrdinalIgnoreCase);
internal bool IsExtendedConnectRequest => Method == HttpMethod.Connect && _headers?.Protocol != null;

#region IDisposable Members

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,7 @@ private async ValueTask<Http2Stream> SendHeadersAsync(HttpRequestMessage request
// Start the write. This serializes access to write to the connection, and ensures that HEADERS
// and CONTINUATION frames stay together, as they must do. We use the lock as well to ensure new
// streams are created and started in order.
await PerformWriteAsync(totalSize, (thisRef: this, http2Stream, headerBytes, endStream: (request.Content == null && !http2Stream.ConnectProtocolEstablished), mustFlush), static (s, writeBuffer) =>
await PerformWriteAsync(totalSize, (thisRef: this, http2Stream, headerBytes, endStream: (request.Content == null && !request.IsExtendedConnectRequest), mustFlush), static (s, writeBuffer) =>
{
if (NetEventSource.Log.IsEnabled()) s.thisRef.Trace(s.http2Stream.StreamId, $"Started writing. Total header bytes={s.headerBytes.Length}");
Expand Down Expand Up @@ -1962,8 +1962,8 @@ public async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, boo
try
{
// Send request headers
bool shouldExpectContinue = request.Content != null && request.HasHeaders && request.Headers.ExpectContinue == true;
Http2Stream http2Stream = await SendHeadersAsync(request, cancellationToken, mustFlush: shouldExpectContinue).ConfigureAwait(false);
bool shouldExpectContinue = (request.Content != null && request.HasHeaders && request.Headers.ExpectContinue == true);
Http2Stream http2Stream = await SendHeadersAsync(request, cancellationToken, mustFlush: shouldExpectContinue || request.IsExtendedConnectRequest).ConfigureAwait(false);

bool duplex = request.Content != null && request.Content.AllowDuplex;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection)
if (_request.Content == null)
{
_requestCompletionState = StreamCompletionState.Completed;
if (_request.IsWebSocketH2Request())
if (_request.IsExtendedConnectRequest)
{
_requestBodyCancellationSource = new CancellationTokenSource();
}
Expand Down Expand Up @@ -637,7 +637,7 @@ private void OnStatus(int statusCode)
}
else
{
if (statusCode == 200 && _response.RequestMessage!.IsWebSocketH2Request())
if (statusCode == 200 && _response.RequestMessage!.IsExtendedConnectRequest)
{
ConnectProtocolEstablished = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ public async ValueTask<HttpResponseMessage> SendWithVersionDetectionAndRetryAsyn
// Use HTTP/3 if possible.
if (IsHttp3Supported() && // guard to enable trimming HTTP/3 support
_http3Enabled &&
!request.IsWebSocketH2Request() &&
!request.IsExtendedConnectRequest &&
(request.Version.Major >= 3 || (request.VersionPolicy == HttpVersionPolicy.RequestVersionOrHigher && IsSecure)))
{
Debug.Assert(async);
Expand Down Expand Up @@ -1050,7 +1050,7 @@ public async ValueTask<HttpResponseMessage> SendWithVersionDetectionAndRetryAsyn
Debug.Assert(connection is not null || !_http2Enabled);
if (connection is not null)
{
if (request.IsWebSocketH2Request())
if (request.IsExtendedConnectRequest)
{
await connection.InitialSettingsReceived.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
if (!connection.IsConnectEnabled)
Expand Down Expand Up @@ -1123,7 +1123,7 @@ public async ValueTask<HttpResponseMessage> SendWithVersionDetectionAndRetryAsyn
if (request.VersionPolicy != HttpVersionPolicy.RequestVersionOrLower)
{
HttpRequestException exception = new HttpRequestException(SR.Format(SR.net_http_requested_version_server_refused, request.Version, request.VersionPolicy), e);
if (request.IsWebSocketH2Request())
if (request.IsExtendedConnectRequest)
{
exception.Data["HTTP2_ENABLED"] = false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,17 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode,
// the task, and we're done.
if (writeTask.IsCompleted)
{
return writeTask;
writeTask.GetAwaiter().GetResult();
ValueTask flushTask = new ValueTask(_stream.FlushAsync());
if (flushTask.IsCompleted)
{
return flushTask;
}
else
{
releaseSendBufferAndSemaphore = false;
return WaitForWriteTaskAsync(flushTask, shouldFlush: false);
}
}

// Up until this point, if an exception occurred (such as when accessing _stream or when
Expand All @@ -447,14 +457,18 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode,
}
}

return WaitForWriteTaskAsync(writeTask);
return WaitForWriteTaskAsync(writeTask, shouldFlush: true);
}

private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask)
private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFlush)
{
try
{
await writeTask.ConfigureAwait(false);
if (shouldFlush)
{
await _stream.FlushAsync().ConfigureAwait(false);
}
}
catch (Exception exc) when (!(exc is OperationCanceledException))
{
Expand All @@ -478,6 +492,7 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM
using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this))
{
await _stream.WriteAsync(new ReadOnlyMemory<byte>(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false);
await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
}
catch (Exception exc) when (!(exc is OperationCanceledException))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella
Write(buffer.Span);
}

public override void Flush() => throw new NotSupportedException();
public override void Flush() { }

public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();

Expand Down

0 comments on commit 30bec96

Please sign in to comment.