Skip to content

Commit

Permalink
Add SftpClient.CopyFileAsync. (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Oct 19, 2024
1 parent c8601e9 commit 5ab3ea9
Show file tree
Hide file tree
Showing 11 changed files with 520 additions and 73 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ class SftpClient : IDisposable

ValueTask RenameAsync(string oldpath, string newpath, CancellationToken cancellationToken = default);

ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default);

ValueTask<FileEntryAttributes?> GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default);
ValueTask SetAttributesAsync(
string path,
Expand Down
2 changes: 2 additions & 0 deletions src/Tmds.Ssh/SftpChannel.PacketType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,7 @@ internal enum PacketType : byte
SSH_FXP_ATTRS = 105,
SSH_FXP_EXTENDED = 200,
SSH_FXP_EXTENDED_REPLY = 201,

SSH_SFTP_STATUS_RESPONSE = 0
}
}
6 changes: 6 additions & 0 deletions src/Tmds.Ssh/SftpChannel.Writer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ internal void WriteInt64(long value)
_length += 8;
}

internal void WriteUInt64(ulong value)
{
BinaryPrimitives.WriteUInt64BigEndian(_buffer.AsSpan(_length), value);
_length += 8;
}

public void WriteAttributes(
long? length = default,
(int Uid, int Gid)? ids = default,
Expand Down
406 changes: 341 additions & 65 deletions src/Tmds.Ssh/SftpChannel.cs

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion src/Tmds.Ssh/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ enum State
// For testing.
internal SshClient SshClient => _client;
internal bool IsDisposed => _state == State.Disposed;
internal SftpExtension EnabledExtensions
{
get
{
SftpChannel channel = _channel ?? throw new InvalidOperationException();
return channel.EnabledExtensions;
}
}

public SftpClient(string destination, ILoggerFactory? loggerFactory = null, SftpClientOptions? options = null) :
this(destination, SshConfigSettings.NoConfig, loggerFactory, options)
Expand Down Expand Up @@ -175,7 +183,7 @@ private async Task<SftpChannel> DoOpenAsync(bool explicitConnect, CancellationTo
bool success = false;
try
{
SftpChannel channel = await _client.OpenSftpChannelAsync(OnChannelAbort, explicitConnect, cancellationToken).ConfigureAwait(false);
SftpChannel channel = await _client.OpenSftpChannelAsync(OnChannelAbort, explicitConnect, _options, cancellationToken).ConfigureAwait(false);
_channel = channel;
success = true;
return channel;
Expand Down Expand Up @@ -267,6 +275,12 @@ public async ValueTask RenameAsync(string oldPath, string newPath, CancellationT
await channel.RenameAsync(oldPath, newPath, cancellationToken).ConfigureAwait(false);
}

public async ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.CopyFileAsync(sourcePath, destinationPath, overwrite, cancellationToken).ConfigureAwait(false);
}

public async ValueTask<FileEntryAttributes?> GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
Expand Down
5 changes: 4 additions & 1 deletion src/Tmds.Ssh/SftpClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@
namespace Tmds.Ssh;

public sealed partial class SftpClientOptions
{ }
{
// For testing.
internal SftpExtension DisabledExtensions { get; set; }
}
9 changes: 9 additions & 0 deletions src/Tmds.Ssh/SftpExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Tmds.Ssh;

[Flags]
enum SftpExtension
{
None = 0,
// https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00#section-7
CopyData = 1 // copy-data 1
}
4 changes: 2 additions & 2 deletions src/Tmds.Ssh/SshClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,13 @@ public async Task<SftpClient> OpenSftpClientAsync(SftpClientOptions? options = n
}
}

internal async Task<SftpChannel> OpenSftpChannelAsync(Action<SshChannel> onAbort, bool explicitConnect, CancellationToken cancellationToken)
internal async Task<SftpChannel> OpenSftpChannelAsync(Action<SshChannel> onAbort, bool explicitConnect, SftpClientOptions options, CancellationToken cancellationToken)
{
SshSession session = await GetSessionAsync(cancellationToken, explicitConnect).ConfigureAwait(false);

var channel = await session.OpenSftpClientChannelAsync(onAbort, cancellationToken).ConfigureAwait(false);

var sftpChannel = new SftpChannel(channel);
var sftpChannel = new SftpChannel(channel, options);

try
{
Expand Down
101 changes: 99 additions & 2 deletions test/Tmds.Ssh.Tests/SftpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Tmds.Ssh.Tests;
public class SftpClientTests
{
const int PacketSize = 32768; // roughly amount of bytes sent/received in a single sftp packet.
const int MultiPacketSize = 2 * PacketSize + 1024;

private readonly SshServer _sshServer;

Expand Down Expand Up @@ -54,7 +55,7 @@ public async Task SftpClientCtorFromSshClientSettings()

[InlineData(10)]
[InlineData(10 * 1024)] // 10 kiB
[InlineData(2 * PacketSize + 1024)]
[InlineData(MultiPacketSize)]
[Theory]
public async Task ReadWriteFile(int fileSize)
{
Expand Down Expand Up @@ -615,7 +616,7 @@ public async Task UploadDownloadDirectory()

[InlineData(0)]
[InlineData(10)]
[InlineData(2 * PacketSize + 1024)]
[InlineData(MultiPacketSize)]
[Theory]
public async Task UploadDownloadFile(int fileSize)
{
Expand Down Expand Up @@ -993,6 +994,102 @@ public async Task CacheLength()
}
}

[InlineData(0, SftpExtension.CopyData)]
[InlineData(10, SftpExtension.CopyData)]
[InlineData(MultiPacketSize, SftpExtension.CopyData)]
[InlineData(0, SftpExtension.None)]
[InlineData(10, SftpExtension.None)]
[InlineData(MultiPacketSize, SftpExtension.None)]
[SkippableTheory]
public async Task CopyFile(int fileSize, SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, fileSize);

string destinationFileName = $"/tmp/{Path.GetRandomFileName()}";
await sftpClient.CopyFileAsync(sourceFileName, destinationFileName);

await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, destinationFileName);
}

[InlineData(true, SftpExtension.CopyData)]
[InlineData(false, SftpExtension.CopyData)]
[InlineData(true, SftpExtension.None)]
[InlineData(false, SftpExtension.None)]
[SkippableTheory]
public async Task CopyFileOverwrite(bool overwrite, SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10);
(string destinationFileName, byte[] destinationData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10);

Task copyTask = sftpClient.CopyFileAsync(sourceFileName, destinationFileName, overwrite).AsTask();

if (overwrite)
{
await copyTask;
}
else
{
await Assert.ThrowsAsync<SftpException>(() => copyTask);
}

byte[] expectedData = overwrite ? sourceData : destinationData;
await AssertRemoteFileContentEqualsAsync(sftpClient, expectedData, destinationFileName);
}

[InlineData(SftpExtension.CopyData)]
[InlineData(SftpExtension.None)]
[SkippableTheory]
public async Task CopyFileToSelfDoesntLooseData(SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10);

await sftpClient.CopyFileAsync(sourceFileName, sourceFileName, overwrite: true);

await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, sourceFileName);
}

