From 8c9e884e3c0e39c2d8d05fd36df9d8e9dc6be98e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 10 Mar 2023 16:55:04 +0100 Subject: [PATCH] Handle errors when resolving ssh aliases to hostnames --- pkg/ssh/ssh.go | 6 ++++- pkg/ssh/ssh_test.go | 65 ++++++++++++++++++++++++++++++--------------- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 0ef9137..4e5216e 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -95,7 +95,11 @@ func (t *Translator) resolve(hostname string) (string, error) { } } - _ = sshCmd.Wait() + err = sshCmd.Wait() + if err != nil || resolvedHost == "" { + // handle failures by returning the original hostname unchanged + resolvedHost = hostname + } if t.cacheMap == nil { t.cacheMap = map[string]string{} diff --git a/pkg/ssh/ssh_test.go b/pkg/ssh/ssh_test.go index 46ea3d4..d61e4e6 100644 --- a/pkg/ssh/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -1,6 +1,7 @@ package ssh import ( + "errors" "fmt" "net/url" "os" @@ -85,7 +86,13 @@ func TestHelperProcess(t *testing.T) { return } if err := func(args []string) error { - fmt.Fprint(os.Stdout, "hostname github.com\n") + if len(args) < 3 || args[2] == "error" { + return errors.New("fatal") + } + if args[2] == "empty.io" { + return nil + } + fmt.Fprintf(os.Stdout, "hostname %s\n", args[2]) return nil }(os.Args[3:]); err != nil { fmt.Fprintln(os.Stderr, err) @@ -111,32 +118,46 @@ func TestTranslator_caching(t *testing.T) { }, } - u1, err := url.Parse("ssh://github1.com/owner/repo.git") - if err != nil { - t.Fatalf("error parsing URL: %v", err) - } - if res := tr.Translate(u1); res.Host != "github.com" { - t.Errorf("expected github.com, got: %q", res.Host) - } - if res := tr.Translate(u1); res.Host != "github.com" { - t.Errorf("expected github.com, got: %q", res.Host) - } - - u2, err := url.Parse("ssh://github2.com/owner/repo.git") - if err != nil { - t.Fatalf("error parsing URL: %v", err) - } - if res := tr.Translate(u2); res.Host != "github.com" { - t.Errorf("expected github.com, got: %q", res.Host) + tests := []struct { + input string + result string + }{ + { + input: "ssh://github1.com/owner/repo.git", + result: "github1.com", + }, + { + input: "ssh://github2.com/owner/repo.git", + result: "github2.com", + }, + { + input: "ssh://empty.io/owner/repo.git", + result: "empty.io", + }, + { + input: "ssh://error/owner/repo.git", + result: "error", + }, } - if res := tr.Translate(u2); res.Host != "github.com" { - t.Errorf("expected github.com, got: %q", res.Host) + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + u, err := url.Parse(tt.input) + if err != nil { + t.Fatalf("error parsing URL: %v", err) + } + if res := tr.Translate(u); res.Host != tt.result { + t.Errorf("expected github.com, got: %q", res.Host) + } + if res := tr.Translate(u); res.Host != tt.result { + t.Errorf("expected github.com, got: %q (second call)", res.Host) + } + }) } if countLookPath != 1 { t.Errorf("expected lookPath to happen 1 time; actual: %d", countLookPath) } - if countNewCommand != 2 { - t.Errorf("expected ssh command to shell out 2 times; actual: %d", countNewCommand) + if countNewCommand != len(tests) { + t.Errorf("expected ssh command to shell out %d times; actual: %d", len(tests), countNewCommand) } }