From cc9ce5c5cd7845da8d7fa488418f961788309cec Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sat, 18 Jun 2022 16:25:30 +0800 Subject: [PATCH] support transfer directories --- trzsz/comm.go | 117 +++++++++++-- trzsz/progress.go | 2 +- trzsz/pty_unix.go | 9 + trzsz/pty_windows.go | 41 +++-- trzsz/transfer.go | 403 +++++++++++++++++++++++++++++++------------ trzsz/trz.go | 21 ++- trzsz/trzsz.go | 39 +++-- trzsz/tsz.go | 36 ++-- trzsz/version.go | 2 +- 9 files changed, 503 insertions(+), 167 deletions(-) diff --git a/trzsz/comm.go b/trzsz/comm.go index 3c69efc..40da6fd 100644 --- a/trzsz/comm.go +++ b/trzsz/comm.go @@ -71,7 +71,7 @@ type ProgressCallback interface { onName(name string) onSize(size int64) onStep(step int64) - onDone(name string) + onDone() } type BufferSize struct { @@ -83,8 +83,9 @@ type Args struct { Overwrite bool `arg:"-y" help:"yes, overwrite existing file(s)"` Binary bool `arg:"-b" help:"binary transfer mode, faster for binary files"` Escape bool `arg:"-e" help:"escape all known control characters"` + Directory bool `arg:"-d" help:"transfer directories and files"` Bufsize BufferSize `arg:"-B" placeholder:"N" default:"10M" help:"max buffer chunk size (1K<=N<=1G). (default: 10M)"` - Timeout int `arg:"-t" placeholder:"N" default:"100" help:"timeout ( N seconds ) for each buffer chunk.\nN <= 0 means never timeout. (default: 100)"` + Timeout int `arg:"-t" placeholder:"N" default:"10" help:"timeout ( N seconds ) for each buffer chunk.\nN <= 0 means never timeout. (default: 10)"` } var sizeRegexp = regexp.MustCompile("(?i)^(\\d+)(b|k|m|g|kb|mb|gb)?$") @@ -196,38 +197,107 @@ func (e *TrzszError) isRemoteFail() bool { } func checkPathWritable(path string) error { - fileInfo, err := os.Stat(path) + info, err := os.Stat(path) if errors.Is(err, os.ErrNotExist) { return newTrzszError(fmt.Sprintf("No such directory: %s", path)) + } else if err != nil { + return err } - if !fileInfo.IsDir() { + if !info.IsDir() { return newTrzszError(fmt.Sprintf("Not a directory: %s", path)) } - if !IsWindows() { - if fileInfo.Mode().Perm()&(1<<7) == 0 { - return newTrzszError(fmt.Sprintf("No permission to write: %s", path)) - } + if syscallAccessWok(path) != nil { + return newTrzszError(fmt.Sprintf("No permission to write: %s", path)) } return nil } -func checkFilesReadable(files []string) error { +type TrzszFile struct { + PathID int `json:"path_id"` + AbsPath string `json:"-"` + RelPath []string `json:"path_name"` + IsDir bool `json:"is_dir"` +} + +func resolveLink(path string) string { + for { + p, err := os.Readlink(path) + if err != nil { + return path + } + path = p + } +} + +func checkPathReadable(pathID int, path string, info os.FileInfo, list *[]*TrzszFile, relPath []string, visitedDir map[string]bool) error { + if !info.IsDir() { + if !info.Mode().IsRegular() { + return newTrzszError(fmt.Sprintf("Not a regular file: %s", path)) + } + if syscallAccessRok(path) != nil { + return newTrzszError(fmt.Sprintf("No permission to read: %s", path)) + } + *list = append(*list, &TrzszFile{pathID, path, relPath, false}) + return nil + } + realPath := resolveLink(path) + if _, ok := visitedDir[realPath]; ok { + return newTrzszError(fmt.Sprintf("Loop link: %s", path)) + } + visitedDir[realPath] = true + *list = append(*list, &TrzszFile{pathID, path, relPath, true}) + f, err := os.Open(path) + if err != nil { + return newTrzszError(fmt.Sprintf("Open [%s] error: %v", path, err)) + } + files, err := f.Readdir(-1) + if err != nil { + return newTrzszError(fmt.Sprintf("Readdir [%s] error: %v", path, err)) + } for _, file := range files { - fileInfo, err := os.Stat(file) + p := filepath.Join(path, file.Name()) + r := make([]string, len(relPath)) + copy(r, relPath) + r = append(r, file.Name()) + if err := checkPathReadable(pathID, p, file, list, r, visitedDir); err != nil { + return err + } + } + return nil +} + +func checkPathsReadable(paths []string, directory bool) ([]*TrzszFile, error) { + var list []*TrzszFile + visitedDir := make(map[string]bool) + for i, p := range paths { + path, err := filepath.Abs(p) + if err != nil { + return nil, err + } + info, err := os.Stat(path) if errors.Is(err, os.ErrNotExist) { - return newTrzszError(fmt.Sprintf("No such file: %s", file)) + return nil, newTrzszError(fmt.Sprintf("No such file: %s", path)) + } else if err != nil { + return nil, err } - if fileInfo.IsDir() { - return newTrzszError(fmt.Sprintf("Is a directory: %s", file)) + if !directory && info.IsDir() { + return nil, newTrzszError(fmt.Sprintf("Is a directory: %s", path)) } - if !fileInfo.Mode().IsRegular() { - return newTrzszError(fmt.Sprintf("Not a regular file: %s", file)) + if err := checkPathReadable(i, path, info, &list, []string{info.Name()}, visitedDir); err != nil { + return nil, err } - if !IsWindows() { - if fileInfo.Mode().Perm()&(1<<8) == 0 { - return newTrzszError(fmt.Sprintf("No permission to read: %s", file)) - } + } + return list, nil +} + +func checkDuplicateNames(list []*TrzszFile) error { + m := make(map[string]bool) + for _, f := range list { + p := filepath.Join(f.RelPath...) + if _, ok := m[p]; ok { + return newTrzszError(fmt.Sprintf("Duplicate name: %s", p)) } + m[p] = true } return nil } @@ -367,3 +437,12 @@ func trimVT100(buf []byte) []byte { } return b.Bytes() } + +func containsString(elems []string, v string) bool { + for _, s := range elems { + if v == s { + return true + } + } + return false +} diff --git a/trzsz/progress.go b/trzsz/progress.go index 9586f72..6fd37a8 100644 --- a/trzsz/progress.go +++ b/trzsz/progress.go @@ -192,7 +192,7 @@ func (p *TextProgressBar) onStep(step int64) { p.showProgress() } -func (p *TextProgressBar) onDone(name string) { +func (p *TextProgressBar) onDone() { } func (p *TextProgressBar) showProgress() { diff --git a/trzsz/pty_unix.go b/trzsz/pty_unix.go index 6bb59fc..eaf7506 100644 --- a/trzsz/pty_unix.go +++ b/trzsz/pty_unix.go @@ -33,6 +33,7 @@ import ( "syscall" "github.com/creack/pty" + "golang.org/x/sys/unix" ) type TrzszPty struct { @@ -109,3 +110,11 @@ func (t *TrzszPty) Terminate() { func (t *TrzszPty) ExitCode() int { return t.cmd.ProcessState.ExitCode() } + +func syscallAccessWok(path string) error { + return syscall.Access(path, unix.W_OK) +} + +func syscallAccessRok(path string) error { + return syscall.Access(path, unix.R_OK) +} diff --git a/trzsz/pty_windows.go b/trzsz/pty_windows.go index 3e19d8d..2f57e49 100644 --- a/trzsz/pty_windows.go +++ b/trzsz/pty_windows.go @@ -37,15 +37,16 @@ import ( ) type TrzszPty struct { - Stdin PtyIO - Stdout PtyIO - cpty *conpty.ConPty - inMode uint32 - outMode uint32 - width int - height int - closed bool - exitCode *uint32 + Stdin PtyIO + Stdout PtyIO + cpty *conpty.ConPty + inMode uint32 + outMode uint32 + width int + height int + closed bool + exitCode *uint32 + startTime time.Time } func getConsoleSize() (int, int, error) { @@ -134,7 +135,7 @@ func Spawn(name string, args ...string) (*TrzszPty, error) { return nil, err } - return &TrzszPty{cpty, cpty, cpty, inMode, outMode, width, height, false, nil}, nil + return &TrzszPty{cpty, cpty, cpty, inMode, outMode, width, height, false, nil, time.Now()}, nil } func (t *TrzszPty) OnResize(cb func(int)) { @@ -148,13 +149,15 @@ func (t *TrzszPty) Close() { if t.closed { return } + t.closed = true t.cpty.Close() resetVirtualTerminal(t.inMode, t.outMode) - time.Sleep(100 * time.Millisecond) - cmd := exec.Command("cmd", "/c", "cls") - cmd.Stdout = os.Stdout - cmd.Run() - t.closed = true + if time.Now().Sub(t.startTime) > 10*time.Second { + time.Sleep(100 * time.Millisecond) + cmd := exec.Command("cmd", "/c", "cls") + cmd.Stdout = os.Stdout + cmd.Run() + } } func (t *TrzszPty) Wait() { @@ -172,3 +175,11 @@ func (t *TrzszPty) ExitCode() int { } return int(*t.exitCode) } + +func syscallAccessWok(path string) error { + return nil +} + +func syscallAccessRok(path string) error { + return nil +} diff --git a/trzsz/transfer.go b/trzsz/transfer.go index 490b126..ff54470 100644 --- a/trzsz/transfer.go +++ b/trzsz/transfer.go @@ -28,6 +28,7 @@ import ( "bytes" "crypto/md5" "encoding/json" + "errors" "fmt" "io/fs" "os" @@ -51,6 +52,7 @@ type TrzszTransfer struct { transferConfig map[string]interface{} protocolNewline string stdinState *term.State + fileNameMap map[int]string } func maxDuration(a, b time.Duration) time.Duration { @@ -75,9 +77,11 @@ func NewTransfer(writer PtyIO, stdinState *term.State) *TrzszTransfer { false, nil, 100 * time.Millisecond, - 0, make(map[string]interface{}), + 0, + make(map[string]interface{}), "\n", stdinState, + make(map[int]string), } } @@ -295,9 +299,10 @@ func (t *TrzszTransfer) recvData(binary bool, escapeCodes [][]byte, timeout time func (t *TrzszTransfer) sendAction(confirm, remoteIsWindows bool) error { actMap := map[string]interface{}{ - "lang": "go", - "confirm": confirm, - "version": kTrzszVersion, + "lang": "go", + "confirm": confirm, + "version": kTrzszVersion, + "support_dir": true, } if IsWindows() { actMap["binary"] = false @@ -339,6 +344,9 @@ func (t *TrzszTransfer) sendConfig(args *Args, escapeChars [][]unicode, tmuxMode cfgMap["binary"] = true cfgMap["escape_chars"] = escapeChars } + if args.Directory { + cfgMap["directory"] = true + } cfgMap["bufsize"] = args.Bufsize.Size cfgMap["timeout"] = args.Timeout if args.Overwrite { @@ -373,7 +381,7 @@ func (t *TrzszTransfer) recvConfig() (map[string]interface{}, error) { } func (t *TrzszTransfer) clientExit(msg string) error { - t.cleanInput(200) + t.cleanInput(200 * time.Millisecond) return t.sendString("EXIT", msg) } @@ -382,7 +390,7 @@ func (t *TrzszTransfer) recvExit() (string, error) { } func (t *TrzszTransfer) serverExit(msg string) { - t.cleanInput(200) + t.cleanInput(500 * time.Millisecond) if t.stdinState != nil { term.Restore(int(os.Stdin.Fd()), t.stdinState) } @@ -433,11 +441,90 @@ func (t *TrzszTransfer) serverError(err error) { t.serverExit(err.Error()) } -func (t *TrzszTransfer) sendFiles(files []string, progress ProgressCallback) ([]string, error) { +func (t *TrzszTransfer) sendFileNum(num int64, progress ProgressCallback) error { + if err := t.sendInteger("NUM", num); err != nil { + return err + } + if err := t.checkInteger(num); err != nil { + return err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onNum(num) + } + return nil +} + +func (t *TrzszTransfer) sendFileName(f *TrzszFile, directory bool, progress ProgressCallback) (*os.File, string, error) { + var fileName string + if directory { + jsonName, err := json.Marshal(f) + if err != nil { + return nil, "", err + } + fileName = string(jsonName) + } else { + fileName = f.RelPath[0] + } + if err := t.sendString("NAME", fileName); err != nil { + return nil, "", err + } + remoteName, err := t.recvString("SUCC", false) + if err != nil { + return nil, "", err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onName(f.RelPath[len(f.RelPath)-1]) + } + if f.IsDir { + return nil, remoteName, nil + } + file, err := os.Open(f.AbsPath) + if err != nil { + return nil, "", err + } + return file, remoteName, nil +} + +func (t *TrzszTransfer) sendFileSize(file *os.File, progress ProgressCallback) (int64, error) { + stat, err := file.Stat() + if err != nil { + return 0, err + } + size := stat.Size() + if err := t.sendInteger("SIZE", size); err != nil { + return 0, err + } + if err := t.checkInteger(size); err != nil { + return 0, err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onSize(size) + } + return size, nil +} + +func (t *TrzszTransfer) sendFileMD5(digest []byte, progress ProgressCallback) error { + if err := t.sendBinary("MD5", digest); err != nil { + return err + } + if err := t.checkBinary(digest); err != nil { + return err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onDone() + } + return nil +} + +func (t *TrzszTransfer) sendFiles(files []*TrzszFile, progress ProgressCallback) ([]string, error) { binary := false if v, ok := t.transferConfig["binary"].(bool); ok { binary = v } + directory := false + if v, ok := t.transferConfig["directory"].(bool); ok { + directory = v + } maxBufSize := int64(10 * 1024 * 1024) if v, ok := t.transferConfig["bufsize"].(float64); ok { maxBufSize = int64(v) @@ -451,59 +538,39 @@ func (t *TrzszTransfer) sendFiles(files []string, progress ProgressCallback) ([] } } - num := int64(len(files)) - if err := t.sendInteger("NUM", num); err != nil { - return nil, err - } - if err := t.checkInteger(num); err != nil { + if err := t.sendFileNum(int64(len(files)), progress); err != nil { return nil, err } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onNum(num) - } bufSize := int64(1024) buffer := make([]byte, bufSize) - remoteNames := make([]string, len(files)) - for i, file := range files { - fileName := filepath.Base(file) - if err := t.sendString("NAME", fileName); err != nil { - return nil, err - } - remoteName, err := t.recvString("SUCC", false) + var remoteNames []string + for _, f := range files { + file, remoteName, err := t.sendFileName(f, directory, progress) if err != nil { return nil, err } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onName(fileName) - } - f, err := os.Open(file) - if err != nil { - return nil, err - } - defer f.Close() - stat, err := f.Stat() - if err != nil { - return nil, err + if !containsString(remoteNames, remoteName) { + remoteNames = append(remoteNames, remoteName) } - fileSize := stat.Size() - if err := t.sendInteger("SIZE", fileSize); err != nil { - return nil, err + if file == nil { + continue } - if err := t.checkInteger(fileSize); err != nil { + + defer file.Close() + + fileSize, err := t.sendFileSize(file, progress) + if err != nil { return nil, err } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onSize(fileSize) - } step := int64(0) hasher := md5.New() for step < fileSize { beginTime := time.Now() - n, err := f.Read(buffer) + n, err := file.Read(buffer) if err != nil { return nil, err } @@ -531,21 +598,184 @@ func (t *TrzszTransfer) sendFiles(files []string, progress ProgressCallback) ([] } } - digest := hasher.Sum(nil) - if err := t.sendBinary("MD5", digest); err != nil { + if err := t.sendFileMD5(hasher.Sum(nil), progress); err != nil { return nil, err } - if err := t.checkBinary(digest); err != nil { - return nil, err + } + + return remoteNames, nil +} + +func (t *TrzszTransfer) recvFileNum(progress ProgressCallback) (int64, error) { + num, err := t.recvInteger("NUM", false) + if err != nil { + return 0, err + } + if err := t.sendInteger("SUCC", num); err != nil { + return 0, err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onNum(num) + } + return num, nil +} + +func doCreateFile(path string) (*os.File, error) { + file, err := os.Create(path) + if err != nil { + if e, ok := err.(*fs.PathError); ok { + if errno, ok := e.Err.(syscall.Errno); ok { + if errno == 13 { + return nil, newTrzszError(fmt.Sprintf("No permission to write: %s", path)) + } else if errno == 21 { + return nil, newTrzszError(fmt.Sprintf("Is a directory: %s", path)) + } + } } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onDone(remoteName) + return nil, err + } + return file, nil +} + +func doCreateDirectory(path string) error { + stat, err := os.Stat(path) + if errors.Is(err, os.ErrNotExist) { + return os.MkdirAll(path, 0755) + } else if err != nil { + return err + } + if !stat.IsDir() { + return newTrzszError(fmt.Sprintf("Not a directory: %s", path)) + } + return nil +} + +func (t *TrzszTransfer) createFile(path, fileName string, overwrite bool) (*os.File, string, error) { + var localName string + if overwrite { + localName = fileName + } else { + var err error + localName, err = getNewName(path, fileName) + if err != nil { + return nil, "", err } + } + file, err := doCreateFile(filepath.Join(path, localName)) + if err != nil { + return nil, "", err + } + return file, localName, nil +} - remoteNames[i] = remoteName +func (t *TrzszTransfer) createDirOrFile(path, name string, overwrite bool) (*os.File, string, string, error) { + var f TrzszFile + if err := json.Unmarshal([]byte(name), &f); err != nil { + return nil, "", "", err + } + if len(f.RelPath) < 1 { + return nil, "", "", newTrzszError(fmt.Sprintf("Invalid name: %s", name)) } - return remoteNames, nil + fileName := f.RelPath[len(f.RelPath)-1] + + var localName string + if overwrite { + localName = f.RelPath[0] + } else { + if v, ok := t.fileNameMap[f.PathID]; ok { + localName = v + } else { + var err error + localName, err = getNewName(path, f.RelPath[0]) + if err != nil { + return nil, "", "", err + } + t.fileNameMap[f.PathID] = localName + } + } + + var fullPath string + if len(f.RelPath) > 1 { + p := filepath.Join(append([]string{path, localName}, f.RelPath[1:len(f.RelPath)-1]...)...) + if err := doCreateDirectory(p); err != nil { + return nil, "", "", err + } + fullPath = filepath.Join(p, fileName) + } else { + fullPath = filepath.Join(path, localName) + } + + if f.IsDir { + if err := doCreateDirectory(fullPath); err != nil { + return nil, "", "", err + } + return nil, localName, fileName, nil + } + + file, err := doCreateFile(fullPath) + if err != nil { + return nil, "", "", err + } + return file, localName, fileName, nil +} + +func (t *TrzszTransfer) recvFileName(path string, directory, overwrite bool, progress ProgressCallback) (*os.File, string, error) { + fileName, err := t.recvString("NAME", false) + if err != nil { + return nil, "", err + } + + var file *os.File + var localName string + if directory { + file, localName, fileName, err = t.createDirOrFile(path, fileName, overwrite) + } else { + file, localName, err = t.createFile(path, fileName, overwrite) + } + if err != nil { + return nil, "", err + } + + if err := t.sendString("SUCC", localName); err != nil { + return nil, "", err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onName(fileName) + } + + return file, localName, nil +} + +func (t *TrzszTransfer) recvFileSize(progress ProgressCallback) (int64, error) { + size, err := t.recvInteger("SIZE", false) + if err != nil { + return 0, err + } + if err := t.sendInteger("SUCC", size); err != nil { + return 0, err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onSize(size) + } + return size, nil +} + +func (t *TrzszTransfer) recvFileMD5(digest []byte, progress ProgressCallback) error { + expectDigest, err := t.recvBinary("MD5", false, nil) + if err != nil { + return err + } + if bytes.Compare(digest, expectDigest) != 0 { + return newTrzszError("Check MD5 failed") + } + if err := t.sendBinary("SUCC", digest); err != nil { + return err + } + if progress != nil && !reflect.ValueOf(progress).IsNil() { + progress.onDone() + } + return nil } func (t *TrzszTransfer) recvFiles(path string, progress ProgressCallback) ([]string, error) { @@ -553,6 +783,10 @@ func (t *TrzszTransfer) recvFiles(path string, progress ProgressCallback) ([]str if v, ok := t.transferConfig["binary"].(bool); ok { binary = v } + directory := false + if v, ok := t.transferConfig["directory"].(bool); ok { + directory = v + } overwrite := false if v, ok := t.transferConfig["overwrite"].(bool); ok { overwrite = v @@ -570,62 +804,32 @@ func (t *TrzszTransfer) recvFiles(path string, progress ProgressCallback) ([]str } } - num, err := t.recvInteger("NUM", false) + num, err := t.recvFileNum(progress) if err != nil { return nil, err } - if err := t.sendInteger("SUCC", num); err != nil { - return nil, err - } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onNum(num) - } - localNames := make([]string, num) + var localNames []string for i := int64(0); i < num; i++ { - fileName, err := t.recvString("NAME", false) + file, localName, err := t.recvFileName(path, directory, overwrite, progress) if err != nil { return nil, err } - localName := fileName - if !overwrite { - localName, err = getNewName(path, fileName) - if err != nil { - return nil, err - } - } - fullPath := filepath.Join(path, localName) - f, err := os.Create(fullPath) - if err != nil { - if e, ok := err.(*fs.PathError); ok { - if errno, ok := e.Err.(syscall.Errno); ok { - if errno == 13 { - return nil, newTrzszError(fmt.Sprintf("No permission to write: %s", fullPath)) - } else if errno == 21 { - return nil, newTrzszError(fmt.Sprintf("Is a directory: %s", fullPath)) - } - } - } - return nil, err - } - defer f.Close() - if err := t.sendString("SUCC", localName); err != nil { - return nil, err + + if !containsString(localNames, localName) { + localNames = append(localNames, localName) } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onName(fileName) + + if file == nil { + continue } - fileSize, err := t.recvInteger("SIZE", false) + defer file.Close() + + fileSize, err := t.recvFileSize(progress) if err != nil { return nil, err } - if err := t.sendInteger("SUCC", fileSize); err != nil { - return nil, err - } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onSize(fileSize) - } step := int64(0) hasher := md5.New() @@ -635,7 +839,7 @@ func (t *TrzszTransfer) recvFiles(path string, progress ProgressCallback) ([]str if err != nil { return nil, err } - if _, err := f.Write(data); err != nil { + if _, err := file.Write(data); err != nil { return nil, err } size := int64(len(data)) @@ -655,22 +859,9 @@ func (t *TrzszTransfer) recvFiles(path string, progress ProgressCallback) ([]str } } - actualDigest := hasher.Sum(nil) - expectDigest, err := t.recvBinary("MD5", false, nil) - if err != nil { + if err := t.recvFileMD5(hasher.Sum(nil), progress); err != nil { return nil, err } - if bytes.Compare(actualDigest, expectDigest) != 0 { - return nil, newTrzszError(fmt.Sprintf("Check MD5 of %s failed", fileName)) - } - if err := t.sendBinary("SUCC", actualDigest); err != nil { - return nil, err - } - if progress != nil && !reflect.ValueOf(progress).IsNil() { - progress.onDone(localName) - } - - localNames[i] = localName } return localNames, nil diff --git a/trzsz/trz.go b/trzsz/trz.go index ab76742..0503c07 100644 --- a/trzsz/trz.go +++ b/trzsz/trz.go @@ -73,6 +73,15 @@ func recvFiles(transfer *TrzszTransfer, args *TrzArgs, tmuxMode TmuxMode, tmuxPa args.Binary = false } + // check if the client doesn't support transfer directory + supportDir := false + if v, ok := action["support_dir"].(bool); ok { + supportDir = v + } + if args.Directory && !supportDir { + return newTrzszError("The client doesn't support transfer directory") + } + escapeChars := getEscapeChars(args.Escape) if err := transfer.sendConfig(&args.Args, escapeChars, tmuxMode, tmuxPaneWidth); err != nil { return err @@ -122,7 +131,7 @@ func TrzMain() int { args.Binary = false } - uniqueId := "0" + uniqueID := "0" if tmuxMode == TmuxNormalMode { columns := getTerminalColumns() if columns > 0 && columns < 40 { @@ -130,13 +139,17 @@ func TrzMain() int { } else { os.Stdout.WriteString("\n\x1b[1A\x1b[0J") } - uniqueId = reverseString(strconv.FormatInt(time.Now().UnixMilli(), 10)) + uniqueID = reverseString(strconv.FormatInt(time.Now().UnixMilli(), 10)) } if IsWindows() { - uniqueId = "1" + uniqueID = "1" } - os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:R:%s:%s\n", kTrzszVersion, uniqueId)) + mode := "R" + if args.Directory { + mode = "D" + } + os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:%s:%s:%s\n", mode, kTrzszVersion, uniqueID)) os.Stdout.Sync() state, err := term.MakeRaw(int(os.Stdin.Fd())) diff --git a/trzsz/trzsz.go b/trzsz/trzsz.go index 5dd2bab..7f59c05 100644 --- a/trzsz/trzsz.go +++ b/trzsz/trzsz.go @@ -57,7 +57,7 @@ var gDragFiles []string = nil var gDragHasDir bool = false var gInterrupting bool = false var gTransfer *TrzszTransfer = nil -var trzszRegexp = regexp.MustCompile("::TRZSZ:TRANSFER:([SR]):(\\d+\\.\\d+\\.\\d+)(:\\d+)?") +var trzszRegexp = regexp.MustCompile("::TRZSZ:TRANSFER:([SRD]):(\\d+\\.\\d+\\.\\d+)(:\\d+)?") func printVersion() { fmt.Printf("trzsz go %s\n", kTrzszVersion) @@ -160,7 +160,7 @@ func chooseDownloadPath() (string, error) { return path, nil } -func chooseUploadFiles() ([]string, error) { +func chooseUploadPaths(directory bool) ([]string, error) { if gDragFiles != nil { files := gDragFiles gDragFiles = nil @@ -175,6 +175,9 @@ func chooseUploadFiles() ([]string, error) { if defaultPath != nil { options = append(options, zenity.Filename(*defaultPath)) } + if directory { + options = append(options, zenity.Directory()) + } files, err := zenity.SelectFileMutiple(options...) if err != nil { return nil, err @@ -241,15 +244,16 @@ func downloadFiles(pty *TrzszPty, transfer *TrzszTransfer, remoteIsWindows bool) return transfer.clientExit(fmt.Sprintf("Saved %s to %s", strings.Join(localNames, ", "), path)) } -func uploadFiles(pty *TrzszPty, transfer *TrzszTransfer, remoteIsWindows bool) error { - files, err := chooseUploadFiles() +func uploadFiles(pty *TrzszPty, transfer *TrzszTransfer, directory, remoteIsWindows bool) error { + paths, err := chooseUploadPaths(directory) if err == zenity.ErrCanceled { return transfer.sendAction(false, remoteIsWindows) } if err != nil { return err } - if err := checkFilesReadable(files); err != nil { + files, err := checkPathsReadable(paths, directory) + if err != nil { return err } @@ -261,6 +265,16 @@ func uploadFiles(pty *TrzszPty, transfer *TrzszTransfer, remoteIsWindows bool) e return err } + overwrite := false + if v, ok := config["overwrite"].(bool); ok { + overwrite = v + } + if overwrite { + if err := checkDuplicateNames(files); err != nil { + return err + } + } + progress, err := newProgressBar(pty, config) if err != nil { return err @@ -293,10 +307,13 @@ func handleTrzsz(pty *TrzszPty, mode byte, remoteIsWindows bool) { }() var err error - if mode == 'S' { + switch mode { + case 'S': err = downloadFiles(pty, transfer, remoteIsWindows) - } else if mode == 'R' { - err = uploadFiles(pty, transfer, remoteIsWindows) + case 'R': + err = uploadFiles(pty, transfer, false, remoteIsWindows) + case 'D': + err = uploadFiles(pty, transfer, true, remoteIsWindows) } if err != nil { transfer.clientError(err) @@ -313,9 +330,9 @@ func uploadDragFiles(pty *TrzszPty) { time.Sleep(200 * time.Millisecond) gInterrupting = false if gDragHasDir { - pty.Stdin.Write([]byte("echo 'upload directory is not supported yet'\n")) + pty.Stdin.Write([]byte("trz -d\r")) } else { - pty.Stdin.Write([]byte("trz\n")) + pty.Stdin.Write([]byte("trz\r")) } time.Sleep(time.Second) if gDragFiles != nil { @@ -433,7 +450,7 @@ func wrapOutput(pty *TrzszPty) { } if gTrzszArgs.DragFile && gDragFiles != nil { output := strings.TrimRight(string(trimVT100(buf)), "\r\n") - if output == "trz" { + if output == "trz" || output == "trz -d" { os.Stdout.WriteString("\r\n") continue } diff --git a/trzsz/tsz.go b/trzsz/tsz.go index 3075c7d..fc001a7 100644 --- a/trzsz/tsz.go +++ b/trzsz/tsz.go @@ -47,7 +47,7 @@ func (TszArgs) Version() string { return fmt.Sprintf("tsz (trzsz) go %s", kTrzszVersion) } -func sendFiles(transfer *TrzszTransfer, args *TszArgs, tmuxMode TmuxMode, tmuxPaneWidth int) error { +func sendFiles(transfer *TrzszTransfer, files []*TrzszFile, args *TszArgs, tmuxMode TmuxMode, tmuxPaneWidth int) error { action, err := transfer.recvAction() if err != nil { return err @@ -71,12 +71,21 @@ func sendFiles(transfer *TrzszTransfer, args *TszArgs, tmuxMode TmuxMode, tmuxPa args.Binary = false } + // check if the client doesn't support transfer directory + supportDir := false + if v, ok := action["support_dir"].(bool); ok { + supportDir = v + } + if args.Directory && !supportDir { + return newTrzszError("The client doesn't support transfer directory") + } + var escapeChars [][]unicode if err := transfer.sendConfig(&args.Args, escapeChars, tmuxMode, tmuxPaneWidth); err != nil { return err } - if _, err := transfer.sendFiles(args.File, nil); err != nil { + if _, err := transfer.sendFiles(files, nil); err != nil { return err } @@ -94,15 +103,22 @@ func TszMain() int { var args TszArgs arg.MustParse(&args) - if err := checkFilesReadable(args.File); err != nil { + files, err := checkPathsReadable(args.File, args.Directory) + if err != nil { fmt.Fprintln(os.Stderr, err) return -1 } + if args.Overwrite { + if err := checkDuplicateNames(files); err != nil { + fmt.Fprintln(os.Stderr, err) + return -2 + } + } tmuxMode, realStdout, tmuxPaneWidth, err := checkTmux() if err != nil { fmt.Fprintln(os.Stderr, err) - return -2 + return -3 } if args.Binary && tmuxMode == TmuxControlMode { @@ -114,7 +130,7 @@ func TszMain() int { args.Binary = false } - uniqueId := "0" + uniqueID := "0" if tmuxMode == TmuxNormalMode { columns := getTerminalColumns() if columns > 0 && columns < 40 { @@ -122,19 +138,19 @@ func TszMain() int { } else { os.Stdout.WriteString("\n\x1b[1A\x1b[0J") } - uniqueId = reverseString(strconv.FormatInt(time.Now().UnixMilli(), 10)) + uniqueID = reverseString(strconv.FormatInt(time.Now().UnixMilli(), 10)) } if IsWindows() { - uniqueId = "1" + uniqueID = "1" } - os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:S:%s:%s\n", kTrzszVersion, uniqueId)) + os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:S:%s:%s\n", kTrzszVersion, uniqueID)) os.Stdout.Sync() state, err := term.MakeRaw(int(os.Stdin.Fd())) if err != nil { fmt.Fprintln(os.Stderr, err) - return -3 + return -4 } defer func() { _ = term.Restore(int(os.Stdin.Fd()), state) }() @@ -148,7 +164,7 @@ func TszMain() int { go wrapStdinInput(transfer) handleServerSignal(transfer) - if err := sendFiles(transfer, &args, tmuxMode, tmuxPaneWidth); err != nil { + if err := sendFiles(transfer, files, &args, tmuxMode, tmuxPaneWidth); err != nil { transfer.serverError(err) } diff --git a/trzsz/version.go b/trzsz/version.go index e6834d1..d78b48c 100644 --- a/trzsz/version.go +++ b/trzsz/version.go @@ -24,4 +24,4 @@ SOFTWARE. package trzsz -const kTrzszVersion = "0.1.6" +const kTrzszVersion = "0.1.7"