From 1c547886c8cfe2b2f38a56ef124fa533d26a45ee Mon Sep 17 00:00:00 2001 From: Raul Santelices Date: Tue, 21 Nov 2023 15:21:35 -0500 Subject: [PATCH] Fix for Remote port forwarding buffers can grow without limits (issue #658) (#913) * Fix for Remote port forwarding buffers can grow without limits (issue #658) * Update test classes to use JUnit 5 * Fix MB computation --- src/main/java/net/schmizz/sshj/Config.java | 4 + .../java/net/schmizz/sshj/ConfigImpl.java | 12 + .../schmizz/sshj/common/CircularBuffer.java | 194 +++++++++++++++ .../connection/channel/AbstractChannel.java | 7 +- .../channel/ChannelInputStream.java | 55 +++-- .../channel/direct/SessionChannel.java | 2 +- .../forwarded/RemotePFPerformanceTest.java | 188 +++++++++++++++ .../sshj/common/CircularBufferTest.java | 221 ++++++++++++++++++ 8 files changed, 649 insertions(+), 34 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/common/CircularBuffer.java create mode 100644 src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java create mode 100644 src/test/java/net/schmizz/sshj/common/CircularBufferTest.java diff --git a/src/main/java/net/schmizz/sshj/Config.java b/src/main/java/net/schmizz/sshj/Config.java index dfb6c1229..24166d08b 100644 --- a/src/main/java/net/schmizz/sshj/Config.java +++ b/src/main/java/net/schmizz/sshj/Config.java @@ -200,4 +200,8 @@ public interface Config { * See {@link #isVerifyHostKeyCertificates()}. */ void setVerifyHostKeyCertificates(boolean value); + + int getMaxCircularBufferSize(); + + void setMaxCircularBufferSize(int maxCircularBufferSize); } diff --git a/src/main/java/net/schmizz/sshj/ConfigImpl.java b/src/main/java/net/schmizz/sshj/ConfigImpl.java index 67243cb3d..23ad79bfd 100644 --- a/src/main/java/net/schmizz/sshj/ConfigImpl.java +++ b/src/main/java/net/schmizz/sshj/ConfigImpl.java @@ -49,6 +49,8 @@ public class ConfigImpl private boolean waitForServerIdentBeforeSendingClientIdent = false; private LoggerFactory loggerFactory; private boolean verifyHostKeyCertificates = true; + // HF-982: default to 16MB buffers. + private int maxCircularBufferSize = 16 * 1024 * 1024; @Override public List> getCipherFactories() { @@ -175,6 +177,16 @@ public LoggerFactory getLoggerFactory() { return loggerFactory; } + @Override + public int getMaxCircularBufferSize() { + return maxCircularBufferSize; + } + + @Override + public void setMaxCircularBufferSize(int maxCircularBufferSize) { + this.maxCircularBufferSize = maxCircularBufferSize; + } + @Override public void setLoggerFactory(LoggerFactory loggerFactory) { this.loggerFactory = loggerFactory; diff --git a/src/main/java/net/schmizz/sshj/common/CircularBuffer.java b/src/main/java/net/schmizz/sshj/common/CircularBuffer.java new file mode 100644 index 000000000..ea47351ea --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/CircularBuffer.java @@ -0,0 +1,194 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +public class CircularBuffer> { + + public static class CircularBufferException + extends SSHException { + + public CircularBufferException(String message) { + super(message); + } + } + + public static final class PlainCircularBuffer + extends CircularBuffer { + + public PlainCircularBuffer(int size, int maxSize) { + super(size, maxSize); + } + } + + /** + * Maximum size of the internal array (one plus the maximum capacity of the buffer). + */ + private final int maxSize; + /** + * Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos. + */ + private byte[] data; + /** + * Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty. + * Can take the value data.length, which is equivalent to 0. + */ + private int rpos; + /** + * Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is + * empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to + * data.length - 1 bytes. Can take the value data.length, which is equivalent to 0. + */ + private int wpos; + + /** + * Determines the size to which to grow the internal array. + */ + private int getNextSize(int currentSize) { + // Use next power of 2. + int nextSize = 1; + while (nextSize < currentSize) { + nextSize <<= 1; + if (nextSize <= 0) { + return maxSize; + } + } + return Math.min(nextSize, maxSize); // limit to max size + } + + /** + * Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/ + */ + public CircularBuffer(int size, int maxSize) { + this.maxSize = maxSize; + if (size > maxSize) { + throw new IllegalArgumentException( + String.format("Initial requested size %d larger than maximum size %d", size, maxSize)); + } + int initialSize = getNextSize(size); + this.data = new byte[initialSize]; + this.rpos = 0; + this.wpos = 0; + } + + /** + * Data available in the buffer for reading. + */ + public int available() { + int available = wpos - rpos; + return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos + } + + private void ensureAvailable(int a) + throws CircularBufferException { + if (available() < a) { + throw new CircularBufferException("Underflow"); + } + } + + /** + * Returns how many more bytes this buffer can receive. + */ + public int maxPossibleRemainingCapacity() { + // Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left. + int remaining = rpos - wpos - 1; + if (remaining < 0) { + remaining += data.length; // adjust if rpos is left of wpos + } + // Add the maximum amount the internal array can grow. + return remaining + maxSize - data.length; + } + + /** + * If the internal array does not have room for "capacity" more bytes, resizes the array to make that room. + */ + void ensureCapacity(int capacity) throws CircularBufferException { + int available = available(); + int remaining = data.length - available; + // If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left. + if (remaining <= capacity) { + int neededSize = available + capacity + 1; + int nextSize = getNextSize(neededSize); + if (nextSize < neededSize) { + throw new CircularBufferException("Attempted overflow"); + } + byte[] tmp = new byte[nextSize]; + // Copy data to the beginning of the new array. + if (wpos >= rpos) { + System.arraycopy(data, rpos, tmp, 0, available); + wpos -= rpos; // wpos must be relative to the new rpos, which will be 0 + } else { + int tail = data.length - rpos; + System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos + System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos + wpos += tail; // wpos must be relative to the new rpos, which will be 0 + } + rpos = 0; + data = tmp; + } + } + + /** + * Reads data from this buffer into the provided array. + */ + public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException { + ensureAvailable(length); + + int rposNext = rpos + length; + if (rposNext <= data.length) { + System.arraycopy(data, rpos, destination, offset, length); + } else { + int tail = data.length - rpos; + System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos + rposNext = length - tail; // rpos wraps around the end of the buffer + System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder + } + // This can make rpos equal data.length, which has the same effect as wpos being 0. + rpos = rposNext; + } + + /** + * Writes data to this buffer from the provided array. + */ + @SuppressWarnings("unchecked") + public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException { + ensureCapacity(length); + + int wposNext = wpos + length; + if (wposNext <= data.length) { + System.arraycopy(source, offset, data, wpos, length); + } else { + int tail = data.length - wpos; + System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos + wposNext = length - tail; // wpos wraps around the end of the buffer + System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder + } + // This can make wpos equal data.length, which has the same effect as wpos being 0. + wpos = wposNext; + + return (T) this; + } + + // Used only for testing. + int length() { + return data.length; + } + + @Override + public String toString() { + return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]"; + } + +} diff --git a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java index cb2373439..0ea8fae45 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java @@ -164,8 +164,7 @@ public String getType() { } @Override - public void handle(Message msg, SSHPacket buf) - throws ConnectionException, TransportException { + public void handle(Message msg, SSHPacket buf) throws SSHException { switch (msg) { case CHANNEL_DATA: @@ -354,7 +353,7 @@ protected void finishOff() { } protected void gotExtendedData(SSHPacket buf) - throws ConnectionException, TransportException { + throws SSHException { throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Extended data not supported on " + type + " channel"); } @@ -375,7 +374,7 @@ protected SSHPacket newBuffer(Message cmd) { } protected void receiveInto(ChannelInputStream stream, SSHPacket buf) - throws ConnectionException, TransportException { + throws SSHException { final int len; try { len = buf.readUInt32AsInt(); diff --git a/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java b/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java index ee03d23cd..530f0167d 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java @@ -38,7 +38,7 @@ public final class ChannelInputStream private final Channel chan; private final Transport trans; private final Window.Local win; - private final Buffer.PlainBuffer buf; + private final CircularBuffer.PlainCircularBuffer buf; private final byte[] b = new byte[1]; private boolean eof; @@ -46,10 +46,11 @@ public final class ChannelInputStream public ChannelInputStream(Channel chan, Transport trans, Window.Local win) { this.chan = chan; - log = chan.getLoggerFactory().getLogger(getClass()); + this.log = chan.getLoggerFactory().getLogger(getClass()); this.trans = trans; this.win = win; - buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize()); + this.buf = new CircularBuffer.PlainCircularBuffer( + chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize()); } @Override @@ -113,48 +114,44 @@ public int read(byte[] b, int off, int len) len = buf.available(); } buf.readRawBytes(b, off, len); - if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) { - buf.clear(); - } - } - if (!chan.getAutoExpand()) { - checkWindow(); + if (!chan.getAutoExpand()) { + checkWindow(); + } } return len; } - public void receive(byte[] data, int offset, int len) - throws ConnectionException, TransportException { + public void receive(byte[] data, int offset, int len) throws SSHException { if (eof) { throw new ConnectionException("Getting data on EOF'ed stream"); } synchronized (buf) { buf.putRawBytes(data, offset, len); buf.notifyAll(); - } - // Potential fix for #203 (window consumed below 0). - // This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST - // And the window has not expanded yet. - synchronized (win) { + // Potential fix for #203 (window consumed below 0). + // This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST + // And the window has not expanded yet. win.consume(len); - } - if (chan.getAutoExpand()) { - checkWindow(); + if (chan.getAutoExpand()) { + checkWindow(); + } } } - private void checkWindow() - throws TransportException { - synchronized (win) { - final long adjustment = win.neededAdjustment(); - if (adjustment > 0) { - log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment); - trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST) - .putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment)); - win.expand(adjustment); - } + private void checkWindow() throws TransportException { + /* + * Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The + * difference between that and the remaining capacity is the maximum adjustment we can make to the window. + */ + final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize(); + final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment); + if (adjustment > 0) { + log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment); + trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST) + .putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment)); + win.expand(adjustment); } } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java index dfdfa55e4..873958204 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java @@ -210,7 +210,7 @@ protected void eofInputStreams() { @Override protected void gotExtendedData(SSHPacket buf) - throws ConnectionException, TransportException { + throws SSHException { try { final int dataTypeCode = buf.readUInt32AsInt(); if (dataTypeCode == 1) diff --git a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java new file mode 100644 index 000000000..448fc09e8 --- /dev/null +++ b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java @@ -0,0 +1,188 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.connection.channel.forwarded; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.Forward; +import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RemotePFPerformanceTest { + + private static final Logger log = LoggerFactory.getLogger(RemotePFPerformanceTest.class); + + @Test + @Disabled + public void startPF() throws IOException, InterruptedException { + DefaultConfig config = new DefaultConfig(); + config.setMaxCircularBufferSize(16 * 1024 * 1024); + SSHClient client = new SSHClient(config); + client.loadKnownHosts(); + client.addHostKeyVerifier("5c:0c:8e:9d:1c:50:a9:ba:a7:05:f6:b1:2b:0b:5f:ba"); + + client.getConnection().getKeepAlive().setKeepAliveInterval(5); + client.connect("localhost"); + client.getConnection().getKeepAlive().setKeepAliveInterval(5); + + Object consumerReadyMonitor = new Object(); + ConsumerThread consumerThread = new ConsumerThread(consumerReadyMonitor); + ProducerThread producerThread = new ProducerThread(); + try { + + client.authPassword(System.getenv().get("USERNAME"), System.getenv().get("PASSWORD")); + + /* + * We make _server_ listen on port 8080, which forwards all connections to us as a channel, and we further + * forward all such channels to google.com:80 + */ + client.getRemotePortForwarder().bind( + // where the server should listen + new Forward(8888), + // what we do with incoming connections that are forwarded to us + new SocketForwardingConnectListener(new InetSocketAddress("localhost", 12345))); + + consumerThread.start(); + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.wait(); + } + producerThread.start(); + + // Wait for consumer to finish receiving data. + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.wait(); + } + + } finally { + producerThread.interrupt(); + consumerThread.interrupt(); + client.disconnect(); + } + } + + private static class ConsumerThread extends Thread { + private final Object consumerReadyMonitor; + + private ConsumerThread(Object consumerReadyMonitor) { + super("Consumer"); + this.consumerReadyMonitor = consumerReadyMonitor; + } + + @Override + public void run() { + try (ServerSocket serverSocket = new ServerSocket(12345)) { + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.notifyAll(); + } + try (Socket acceptedSocket = serverSocket.accept()) { + InputStream in = acceptedSocket.getInputStream(); + int numRead; + byte[] buf = new byte[40000]; + //byte[] buf = new byte[255 * 4 * 1000]; + byte expectedNext = 1; + while ((numRead = in.read(buf)) != 0) { + if (Thread.interrupted()) { + log.info("Consumer thread interrupted"); + return; + } + log.info(String.format("Read %d characters; values from %d to %d", numRead, buf[0], buf[numRead - 1])); + if (buf[numRead - 1] == 0) { + verifyData(buf, numRead - 1, expectedNext); + break; + } + expectedNext = verifyData(buf, numRead, expectedNext); + // Slow down consumer to test buffering. + Thread.sleep(Long.parseLong(System.getenv().get("DELAY_MS"))); + } + log.info("Consumer read end of stream value: " + numRead); + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.notifyAll(); + } + } + } catch (Exception e) { + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.notifyAll(); + } + e.printStackTrace(); + } + } + + private byte verifyData(byte[] buf, int numRead, byte expectedNext) { + for (int i = 0; i < numRead; ++i) { + if (buf[i] != expectedNext) { + fail("Expected buf[" + i + "]=" + buf[i] + " to be " + expectedNext); + } + if (++expectedNext == 0) { + expectedNext = 1; + } + } + return expectedNext; + } + } + + private static class ProducerThread extends Thread { + private ProducerThread() { + super("Producer"); + } + + @Override + public void run() { + try (Socket clientSocket = new Socket("127.0.0.1", 8888); + OutputStream writer = clientSocket.getOutputStream()) { + byte[] buf = getData(); + assertEquals(buf[0], 1); + assertEquals(buf[buf.length - 1], -1); + for (int i = 0; i < 1000; ++i) { + writer.write(buf); + if (Thread.interrupted()) { + log.info("Consumer thread interrupted"); + return; + } + log.info(String.format("Wrote %d characters; values from %d to %d", buf.length, buf[0], buf[buf.length - 1])); + } + writer.write(0); // end of stream value + log.info("Producer finished sending data"); + } catch (Exception e) { + e.printStackTrace(); + } + } + + private byte[] getData() { + byte[] buf = new byte[255 * 4 * 1000]; + byte nextValue = 1; + for (int i = 0; i < buf.length; ++i) { + buf[i] = nextValue++; + // reserve 0 for end of stream + if (nextValue == 0) { + nextValue = 1; + } + } + return buf; + } + } + +} diff --git a/src/test/java/net/schmizz/sshj/common/CircularBufferTest.java b/src/test/java/net/schmizz/sshj/common/CircularBufferTest.java new file mode 100644 index 000000000..da53afc38 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/common/CircularBufferTest.java @@ -0,0 +1,221 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +import static org.junit.jupiter.api.Assertions.*; + +import net.schmizz.sshj.common.CircularBuffer.CircularBufferException; +import net.schmizz.sshj.common.CircularBuffer.PlainCircularBuffer; +import org.junit.jupiter.api.Test; + +public class CircularBufferTest { + + @Test + public void shouldStoreDataCorrectlyWithoutResizing() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(256, Integer.MAX_VALUE); + + byte[] dataToWrite = getData(500); + buffer.putRawBytes(dataToWrite, 0, 100); + buffer.putRawBytes(dataToWrite, 100, 100); + + byte[] dataToRead = new byte[500]; + buffer.readRawBytes(dataToRead, 0, 80); + buffer.readRawBytes(dataToRead, 80, 80); + + buffer.putRawBytes(dataToWrite, 200, 100); + buffer.readRawBytes(dataToRead, 160, 80); + + buffer.putRawBytes(dataToWrite, 300, 100); + buffer.readRawBytes(dataToRead, 240, 80); + + buffer.putRawBytes(dataToWrite, 400, 100); + buffer.readRawBytes(dataToRead, 320, 80); + buffer.readRawBytes(dataToRead, 400, 100); + + assertEquals(256, buffer.length()); + assertArrayEquals(dataToWrite, dataToRead); + } + + @Test + public void shouldStoreDataCorrectlyWithResizing() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + + byte[] dataToWrite = getData(500); + buffer.putRawBytes(dataToWrite, 0, 100); + buffer.putRawBytes(dataToWrite, 100, 100); + + byte[] dataToRead = new byte[500]; + buffer.readRawBytes(dataToRead, 0, 80); + buffer.readRawBytes(dataToRead, 80, 80); + + buffer.putRawBytes(dataToWrite, 200, 100); + buffer.readRawBytes(dataToRead, 160, 80); + + buffer.putRawBytes(dataToWrite, 300, 100); + buffer.readRawBytes(dataToRead, 240, 80); + + buffer.putRawBytes(dataToWrite, 400, 100); + buffer.readRawBytes(dataToRead, 320, 80); + + buffer.readRawBytes(dataToRead, 400, 100); + + assertEquals(256, buffer.length()); + assertArrayEquals(dataToWrite, dataToRead); + } + + @Test + public void shouldNotOverflowWhenWritingFullLengthToTheEnd() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + + byte[] dataToWrite = getData(64); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should write to the end + + assertEquals(64, buffer.available()); + assertEquals(64 * 2, buffer.length()); + } + + @Test + public void shouldNotOverflowWhenWritingFullLengthWrapsAround() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + + // Move 1 byte forward. + buffer.putRawBytes(new byte[1], 0, 1); + buffer.readRawBytes(new byte[1], 0, 1); + + // Force writes to wrap around. + byte[] dataToWrite = getData(64); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should wrap around the end + + assertEquals(64, buffer.available()); + assertEquals(64 * 2, buffer.length()); + } + + @Test + public void shouldAllowWritingMaxCapacityFromZero() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + // Max capacity is always one less than the buffer size. + int maxCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1, maxCapacity); + + byte[] dataToWrite = getData(maxCapacity); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); + + assertEquals(dataToWrite.length, buffer.available()); + assertEquals(64, buffer.length()); + } + + @Test + public void shouldAllowWritingMaxRemainingCapacity() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + final int initiallyWritten = 10; + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + + // Max remaining capacity is always one less than the remaining buffer size. + int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity); + + byte[] dataToWrite = getData(maxRemainingCapacity); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); + + assertEquals(dataToWrite.length + initiallyWritten, buffer.available()); + assertEquals(64, buffer.length()); + } + + @Test + public void shouldAllowWritingMaxRemainingCapacityAfterWrappingAround() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + // Cause the internal write pointer to wrap around and be left of the read pointer. + final int initiallyWritten = 40; + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + buffer.readRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + + // Max remaining capacity is always one less than the remaining buffer size. + int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity); + + byte[] dataToWrite = getData(maxRemainingCapacity); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); + + assertEquals(dataToWrite.length + initiallyWritten, buffer.available()); + assertEquals(64, buffer.length()); + } + + @Test + public void shouldOverflowWhenWritingOverMaxRemainingCapacity() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + final int initiallyWritten = 10; + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + + // Max remaining capacity is always one less than the remaining buffer size. + int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity); + + byte[] dataToWrite = getData(maxRemainingCapacity + 1); + assertThrows(CircularBufferException.class, () -> buffer.putRawBytes(dataToWrite, 0, dataToWrite.length)); + } + + @Test + public void shouldThrowWhenReadingEmptyBuffer() { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[1], 0, 1)); + } + + @Test + public void shouldThrowWhenReadingMoreThanAvailable() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + buffer.putRawBytes(new byte[1], 0, 1); + assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[2], 0, 2)); + } + + @Test + public void shouldThrowOnAboveMaximumInitialSize() { + assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(65, 64)); + } + + @Test + public void shouldThrowOnMaximumInitialSize() { + assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(Integer.MAX_VALUE, 64)); + } + + @Test + public void shouldAllowFullCapacity() throws CircularBufferException { + int maxSize = 1024; + PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize); + buffer.ensureCapacity(maxSize - 1); + assertEquals(maxSize - 1, buffer.maxPossibleRemainingCapacity()); + } + + @Test + public void shouldThrowOnTooLargeRequestedCapacity() { + int maxSize = 1024; + PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize); + assertThrows(CircularBufferException.class, () -> buffer.ensureCapacity(maxSize)); + } + + private static byte[] getData(int length) { + byte[] data = new byte[length]; + byte nextValue = 0; + for (int i = 0; i < length; ++i) { + data[i] = nextValue++; + } + return data; + } +}