Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement preserve #81

Merged
merged 10 commits into from
May 26, 2024
39 changes: 35 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,14 +315,38 @@ func (a *Client) CopyFromRemotePassThru(
remotePath string,
passThru PassThru,
) error {
_, err := a.copyFromRemote(ctx, w, remotePath, passThru, false)

return err
}

// CopyFroRemoteFileInfos copies a file from the remote to a given writer and return a FileInfos struct
// containing information about the file such as permissions, the file size, modification time and access time
func (a *Client) CopyFromRemoteFileInfos(
datadius marked this conversation as resolved.
Show resolved Hide resolved
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) (*FileInfos, error) {
return a.copyFromRemote(ctx, w, remotePath, passThru, true)
}

func (a *Client) copyFromRemote(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
preserveFileTimes bool,
) (*FileInfos, error) {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
return nil, fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
}
defer session.Close()

wg := sync.WaitGroup{}
errCh := make(chan error, 4)
var fileInfos *FileInfos

wg.Add(1)
go func() {
Expand All @@ -349,7 +373,11 @@ func (a *Client) CopyFromRemotePassThru(
}
defer in.Close()

err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
if preserveFileTimes {
err = session.Start(fmt.Sprintf("%s -pf %q", a.RemoteBinary, remotePath))
} else {
err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
}
if err != nil {
errCh <- err
return
Expand All @@ -367,6 +395,8 @@ func (a *Client) CopyFromRemotePassThru(
return
}

fileInfos = fileInfo

err = Ack(in)
if err != nil {
errCh <- err
Expand Down Expand Up @@ -403,11 +433,12 @@ func (a *Client) CopyFromRemotePassThru(
}

if err := wait(&wg, ctx); err != nil {
return err
return nil, err
}

finalErr := <-errCh
close(errCh)
return finalErr
return fileInfos, finalErr
}

func (a *Client) Close() {
Expand Down
2 changes: 1 addition & 1 deletion configurer.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,6 @@ func (c *ClientConfigurer) Create() Client {
Timeout: c.timeout,
RemoteBinary: c.remoteBinary,
sshClient: c.sshClient,
closeHandler: EmptyHandler{},
closeHandler: EmptyHandler{},
}
}
37 changes: 24 additions & 13 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,19 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
return nil, err
}

message, err = bufferedReader.ReadString('\n')
if err == io.EOF {
// A custom ssh server can send both time, permissions and size information at once
// without needing an Ack response. Example: wish from charmbracelet while using their default scp implementation
// If the buffer is empty, then it's likely the default implementation for ssh, so send Ack
if bufferedReader.Buffered() == 0 {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit lost here what this is supposed to mean. Buffered returns the number of bytes remaining in the read buffer, but as far as I see the buffer should already be empty since we read all its information for parsing the time message. We do not make any Acks in between so we do not expect another message back from the remote. My suggestion is to add more comments to this function to make it more clear why certain checks are necessary and what part of the protocol we are considering at each line in the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/charmbracelet/wish/blob/e6e9fc4c8a253cf263334efa95d31e7af7034970/scp/scp.go#L128

wish is a popular ssh custom server. Their implementation allows for sending out the time and permissions in the same message. There is an argument that it's not following the protocol, but being so popular, I thought it would be nice to accommodate and take into consideration the instances in which someone could make a custom ssh implementation.

If you believe we shouldn't accommodate that kind of custom implementation, then I will remove that piece of logic to keep it clean.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I think you make a good argument for keeping it, but make sure that it does not interfere with implementations that do follow the protocol.

err = Ack(writer)
if err != nil {
return fileInfos, err
}
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}
}

if err != nil && err != io.EOF {
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}

Expand All @@ -102,7 +101,7 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
type FileInfos struct {
Message string
Filename string
Permissions string
Permissions uint32
Size int64
Atime int64
Mtime int64
Expand All @@ -119,7 +118,7 @@ func (fileInfos *FileInfos) Update(new *FileInfos) {
if new.Filename != "" {
fileInfos.Filename = new.Filename
}
if new.Permissions != "" {
if new.Permissions != 0 {
fileInfos.Permissions = new.Permissions
}
if new.Size != 0 {
Expand All @@ -140,14 +139,19 @@ func ParseFileInfos(message string, fileInfos *FileInfos) error {
return errors.New("unable to parse Chmod protocol")
}

permissions, err := strconv.ParseUint(parts[0][1:], 0, 32)
if err != nil {
return err
}

size, err := strconv.Atoi(parts[1])
if err != nil {
return err
}

fileInfos.Update(&FileInfos{
Filename: parts[2],
Permissions: parts[0],
Permissions: uint32(permissions),
Size: int64(size),
})

Expand All @@ -164,11 +168,18 @@ func ParseFileTime(
return errors.New("unable to parse Time protocol")
}

aTime, err := strconv.Atoi(string(parts[0][0:10]))
if len(parts[0]) != 10 {
return errors.New("length of ATime is not 10")
}
mTime, err := strconv.Atoi(parts[0][0:10])
if err != nil {
return errors.New("unable to parse ATime component of message")
}
mTime, err := strconv.Atoi(string(parts[2][0:10]))

if len(parts[2]) != 10 {
return errors.New("length of MTime is not 10")
}
aTime, err := strconv.Atoi(parts[2][0:10])
if err != nil {
return errors.New("unable to parse MTime component of message")
}
Expand Down
63 changes: 61 additions & 2 deletions tests/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scp
import (
"context"
"fmt"
"io/fs"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -207,7 +208,35 @@ func TestDownloadFile(t *testing.T) {
client := establishConnection(t)
defer client.Close()

// Open a file we can transfer to the remote container.
// Create a local file to write to.
f, err := os.OpenFile("./tmp/output.txt", os.O_RDWR|os.O_CREATE, 0777)
if err != nil {
t.Errorf("Couldn't open the output file")
}
defer f.Close()

// Use a file name with exotic characters and spaces in them.
// If this test works for this, simpler files should not be a problem.
err = client.CopyFromRemote(context.Background(), f, "/input/Exöt1ç download file.txt.txt")
if err != nil {
t.Errorf("Copy failed from remote: %s", err.Error())
}

content, err := os.ReadFile("./tmp/output.txt")
if err != nil {
t.Errorf("Result file could not be read: %s", err)
}

text := string(content)
expected := "It works for download!\n"
if strings.Compare(text, expected) != 0 {
t.Errorf("Got different text than expected, expected %q got, %q", expected, text)
}
}

func TestDownloadFileInfo(t *testing.T) {
client := establishConnection(t)
defer client.Close()
f, _ := os.Open("./data/input.txt")
defer f.Close()

Expand All @@ -220,7 +249,12 @@ func TestDownloadFile(t *testing.T) {

// Use a file name with exotic characters and spaces in them.
// If this test works for this, simpler files should not be a problem.
err = client.CopyFromRemote(context.Background(), f, "/input/Exöt1ç download file.txt.txt")
fileInfos, err := client.CopyFromRemoteFileInfos(
context.Background(),
f,
"/input/Exöt1ç download file.txt.txt",
nil,
)
if err != nil {
t.Errorf("Copy failed from remote: %s", err.Error())
}
Expand All @@ -235,6 +269,31 @@ func TestDownloadFile(t *testing.T) {
if strings.Compare(text, expected) != 0 {
t.Errorf("Got different text than expected, expected %q got, %q", expected, text)
}

fileStat, err := os.Stat("./data/Exöt1ç download file.txt.txt")
if err != nil {
t.Errorf("Result file could not be read: %s", err)
}

if fileInfos.Size != fileStat.Size() {
t.Errorf("File size does not match")
}

if fs.FileMode(fileInfos.Permissions) == fs.FileMode(0777) {
t.Errorf(
"File permissions don't match %s vs %s",
fs.FileMode(fileInfos.Permissions),
fileStat.Mode().Perm(),
)
}

if fileInfos.Mtime != fileStat.ModTime().Unix() {
t.Errorf(
"File modification time does not match %d vs %d",
fileInfos.Mtime,
fileStat.ModTime().Unix(),
)
}
bramvdbogaerde marked this conversation as resolved.
Show resolved Hide resolved
}

// TestTimeoutDownload tests that a timeout error is produced if the file is not copied in the given
Expand Down