Skip to content

Commit

Permalink
SftpFileSystemEnumerator: don't throw when nested directory is no lon…
Browse files Browse the repository at this point in the history
…ger found during recursion (#246)
  • Loading branch information
tmds authored Oct 25, 2024
1 parent 86c1bd7 commit daebaa1
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 26 deletions.
8 changes: 7 additions & 1 deletion src/Tmds.Ssh/SftpChannel.PendingOperation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ internal void HandleReply(SftpChannel channel, ReadOnlySpan<byte> reply)
or (SftpError.NoSuchFile, PacketType.SSH_FXP_STAT // GetAttributes: return null
or PacketType.SSH_FXP_LSTAT // GetAttributes: return null
or PacketType.SSH_FXP_OPEN // OpenFile: return null
or PacketType.SSH_FXP_OPENDIR // OpenDirectory: return null
or PacketType.SSH_FXP_REMOVE // DeleteFile: don't throw
or PacketType.SSH_FXP_RMDIR // DeleteDirectory: don't throw
)
Expand All @@ -139,13 +140,18 @@ internal void HandleReply(SftpChannel channel, ReadOnlySpan<byte> reply)
switch (RequestType, responseType)
{
case (PacketType.SSH_FXP_OPEN, _):
{
SftpFile? file = error == SftpError.NoSuchFile ? null : new SftpFile(channel, handle: reader.ReadStringAsBytes(), (FileOpenOptions)Options!);
Options = null;
SetResult(file);
return;
}
case (PacketType.SSH_FXP_OPENDIR, _):
SetResult(new SftpFile(channel, handle: reader.ReadStringAsBytes(), SftpClient.DefaultFileOpenOptions));
{
SftpFile? file = error == SftpError.NoSuchFile ? null : new SftpFile(channel, handle: reader.ReadStringAsBytes(), SftpClient.DefaultFileOpenOptions);
SetResult(file);
return;
}
case (PacketType.SSH_FXP_STAT, _):
case (PacketType.SSH_FXP_LSTAT, _):
case (PacketType.SSH_FXP_FSTAT, _):
Expand Down
4 changes: 2 additions & 2 deletions src/Tmds.Ssh/SftpChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ public ValueTask CreateSymbolicLinkAsync(string linkPath, string targetPath, boo
return ExecuteAsync(packet, id, pendingOperation, cancellationToken);
}

public ValueTask<SftpFile> OpenDirectoryAsync(string path, CancellationToken cancellationToken = default)
public ValueTask<SftpFile?> OpenDirectoryAsync(string path, CancellationToken cancellationToken = default)
{
PacketType packetType = PacketType.SSH_FXP_OPENDIR;

Expand All @@ -549,7 +549,7 @@ public ValueTask<SftpFile> OpenDirectoryAsync(string path, CancellationToken can
packet.WriteString(path);

// note: Return as 'SftpFile' so it gets Disposed in case the open is cancelled.
return ExecuteAsync<SftpFile>(packet, id, pendingOperation, cancellationToken);
return ExecuteAsync<SftpFile?>(packet, id, pendingOperation, cancellationToken);
}

public async ValueTask CreateDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = SftpClient.DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default)
Expand Down
6 changes: 0 additions & 6 deletions src/Tmds.Ssh/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,6 @@ public async ValueTask CreateSymbolicLinkAsync(string linkPath, string targetPat
public IAsyncEnumerable<T> GetDirectoryEntriesAsync<T>(string path, SftpFileEntryTransform<T> transform, EnumerationOptions? options = null)
=> new SftpFileSystemEnumerable<T>(this, path, transform, options ?? DefaultEnumerationOptions);

internal async ValueTask<SftpFile> OpenDirectoryAsync(string path, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
return await channel.OpenDirectoryAsync(path, cancellationToken);
}

public ValueTask CreateDirectoryAsync(string path, CancellationToken cancellationToken)
=> CreateDirectoryAsync(path, createParents: false, DefaultCreateDirectoryPermissions, cancellationToken);

Expand Down
65 changes: 49 additions & 16 deletions src/Tmds.Ssh/SftpFileSystemEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ sealed class SftpFileSystemEnumerator<T> : IAsyncEnumerator<T>
private readonly SftpFileEntryPredicate? _shouldInclude;

private Queue<string>? _pending;
private string _path;
private string? _path;
private SftpChannel? _channel;
private SftpFile? _fileHandle;

Expand Down Expand Up @@ -141,12 +141,13 @@ private async ValueTask<bool> TryReadNewBufferAsync()
{
_fileHandle!.Dispose();
_fileHandle = null;
_path = null;
}

if (_pending?.TryDequeue(out string? path) == true)
{
_path = path;
}
else
if (_fileHandle is null)
{
await OpenFileHandleAsync();
if (_fileHandle is null)
{
return false;
}
Expand All @@ -158,16 +159,7 @@ private async ValueTask<bool> TryReadNewBufferAsync()

private async ValueTask ReadNewBufferAsync()
{
if (_fileHandle is null)
{
if (_channel is null)
{
Debug.Assert(_client is not null);
_channel = await _client.GetChannelAsync(_cancellationToken);
}
_fileHandle = await _channel.OpenDirectoryAsync(_path, _cancellationToken).ConfigureAwait(false);
_readAhead = _channel.ReadDirAsync(_fileHandle, _cancellationToken);
}
Debug.Assert(_fileHandle is not null);
Debug.Assert(_channel is not null);

const int CountIndex = 4 /* packet length */ + 1 /* packet type */ + 4 /* id */;
Expand All @@ -186,8 +178,48 @@ private async ValueTask ReadNewBufferAsync()
}
}

private async ValueTask OpenFileHandleAsync()
{
if (_channel is null)
{
Debug.Assert(_client is not null);
_channel = await _client.GetChannelAsync(_cancellationToken);
}

string? path = _path;
bool isRootPath = path is not null; // path passed to the constructor.
do
{
if (!isRootPath)
{
if (_pending?.TryDequeue(out path) != true)
{
return;
}
}
Debug.Assert(path is not null);

_fileHandle = await _channel.OpenDirectoryAsync(path, _cancellationToken).ConfigureAwait(false);
if (_fileHandle is null)
{
if (isRootPath)
{
throw new SftpException(SftpError.NoSuchFile);
}
else
{
continue;
}
}
_path = path;
_readAhead = _channel.ReadDirAsync(_fileHandle, _cancellationToken);
return;
} while (true);
}

private bool ReadNextEntry(bool followLink, out string? linkPath, out Memory<byte> linkEntry)
{
Debug.Assert(_path is not null);
int startOffset = _bufferOffset;
SftpFileEntry entry = new SftpFileEntry(_path, _readDirPacket.AsSpan(startOffset), _pathBuffer, _nameBuffer, out int entryLength);

Expand Down Expand Up @@ -238,6 +270,7 @@ private bool SetCurrent(ref SftpFileEntry entry)

private async Task<bool> ReadLinkTargetEntry(string linkPath, Memory<byte> linkEntry)
{
Debug.Assert(_path is not null);
FileEntryAttributes? attributes = await _channel!.GetAttributesAsync(linkPath, followLinks: true, _cancellationToken).ConfigureAwait(false);
if (attributes is not null)
{
Expand Down
48 changes: 47 additions & 1 deletion test/Tmds.Ssh.Tests/SftpClientTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Xunit;
using Xunit.Abstractions;

namespace Tmds.Ssh.Tests;

Expand All @@ -9,10 +10,17 @@ public class SftpClientTests
const int MultiPacketSize = 2 * PacketSize + 1024;

private readonly SshServer _sshServer;
private readonly ITestOutputHelper _output;

public SftpClientTests(SshServer sshServer)
private void WriteMessage(string message)
{
_output.WriteLine(message);
}

public SftpClientTests(SshServer sshServer, ITestOutputHelper output)
{
_sshServer = sshServer;
_output = output;
}

[Fact]
Expand Down Expand Up @@ -538,6 +546,44 @@ public async Task EnumerateRootDirectory()
}
}

[Fact]
public async Task EnumerateRootNotFound()
{
using var sftpClient = await _sshServer.CreateSftpClientAsync();

string path = "/no_such_dir";
var exception = await Assert.ThrowsAsync<SftpException>(() => sftpClient.GetDirectoryEntriesAsync(path).ToListAsync().AsTask());
Assert.Equal(SftpError.NoSuchFile, exception.Error);
}

[Fact]
public async Task EnumerateNestedDirNotFoundDoesNotThrow()
{
using var sftpClient = await _sshServer.CreateSftpClientAsync();

string directoryPath = $"/tmp/{Path.GetRandomFileName()}";
await sftpClient.CreateNewDirectoryAsync(directoryPath);
string childDirectoryPath = $"{directoryPath}/child";
await sftpClient.CreateDirectoryAsync(childDirectoryPath);
string childChildDirectoryPath = $"{childDirectoryPath}/nestedchild";
await sftpClient.CreateDirectoryAsync(childChildDirectoryPath);

bool childDirWasReturned = false;
int count = 0;
await foreach (var entry in sftpClient.GetDirectoryEntriesAsync(directoryPath, new Tmds.Ssh.EnumerationOptions() { RecurseSubdirectories = true }))
{
count++;
childDirWasReturned = entry.Path == childDirectoryPath;

// Delete the nested directories before we recurse into them.
await sftpClient.DeleteDirectoryAsync(childChildDirectoryPath);
await sftpClient.DeleteDirectoryAsync(childDirectoryPath);
}

Assert.True(childDirWasReturned);
Assert.Equal(1, count);
}

[Fact]
public void DefaultEnumerationOptions()
{
Expand Down

0 comments on commit daebaa1

Please sign in to comment.