Skip to content

Commit

Permalink
#77 Build host key callback for the SSH client from the known_hosts f…
Browse files Browse the repository at this point in the history
…ile(s)
  • Loading branch information
zshamrock committed Nov 29, 2018
1 parent 2e505d5 commit 1111250
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
30 changes: 28 additions & 2 deletions command/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package command

import (
"fmt"
"github.com/zshamrock/vmx/config"
"io/ioutil"
"log"
"os"
Expand All @@ -10,13 +11,15 @@ import (

"github.com/kevinburke/ssh_config"
cryptoSSH "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

const (
SshConfigUserKey = "User"
SshConfigHostnameKey = "Hostname"
SshConfigIdentityFileKey = "IdentityFile"
ignoredIdentitySshFile = "~/.ssh/identity"
knownHostsFileName = "known_hosts"
)

// ssh implements scp connection to the remote instance
Expand All @@ -27,17 +30,19 @@ func ssh(sshConfig *ssh_config.Config, host, command string, follow bool, ch cha
identityFile, _ := sshConfig.Get(host, SshConfigIdentityFileKey)
var identityFilePath string
if identityFile == "" || identityFile == ignoredIdentitySshFile {
identityFilePath = filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa")
identityFilePath = filepath.Join(config.DefaultConfig.SSHConfigDir, "id_rsa")
} else {
identityFilePath = os.ExpandEnv(strings.Replace(identityFile, "~", "${HOME}", -1))
identityFilePath = strings.Replace(identityFile, "~", config.DefaultConfig.SSHConfigDir, -1)
}
pk, _ := ioutil.ReadFile(identityFilePath)
signer, _ := cryptoSSH.ParsePrivateKey([]byte(pk))
hostKeyCallback := buildHostKeyCallback()
config := &cryptoSSH.ClientConfig{
User: user,
Auth: []cryptoSSH.AuthMethod{
cryptoSSH.PublicKeys(signer),
},
HostKeyCallback: hostKeyCallback,
}
client, err := cryptoSSH.Dial("tcp", fmt.Sprintf("%s:22", hostname), config)
if err != nil {
Expand Down Expand Up @@ -66,3 +71,24 @@ func ssh(sshConfig *ssh_config.Config, host, command string, follow bool, ch cha
output.String(),
}
}

func buildHostKeyCallback() cryptoSSH.HostKeyCallback {
configuredDefaultKnownHostsFile := filepath.Join(config.DefaultConfig.SSHConfigDir, knownHostsFileName)
_, err := os.Stat(configuredDefaultKnownHostsFile)
knownHostsFiles := make([]string, 0, 2)
if err == nil {
knownHostsFiles = append(knownHostsFiles, configuredDefaultKnownHostsFile)
}
if defaultKnownHostsFile := filepath.Join(os.ExpandEnv(config.DefaultSSHConfigHome), knownHostsFileName); configuredDefaultKnownHostsFile != defaultKnownHostsFile {
knownHostsFiles = append(knownHostsFiles, defaultKnownHostsFile)
}
if len(knownHostsFiles) == 0 {
fmt.Printf("No %s files are found\n", knownHostsFileName)
os.Exit(1)
}
hostKeyCallback, err := knownhosts.New(knownHostsFiles...)
if err != nil {
log.Panicf("Failed to to read %s file %v\n", knownHostsFileName, err.Error())
}
return hostKeyCallback
}
4 changes: 2 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const (
vmxHomeEnvVar = "VMX_HOME"
defaultVmxHome = "${HOME}/.vmx"
vmxSSHConfigHomeEnvVar = "VMX_SSH_CONFIG_HOME"
defaultSSHConfigHome = "${HOME}/.ssh"
DefaultSSHConfigHome = "${HOME}/.ssh"

CommandNameConfirmationSuffix = "!"
HostsGroupChildrenSuffix = ":children"
Expand All @@ -31,7 +31,7 @@ func (c VMXConfig) GetDir(profile string) string {
var DefaultConfig VMXConfig

func init() {
DefaultConfig = VMXConfig{os.ExpandEnv(defaultVmxHome), os.ExpandEnv(defaultSSHConfigHome)}
DefaultConfig = VMXConfig{os.ExpandEnv(defaultVmxHome), os.ExpandEnv(DefaultSSHConfigHome)}
vmxHome, ok := os.LookupEnv(vmxHomeEnvVar)
if ok {
DefaultConfig.Dir = vmxHome
Expand Down
2 changes: 1 addition & 1 deletion config/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func GetHostNames() []string {
}

// Reading hosts from ~/.ssh/config
f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
f, _ := os.Open(filepath.Join(DefaultConfig.SSHConfigDir, "config"))
cfg, _ := ssh_config.Decode(f)
for _, host := range cfg.Hosts {
for _, pattern := range host.Patterns {
Expand Down

0 comments on commit 1111250

Please sign in to comment.