diff --git a/build.gradle b/build.gradle index 26bce92b4..6d0338917 100644 --- a/build.gradle +++ b/build.gradle @@ -21,6 +21,9 @@ configurations { } test { + testLogging { + exceptionFormat = 'full' + } include "**/*Test.*" if (!project.hasProperty("allTests")) { useJUnit { diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/AbstractSCPClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/AbstractSCPClient.java new file mode 100644 index 000000000..af455da07 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/xfer/scp/AbstractSCPClient.java @@ -0,0 +1,16 @@ +package net.schmizz.sshj.xfer.scp; + +abstract class AbstractSCPClient { + + protected final SCPEngine engine; + protected int bandwidthLimit; + + AbstractSCPClient(SCPEngine engine) { + this.engine = engine; + } + + AbstractSCPClient(SCPEngine engine, int bandwidthLimit) { + this.engine = engine; + this.bandwidthLimit = bandwidthLimit; + } +} diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java index c2ad4258b..4c35513d8 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java @@ -24,18 +24,22 @@ import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; -import java.util.LinkedList; import java.util.List; +import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArgument; +import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments; + /** Support for uploading files over a connected link using SCP. */ -public final class SCPDownloadClient { +public final class SCPDownloadClient extends AbstractSCPClient { private boolean recursiveMode = true; - private final SCPEngine engine; - SCPDownloadClient(SCPEngine engine) { - this.engine = engine; + super(engine); + } + + SCPDownloadClient(SCPEngine engine, int bandwidthLimit) { + super(engine, bandwidthLimit); } /** Download a file from {@code sourcePath} on the connected host to {@code targetPath} locally. */ @@ -60,12 +64,12 @@ public void setRecursiveMode(boolean recursive) { void startCopy(String sourcePath, LocalDestFile targetFile) throws IOException { - List args = new LinkedList(); - args.add(Arg.SOURCE); - args.add(Arg.QUIET); - args.add(Arg.PRESERVE_TIMES); - if (recursiveMode) - args.add(Arg.RECURSIVE); + List args = SCPArguments.with(Arg.SOURCE) + .and(Arg.QUIET) + .and(Arg.PRESERVE_TIMES) + .and(Arg.RECURSIVE, recursiveMode) + .and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0)) + .arguments(); engine.execSCPWith(args, sourcePath); engine.signal("Start status OK"); diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java index 66bac99dd..321ec943e 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.LinkedList; import java.util.List; /** @see SCP Protocol */ @@ -39,7 +40,8 @@ enum Arg { RECURSIVE('r'), VERBOSE('v'), PRESERVE_TIMES('p'), - QUIET('q'); + QUIET('q'), + LIMIT('l'); private final char a; @@ -97,10 +99,10 @@ void cleanSlate() { exitStatus = -1; } - void execSCPWith(List args, String path) + void execSCPWith(List args, String path) throws SSHException { final StringBuilder cmd = new StringBuilder(SCP_COMMAND); - for (Arg arg : args) { + for (SCPArgument arg : args) { cmd.append(" ").append(arg); } cmd.append(" "); @@ -186,4 +188,84 @@ TransferListener getTransferListener() { return listener; } + public static class SCPArgument { + + private Arg name; + private String value; + + private SCPArgument(Arg name, String value) { + this.name = name; + this.value = value; + } + + public static SCPArgument addArgument(Arg name, String value) { + return new SCPArgument(name, value); + } + + @Override + public String toString() { + String option = name.toString(); + if (value != null) { + option = option + value; + } + return option; + } + } + + public static class SCPArguments { + + private static List args = null; + + private SCPArguments() { + this.args = new LinkedList(); + } + + private static void addArgument(Arg name, String value, boolean accept) { + if (accept) { + args.add(SCPArgument.addArgument(name, value)); + } + } + + public static SCPArguments with(Arg name) { + return with(name, null, true); + } + + public static SCPArguments with(Arg name, String value) { + return with(name, value, true); + } + + public static SCPArguments with(Arg name, boolean accept) { + return with(name, null, accept); + } + + public static SCPArguments with(Arg name, String value, boolean accept) { + SCPArguments scpArguments = new SCPArguments(); + addArgument(name, value, accept); + return scpArguments; + } + + public SCPArguments and(Arg name) { + addArgument(name, null, true); + return this; + } + + public SCPArguments and(Arg name, String value) { + addArgument(name, value, true); + return this; + } + + public SCPArguments and(Arg name, boolean accept) { + addArgument(name, null, accept); + return this; + } + + public SCPArguments and(Arg name, String value, boolean accept) { + addArgument(name, value, accept); + return this; + } + + public List arguments() { + return args; + } + } } diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java index 7e055e3bb..c25cdde6b 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java @@ -28,18 +28,23 @@ public class SCPFileTransfer extends AbstractFileTransfer implements FileTransfer { + /** Default bandwidth limit for SCP transfer in kilobit/s (-1 means unlimited) */ + private static final int DEFAULT_BANDWIDTH_LIMIT = -1; + private final SessionFactory sessionFactory; + private int bandwidthLimit; public SCPFileTransfer(SessionFactory sessionFactory) { this.sessionFactory = sessionFactory; + this.bandwidthLimit = DEFAULT_BANDWIDTH_LIMIT; } public SCPDownloadClient newSCPDownloadClient() { - return new SCPDownloadClient(newSCPEngine()); + return new SCPDownloadClient(newSCPEngine(), bandwidthLimit); } public SCPUploadClient newSCPUploadClient() { - return new SCPUploadClient(newSCPEngine()); + return new SCPUploadClient(newSCPEngine(), bandwidthLimit); } private SCPEngine newSCPEngine() { @@ -70,4 +75,10 @@ public void upload(LocalSourceFile localFile, String remotePath) newSCPUploadClient().copy(localFile, remotePath); } + public SCPFileTransfer bandwidthLimit(int limit) { + if (limit > 0) { + this.bandwidthLimit = limit; + } + return this; + } } diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java index 474ce1fc1..58fa3257c 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java @@ -24,17 +24,22 @@ import java.io.IOException; import java.io.InputStream; -import java.util.LinkedList; import java.util.List; +import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArgument; +import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments; + /** Support for uploading files over a connected link using SCP. */ -public final class SCPUploadClient { +public final class SCPUploadClient extends AbstractSCPClient { - private final SCPEngine engine; private LocalFileFilter uploadFilter; SCPUploadClient(SCPEngine engine) { - this.engine = engine; + super(engine); + } + + SCPUploadClient(SCPEngine engine, int bandwidthLimit) { + super(engine, bandwidthLimit); } /** Upload a local file from {@code localFile} to {@code targetPath} on the remote host. */ @@ -55,11 +60,11 @@ public void setUploadFilter(LocalFileFilter uploadFilter) { private synchronized void startCopy(LocalSourceFile sourceFile, String targetPath) throws IOException { - List args = new LinkedList(); - args.add(Arg.SINK); - args.add(Arg.RECURSIVE); - if (sourceFile.providesAtimeMtime()) - args.add(Arg.PRESERVE_TIMES); + List args = SCPArguments.with(Arg.SINK) + .and(Arg.RECURSIVE) + .and(Arg.PRESERVE_TIMES, sourceFile.providesAtimeMtime()) + .and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0)) + .arguments(); engine.execSCPWith(args, targetPath); engine.check("Start status OK"); process(engine.getTransferListener(), sourceFile); diff --git a/src/test/java/net/schmizz/sshj/xfer/scp/SCPFileTransferTest.java b/src/test/java/net/schmizz/sshj/xfer/scp/SCPFileTransferTest.java new file mode 100644 index 000000000..d7a8d466b --- /dev/null +++ b/src/test/java/net/schmizz/sshj/xfer/scp/SCPFileTransferTest.java @@ -0,0 +1,82 @@ +package net.schmizz.sshj.xfer.scp; + +import com.hierynomus.sshj.test.SshFixture; +import com.hierynomus.sshj.test.util.FileUtil; +import net.schmizz.sshj.SSHClient; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; + +import static junit.framework.Assert.assertFalse; +import static junit.framework.Assert.assertTrue; + +public class SCPFileTransferTest { + + public static final String DEFAULT_FILE_NAME = "my_file.txt"; + File targetDir; + File sourceFile; + File targetFile; + SSHClient sshClient; + + @Rule + public SshFixture fixture = new SshFixture(); + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Before + public void init() throws IOException { + sourceFile = tempFolder.newFile(DEFAULT_FILE_NAME); + FileUtil.writeToFile(sourceFile, "This is my file"); + targetDir = tempFolder.newFolder(); + targetFile = new File(targetDir + File.separator + DEFAULT_FILE_NAME); + sshClient = fixture.setupConnectedDefaultClient(); + sshClient.authPassword("test", "test"); + } + + @After + public void cleanup() { + if (targetFile.exists()) { + targetFile.delete(); + } + } + + @Test + public void shouldSCPUploadFile() throws IOException { + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer(); + assertFalse(targetFile.exists()); + scpFileTransfer.upload(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } + + @Test + public void shouldSCPUploadFileWithBandwidthLimit() throws IOException { + // Limit upload transfer at 2Mo/s + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer().bandwidthLimit(16000); + assertFalse(targetFile.exists()); + scpFileTransfer.upload(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } + + @Test + public void shouldSCPDownloadFile() throws IOException { + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer(); + assertFalse(targetFile.exists()); + scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } + + @Test + public void shouldSCPDownloadFileWithBandwidthLimit() throws IOException { + // Limit download transfer at 128Ko/s + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer().bandwidthLimit(1024); + assertFalse(targetFile.exists()); + scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } +}