diff --git a/command/ssh.go b/command/ssh.go index eeb4d13..1515822 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -2,6 +2,7 @@ package command import ( "fmt" + "github.com/zshamrock/vmx/config" "io/ioutil" "log" "os" @@ -10,6 +11,7 @@ import ( "github.com/kevinburke/ssh_config" cryptoSSH "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" ) const ( @@ -17,6 +19,7 @@ const ( SshConfigHostnameKey = "Hostname" SshConfigIdentityFileKey = "IdentityFile" ignoredIdentitySshFile = "~/.ssh/identity" + knownHostsFileName = "known_hosts" ) // ssh implements scp connection to the remote instance @@ -27,17 +30,19 @@ func ssh(sshConfig *ssh_config.Config, host, command string, follow bool, ch cha identityFile, _ := sshConfig.Get(host, SshConfigIdentityFileKey) var identityFilePath string if identityFile == "" || identityFile == ignoredIdentitySshFile { - identityFilePath = filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa") + identityFilePath = filepath.Join(config.DefaultConfig.SSHConfigDir, "id_rsa") } else { - identityFilePath = os.ExpandEnv(strings.Replace(identityFile, "~", "${HOME}", -1)) + identityFilePath = strings.Replace(identityFile, "~", config.DefaultConfig.SSHConfigDir, -1) } pk, _ := ioutil.ReadFile(identityFilePath) signer, _ := cryptoSSH.ParsePrivateKey([]byte(pk)) + hostKeyCallback := buildHostKeyCallback() config := &cryptoSSH.ClientConfig{ User: user, Auth: []cryptoSSH.AuthMethod{ cryptoSSH.PublicKeys(signer), }, + HostKeyCallback: hostKeyCallback, } client, err := cryptoSSH.Dial("tcp", fmt.Sprintf("%s:22", hostname), config) if err != nil { @@ -66,3 +71,24 @@ func ssh(sshConfig *ssh_config.Config, host, command string, follow bool, ch cha output.String(), } } + +func buildHostKeyCallback() cryptoSSH.HostKeyCallback { + configuredDefaultKnownHostsFile := filepath.Join(config.DefaultConfig.SSHConfigDir, knownHostsFileName) + _, err := os.Stat(configuredDefaultKnownHostsFile) + knownHostsFiles := make([]string, 0, 2) + if err == nil { + knownHostsFiles = append(knownHostsFiles, configuredDefaultKnownHostsFile) + } + if defaultKnownHostsFile := filepath.Join(os.ExpandEnv(config.DefaultSSHConfigHome), knownHostsFileName); configuredDefaultKnownHostsFile != defaultKnownHostsFile { + knownHostsFiles = append(knownHostsFiles, defaultKnownHostsFile) + } + if len(knownHostsFiles) == 0 { + fmt.Printf("No %s files are found\n", knownHostsFileName) + os.Exit(1) + } + hostKeyCallback, err := knownhosts.New(knownHostsFiles...) + if err != nil { + log.Panicf("Failed to to read %s file %v\n", knownHostsFileName, err.Error()) + } + return hostKeyCallback +} diff --git a/config/config.go b/config/config.go index d102d9d..c0c1f14 100644 --- a/config/config.go +++ b/config/config.go @@ -9,7 +9,7 @@ const ( vmxHomeEnvVar = "VMX_HOME" defaultVmxHome = "${HOME}/.vmx" vmxSSHConfigHomeEnvVar = "VMX_SSH_CONFIG_HOME" - defaultSSHConfigHome = "${HOME}/.ssh" + DefaultSSHConfigHome = "${HOME}/.ssh" CommandNameConfirmationSuffix = "!" HostsGroupChildrenSuffix = ":children" @@ -31,7 +31,7 @@ func (c VMXConfig) GetDir(profile string) string { var DefaultConfig VMXConfig func init() { - DefaultConfig = VMXConfig{os.ExpandEnv(defaultVmxHome), os.ExpandEnv(defaultSSHConfigHome)} + DefaultConfig = VMXConfig{os.ExpandEnv(defaultVmxHome), os.ExpandEnv(DefaultSSHConfigHome)} vmxHome, ok := os.LookupEnv(vmxHomeEnvVar) if ok { DefaultConfig.Dir = vmxHome diff --git a/config/init.go b/config/init.go index 7b189e1..62e4227 100644 --- a/config/init.go +++ b/config/init.go @@ -153,7 +153,7 @@ func GetHostNames() []string { } // Reading hosts from ~/.ssh/config - f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) + f, _ := os.Open(filepath.Join(DefaultConfig.SSHConfigDir, "config")) cfg, _ := ssh_config.Decode(f) for _, host := range cfg.Hosts { for _, pattern := range host.Patterns {