From cb9acdff1f8660963beaeed735bc6d722f98498b Mon Sep 17 00:00:00 2001
From: Vladimir Lagunov
Date: Fri, 4 Mar 2022 19:13:05 +0700
Subject: [PATCH] Fix ReadAheadRemoteFileInputStream not reading the whole
 file if a buffer is too big

Before this change, wrapping a ReadAheadRemoteFileInputStream into a
BufferedInputStream with a big buffer made the SSH client request big
packets from the server. It turned out that if the server sent a response
smaller than requested, the client did not adjust to the decreased window
size and read the file incorrectly.

This change detects cases where the server is not able to fulfil the
client's requests. The client now adjusts the maximum request length,
sends new read-ahead requests, and ignores all read-ahead requests sent
earlier.

Simply hard-coding some allegedly small constant buffer size would not
have helped in all possible cases: there is no way for a client to
explicitly query the maximum request length, and the limits differ from
server to server. For instance, OpenSSH defines SFTP_MAX_MSG_LENGTH as
256 * 1024, while Apache SSHD defines MAX_READDATA_PACKET_LENGTH as
63 * 1024 and allows that size to be redefined.

Interestingly, a similar issue #183 was fixed many years ago, and the bug
addressed here was actually in the code introduced by that fix.
---
 .../net/schmizz/sshj/sftp/RemoteFile.java     | 101 ++++++++++--------
 .../hierynomus/sshj/sftp/RemoteFileTest.java  |  57 +++++++++-
 2 files changed, 112 insertions(+), 46 deletions(-)

diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
index a5558030..81e95fe1 100644
--- a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
+++ b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
@@ -220,16 +220,42 @@ public int read(byte[] into, int off, int len) throws IOException {
 
     public class ReadAheadRemoteFileInputStream extends InputStream {
+        private class UnconfirmedRead {
+            private final long offset;
+            private final Promise<Response, SFTPException> promise;
+            private final int length;
+
+            private UnconfirmedRead(long offset, int length, Promise<Response, SFTPException> promise) {
+                this.offset = offset;
+                this.length = length;
+                this.promise = promise;
+            }
+
+            UnconfirmedRead(long offset, int length) throws IOException {
+                this(offset, length, RemoteFile.this.asyncRead(offset, length));
+            }
+
+            public long getOffset() {
+                return offset;
+            }
+
+            public Promise<Response, SFTPException> getPromise() {
+                return promise;
+            }
+
+            public int getLength() {
+                return length;
+            }
+        }
 
         private final byte[] b = new byte[1];
 
         private final int maxUnconfirmedReads;
         private final long readAheadLimit;
-        private final Queue<Promise<Response, SFTPException>> unconfirmedReads = new LinkedList<Promise<Response, SFTPException>>();
-        private final Queue<Long> unconfirmedReadOffsets = new LinkedList<Long>();
+        private final Queue<UnconfirmedRead> unconfirmedReads = new LinkedList<>();
 
-        private long requestOffset;
-        private long responseOffset;
+        private long currentOffset;
+        private int maxReadLength = Integer.MAX_VALUE;
         private boolean eof;
 
         public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) {
@@ -247,28 +273,42 @@ public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset,
             assert 0 <= fileOffset;
 
             this.maxUnconfirmedReads = maxUnconfirmedReads;
-            this.requestOffset = this.responseOffset = fileOffset;
+            this.currentOffset = fileOffset;
             this.readAheadLimit = readAheadLimit > 0 ? fileOffset + readAheadLimit : Long.MAX_VALUE;
         }
 
         private ByteArrayInputStream pending = new ByteArrayInputStream(new byte[0]);
 
         private boolean retrieveUnconfirmedRead(boolean blocking) throws IOException {
-            if (unconfirmedReads.size() <= 0) {
+            final UnconfirmedRead unconfirmedRead = unconfirmedReads.peek();
+            if (unconfirmedRead == null || !blocking && !unconfirmedRead.getPromise().isDelivered()) {
                 return false;
             }
+            unconfirmedReads.remove(unconfirmedRead);
 
-            if (!blocking && !unconfirmedReads.peek().isDelivered()) {
-                return false;
-            }
-
-            unconfirmedReadOffsets.remove();
-            final Response res = unconfirmedReads.remove().retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
+            final Response res = unconfirmedRead.promise.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
             switch (res.getType()) {
 
                 case DATA:
                     int recvLen = res.readUInt32AsInt();
-                    responseOffset += recvLen;
-                    pending = new ByteArrayInputStream(res.array(), res.rpos(), recvLen);
+                    if (unconfirmedRead.offset == currentOffset) {
+                        currentOffset += recvLen;
+                        pending = new ByteArrayInputStream(res.array(), res.rpos(), recvLen);
+
+                        if (recvLen < unconfirmedRead.length) {
+                            // The server returned a packet smaller than the client had requested.
+                            // It can be caused by at least one of the following:
+                            // * The file has been read fully. Then a few futile read requests may be sent during
+                            //   the next read(), but the file will be downloaded correctly anyway.
+                            // * The server shapes the request length. Then the read window will be adjusted,
+                            //   and all further read-ahead requests won't be shaped.
+                            // * The file on the server is not a regular file; it is something like a FIFO.
+                            //   Then the window will shrink, and the client will read the file slower than it
+                            //   hypothetically could. This should be a rare case, and it is not worth implementing
+                            //   a congestion control algorithm here.
+                            maxReadLength = recvLen;
+                            unconfirmedReads.clear();
+                        }
+                    }
                     break;
 
                 case STATUS:
@@ -296,49 +336,24 @@ public int read(byte[] into, int off, int len) throws IOException {
 
                 // we also need to go here for len <= 0, because pending may be at
                 // EOF in which case it would return -1 instead of 0
 
+                long requestOffset = currentOffset;
                 while (unconfirmedReads.size() <= maxUnconfirmedReads) {
                     // Send read requests as long as there is no EOF and we have not reached the maximum parallelism
-                    int reqLen = Math.max(1024, len); // don't be shy!
+                    int reqLen = Math.min(Math.max(1024, len), maxReadLength);
                     if (readAheadLimit > requestOffset) {
                         long remaining = readAheadLimit - requestOffset;
                         if (reqLen > remaining) {
                             reqLen = (int) remaining;
                         }
                     }
-                    unconfirmedReads.add(RemoteFile.this.asyncRead(requestOffset, reqLen));
-                    unconfirmedReadOffsets.add(requestOffset);
+                    unconfirmedReads.add(new UnconfirmedRead(requestOffset, reqLen));
                     requestOffset += reqLen;
                     if (requestOffset >= readAheadLimit) {
                         break;
                     }
                 }
 
-                long nextOffset = unconfirmedReadOffsets.peek();
-                if (responseOffset != nextOffset) {
-
-                    // the server could not give us all the data we needed, so
-                    // we try to fill the gap synchronously
-
-                    assert responseOffset < nextOffset;
-                    assert 0 < (nextOffset - responseOffset);
-                    assert (nextOffset - responseOffset) <= Integer.MAX_VALUE;
-
-                    byte[] buf = new byte[(int) (nextOffset - responseOffset)];
-                    int recvLen = RemoteFile.this.read(responseOffset, buf, 0, buf.length);
-
-                    if (recvLen < 0) {
-                        eof = true;
-                        return -1;
-                    }
-
-                    if (0 == recvLen) {
-                        // avoid infinite loops
-                        throw new SFTPException("Unexpected response size (0), bailing out");
-                    }
-
-                    responseOffset += recvLen;
-                    pending = new ByteArrayInputStream(buf, 0, recvLen);
-                } else if (!retrieveUnconfirmedRead(true /*blocking*/)) {
+                if (!retrieveUnconfirmedRead(true /*blocking*/)) {
 
                     // this may happen if we change prefetch strategy
                     // currently, we should never get here...
 
diff --git a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
index 949a917c..c69d1240 100644
--- a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
+++ b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
@@ -17,22 +17,24 @@
 import com.hierynomus.sshj.test.SshFixture;
 import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.common.ByteArrayUtils;
 import net.schmizz.sshj.sftp.OpenMode;
 import net.schmizz.sshj.sftp.RemoteFile;
 import net.schmizz.sshj.sftp.SFTPEngine;
 import net.schmizz.sshj.sftp.SFTPException;
+import org.apache.sshd.common.util.io.IoUtils;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
+import java.security.SecureRandom;
 import java.util.EnumSet;
 import java.util.Random;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
 public class RemoteFileTest {
@@ -174,4 +176,53 @@ public void limitedReadAheadInputStream() throws IOException {
         assertThat("The written and received data should match", data, equalTo(test2));
     }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_FullSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 1024 * 1024);
+    }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_HalfSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 512 * 1024);
+    }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_QuarterSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 256 * 1024);
+    }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_SmallSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 1024);
+    }
+
+    private void doTestShouldReadCorrectlyWhenWrappedInBufferedStream(int fileSize, int bufferSize) throws IOException {
+        SSHClient ssh = fixture.setupConnectedDefaultClient();
+        ssh.authPassword("test", "test");
+        SFTPEngine sftp = new SFTPEngine(ssh).init();
+
+        final byte[] expected = new byte[fileSize];
+        new SecureRandom(new byte[] { 31 }).nextBytes(expected);
+
+        File file = temp.newFile("shouldReadCorrectlyWhenWrappedInBufferedStream.bin");
+        try (OutputStream fStream = new FileOutputStream(file)) {
+            IoUtils.copy(new ByteArrayInputStream(expected), fStream);
+        }
+
+        RemoteFile rf = sftp.open(file.getPath());
+        final byte[] actual;
+        try (InputStream inputStream = new BufferedInputStream(
+                rf.new ReadAheadRemoteFileInputStream(10),
+                bufferSize)
+        ) {
+            ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            IoUtils.copy(inputStream, baos, expected.length);
+            actual = baos.toByteArray();
+        }
+
+        assertEquals("The file should be fully read", expected.length, actual.length);
+        assertThat("The file should be read correctly",
+                ByteArrayUtils.equals(expected, 0, actual, 0, expected.length));
+    }
 }
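
For context, the failure mode fixed by this patch is easy to hit from ordinary client code: any BufferedInputStream whose buffer is larger than the server's maximum SFTP packet size makes read() ask for oversized packets. Below is a minimal sketch of such a download, mirroring the API calls used in the new test; it is not part of the patch, and the host name, credentials, and file paths are placeholders.

import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.sftp.RemoteFile;
import net.schmizz.sshj.sftp.SFTPEngine;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;

public class ReadAheadDownloadExample {
    public static void main(String[] args) throws IOException {
        try (SSHClient ssh = new SSHClient()) {
            ssh.loadKnownHosts();
            ssh.connect("sftp.example.com");       // placeholder host
            ssh.authPassword("user", "password");  // placeholder credentials

            try (SFTPEngine sftp = new SFTPEngine(ssh).init();
                 RemoteFile rf = sftp.open("/data/big-file.bin")) {  // placeholder remote path
                // A 1 MiB buffer exceeds the packet size limit of many servers, so each
                // read-ahead request may be answered with a shorter packet; before this
                // change the stream then read the file incorrectly.
                try (InputStream in = new BufferedInputStream(
                             rf.new ReadAheadRemoteFileInputStream(16), 1024 * 1024);
                     OutputStream out = Files.newOutputStream(Paths.get("big-file.bin"))) {
                    byte[] buf = new byte[8192];
                    int n;
                    while ((n = in.read(buf)) != -1) {
                        out.write(buf, 0, n);
                    }
                }
            }
        }
    }
}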