Skip to content
This repository has been archived by the owner on Jan 17, 2021. It is now read-only.

Commit

Permalink
Merge pull request #116 from cdr/reuse-ssh-connection
Browse files Browse the repository at this point in the history
Add SSH master connection feature
  • Loading branch information
deansheather authored Jun 28, 2019
2 parents c637d40 + bbd94c5 commit 2693c3f
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 19 deletions.
23 changes: 13 additions & 10 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ var _ interface {
} = new(rootCmd)

type rootCmd struct {
skipSync bool
syncBack bool
printVersion bool
bindAddr string
sshFlags string
skipSync bool
syncBack bool
printVersion bool
noReuseConnection bool
bindAddr string
sshFlags string
}

func (c *rootCmd) Spec() cli.CommandSpec {
Expand All @@ -53,6 +54,7 @@ func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) {
fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host")
fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination")
fl.BoolVar(&c.printVersion, "version", false, "print version information and exit")
fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket")
fl.StringVar(&c.bindAddr, "bind", "", "local bind address for SSH tunnel, in [HOST][:PORT] syntax (default: 127.0.0.1)")
fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags")
}
Expand All @@ -76,10 +78,11 @@ func (c *rootCmd) Run(fl *flag.FlagSet) {
}

err := sshCode(host, dir, options{
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
reuseConnection: !c.noReuseConnection,
})

if err != nil {
Expand All @@ -101,7 +104,7 @@ Environment variables:
More info: https://github.com/cdr/sshcode
Arguments:
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
%vDIR is optional.`,
helpTab, vsCodeConfigDirEnv,
helpTab, vsCodeExtensionsDirEnv,
Expand Down
154 changes: 145 additions & 9 deletions sshcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"path/filepath"
"strconv"
"strings"
"syscall"
"time"

"github.com/pkg/browser"
Expand All @@ -21,18 +22,23 @@ import (

const codeServerPath = "~/.cache/sshcode/sshcode-server"

const (
sshDirectory = "~/.ssh"
sshDirectoryUnsafeModeMask = 0022
sshControlPath = sshDirectory + "/control-%h-%p-%r"
)

type options struct {
skipSync bool
syncBack bool
noOpen bool
bindAddr string
remotePort string
sshFlags string
skipSync bool
syncBack bool
noOpen bool
reuseConnection bool
bindAddr string
remotePort string
sshFlags string
}

func sshCode(host, dir string, o options) error {
flog.Info("ensuring code-server is updated...")

host, extraSSHFlags, err := parseHost(host)
if err != nil {
return xerrors.Errorf("failed to parse host IP: %w", err)
Expand All @@ -53,6 +59,24 @@ func sshCode(host, dir string, o options) error {
return xerrors.Errorf("failed to find available remote port: %w", err)
}

// Check the SSH directory's permissions and warn the user if it is not safe.
o.reuseConnection = checkSSHDirectory(sshDirectory, o.reuseConnection)

// Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication
// only happens on the initial connection.
if o.reuseConnection {
flog.Info("starting SSH master connection...")
newSSHFlags, cancel, err := startSSHMaster(o.sshFlags, sshControlPath, host)
defer cancel()
if err != nil {
flog.Error("failed to start SSH master connection: %v", err)
o.reuseConnection = false
} else {
o.sshFlags = newSSHFlags
}
}

flog.Info("ensuring code-server is updated...")
dlScript := downloadScript(codeServerPath)

// Downloads the latest code-server and allows it to be executed.
Expand Down Expand Up @@ -147,8 +171,8 @@ func sshCode(host, dir string, o options) error {
case <-c:
}

flog.Info("shutting down")
if !o.syncBack || o.skipSync {
flog.Info("shutting down")
return nil
}

Expand All @@ -167,6 +191,24 @@ func sshCode(host, dir string, o options) error {
return nil
}

// expandPath returns an expanded version of path.
func expandPath(path string) string {
path = filepath.Clean(os.ExpandEnv(path))

// Replace tilde notation in path with the home directory. You can't replace the first instance of `~` in the
// string with the homedir as having a tilde in the middle of a filename is valid.
homedir := os.Getenv("HOME")
if homedir != "" {
if path == "~" {
path = homedir
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(homedir, path[2:])
}
}

return filepath.Clean(path)
}

func parseBindAddr(bindAddr string) (string, error) {
if !strings.Contains(bindAddr, ":") {
bindAddr += ":"
Expand Down Expand Up @@ -263,6 +305,100 @@ func randomPort() (string, error) {
return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

// checkSSHDirectory performs sanity and safety checks on sshDirectory, and
// returns a new value for o.reuseConnection depending on the checks.
func checkSSHDirectory(sshDirectory string, reuseConnection bool) bool {
sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory))
if err != nil {
if reuseConnection {
flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err)
}
reuseConnection = false
} else {
if !sshDirectoryMode.IsDir() {
if reuseConnection {
flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory)
} else {
flog.Info("warning: %v is not a directory", sshDirectory)
}
reuseConnection = false
}
if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 {
flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+
"the owner (and files inside should be set to 0600)", sshDirectory)
}
}
return reuseConnection
}

// startSSHMaster starts an SSH master connection and waits for it to be ready.
// It returns a new set of SSH flags for child SSH processes to use.
func startSSHMaster(sshFlags string, sshControlPath string, host string) (string, func(), error) {
ctx, cancel := context.WithCancel(context.Background())

newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, sshFlags, sshControlPath)

// -MN means "start a master socket and don't open a session, just connect".
sshCmdStr := fmt.Sprintf(`exec ssh %v -MNq %v`, newSSHFlags, host)
sshMasterCmd := exec.CommandContext(ctx, "sh", "-c", sshCmdStr)
sshMasterCmd.Stdin = os.Stdin
sshMasterCmd.Stderr = os.Stderr

// Gracefully stop the SSH master.
stopSSHMaster := func() {
if sshMasterCmd.Process != nil {
if sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited() {
return
}
err := sshMasterCmd.Process.Signal(syscall.SIGTERM)
if err != nil {
flog.Error("failed to send SIGTERM to SSH master process: %v", err)
}
}
cancel()
}

// Start ssh master and wait. Waiting prevents the process from becoming a zombie process if it dies before
// sshcode does, and allows sshMasterCmd.ProcessState to be populated.
err := sshMasterCmd.Start()
go sshMasterCmd.Wait()
if err != nil {
return "", stopSSHMaster, err
}
err = checkSSHMaster(sshMasterCmd, newSSHFlags, host)
if err != nil {
stopSSHMaster()
return "", stopSSHMaster, xerrors.Errorf("SSH master wasn't ready on time: %w", err)
}
return newSSHFlags, stopSSHMaster, nil
}

// checkSSHMaster polls every second for 30 seconds to check if the SSH master
// is ready.
func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error {
var (
maxTries = 30
sleepDur = time.Second
err error
)
for i := 0; i < maxTries; i++ {
// Check if the master is running.
if sshMasterCmd.Process == nil || (sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited()) {
return xerrors.Errorf("SSH master process is not running")
}

// Check if it's ready.
sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host)
sshCmd := exec.Command("sh", "-c", sshCmdStr)
err = sshCmd.Run()
if err == nil {
return nil
}
time.Sleep(sleepDur)
}
return xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

func syncUserSettings(sshFlags string, host string, back bool) error {
localConfDir, err := configDir()
if err != nil {
Expand Down

0 comments on commit 2693c3f

Please sign in to comment.