diff --git a/README.md b/README.md index 02d2298..825f48e 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,8 @@ class SftpClient : IDisposable ValueTask RenameAsync(string oldpath, string newpath, CancellationToken cancellationToken = default); + ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default); + ValueTask GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default); ValueTask SetAttributesAsync( string path, diff --git a/src/Tmds.Ssh/SftpChannel.PacketType.cs b/src/Tmds.Ssh/SftpChannel.PacketType.cs index 06660a6..c16cf5a 100644 --- a/src/Tmds.Ssh/SftpChannel.PacketType.cs +++ b/src/Tmds.Ssh/SftpChannel.PacketType.cs @@ -34,5 +34,7 @@ internal enum PacketType : byte SSH_FXP_ATTRS = 105, SSH_FXP_EXTENDED = 200, SSH_FXP_EXTENDED_REPLY = 201, + + SSH_SFTP_STATUS_RESPONSE = 0 } } diff --git a/src/Tmds.Ssh/SftpChannel.Writer.cs b/src/Tmds.Ssh/SftpChannel.Writer.cs index cd6e288..742264d 100644 --- a/src/Tmds.Ssh/SftpChannel.Writer.cs +++ b/src/Tmds.Ssh/SftpChannel.Writer.cs @@ -69,6 +69,12 @@ internal void WriteInt64(long value) _length += 8; } + internal void WriteUInt64(ulong value) + { + BinaryPrimitives.WriteUInt64BigEndian(_buffer.AsSpan(_length), value); + _length += 8; + } + public void WriteAttributes( long? length = default, (int Uid, int Gid)? ids = default, diff --git a/src/Tmds.Ssh/SftpChannel.cs b/src/Tmds.Ssh/SftpChannel.cs index 3b355ea..284b776 100644 --- a/src/Tmds.Ssh/SftpChannel.cs +++ b/src/Tmds.Ssh/SftpChannel.cs @@ -27,16 +27,18 @@ sealed partial class SftpChannel : IDisposable // An onGoing ValueTask may allocate multiple buffers. const int MaxConcurrentBuffers = 64; - internal SftpChannel(ISshChannel channel) + internal SftpChannel(ISshChannel channel, SftpClientOptions options) { _channel = channel; _receivePacketSize = _channel.ReceiveMaxPacket; + _options = options; } public CancellationToken ChannelAborted => _channel.ChannelAborted; private readonly ISshChannel _channel; + private readonly SftpClientOptions _options; // Limits the number of buffers concurrently used for uploading/downloading. private readonly SemaphoreSlim s_downloadBufferSemaphore = new SemaphoreSlim(MaxConcurrentBuffers, MaxConcurrentBuffers); @@ -46,6 +48,7 @@ public CancellationToken ChannelAborted private int _nextId = 5; private int GetNextId() => Interlocked.Increment(ref _nextId); private int _receivePacketSize; + private SftpExtension _supportedExtensions; internal int GetMaxWritePayload(byte[] handle) // SSH_FXP_WRITE payload => _channel.SendMaxPacket @@ -56,6 +59,13 @@ internal int MaxReadPayload // SSH_FXP_DATA payload => _channel.ReceiveMaxPacket - 4 /* packet length */ - 1 /* packet type */ - 4 /* id */ - 4 /* payload length */; + internal int GetCopyBetweenSftpFilesBufferSize(byte[] destinationHandle) + => Math.Min(MaxReadPayload, GetMaxWritePayload(destinationHandle)); + + internal SftpExtension EnabledExtensions => _supportedExtensions; + + private bool SupportsCopyData => (_supportedExtensions & SftpExtension.CopyData) != 0; + public void Dispose() { _channel.Dispose(); @@ -168,6 +178,209 @@ public ValueTask RenameAsync(string oldPath, string newPath, CancellationToken c return ExecuteAsync(packet, id, pendingOperation, cancellationToken); } + public async ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default) + { + // Get the source file attributes and open it in parallel. + // We get the attribute to dermine the permissions for the destination path. + ValueTask sourceAttributesTask = GetAttributesAsync(sourcePath, followLinks: true, cancellationToken); + using SftpFile? sourceFile = await OpenFileCoreAsync(sourcePath, SftpOpenFlags.Open | SftpOpenFlags.Read, default(UnixFilePermissions), SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false); + if (sourceFile is null) + { + throw new SftpException(SftpError.NoSuchFile); + } + // Get the attributes ignoring any errors and falling back to getting them from the handle (unlikely). + FileEntryAttributes? sourceAttributes = null; + try + { + sourceAttributes = await sourceAttributesTask; + } + catch + { } + if (sourceAttributes is null) + { + sourceAttributes = await sourceFile.GetAttributesAsync(cancellationToken).ConfigureAwait(false); + } + + UnixFilePermissions permissions = sourceAttributes.Permissions & OwnershipPermissions; // Do not preserve setid bits (since the owner may change). + + // Refresh our source length from the handle (in parallel with with opening the destination file). +#pragma warning disable CS8619 // Nullability of reference types in value doesn't match target type. + sourceAttributesTask = sourceFile.GetAttributesAsync(cancellationToken); +#pragma warning restore CS8619 + + // When we are overwriting, the file may exists and be larger than the source file. + // We could open with Truncate but then the user would lose their data if they (by accident) uses a source and destination that are the same file. + // To avoid that, we'll truncate after copying the data instead. + SftpOpenFlags openFlags = overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew; + using SftpFile destinationFile = (await OpenFileCoreAsync(destinationPath, openFlags | SftpOpenFlags.Write, permissions, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!; + + // Get the length before we start writing so we know if we need to truncate. + ValueTask initialLengthTask = overwrite ? destinationFile.GetLengthAsync(cancellationToken) : ValueTask.FromResult(0L); + + long copyLength = (await sourceAttributesTask.ConfigureAwait(false))!.Length; + if (copyLength > 0) + { + bool doCopyAsync = true; + if (SupportsCopyData) + { + try + { + await CopyDataAsync(sourceFile.Handle, 0, destinationFile.Handle, 0, (ulong)copyLength, cancellationToken).ConfigureAwait(false); + doCopyAsync = false; + } + catch (SftpException ex) when (ex.Error == SftpError.Eof || // source has less data than copyLength (unlikely). + ex.Error == SftpError.Failure) // (maybe) source and destination are same path + { + // Fall through to async copy. + } + } + + if (doCopyAsync) + { + await CopyAsync(copyLength, cancellationToken).ConfigureAwait(false); + } + } + + // Truncate if the sourceFile is smaller than the destination file's initial length. + long initialLength = await initialLengthTask.ConfigureAwait(false); + if (initialLength > copyLength) + { + await destinationFile.SetLengthAsync(copyLength).ConfigureAwait(false); + } + + async ValueTask CopyAsync(long length, CancellationToken cancellationToken) + { + Debug.Assert(length > 0); + + int bufferSize = GetCopyBetweenSftpFilesBufferSize(destinationFile.Handle); + + ValueTask previous = default; + + CancellationTokenSource breakLoop = new(); + + for (long offset = 0; offset < length; offset += bufferSize) + { + if (!breakLoop.IsCancellationRequested) + { + await s_downloadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + previous = CopyBuffer(previous, offset, bufferSize); + } + } + + await previous.ConfigureAwait(false); + + async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length) + { + try + { + do + { + byte[]? buffer = null; + try + { + int bytesRead; + try + { + if (breakLoop.IsCancellationRequested) + { + return; + } + + buffer = ArrayPool.Shared.Rent(length); + bytesRead = await sourceFile.ReadAtAsync(buffer, sourceFile.Position + offset, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + break; + } + + // Our download buffer becomes an upload buffer. + await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + breakLoop.Cancel(); + throw; + } + finally + { + s_downloadBufferSemaphore.Release(); + } + try + { + await destinationFile.WriteAtAsync(buffer.AsMemory(0, bytesRead), offset).ConfigureAwait(false); + length -= bytesRead; + offset += bytesRead; + } + catch + { + breakLoop.Cancel(); + throw; + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + buffer = null; + } + s_uploadBufferSemaphore.Release(); + } + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + } + if (length > 0) + { + await s_downloadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + } + } while (length > 0); + } + finally + { + await previousCopy.ConfigureAwait(false); + } + } + } + } + + // https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00#section-7 + private ValueTask CopyDataAsync(byte[] sourceFileHandle, ulong sourceOffset, byte[] destinationFileHandle, ulong destinationOffset, ulong? length, CancellationToken cancellationToken = default) + { + Debug.Assert((_supportedExtensions & SftpExtension.CopyData) != 0); + + if (length == 0) + { + return default; + } + + /* + byte SSH_FXP_EXTENDED + uint32 request-id + string "copy-data" + string read-from-handle + uint64 read-from-offset + uint64 read-data-length + string write-to-handle + uint64 write-to-offset + */ + PacketType packetType = PacketType.SSH_FXP_EXTENDED; + int id = GetNextId(); + PendingOperation pendingOperation = CreatePendingOperation(PacketType.SSH_SFTP_STATUS_RESPONSE); + Packet packet = new Packet(packetType); + packet.WriteInt(id); + packet.WriteString("copy-data"); + packet.WriteString(sourceFileHandle); + packet.WriteUInt64(sourceOffset); + packet.WriteUInt64(length ?? 0); + packet.WriteString(destinationFileHandle); + packet.WriteUInt64(destinationOffset); + return ExecuteAsync(packet, id, pendingOperation, cancellationToken); + } + public ValueTask GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default) { PacketType packetType = followLinks ? PacketType.SSH_FXP_STAT : PacketType.SSH_FXP_LSTAT; @@ -541,11 +754,11 @@ private static UnixFilePermissions GetPermissionsForDirectory(string directoryPa { const UnixFilePermissions Default = SftpClient.DefaultCreateDirectoryPermissions & ~PretendUMask; #if NET7_0_OR_GREATER - if (!OperatingSystem.IsWindows()) - { - return File.GetUnixFileMode(directoryPath).ToUnixFilePermissions(); - } - return Default; // TODO: do something better on Windows? + if (!OperatingSystem.IsWindows()) + { + return File.GetUnixFileMode(directoryPath).ToUnixFilePermissions(); + } + return Default; // TODO: do something better on Windows? #else return Default; #endif @@ -555,11 +768,11 @@ private static UnixFilePermissions GetPermissionsForFile(SafeFileHandle fileHand { const UnixFilePermissions Default = SftpClient.DefaultCreateFilePermissions & ~PretendUMask; #if NET7_0_OR_GREATER - if (!OperatingSystem.IsWindows()) - { - return File.GetUnixFileMode(fileHandle).ToUnixFilePermissions(); - } - return Default; // TODO: do something better on Windows? + if (!OperatingSystem.IsWindows()) + { + return File.GetUnixFileMode(fileHandle).ToUnixFilePermissions(); + } + return Default; // TODO: do something better on Windows? #else return Default; #endif @@ -577,12 +790,16 @@ public async ValueTask UploadFileAsync(string localPath, string remotePath, long ValueTask previous = default; + CancellationTokenSource? breakLoop = length > 0 ? new() : null; + for (long offset = 0; offset < length; offset += GetMaxWritePayload(remoteFile.Handle)) { - // Obtain a buffer before starting the copy to ensure we're not competing - // for buffers with the previous copy. - await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - previous = CopyBuffer(previous, offset, GetMaxWritePayload(remoteFile.Handle)); + Debug.Assert(breakLoop is not null); + if (!breakLoop.IsCancellationRequested) + { + await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + previous = CopyBuffer(previous, offset, GetMaxWritePayload(remoteFile.Handle)); + } } await previous.ConfigureAwait(false); @@ -591,31 +808,46 @@ public async ValueTask UploadFileAsync(string localPath, string remotePath, long async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length) { - byte[]? buffer = null; try { - buffer = ArrayPool.Shared.Rent(length); - do + byte[]? buffer = null; + try { - int bytesRead = RandomAccess.Read(localFile, buffer.AsSpan(0, length), offset); - if (bytesRead == 0) + if (breakLoop.IsCancellationRequested) { - break; + return; } - await remoteFile.WriteAtAsync(buffer.AsMemory(0, bytesRead), offset, cancellationToken).ConfigureAwait(false); - length -= bytesRead; - offset += bytesRead; - } while (length > 0); - await previousCopy.ConfigureAwait(false); + buffer = ArrayPool.Shared.Rent(length); + do + { + int bytesRead = RandomAccess.Read(localFile, buffer.AsSpan(0, length), offset); + if (bytesRead == 0) + { + break; + } + await remoteFile.WriteAtAsync(buffer.AsMemory(0, bytesRead), offset, cancellationToken).ConfigureAwait(false); + length -= bytesRead; + offset += bytesRead; + } while (length > 0); + } + catch + { + breakLoop.Cancel(); + throw; + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + s_uploadBufferSemaphore.Release(); + } } finally { - if (buffer != null) - { - ArrayPool.Shared.Return(buffer); - } - s_uploadBufferSemaphore.Release(); + await previousCopy.ConfigureAwait(false); } } } @@ -762,14 +994,14 @@ static string EnsureParentDirectory(string lastDirectory, string itemPath) private static void CreateLocalDirectory(string path, UnixFilePermissions permissions) { #if NET7_0_OR_GREATER - if (OperatingSystem.IsWindows()) - { - Directory.CreateDirectory(path); - } - else - { - Directory.CreateDirectory(path, (permissions & CreateDirectoryPermissionMask).ToUnixFileMode()); - } + if (OperatingSystem.IsWindows()) + { + Directory.CreateDirectory(path); + } + else + { + Directory.CreateDirectory(path, (permissions & CreateDirectoryPermissionMask).ToUnixFileMode()); + } #else Directory.CreateDirectory(path); #endif @@ -785,10 +1017,10 @@ private static FileStream OpenFileStream(string path, FileMode mode, FileAccess Share = share }; #if NET7_0_OR_GREATER - if (!OperatingSystem.IsWindows()) - { - options.UnixCreateMode = (permissions & CreateFilePermissionMask).ToUnixFileMode(); - } + if (!OperatingSystem.IsWindows()) + { + options.UnixCreateMode = (permissions & CreateFilePermissionMask).ToUnixFileMode(); + } #endif return new FileStream(path, options); } @@ -839,7 +1071,7 @@ private async ValueTask DownloadFileAsync(string remotePath, string localPath, l FileEntryAttributes? attributes = await getAttributes.ConfigureAwait(false); if (attributes is null) // unlikely { - attributes = await remoteFile.GetAttributesAsync(cancellationToken). ConfigureAwait(false); + attributes = await remoteFile.GetAttributesAsync(cancellationToken).ConfigureAwait(false); } length = attributes.Length; permissions = attributes.Permissions; @@ -849,12 +1081,16 @@ private async ValueTask DownloadFileAsync(string remotePath, string localPath, l ValueTask previous = default; + CancellationTokenSource? breakLoop = length > 0 ? new() : null; + for (long offset = 0; offset < length; offset += MaxReadPayload) { - // Obtain a buffer before starting the copy to ensure we're not competing - // for buffers with the previous copy. - await s_downloadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - previous = CopyBuffer(previous, offset, MaxReadPayload); + Debug.Assert(breakLoop is not null); + if (!breakLoop.IsCancellationRequested) + { + await s_downloadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + previous = CopyBuffer(previous, offset, MaxReadPayload); + } } await previous.ConfigureAwait(false); @@ -863,31 +1099,46 @@ private async ValueTask DownloadFileAsync(string remotePath, string localPath, l async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length) { - byte[]? buffer = null; try { - buffer = ArrayPool.Shared.Rent(length); - do + byte[]? buffer = null; + try { - int bytesRead = await remoteFile.ReadAtAsync(buffer, offset, cancellationToken).ConfigureAwait(false); - if (bytesRead == 0) + if (breakLoop.IsCancellationRequested) { - break; + return; } - RandomAccess.Write(localFile.SafeFileHandle, buffer.AsSpan(0, bytesRead), offset); - length -= bytesRead; - offset += bytesRead; - } while (length > 0); - await previousCopy.ConfigureAwait(false); + buffer = ArrayPool.Shared.Rent(length); + do + { + int bytesRead = await remoteFile.ReadAtAsync(buffer, offset, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + break; + } + RandomAccess.Write(localFile.SafeFileHandle, buffer.AsSpan(0, bytesRead), offset); + length -= bytesRead; + offset += bytesRead; + } while (length > 0); + } + catch + { + breakLoop.Cancel(); + throw; + } + finally + { + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } + s_downloadBufferSemaphore.Release(); + } } finally { - if (buffer != null) - { - ArrayPool.Shared.Return(buffer); - } - s_downloadBufferSemaphore.Release(); + await previousCopy.ConfigureAwait(false); } } } @@ -1058,11 +1309,36 @@ private async Task ReadAllPacketsAsync() private void HandleVersionPacket(ReadOnlySpan packet) { - PacketType type = (PacketType)packet[0]; + PacketReader reader = new(packet); + + PacketType type = reader.ReadPacketType(); + if (type != PacketType.SSH_FXP_VERSION) { throw new SshChannelException($"Expected packet SSH_FXP_VERSION, but received {type}."); } + + uint version = reader.ReadUInt(); + if (version != ProtocolVersion) + { + throw new SshOperationException($"Unsupported protocol version {version}."); + } + + SftpExtension supportedExtensions = default; + while (!reader.Remainder.IsEmpty) + { + string extensionName = reader.ReadString(); + string extensionData = reader.ReadString(); + + switch (extensionName, extensionData) + { + case ("copy-data", "1"): + supportedExtensions |= SftpExtension.CopyData; + break; + } + } + + _supportedExtensions = supportedExtensions & ~_options.DisabledExtensions; } internal ValueTask ReadFileAsync(byte[] handle, long offset, Memory buffer, CancellationToken cancellationToken) diff --git a/src/Tmds.Ssh/SftpClient.cs b/src/Tmds.Ssh/SftpClient.cs index 97e0357..8a63a38 100644 --- a/src/Tmds.Ssh/SftpClient.cs +++ b/src/Tmds.Ssh/SftpClient.cs @@ -48,6 +48,14 @@ enum State // For testing. internal SshClient SshClient => _client; internal bool IsDisposed => _state == State.Disposed; + internal SftpExtension EnabledExtensions + { + get + { + SftpChannel channel = _channel ?? throw new InvalidOperationException(); + return channel.EnabledExtensions; + } + } public SftpClient(string destination, ILoggerFactory? loggerFactory = null, SftpClientOptions? options = null) : this(destination, SshConfigSettings.NoConfig, loggerFactory, options) @@ -175,7 +183,7 @@ private async Task DoOpenAsync(bool explicitConnect, CancellationTo bool success = false; try { - SftpChannel channel = await _client.OpenSftpChannelAsync(OnChannelAbort, explicitConnect, cancellationToken).ConfigureAwait(false); + SftpChannel channel = await _client.OpenSftpChannelAsync(OnChannelAbort, explicitConnect, _options, cancellationToken).ConfigureAwait(false); _channel = channel; success = true; return channel; @@ -267,6 +275,12 @@ public async ValueTask RenameAsync(string oldPath, string newPath, CancellationT await channel.RenameAsync(oldPath, newPath, cancellationToken).ConfigureAwait(false); } + public async ValueTask CopyFileAsync(string sourcePath, string destinationPath, bool overwrite = false, CancellationToken cancellationToken = default) + { + var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); + await channel.CopyFileAsync(sourcePath, destinationPath, overwrite, cancellationToken).ConfigureAwait(false); + } + public async ValueTask GetAttributesAsync(string path, bool followLinks = true, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); diff --git a/src/Tmds.Ssh/SftpClientOptions.cs b/src/Tmds.Ssh/SftpClientOptions.cs index 6be0572..88cd0ea 100644 --- a/src/Tmds.Ssh/SftpClientOptions.cs +++ b/src/Tmds.Ssh/SftpClientOptions.cs @@ -4,4 +4,7 @@ namespace Tmds.Ssh; public sealed partial class SftpClientOptions -{ } +{ + // For testing. + internal SftpExtension DisabledExtensions { get; set; } +} diff --git a/src/Tmds.Ssh/SftpExtension.cs b/src/Tmds.Ssh/SftpExtension.cs new file mode 100644 index 0000000..e1f0998 --- /dev/null +++ b/src/Tmds.Ssh/SftpExtension.cs @@ -0,0 +1,9 @@ +namespace Tmds.Ssh; + +[Flags] +enum SftpExtension +{ + None = 0, + // https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00https://datatracker.ietf.org/doc/html/draft-ietf-secsh-filexfer-extensions-00#section-7 + CopyData = 1 // copy-data 1 +} \ No newline at end of file diff --git a/src/Tmds.Ssh/SshClient.cs b/src/Tmds.Ssh/SshClient.cs index 3528cc1..64f1522 100644 --- a/src/Tmds.Ssh/SshClient.cs +++ b/src/Tmds.Ssh/SshClient.cs @@ -271,13 +271,13 @@ public async Task OpenSftpClientAsync(SftpClientOptions? options = n } } - internal async Task OpenSftpChannelAsync(Action onAbort, bool explicitConnect, CancellationToken cancellationToken) + internal async Task OpenSftpChannelAsync(Action onAbort, bool explicitConnect, SftpClientOptions options, CancellationToken cancellationToken) { SshSession session = await GetSessionAsync(cancellationToken, explicitConnect).ConfigureAwait(false); var channel = await session.OpenSftpClientChannelAsync(onAbort, cancellationToken).ConfigureAwait(false); - var sftpChannel = new SftpChannel(channel); + var sftpChannel = new SftpChannel(channel, options); try { diff --git a/test/Tmds.Ssh.Tests/SftpClientTests.cs b/test/Tmds.Ssh.Tests/SftpClientTests.cs index 53c6a07..5bd4036 100644 --- a/test/Tmds.Ssh.Tests/SftpClientTests.cs +++ b/test/Tmds.Ssh.Tests/SftpClientTests.cs @@ -6,6 +6,7 @@ namespace Tmds.Ssh.Tests; public class SftpClientTests { const int PacketSize = 32768; // roughly amount of bytes sent/received in a single sftp packet. + const int MultiPacketSize = 2 * PacketSize + 1024; private readonly SshServer _sshServer; @@ -54,7 +55,7 @@ public async Task SftpClientCtorFromSshClientSettings() [InlineData(10)] [InlineData(10 * 1024)] // 10 kiB - [InlineData(2 * PacketSize + 1024)] + [InlineData(MultiPacketSize)] [Theory] public async Task ReadWriteFile(int fileSize) { @@ -615,7 +616,7 @@ public async Task UploadDownloadDirectory() [InlineData(0)] [InlineData(10)] - [InlineData(2 * PacketSize + 1024)] + [InlineData(MultiPacketSize)] [Theory] public async Task UploadDownloadFile(int fileSize) { @@ -993,6 +994,102 @@ public async Task CacheLength() } } + [InlineData(0, SftpExtension.CopyData)] + [InlineData(10, SftpExtension.CopyData)] + [InlineData(MultiPacketSize, SftpExtension.CopyData)] + [InlineData(0, SftpExtension.None)] + [InlineData(10, SftpExtension.None)] + [InlineData(MultiPacketSize, SftpExtension.None)] + [SkippableTheory] + public async Task CopyFile(int fileSize, SftpExtension sftpExtensions) + { + using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions); + + (string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, fileSize); + + string destinationFileName = $"/tmp/{Path.GetRandomFileName()}"; + await sftpClient.CopyFileAsync(sourceFileName, destinationFileName); + + await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, destinationFileName); + } + + [InlineData(true, SftpExtension.CopyData)] + [InlineData(false, SftpExtension.CopyData)] + [InlineData(true, SftpExtension.None)] + [InlineData(false, SftpExtension.None)] + [SkippableTheory] + public async Task CopyFileOverwrite(bool overwrite, SftpExtension sftpExtensions) + { + using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions); + + (string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10); + (string destinationFileName, byte[] destinationData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10); + + Task copyTask = sftpClient.CopyFileAsync(sourceFileName, destinationFileName, overwrite).AsTask(); + + if (overwrite) + { + await copyTask; + } + else + { + await Assert.ThrowsAsync(() => copyTask); + } + + byte[] expectedData = overwrite ? sourceData : destinationData; + await AssertRemoteFileContentEqualsAsync(sftpClient, expectedData, destinationFileName); + } + + [InlineData(SftpExtension.CopyData)] + [InlineData(SftpExtension.None)] + [SkippableTheory] + public async Task CopyFileToSelfDoesntLooseData(SftpExtension sftpExtensions) + { + using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions); + + (string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 10); + + await sftpClient.CopyFileAsync(sourceFileName, sourceFileName, overwrite: true); + + await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, sourceFileName); + } + + [InlineData(SftpExtension.CopyData)] + [InlineData(SftpExtension.None)] + [SkippableTheory] + public async Task CopyFileOverwriteToLargerTruncates(SftpExtension sftpExtensions) + { + using var sftpClient = await _sshServer.CreateSftpClientAsync(sftpExtensions); + + const int SourceLength = 10; + (string sourceFileName, byte[] sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: SourceLength); + const int DestinationLength = SourceLength + SourceLength; + (string destinationFileName, byte[] destinationData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: DestinationLength); + + await sftpClient.CopyFileAsync(sourceFileName, destinationFileName, overwrite: true).AsTask(); + + await AssertRemoteFileContentEqualsAsync(sftpClient, sourceData, destinationFileName); + } + + private async Task AssertRemoteFileContentEqualsAsync(SftpClient client, byte[] expected, string remoteFileName) + { + using var readFile = await client.OpenFileAsync(remoteFileName, FileAccess.Read); + Assert.NotNull(readFile); + var memoryStream = new MemoryStream(); + await readFile.CopyToAsync(memoryStream); + Assert.Equal(expected, memoryStream.ToArray()); + } + + private async Task<(string filename, byte[] data)> CreateRemoteFileWithRandomDataAsync(SftpClient client, int length) + { + string filename = $"/tmp/{Path.GetRandomFileName()}"; + byte[] data = new byte[10]; + Random.Shared.NextBytes(data); + using var writeFile = await client.CreateNewFileAsync(filename, FileAccess.Write); + await writeFile.WriteAsync(data.AsMemory()); + return (filename, data); + } + [InlineData(true)] [InlineData(false)] [Theory] diff --git a/test/Tmds.Ssh.Tests/SftpExtension.cs b/test/Tmds.Ssh.Tests/SftpExtension.cs new file mode 100644 index 0000000..cb7c86c --- /dev/null +++ b/test/Tmds.Ssh.Tests/SftpExtension.cs @@ -0,0 +1,9 @@ +namespace Tmds.Ssh.Tests; + +// Copy of Tmds.Ssh.SftpExtensions with public access. +[Flags] +public enum SftpExtension +{ + None = Tmds.Ssh.SftpExtension.None, + CopyData = Tmds.Ssh.SftpExtension.CopyData +} \ No newline at end of file diff --git a/test/Tmds.Ssh.Tests/SshServer.cs b/test/Tmds.Ssh.Tests/SshServer.cs index 2243fef..1fec905 100644 --- a/test/Tmds.Ssh.Tests/SshServer.cs +++ b/test/Tmds.Ssh.Tests/SshServer.cs @@ -4,6 +4,7 @@ using Xunit; using Xunit.Abstractions; using Xunit.Sdk; +using SkipException = Xunit.SkipException; namespace Tmds.Ssh.Tests; @@ -347,11 +348,39 @@ public async Task CreateClientAsync(Action? config return client; } - public async Task CreateSftpClientAsync(Action? configureSsh = null, CancellationToken cancellationToken = default, bool connect = true) + public async Task CreateSftpClientAsync(Tmds.Ssh.Tests.SftpExtension enabledExtensions, Action? configureSsh = null, CancellationToken cancellationToken = default) { var settings = CreateSshClientSettings(configureSsh); - var client = new SftpClient(settings); + SftpClientOptions? options = new() + { + DisabledExtensions = (Tmds.Ssh.SftpExtension)~enabledExtensions + }; + + var client = new SftpClient(settings, options: options); + + await client.ConnectAsync(cancellationToken); + + if (client.EnabledExtensions != (Tmds.Ssh.SftpExtension)enabledExtensions) + { + throw new SkipException($"The test server does not support the required {((Tmds.Ssh.SftpExtension)enabledExtensions) & ~client.EnabledExtensions} extensions."); + } + + return client; + } + + public async Task CreateSftpClientAsync(Action? configureSsh = null, Action? configureSftp = null, CancellationToken cancellationToken = default, bool connect = true) + { + var settings = CreateSshClientSettings(configureSsh); + + SftpClientOptions? sftpClientOptions = null; + if (configureSftp is not null) + { + sftpClientOptions = new(); + configureSftp.Invoke(sftpClientOptions); + } + + var client = new SftpClient(settings, options: sftpClientOptions); if (connect) {