Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix pinning in quic #52368

Merged
merged 5 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ internal delegate uint StreamShutdownDelegate(
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal delegate uint StreamSendDelegate(
SafeMsQuicStreamHandle stream,
QuicBuffer* buffers,
SafeMsQuicBufferHandle buffers,
uint bufferCount,
QUIC_SEND_FLAGS flags,
IntPtr clientSendContext);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;
using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods;

namespace System.Net.Quic.Implementations.MsQuic.Internal
{
internal sealed unsafe class SafeMsQuicBufferHandle : SafeHandle
{
public int Count;
wfurt marked this conversation as resolved.
Show resolved Hide resolved

public override bool IsInvalid => handle == IntPtr.Zero;

public SafeMsQuicBufferHandle(int count)
: base(IntPtr.Zero, ownsHandle: true)
{
IntPtr buffer = Marshal.AllocHGlobal(sizeof(QuicBuffer) * count);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't need to do this? What pattern is causing this to be required?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean AllocHGlobal or wrapping it in SafeMsQuicBufferHandle?

Copy link
Member

@davidfowl davidfowl May 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both. We do a similar thing in the ASP.NET Core servers and in sockets and we don't use this pattern AFAIK. What makes MSQuic unique here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the SafeHandle should be unnecessary here. The code has to be careful about managing the lifetimes of the async buffers. SafeHandle does not make it easier.

It is an interesting question whether it is better to use unmanaged memory or POH-allocated memory here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is another example from SslStream.

const int NumSecBuffers = 4; // header + data + trailer + empty
Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers)
{
pBuffers = unmanagedBuffer
};
fixed (byte* outputPtr = output)

Because we need to only live through the one native call, we can allocate the buffer array on stack. For msquic we need to wait until HandleEventSendComplete() is invoked. So we need to keep it alive past the sending function.

Anyway, we could use the pinning if we want to but as I mentioned this looked simpler and less work on each call.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example you have to be careful about situation when somebody disposes the object while the operation is still in flight.

Yep, I'm aware and we also go back and forth on how useful SafeHandles are in certain places, but I'm surprised SafeHandle + native buffer is the solution here. Are we pooling these things?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no pool at the moment. The main goal was to stabilize the test runs and avoid corruption and crashes and to unblock @ManickaP and @CarnaViire. I did 700+ runs and I did not see any failure. I think it would be better to refactor and/or optimize once we have solid CI in place.

I think we will need more work toward #5262 and #32142. The SafeHandle primarily helps with p/invokes. For writes, we will need to preserve the memory longer unless we change the handout model.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SafeHandle does not help you much here. You are directly managing the lifetime of the MemoryHandles for the individual buffer, so you can directly managed the lifetime for the array of buffer pointers as well. SafeHandle here is a lot of overhead for negligible benefit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand @jkotas. When we agree on #5262 and #32142 I will either (1) use native memory directly or (2) fall back to pinning on each call. I will also add test for concurrent dispose(s) and IO. I'm pretty sure we still have gaps there. (and possibly investigate and fix #52048 as well)

SetHandle(buffer);
Count = count;
}

protected override bool ReleaseHandle()
{
Marshal.FreeHGlobal(handle);
return true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ private sealed class State

// Buffers to hold during a call to send.
public MemoryHandle[] BufferArrays = new MemoryHandle[1];
public QuicBuffer[] SendQuicBuffers = new QuicBuffer[1];

// Handle to pinned SendQuicBuffers.
public GCHandle SendHandle;
public SafeMsQuicBufferHandle SendBufferHandle = new SafeMsQuicBufferHandle(1);
public int SendBufferCount;

// Resettable completions to be used for multiple calls to send, start, and shutdown.
public readonly ResettableCompletionSource<uint> SendResettableCompletionSource = new ResettableCompletionSource<uint>();
Expand Down Expand Up @@ -176,14 +174,12 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<ReadOnlyMemory<byte>

using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken).ConfigureAwait(false);
await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);

HandleWriteCompletedState();
}

internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool endStream, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();

using CancellationTokenRegistration registration = await HandleWriteStartState(cancellationToken).ConfigureAwait(false);