[InlineData(SftpExtension.CopyData)]
[InlineData(SftpExtension.None)]
[SkippableTheory]
public async Task CopyFileOverwriteToLargerTruncates(SftpExtension sftpExtensions)
{
using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions);

const int SourceLength = 10;
(string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: SourceLength);
const int DestinationLength = SourceLength + SourceLength;
(string destinationFileName, byte[] destinationData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: DestinationLength);

await sftpClient.CopyFileAsync(sourceFileName, destinationFileName, overwrite: true).AsTask();

await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, destinationFileName);
}

private async Task AssertRemoteFileContentEqualsAsync(SftpClient client, byte[] expected, string remoteFileName)
{
using var readFile = await client.OpenFileAsync(remoteFileName, FileAccess.Read);
Assert.NotNull(readFile);
var memoryStream = new MemoryStream();
await readFile.CopyToAsync(memoryStream);
Assert.Equal(expected, memoryStream.ToArray());
}

private async Task<(string filename, byte[] data)> CreateRemoteFileWithRandomDataAsync(SftpClient client, int length)
{
string filename = $"/tmp/{Path.GetRandomFileName()}";
byte[] data = new byte[10];
Random.Shared.NextBytes(data);
using var writeFile = await client.CreateNewFileAsync(filename, FileAccess.Write);
await writeFile.WriteAsync(data.AsMemory());
return (filename, data);
}

[InlineData(true)]
[InlineData(false)]
[Theory]
Expand Down
9 changes: 9 additions & 0 deletions test/Tmds.Ssh.Tests/SftpExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Tmds.Ssh.Tests;

// Copy of Tmds.Ssh.SftpExtensions with public access.
[Flags]
public enum SftpExtension
{
None = Tmds.Ssh.SftpExtension.None,
CopyData = Tmds.Ssh.SftpExtension.CopyData
}
33 changes: 31 additions & 2 deletions test/Tmds.Ssh.Tests/SshServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Xunit;
using Xunit.Abstractions;
using Xunit.Sdk;
using SkipException = Xunit.SkipException;

namespace Tmds.Ssh.Tests;

Expand Down Expand Up @@ -347,11 +348,39 @@ public async Task<SshClient> CreateClientAsync(Action<SshClientSettings>? config
return client;
}

public async Task<SftpClient> CreateSftpClientAsync(Action<SshClientSettings>? configureSsh = null, CancellationToken cancellationToken = default, bool connect = true)
public async Task<SftpClient> CreateSftpClientAsync(Tmds.Ssh.Tests.SftpExtension enabledExtensions, Action<SshClientSettings>? configureSsh = null, CancellationToken cancellationToken = default)
{
var settings = CreateSshClientSettings(configureSsh);

var client = new SftpClient(settings);
SftpClientOptions? options = new()
{
DisabledExtensions = (Tmds.Ssh.SftpExtension)~enabledExtensions
};

var client = new SftpClient(settings, options: options);

await client.ConnectAsync(cancellationToken);

if (client.EnabledExtensions != (Tmds.Ssh.SftpExtension)enabledExtensions)
{
throw new SkipException($"The test server does not support the required {((Tmds.Ssh.SftpExtension)enabledExtensions) & ~client.EnabledExtensions} extensions.");
}

return client;
}

public async Task<SftpClient> CreateSftpClientAsync(Action<SshClientSettings>? configureSsh = null, Action<SftpClientOptions>? configureSftp = null, CancellationToken cancellationToken = default, bool connect = true)
{
var settings = CreateSshClientSettings(configureSsh);

SftpClientOptions? sftpClientOptions = null;
if (configureSftp is not null)
{
sftpClientOptions = new();
configureSftp.Invoke(sftpClientOptions);
}

var client = new SftpClient(settings, options: sftpClientOptions);

if (connect)
{
Expand Down

0 comments on commit 5ab3ea9

Please sign in to comment.