diff --git a/probe/cri/registry.go b/probe/cri/registry.go index c346ec59b3..a069103e24 100644 --- a/probe/cri/registry.go +++ b/probe/cri/registry.go @@ -11,48 +11,59 @@ import ( client "github.com/weaveworks/scope/cri/runtime" ) -const unixProtocol = "unix" +const ( + unixProtocol = "unix" + tcpProtocol = "tcp" +) func dial(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout(unixProtocol, addr, timeout) } func getAddressAndDialer(endpoint string) (string, func(addr string, timeout time.Duration) (net.Conn, error), error) { - protocol, addr, err := parseEndpointWithFallbackProtocol(endpoint, unixProtocol) + addr, err := parseEndpointWithFallbackProtocol(endpoint, unixProtocol) if err != nil { return "", nil, err } - if protocol != unixProtocol { - return "", nil, fmt.Errorf("endpoint was not unix socket %v", protocol) - } return addr, dial, nil } -func parseEndpointWithFallbackProtocol(endpoint string, fallbackProtocol string) (protocol string, addr string, err error) { - if protocol, addr, err = parseEndpoint(endpoint); err != nil && protocol == "" { +func parseEndpointWithFallbackProtocol(endpoint string, fallbackProtocol string) (addr string, err error) { + var protocol string + + protocol, addr, err = parseEndpoint(endpoint) + + if err != nil { + return "", err + } + + if protocol == "" { fallbackEndpoint := fallbackProtocol + "://" + endpoint - protocol, addr, err = parseEndpoint(fallbackEndpoint) + _, addr, err = parseEndpoint(fallbackEndpoint) + if err != nil { - return "", "", err + return "", err } } - return + return addr, err } func parseEndpoint(endpoint string) (string, string, error) { u, err := url.Parse(endpoint) + if err != nil { return "", "", err } - if u.Scheme == "tcp" { - return "tcp", u.Host, nil - } else if u.Scheme == "unix" { - return "unix", u.Path, nil - } else if u.Scheme == "" { - return "", "", fmt.Errorf("Using %q as endpoint is deprecated, please consider using full url format", endpoint) - } else { + switch u.Scheme { + case tcpProtocol: + return tcpProtocol, u.Host, fmt.Errorf("endpoint was not unix socket %v", u.Scheme) + case unixProtocol: + return unixProtocol, u.Path, nil + case "": + return "", "", nil + default: return u.Scheme, "", fmt.Errorf("protocol %q not supported", u.Scheme) } } diff --git a/probe/cri/registry_test.go b/probe/cri/registry_test.go index 916f06e4fd..2c91a16842 100644 --- a/probe/cri/registry_test.go +++ b/probe/cri/registry_test.go @@ -7,15 +7,36 @@ import ( "github.com/weaveworks/scope/probe/cri" ) -func TestParseHttpEndpointUrl(t *testing.T) { - _, err := cri.NewCRIClient("http://xyz.com") +var nonUnixSocketsTest = []struct { + endpoint string + errorMessage string +}{ + {"http://xyz.com", "protocol \"http\" not supported"}, + {"tcp://var/unix.sock", "endpoint was not unix socket tcp"}, + {"http://[fe80::%31]/", "parse http://[fe80::%31]/: invalid URL escape \"%31\""}, +} + +func TestParseNonUnixEndpointUrl(t *testing.T) { + for _, tt := range nonUnixSocketsTest { + _, err := cri.NewCRIClient(tt.endpoint) - assert.Equal(t, "protocol \"http\" not supported", err.Error()) + assert.Equal(t, tt.errorMessage, err.Error()) + } } -func TestParseTcpEndpointUrl(t *testing.T) { - client, err := cri.NewCRIClient("127.0.0.1") +var unixSocketsTest = []string{ + "127.0.0.1", // tests the fallback endpoint + "unix://127.0.0.1", + "unix///var/run/dockershim.sock", + "var/run/dockershim.sock", +} + +func TestParseUnixEndpointUrl(t *testing.T) { + for _, tt := range unixSocketsTest { + client, err := cri.NewCRIClient(tt) + + assert.Equal(t, nil, err) + assert.NotEqual(t, nil, client) + } - assert.Equal(t, nil, err) - assert.NotEqual(t, nil, client) }