diff --git a/src/System.Buffers.Primitives/System/Buffers/ReadOnlyBuffer_helpers.cs b/src/System.Buffers.Primitives/System/Buffers/ReadOnlyBuffer_helpers.cs index e16113adc19..8e7c268ac04 100644 --- a/src/System.Buffers.Primitives/System/Buffers/ReadOnlyBuffer_helpers.cs +++ b/src/System.Buffers.Primitives/System/Buffers/ReadOnlyBuffer_helpers.cs @@ -189,10 +189,9 @@ private static SequencePosition SeekMultiSegment(IMemoryList start, int st memory = memory.Slice(0, currentEnd - currentIndex); // We would prefer to put cursor in the beginning of next segment - // then past the end of previous one, but only if next exists - + // then past the end of previous one, but only if we are not leaving current buffer if (memory.Length > bytes || - (memory.Length == bytes && current.Next == null)) + (memory.Length == bytes && current == end)) { result = new SequencePosition(current, currentIndex + (int)bytes); foundResult = true; diff --git a/src/System.IO.Pipelines/System/IO/Pipelines/Pipe.cs b/src/System.IO.Pipelines/System/IO/Pipelines/Pipe.cs index 1e9a54fe046..18672baadcf 100644 --- a/src/System.IO.Pipelines/System/IO/Pipelines/Pipe.cs +++ b/src/System.IO.Pipelines/System/IO/Pipelines/Pipe.cs @@ -407,8 +407,7 @@ internal void Advance(SequencePosition consumed, SequencePosition examined) // if we are going to return commit head // we need to check that there is no writing operation that // might be using tailspace - if (consumed.Index == returnEnd.Length && - !(_commitHead == returnEnd && _writingHead != null)) + if (consumed.Index == returnEnd.Length && _writingHead != returnEnd) { var nextBlock = returnEnd.NextSegment; if (_commitHead == returnEnd) diff --git a/tests/System.Buffers.Primitives.Tests/ReadableBufferFacts.cs b/tests/System.Buffers.Primitives.Tests/ReadableBufferFacts.cs index c23cf506f25..779a6bfd62f 100644 --- a/tests/System.Buffers.Primitives.Tests/ReadableBufferFacts.cs +++ b/tests/System.Buffers.Primitives.Tests/ReadableBufferFacts.cs @@ -87,7 +87,7 @@ public void ReadableBufferDoesNotAllowSlicingOutOfRange(Action(() => buffer.GetPosition(buffer.Start, 101)); } [Fact] - public void ReadableBufferMove_DoesNotAlowNegative() + public void ReadableBufferGetPosition_DoesNotAlowNegative() { var buffer = Factory.CreateOfSize(20); Assert.Throws(() => buffer.GetPosition(buffer.Start, -1)); @@ -140,7 +140,7 @@ public void SegmentStartIsConsideredInBoundsCheck() } [Fact] - public void MovePrefersNextSegment() + public void GetPositionPrefersNextSegment() { var bufferSegment1 = new BufferSegment(); bufferSegment1.SetMemory(new OwnedArray(new byte[100]), 49, 99); @@ -157,6 +157,29 @@ public void MovePrefersNextSegment() Assert.Equal(bufferSegment2, c1.Segment); } + [Fact] + public void GetPositionDoesNotCrossOutsideBuffer() + { + var bufferSegment1 = new BufferSegment(); + bufferSegment1.SetMemory(new OwnedArray(new byte[100]), 0, 100); + + var bufferSegment2 = new BufferSegment(); + bufferSegment2.SetMemory(new OwnedArray(new byte[100]), 0, 100); + + var bufferSegment3 = new BufferSegment(); + bufferSegment3.SetMemory(new OwnedArray(new byte[100]), 0, 0); + + bufferSegment1.SetNext(bufferSegment2); + bufferSegment2.SetNext(bufferSegment3); + + var readableBuffer = new ReadOnlyBuffer(bufferSegment1, 0, bufferSegment2, 100); + + var c1 = readableBuffer.GetPosition(readableBuffer.Start, 200); + + Assert.Equal(100, c1.Index); + Assert.Equal(bufferSegment2, c1.Segment); + } + [Fact] public void Create_WorksWithArray() { diff --git a/tests/System.IO.Pipelines.Tests/PipelineReaderWriterFacts.cs b/tests/System.IO.Pipelines.Tests/PipelineReaderWriterFacts.cs index 0e0b6c630ae..72298874a17 100644 --- a/tests/System.IO.Pipelines.Tests/PipelineReaderWriterFacts.cs +++ b/tests/System.IO.Pipelines.Tests/PipelineReaderWriterFacts.cs @@ -531,5 +531,34 @@ public async Task AdvanceResetsCommitHeadIndex() awaitable = _pipe.Reader.ReadAsync(); Assert.False(awaitable.IsCompleted); } + + [Fact] + public async Task AdvanceWithGetPositionCrossingIntoWriteHeadWorks() + { + // Create two blocks + var memory = _pipe.Writer.GetMemory(1); + _pipe.Writer.Advance(memory.Length); + memory = _pipe.Writer.GetMemory(1); + _pipe.Writer.Advance(memory.Length); + await _pipe.Writer.FlushAsync(); + + // Read single block + var readResult = await _pipe.Reader.ReadAsync(); + + // Allocate more memory + memory = _pipe.Writer.GetMemory(1); + + // Create position that would cross into write head + var buffer = readResult.Buffer; + var position = buffer.GetPosition(buffer.Start, buffer.Length); + + // Return everything + _pipe.Reader.AdvanceTo(position); + + // Advance writer + _pipe.Writer.Advance(memory.Length); + _pipe.Writer.Commit(); + } + } }