From 9ce903333efeb3eb5fef069f9caf868cb2309170 Mon Sep 17 00:00:00 2001 From: Pavel Krymets Date: Mon, 5 Mar 2018 23:47:11 -0800 Subject: [PATCH] Fix semantics of ArrayMemoryPool (#27615) * Fix semantics of ArrayMemoryPool * More thread safety * Fix pipes and add tests --- .../src/System/IO/Pipelines/BufferSegment.cs | 1 - .../tests/FlushAsyncCancellationTests.cs | 2 +- .../tests/PipePoolTests.cs | 35 +++++++++++++++--- .../tests/TestMemoryPool.cs | 4 ++ src/System.Memory/src/System.Memory.csproj | 1 + .../ArrayMemoryPool.ArrayMemoryPoolBuffer.cs | 37 ++++++++++++------- .../tests/MemoryPool/MemoryPool.cs | 37 +++++++++++++++++-- 7 files changed, 93 insertions(+), 24 deletions(-) diff --git a/src/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegment.cs b/src/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegment.cs index 3ef8402e1c82..511e327af09d 100644 --- a/src/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegment.cs +++ b/src/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegment.cs @@ -61,7 +61,6 @@ public void SetMemory(OwnedMemory buffer) public void SetMemory(OwnedMemory ownedMemory, int start, int end, bool readOnly = false) { _ownedMemory = ownedMemory; - _ownedMemory.Retain(); AvailableMemory = _ownedMemory.Memory; diff --git a/src/System.IO.Pipelines/tests/FlushAsyncCancellationTests.cs b/src/System.IO.Pipelines/tests/FlushAsyncCancellationTests.cs index 9fa2400680d8..95aceea93fa8 100644 --- a/src/System.IO.Pipelines/tests/FlushAsyncCancellationTests.cs +++ b/src/System.IO.Pipelines/tests/FlushAsyncCancellationTests.cs @@ -313,7 +313,7 @@ public static class TestWriterExtensions { public static PipeWriter WriteEmpty(this PipeWriter writer, int count) { - writer.GetMemory(count); + writer.GetSpan(count).Slice(0, count).Fill(0); writer.Advance(count); return writer; } diff --git a/src/System.IO.Pipelines/tests/PipePoolTests.cs b/src/System.IO.Pipelines/tests/PipePoolTests.cs index 37394ef743e5..b5739b5af4c1 100644 --- a/src/System.IO.Pipelines/tests/PipePoolTests.cs +++ b/src/System.IO.Pipelines/tests/PipePoolTests.cs @@ -13,6 +13,7 @@ public class PipePoolTests private class DisposeTrackingBufferPool : TestMemoryPool { public int ReturnedBlocks { get; set; } + public int DisposedBlocks { get; set; } public int CurrentlyRentedBlocks { get; set; } public override OwnedMemory Rent(int size) @@ -26,10 +27,12 @@ protected override void Dispose(bool disposing) private class DisposeTrackingOwnedMemory : OwnedMemory { - private readonly byte[] _array; + private byte[] _array; private readonly DisposeTrackingBufferPool _bufferPool; + private int _refCount = 1; + public DisposeTrackingOwnedMemory(byte[] array, DisposeTrackingBufferPool bufferPool) { _array = array; @@ -49,9 +52,9 @@ public override Span Span } } - public override bool IsDisposed { get; } + public override bool IsDisposed => _array == null; - protected override bool IsRetained => true; + protected override bool IsRetained => _refCount > 0; public override MemoryHandle Pin(int byteOffset = 0) { @@ -68,18 +71,26 @@ protected override bool TryGetArray(out ArraySegment arraySegment) protected override void Dispose(bool disposing) { - throw new NotImplementedException(); + if (IsRetained) + { + throw new InvalidOperationException(); + } + _bufferPool.DisposedBlocks++; + + _array = null; } public override bool Release() { _bufferPool.ReturnedBlocks++; _bufferPool.CurrentlyRentedBlocks--; + _refCount--; return IsRetained; } public override void Retain() { + _refCount++; } } } @@ -102,6 +113,8 @@ public async Task AdvanceToEndReturnsAllBlocks() pipe.Reader.AdvanceTo(readResult.Buffer.End); Assert.Equal(0, pool.CurrentlyRentedBlocks); + Assert.Equal(0, pool.DisposedBlocks); + Assert.Equal(3, pool.ReturnedBlocks); } [Fact] @@ -128,6 +141,10 @@ public async Task CanWriteAfterReturningMultipleBlocks() // Try writing more await pipe.Writer.WriteAsync(new byte[writeSize]); + + Assert.Equal(1, pool.CurrentlyRentedBlocks); + Assert.Equal(0, pool.DisposedBlocks); + Assert.Equal(2, pool.ReturnedBlocks); } [Fact] @@ -141,10 +158,12 @@ public async Task MultipleCompleteReaderWriterCauseDisposeOnlyOnce() readerWriter.Writer.Complete(); readerWriter.Reader.Complete(); Assert.Equal(1, pool.ReturnedBlocks); + Assert.Equal(0, pool.DisposedBlocks); readerWriter.Writer.Complete(); readerWriter.Reader.Complete(); Assert.Equal(1, pool.ReturnedBlocks); + Assert.Equal(0, pool.DisposedBlocks); } [Fact] @@ -174,11 +193,13 @@ public void ReturnsWriteHeadOnComplete() { var pool = new DisposeTrackingBufferPool(); var pipe = new Pipe(new PipeOptions(pool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false)); - var memory = pipe.Writer.GetMemory(512); + pipe.Writer.GetMemory(512); pipe.Reader.Complete(); pipe.Writer.Complete(); Assert.Equal(0, pool.CurrentlyRentedBlocks); + Assert.Equal(1, pool.ReturnedBlocks); + Assert.Equal(0, pool.DisposedBlocks); } [Fact] @@ -186,12 +207,14 @@ public void ReturnsWriteHeadWhenRequestingLargerBlock() { var pool = new DisposeTrackingBufferPool(); var pipe = new Pipe(new PipeOptions(pool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false)); - var memory = pipe.Writer.GetMemory(512); + pipe.Writer.GetMemory(512); pipe.Writer.GetMemory(4096); pipe.Reader.Complete(); pipe.Writer.Complete(); Assert.Equal(0, pool.CurrentlyRentedBlocks); + Assert.Equal(2, pool.ReturnedBlocks); + Assert.Equal(0, pool.DisposedBlocks); } [Fact] diff --git a/src/System.IO.Pipelines/tests/TestMemoryPool.cs b/src/System.IO.Pipelines/tests/TestMemoryPool.cs index c1706a81a577..ae1efcc3cac5 100644 --- a/src/System.IO.Pipelines/tests/TestMemoryPool.cs +++ b/src/System.IO.Pipelines/tests/TestMemoryPool.cs @@ -53,6 +53,7 @@ public PooledMemory(OwnedMemory ownedMemory, TestMemoryPool pool) _ownedMemory = ownedMemory; _pool = pool; _leaser = Environment.StackTrace; + _referenceCount = 1; } ~PooledMemory() @@ -75,12 +76,15 @@ public override MemoryHandle Pin(int byteOffset = 0) public override void Retain() { _pool.CheckDisposed(); + _ownedMemory.Retain(); Interlocked.Increment(ref _referenceCount); } public override bool Release() { _pool.CheckDisposed(); + _ownedMemory.Release(); + int newRefCount = Interlocked.Decrement(ref _referenceCount); if (newRefCount < 0) diff --git a/src/System.Memory/src/System.Memory.csproj b/src/System.Memory/src/System.Memory.csproj index 226268054836..0dcf986ae881 100644 --- a/src/System.Memory/src/System.Memory.csproj +++ b/src/System.Memory/src/System.Memory.csproj @@ -147,6 +147,7 @@ + diff --git a/src/System.Memory/src/System/Buffers/ArrayMemoryPool.ArrayMemoryPoolBuffer.cs b/src/System.Memory/src/System/Buffers/ArrayMemoryPool.ArrayMemoryPoolBuffer.cs index 08015646e3f7..532edd108b12 100644 --- a/src/System.Memory/src/System/Buffers/ArrayMemoryPool.ArrayMemoryPoolBuffer.cs +++ b/src/System.Memory/src/System/Buffers/ArrayMemoryPool.ArrayMemoryPoolBuffer.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System.Runtime.InteropServices; +using System.Threading; + #if !netstandard using Internal.Runtime.CompilerServices; #else @@ -21,13 +23,14 @@ private sealed class ArrayMemoryPoolBuffer : OwnedMemory public ArrayMemoryPoolBuffer(int size) { _array = ArrayPool.Shared.Rent(size); + _refCount = 1; } public sealed override int Length => _array.Length; public sealed override bool IsDisposed => _array == null; - protected sealed override bool IsRetained => _refCount > 0; + protected sealed override bool IsRetained => Volatile.Read(ref _refCount) > 0; public sealed override Span Span { @@ -79,22 +82,30 @@ public sealed override MemoryHandle Pin(int byteOffset = 0) public sealed override void Retain() { - if (IsDisposed) - ThrowHelper.ThrowObjectDisposedException_ArrayMemoryPoolBuffer(); - - _refCount++; + while (true) + { + int currentCount = Volatile.Read(ref _refCount); + if (currentCount <= 0) ThrowHelper.ThrowObjectDisposedException_ArrayMemoryPoolBuffer(); + if (Interlocked.CompareExchange(ref _refCount, currentCount + 1, currentCount) == currentCount) break; + } } public sealed override bool Release() { - if (IsDisposed) - ThrowHelper.ThrowObjectDisposedException_ArrayMemoryPoolBuffer(); - - int newRefCount = --_refCount; - if (newRefCount < 0) - ThrowHelper.ThrowInvalidOperationException(); - - return newRefCount != 0; + while (true) + { + int currentCount = Volatile.Read(ref _refCount); + if (currentCount <= 0) ThrowHelper.ThrowObjectDisposedException_ArrayMemoryPoolBuffer(); + if (Interlocked.CompareExchange(ref _refCount, currentCount - 1, currentCount) == currentCount) + { + if (currentCount == 1) + { + Dispose(); + return false; + } + return true; + } + } } } } diff --git a/src/System.Memory/tests/MemoryPool/MemoryPool.cs b/src/System.Memory/tests/MemoryPool/MemoryPool.cs index 077ce931c083..63d34199721d 100644 --- a/src/System.Memory/tests/MemoryPool/MemoryPool.cs +++ b/src/System.Memory/tests/MemoryPool/MemoryPool.cs @@ -29,6 +29,7 @@ public static void DisposingTheSharedPoolIsANop() using (OwnedMemory block = mp.Rent(10)) { Assert.True(block.Length >= 10); + block.Release(); } } @@ -55,6 +56,7 @@ public static void MemoryPoolSpan() Assert.Equal((IntPtr)newMemoryHandle.Pointer, (IntPtr)pSpan); } } + block.Release(); } } @@ -77,6 +79,7 @@ public static void MemoryPoolPin(int byteOffset) Assert.Equal((IntPtr)pSpan, ((IntPtr)newMemoryHandle.Pointer) - byteOffset); } } + block.Release(); } } @@ -109,7 +112,7 @@ public static void MemoryPoolPinOffsetAtEnd() { return; // The pool gave us a very large block - too big to compute the byteOffset needed to carry out this test. Skip. } - + using (MemoryHandle newMemoryHandle = block.Pin(byteOffset: byteOffset)) { unsafe @@ -177,6 +180,7 @@ public static void EachRentalIsUniqueUntilDisposed() foreach (OwnedMemory prior in priorBlocks) { + prior.Release(); prior.Dispose(); } } @@ -187,6 +191,7 @@ public static void RentWithDefaultSize() using (OwnedMemory block = MemoryPool.Shared.Rent(minBufferSize: -1)) { Assert.True(block.Length >= 1); + block.Release(); } } @@ -224,6 +229,7 @@ public static void MemoryPoolTryGetArray() Assert.Equal((IntPtr)pSpan, (IntPtr)pArray); } } + block.Release(); } } @@ -243,10 +249,13 @@ public static void RefCounting() moreToGo = block.Release(); Assert.True(moreToGo); + moreToGo = block.Release(); + Assert.True(moreToGo); + moreToGo = block.Release(); Assert.False(moreToGo); - Assert.Throws(() => block.Release()); + Assert.Throws(() => block.Release()); } } @@ -255,7 +264,7 @@ public static void IsDisposed() { OwnedMemory block = MemoryPool.Shared.Rent(42); Assert.False(block.IsDisposed); - block.Dispose(); + block.Release(); Assert.True(block.IsDisposed); block.Dispose(); Assert.True(block.IsDisposed); @@ -265,6 +274,7 @@ public static void IsDisposed() public static void ExtraDisposesAreIgnored() { OwnedMemory block = MemoryPool.Shared.Rent(42); + block.Release(); block.Dispose(); block.Dispose(); } @@ -273,6 +283,7 @@ public static void ExtraDisposesAreIgnored() public static void NoSpanAfterDispose() { OwnedMemory block = MemoryPool.Shared.Rent(42); + block.Release(); block.Dispose(); Assert.Throws(() => block.Span.DontBox()); } @@ -281,6 +292,7 @@ public static void NoSpanAfterDispose() public static void NoRetainAfterDispose() { OwnedMemory block = MemoryPool.Shared.Rent(42); + block.Release(); block.Dispose(); Assert.Throws(() => block.Retain()); } @@ -289,6 +301,7 @@ public static void NoRetainAfterDispose() public static void NoRelease_AfterDispose() { OwnedMemory block = MemoryPool.Shared.Rent(42); + block.Release(); block.Dispose(); Assert.Throws(() => block.Release()); } @@ -297,6 +310,7 @@ public static void NoRelease_AfterDispose() public static void NoPinAfterDispose() { OwnedMemory block = MemoryPool.Shared.Rent(42); + block.Release(); block.Dispose(); Assert.Throws(() => block.Pin()); } @@ -306,9 +320,26 @@ public static void NoTryGetArrayAfterDispose() { OwnedMemory block = MemoryPool.Shared.Rent(42); Memory memory = block.Memory; + block.Release(); block.Dispose(); Assert.Throws(() => MemoryMarshal.TryGetArray(memory, out ArraySegment arraySegment)); } + + [Fact] + public static void IsRetainedWhenReturned() + { + OwnedMemory block = MemoryPool.Shared.Rent(42); + Assert.False(block.Release()); + } + + [Fact] + public static void IsDisposedWhenReleased() + { + OwnedMemory block = MemoryPool.Shared.Rent(42); + block.Release(); + + Assert.True(block.IsDisposed); + } } }