await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);
Expand Down Expand Up @@ -212,7 +208,7 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
bool shouldComplete = false;
lock (state)
{
if (state.SendState == SendState.None)
if (state.SendState == SendState.None || state.SendState == SendState.Pending)
{
state.SendState = SendState.Aborted;
shouldComplete = true;
Expand Down Expand Up @@ -240,7 +236,7 @@ private void HandleWriteCompletedState()
{
lock (_state)
{
if (_state.SendState == SendState.Finished)
if (_state.SendState == SendState.Finished || _state.SendState == SendState.Aborted)
{
_state.SendState = SendState.None;
}
Expand Down Expand Up @@ -501,11 +497,11 @@ private void Dispose(bool disposing)
return;
}

_disposed = true;
_state.Handle.Dispose();
_state.SendBufferHandle.Dispose();
if (_stateHandle.IsAllocated) _stateHandle.Free();
CleanupSendState(_state);

_disposed = true;
}

private void EnableReceive()
Expand Down Expand Up @@ -602,7 +598,7 @@ private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
bool shouldComplete = false;
lock (state)
{
if (state.SendState == SendState.None)
if (state.SendState == SendState.None || state.SendState == SendState.Pending)
{
shouldComplete = true;
}
Expand Down Expand Up @@ -761,7 +757,7 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt)

lock (state)
{
if (state.SendState == SendState.None)
if (state.SendState == SendState.Pending)
{
state.SendState = SendState.Finished;
complete = true;
Expand All @@ -771,7 +767,6 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt)
if (complete)
{
CleanupSendState(state);

// TODO throw if a write was canceled.
state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success);
}
Expand All @@ -781,15 +776,14 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt)

private static void CleanupSendState(State state)
{
if (state.SendHandle.IsAllocated)
{
state.SendHandle.Free();
}
lock (state) {
wfurt marked this conversation as resolved.
Show resolved Hide resolved
Debug.Assert(state.SendState != SendState.Pending);
Debug.Assert(state.SendBufferCount <= state.BufferArrays.Length);

// Callings dispose twice on a memory handle should be okay
foreach (MemoryHandle buffer in state.BufferArrays)
{
buffer.Dispose();
for (int i = 0; i < state.SendBufferCount; i++)
{
state.BufferArrays[i].Dispose();
}
}
}

Expand All @@ -798,6 +792,12 @@ private unsafe ValueTask SendReadOnlyMemoryAsync(
ReadOnlyMemory<byte> buffer,
QUIC_SEND_FLAGS flags)
{
lock (_state)
{
Debug.Assert(_state.SendState != SendState.Pending);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we want to prevent overlapping Sends. Shouldn't this then rather be a condition with throw QuicException...?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not meant as guard agains overlaying Sends. (I think that should be done much earlier)
I added this to make sure our internal state logic (and msquic) always moves to some other state.

_state.SendState = buffer.IsEmpty ? SendState.Finished : SendState.Pending;
}

if (buffer.IsEmpty)
{
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
Expand All @@ -809,18 +809,17 @@ private unsafe ValueTask SendReadOnlyMemoryAsync(
}

MemoryHandle handle = buffer.Pin();
_state.SendQuicBuffers[0].Length = (uint)buffer.Length;
_state.SendQuicBuffers[0].Buffer = (byte*)handle.Pointer;
QuicBuffer* quicBuffer = (QuicBuffer*)_state.SendBufferHandle.DangerousGetHandle();

_state.BufferArrays[0] = handle;

_state.SendHandle = GCHandle.Alloc(_state.SendQuicBuffers, GCHandleType.Pinned);
quicBuffer->Length = (uint)buffer.Length;
quicBuffer->Buffer = (byte*)handle.Pointer;

var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_state.SendQuicBuffers, 0);
_state.BufferArrays[0] = handle;
_state.SendBufferCount = 1;

uint status = MsQuicApi.Api.StreamSendDelegate(
_state.Handle,
quicBufferPointer,
_state.SendBufferHandle,
bufferCount: 1,
flags,
IntPtr.Zero);
Expand All @@ -841,6 +840,13 @@ private unsafe ValueTask SendReadOnlySequenceAsync(
ReadOnlySequence<byte> buffers,
QUIC_SEND_FLAGS flags)
{

lock (_state)
{
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
}

if (buffers.IsEmpty)
{
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
Expand All @@ -851,38 +857,37 @@ private unsafe ValueTask SendReadOnlySequenceAsync(
return default;
}

uint count = 0;
int count = 0;

foreach (ReadOnlyMemory<byte> buffer in buffers)
{
++count;
}

if (_state.SendQuicBuffers.Length < count)
if (_state.SendBufferHandle.Count < count)
{
_state.SendQuicBuffers = new QuicBuffer[count];
_state.SendBufferHandle.Dispose();
_state.SendBufferHandle = new SafeMsQuicBufferHandle(count);
_state.BufferArrays = new MemoryHandle[count];
}

_state.SendBufferCount = count;
count = 0;

QuicBuffer* quicBuffers = (QuicBuffer*)_state.SendBufferHandle.DangerousGetHandle();
foreach (ReadOnlyMemory<byte> buffer in buffers)
{
MemoryHandle handle = buffer.Pin();
_state.SendQuicBuffers[count].Length = (uint)buffer.Length;
_state.SendQuicBuffers[count].Buffer = (byte*)handle.Pointer;
quicBuffers[count].Length = (uint)buffer.Length;
quicBuffers[count].Buffer = (byte*)handle.Pointer;
_state.BufferArrays[count] = handle;
++count;
}

_state.SendHandle = GCHandle.Alloc(_state.SendQuicBuffers, GCHandleType.Pinned);

var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_state.SendQuicBuffers, 0);

uint status = MsQuicApi.Api.StreamSendDelegate(
_state.Handle,
quicBufferPointer,
count,
_state.SendBufferHandle,
(uint)count,
flags,
IntPtr.Zero);

Expand All @@ -902,6 +907,12 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync(
ReadOnlyMemory<ReadOnlyMemory<byte>> buffers,
QUIC_SEND_FLAGS flags)
{
lock (_state)
{
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
}

if (buffers.IsEmpty)
{
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
Expand All @@ -916,28 +927,29 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync(

uint length = (uint)array.Length;

if (_state.SendQuicBuffers.Length < length)
if (_state.SendBufferHandle.Count < array.Length)
{
_state.SendQuicBuffers = new QuicBuffer[length];
_state.BufferArrays = new MemoryHandle[length];
_state.SendBufferHandle.Dispose();
_state.SendBufferHandle = new SafeMsQuicBufferHandle(array.Length);
_state.BufferArrays = new MemoryHandle[array.Length];
}

_state.SendBufferCount = array.Length;
QuicBuffer* quicBuffers = (QuicBuffer*)_state.SendBufferHandle.DangerousGetHandle();
for (int i = 0; i < length; i++)
{
ReadOnlyMemory<byte> buffer = array[i];
MemoryHandle handle = buffer.Pin();
_state.SendQuicBuffers[i].Length = (uint)buffer.Length;
_state.SendQuicBuffers[i].Buffer = (byte*)handle.Pointer;
_state.BufferArrays[i] = handle;
}

_state.SendHandle = GCHandle.Alloc(_state.SendQuicBuffers, GCHandleType.Pinned);
quicBuffers[i].Length = (uint)buffer.Length;
quicBuffers[i].Buffer = (byte*)handle.Pointer;

var quicBufferPointer = (QuicBuffer*)Marshal.UnsafeAddrOfPinnedArrayElement(_state.SendQuicBuffers, 0);
_state.BufferArrays[i] = handle;
}

uint status = MsQuicApi.Api.StreamSendDelegate(
_state.Handle,
quicBufferPointer,
_state.SendBufferHandle,
length,
flags,
IntPtr.Zero);
Expand Down Expand Up @@ -1014,6 +1026,7 @@ private enum ShutdownState
private enum SendState
{
None,
Pending,
Aborted,
Finished
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ public BufferSegment Append(ReadOnlyMemory<byte> memory)
}
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/52047")]
[Fact]
public async Task ByteMixingOrNativeAVE_MinimalFailingTest()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ public async Task GetStreamIdWithoutStartWorks()
await clientConnection.CloseAsync(0);
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/52047")]
[Fact]
public async Task LargeDataSentAndReceived()
{
Expand Down Expand Up @@ -348,7 +347,6 @@ private static async Task SendAndReceiveEOFAsync(QuicStream s1, QuicStream s2)
Assert.Equal(0, bytesRead);
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/52047")]
[Theory]
[MemberData(nameof(ReadWrite_Random_Success_Data))]
public async Task ReadWrite_Random_Success(int readSize, int writeSize)
Expand Down Expand Up @@ -434,7 +432,7 @@ await Task.Run(async () =>
byte[] buffer = new byte[100];
QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => serverStream.ReadAsync(buffer).AsTask());
Assert.Equal(ExpectedErrorCode, ex.ErrorCode);
}).WaitAsync(TimeSpan.FromSeconds(5));
}).WaitAsync(TimeSpan.FromSeconds(15));
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/32050")]
Expand Down