Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement preserve #81

Merged
merged 10 commits into from
May 26, 2024
42 changes: 38 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type Client struct {
// Handler called when calling `Close` to clean up any remaining
// resources managed by `Client`.
closeHandler ICloseHandler

// Preserve the remote file permissions, modification time and access time
datadius marked this conversation as resolved.
Show resolved Hide resolved
preserve bool
}

// Connect connects to the remote SSH server, returns error if it couldn't establish a session to the SSH server.
Expand Down Expand Up @@ -315,14 +318,28 @@ func (a *Client) CopyFromRemotePassThru(
remotePath string,
passThru PassThru,
) error {
_, err := a.CopyFromRemoteFileInfos(ctx, w, remotePath, passThru)

return err
}

// CopyFroRemoteFileInfos copies a file from the remote to a given writer and return a FileInfos struct
// containing information about the file such as permissions, the file size, modification time and access time
func (a *Client) CopyFromRemoteFileInfos(
datadius marked this conversation as resolved.
Show resolved Hide resolved
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) (*FileInfos, error) {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
return nil, fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
}
defer session.Close()

wg := sync.WaitGroup{}
errCh := make(chan error, 4)
fileInfosCh := make(chan *FileInfos, 4)
datadius marked this conversation as resolved.
Show resolved Hide resolved

wg.Add(1)
go func() {
Expand All @@ -338,35 +355,50 @@ func (a *Client) CopyFromRemotePassThru(

r, err := session.StdoutPipe()
if err != nil {
fileInfosCh <- nil
errCh <- err
return
}

in, err := session.StdinPipe()
if err != nil {
fileInfosCh <- nil
errCh <- err
return
}
defer in.Close()

err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
if a.preserve {
err = session.Start(fmt.Sprintf("%s -pf %q", a.RemoteBinary, remotePath))
} else {
err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
}
if err != nil {
fileInfosCh <- nil
errCh <- err
return
}

err = Ack(in)
if err != nil {
fileInfosCh <- nil
errCh <- err
return
}

fileInfo, err := ParseResponse(r, in)
if err != nil {
fileInfosCh <- nil
errCh <- err
return
}

if fileInfo != nil {
fileInfosCh <- fileInfo
} else {
fileInfosCh <- nil
}
datadius marked this conversation as resolved.
Show resolved Hide resolved

err = Ack(in)
if err != nil {
errCh <- err
Expand Down Expand Up @@ -403,11 +435,13 @@ func (a *Client) CopyFromRemotePassThru(
}

if err := wait(&wg, ctx); err != nil {
return err
return nil, err
}

finalErr := <-errCh
fileInfos := <-fileInfosCh
close(errCh)
return finalErr
return fileInfos, finalErr
}

func (a *Client) Close() {
Expand Down
12 changes: 11 additions & 1 deletion configurer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ClientConfigurer struct {
timeout time.Duration
remoteBinary string
sshClient *ssh.Client
preserve bool
datadius marked this conversation as resolved.
Show resolved Hide resolved
}

// NewConfigurer creates a new client configurer.
Expand All @@ -36,6 +37,7 @@ func NewConfigurer(host string, config *ssh.ClientConfig) *ClientConfigurer {
clientConfig: config,
timeout: 0, // no timeout by default
remoteBinary: "scp",
preserve: false,
}
}

Expand Down Expand Up @@ -70,6 +72,13 @@ func (c *ClientConfigurer) SSHClient(sshClient *ssh.Client) *ClientConfigurer {
return c
}

// Preserve alters the preserve flag
// Defaults to false
func (c *ClientConfigurer) Preserve(preserve bool) *ClientConfigurer {
datadius marked this conversation as resolved.
Show resolved Hide resolved
c.preserve = preserve
return c
}

// Create builds a client with the configuration stored within the ClientConfigurer.
func (c *ClientConfigurer) Create() Client {
return Client{
Expand All @@ -78,6 +87,7 @@ func (c *ClientConfigurer) Create() Client {
Timeout: c.timeout,
RemoteBinary: c.remoteBinary,
sshClient: c.sshClient,
closeHandler: EmptyHandler{},
preserve: c.preserve,
closeHandler: EmptyHandler{},
}
}
25 changes: 13 additions & 12 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,16 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
return nil, err
}

message, err = bufferedReader.ReadString('\n')
if err == io.EOF {
if bufferedReader.Buffered() == 0 {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit lost here what this is supposed to mean. Buffered returns the number of bytes remaining in the read buffer, but as far as I see the buffer should already be empty since we read all its information for parsing the time message. We do not make any Acks in between so we do not expect another message back from the remote. My suggestion is to add more comments to this function to make it more clear why certain checks are necessary and what part of the protocol we are considering at each line in the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/charmbracelet/wish/blob/e6e9fc4c8a253cf263334efa95d31e7af7034970/scp/scp.go#L128

wish is a popular ssh custom server. Their implementation allows for sending out the time and permissions in the same message. There is an argument that it's not following the protocol, but being so popular, I thought it would be nice to accommodate and take into consideration the instances in which someone could make a custom ssh implementation.

If you believe we shouldn't accommodate that kind of custom implementation, then I will remove that piece of logic to keep it clean.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I think you make a good argument for keeping it, but make sure that it does not interfere with implementations that do follow the protocol.

err = Ack(writer)
if err != nil {
return fileInfos, err
}
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}
}

if err != nil && err != io.EOF {
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}

Expand All @@ -102,7 +98,7 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
type FileInfos struct {
Message string
Filename string
Permissions string
Permissions uint32
Size int64
Atime int64
Mtime int64
Expand All @@ -119,7 +115,7 @@ func (fileInfos *FileInfos) Update(new *FileInfos) {
if new.Filename != "" {
fileInfos.Filename = new.Filename
}
if new.Permissions != "" {
if new.Permissions != 0 {
fileInfos.Permissions = new.Permissions
}
if new.Size != 0 {
Expand All @@ -140,14 +136,19 @@ func ParseFileInfos(message string, fileInfos *FileInfos) error {
return errors.New("unable to parse Chmod protocol")
}

permissions, err := strconv.ParseUint(parts[0][1:], 0, 32)
if err != nil {
return err
}

size, err := strconv.Atoi(parts[1])
if err != nil {
return err
}

fileInfos.Update(&FileInfos{
Filename: parts[2],
Permissions: parts[0],
Permissions: uint32(permissions),
Size: int64(size),
})

Expand All @@ -168,7 +169,7 @@ func ParseFileTime(
if err != nil {
return errors.New("unable to parse ATime component of message")
}
mTime, err := strconv.Atoi(string(parts[2][0:10]))
mTime, err := strconv.ParseUint(string(parts[2][0:10]), 0, 32)
datadius marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return errors.New("unable to parse MTime component of message")
}
Expand Down
87 changes: 85 additions & 2 deletions tests/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ func establishConnection(t *testing.T) scp.Client {
return client
}

func establishPreserveConnection(t *testing.T) scp.Client {
datadius marked this conversation as resolved.
Show resolved Hide resolved
clientConfig, err := buildClientConfig()
if err != nil {
t.Fatalf("Couldn't build the client configuration: %s", err)
}

// Create a new SCP client.
client := scp.NewConfigurer("127.0.0.1:2244", &clientConfig).Preserve(true).Create()

// Connect to the remote server.
err = client.Connect()
if err != nil {
t.Fatalf("Couldn't establish a connection to the remote server: %s", err)
}
return client
}

// TestCopy tests the basic functionality of copying a file to the remote
// destination.
//
Expand Down Expand Up @@ -196,6 +213,17 @@ func download(client *scp.Client, file *os.File, remotePath string) error {
return client.CopyFromRemote(context.Background(), file, remotePath)
}

func downloadFileInfo(
datadius marked this conversation as resolved.
Show resolved Hide resolved
client *scp.Client,
file *os.File,
remotePath string,
) (*scp.FileInfos, error) {

fileInfos, err := client.CopyFromRemoteFileInfos(context.Background(), file, remotePath, nil)

return fileInfos, err
}

// TestDownloadFile tests the basic functionality of copying a file from the
// remote destination.
//
Expand All @@ -207,7 +235,35 @@ func TestDownloadFile(t *testing.T) {
client := establishConnection(t)
defer client.Close()

// Open a file we can transfer to the remote container.
// Create a local file to write to.
f, err := os.OpenFile("./tmp/output.txt", os.O_RDWR|os.O_CREATE, 0777)
if err != nil {
t.Errorf("Couldn't open the output file")
}
defer f.Close()

// Use a file name with exotic characters and spaces in them.
// If this test works for this, simpler files should not be a problem.
err = client.CopyFromRemote(context.Background(), f, "/input/Exöt1ç download file.txt.txt")
if err != nil {
t.Errorf("Copy failed from remote: %s", err.Error())
}

content, err := os.ReadFile("./tmp/output.txt")
if err != nil {
t.Errorf("Result file could not be read: %s", err)
}

text := string(content)
expected := "It works for download!\n"
if strings.Compare(text, expected) != 0 {
t.Errorf("Got different text than expected, expected %q got, %q", expected, text)
}
}

func TestDownloadFileInfo(t *testing.T) {
client := establishPreserveConnection(t)
defer client.Close()
f, _ := os.Open("./data/input.txt")
defer f.Close()

Expand All @@ -220,7 +276,12 @@ func TestDownloadFile(t *testing.T) {

// Use a file name with exotic characters and spaces in them.
// If this test works for this, simpler files should not be a problem.
err = client.CopyFromRemote(context.Background(), f, "/input/Exöt1ç download file.txt.txt")
fileInfos, err := client.CopyFromRemoteFileInfos(
context.Background(),
f,
"/input/Exöt1ç download file.txt.txt",
nil,
)
if err != nil {
t.Errorf("Copy failed from remote: %s", err.Error())
}
Expand All @@ -235,6 +296,28 @@ func TestDownloadFile(t *testing.T) {
if strings.Compare(text, expected) != 0 {
t.Errorf("Got different text than expected, expected %q got, %q", expected, text)
}

fileStat, err := os.Stat("./tmp/output.txt")
if err != nil {
t.Errorf("Result file could not be read: %s", err)
}

if fileStat.Size() != fileInfos.Size {
t.Errorf("File size does not match")
}

if fileInfos.Mtime == 0 {
t.Errorf("No file mtime preserved")
}

if fileInfos.Atime == 0 {
t.Errorf("No file atime preserved")
}
bramvdbogaerde marked this conversation as resolved.
Show resolved Hide resolved

if fileInfos.Permissions == 0 {
t.Errorf("No file permissions preserved")
}

Copy link
Owner

@bramvdbogaerde bramvdbogaerde Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when atime is parsed in the wrong way so that it is differnet from zero but not the expected value? I would expect a comparison between the received value and the actual value here, would that be difficult to add?

}

// TestTimeoutDownload tests that a timeout error is produced if the file is not copied in the given
Expand Down