diff --git a/source/gateway.go b/source/gateway.go index 4d325d627a..b3aec2af3b 100644 --- a/source/gateway.go +++ b/source/gateway.go @@ -19,6 +19,7 @@ package source import ( "context" "fmt" + "net/netip" "sort" "strings" "text/template" @@ -486,39 +487,89 @@ func gwProtocolMatches(a, b v1.ProtocolType) bool { } // gwMatchingHost returns the most-specific overlapping host and a bool indicating if one was found. -// For example, if one host is "*.foo.com" and the other is "bar.foo.com", "bar.foo.com" will be returned. -// An empty string matches anything. -func gwMatchingHost(gwHost, rtHost string) (string, bool) { - gwHost = toLowerCaseASCII(gwHost) // TODO: trim "." suffix? - rtHost = toLowerCaseASCII(rtHost) // TODO: trim "." suffix? +// Hostnames that are prefixed with a wildcard label (`*.`) are interpreted as a suffix match. +// That means that "*.example.com" would match both "test.example.com" and "foo.test.example.com", +// but not "example.com". An empty string matches anything. +func gwMatchingHost(a, b string) (string, bool) { + var ok bool + if a, ok = gwHost(a); !ok { + return "", false + } + if b, ok = gwHost(b); !ok { + return "", false + } - if gwHost == "" { - return rtHost, true + if a == "" { + return b, true } - if rtHost == "" { - return gwHost, true + if b == "" || a == b { + return a, true } + if na, nb := len(a), len(b); nb < na || (na == nb && strings.HasPrefix(b, "*.")) { + a, b = b, a + } + if strings.HasPrefix(a, "*.") && strings.HasSuffix(b, a[1:]) { + return b, true + } + return "", false +} - gwParts := strings.Split(gwHost, ".") - rtParts := strings.Split(rtHost, ".") - if len(gwParts) != len(rtParts) { +// gwHost returns the canonical host and a value indicating if it's valid. +func gwHost(host string) (string, bool) { + if host == "" { + return "", true + } + if isIPAddr(host) || !isDNS1123Domain(strings.TrimPrefix(host, "*.")) { return "", false } + return toLowerCaseASCII(host), true // TODO: trim "." suffix? +} + +// isIPAddr returns whether s in an IP address. +func isIPAddr(s string) bool { + _, err := netip.ParseAddr(s) + return err == nil +} - host := rtHost - for i, gwPart := range gwParts { - switch rtPart := rtParts[i]; { - case rtPart == gwPart: - // continue - case i == 0 && gwPart == "*": - // continue - case i == 0 && rtPart == "*": - host = gwHost // gwHost is more specific - default: - return "", false +// isDNS1123Domain returns whether s is a valid domain name according to RFC 1123. +func isDNS1123Domain(s string) bool { + if n := len(s); n == 0 || n > 255 { + return false + } + for lbl, rest := "", s; rest != ""; { + if lbl, rest, _ = strings.Cut(rest, "."); !isDNS1123Label(lbl) { + return false } } - return host, true + return true +} + +// isDNS1123Label returns whether s is a valid domain label according to RFC 1123. +func isDNS1123Label(s string) bool { + n := len(s) + if n == 0 || n > 63 { + return false + } + if !isAlphaNum(s[0]) || !isAlphaNum(s[n-1]) { + return false + } + for i, k := 1, n-1; i < k; i++ { + if b := s[i]; b != '-' && !isAlphaNum(b) { + return false + } + } + return true +} + +func isAlphaNum(b byte) bool { + switch { + case 'a' <= b && b <= 'z', + 'A' <= b && b <= 'Z', + '0' <= b && b <= '9': + return true + default: + return false + } } func strVal(ptr *string, def string) string { diff --git a/source/gateway_test.go b/source/gateway_test.go new file mode 100644 index 0000000000..96291b1b92 --- /dev/null +++ b/source/gateway_test.go @@ -0,0 +1,175 @@ +package source + +import ( + "strings" + "testing" +) + +func TestGatewayMatchingHost(t *testing.T) { + tests := []struct { + desc string + a, b string + host string + ok bool + }{ + { + desc: "ipv4-rejected", + a: "1.2.3.4", + ok: false, + }, + { + desc: "ipv6-rejected", + a: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + ok: false, + }, + { + desc: "empty-matches-empty", + ok: true, + }, + { + desc: "empty-matches-nonempty", + a: "example.net", + host: "example.net", + ok: true, + }, + { + desc: "simple-match", + a: "example.net", + b: "example.net", + host: "example.net", + ok: true, + }, + { + desc: "wildcard-matches-longer", + a: "*.example.net", + b: "test.example.net", + host: "test.example.net", + ok: true, + }, + { + desc: "wildcard-matches-equal-length", + a: "*.example.net", + b: "a.example.net", + host: "a.example.net", + ok: true, + }, + { + desc: "wildcard-matches-multiple-subdomains", + a: "*.example.net", + b: "foo.bar.test.example.net", + host: "foo.bar.test.example.net", + ok: true, + }, + { + desc: "wildcard-doesnt-match-parent", + a: "*.example.net", + b: "example.net", + ok: false, + }, + { + desc: "wildcard-must-be-complete-label", + a: "*example.net", + b: "test.example.net", + ok: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + for i := 0; i < 2; i++ { + if host, ok := gwMatchingHost(tt.a, tt.b); host != tt.host || ok != tt.ok { + t.Errorf( + "gwMatchingHost(%q, %q); got: %q, %v; want: %q, %v", + tt.a, tt.b, host, ok, tt.host, tt.ok, + ) + } + tt.a, tt.b = tt.b, tt.a + } + }) + + } +} + +func TestIsDNS1123Domain(t *testing.T) { + tests := []struct { + desc string + in string + ok bool + }{ + { + desc: "empty", + ok: false, + }, + { + desc: "label-too-long", + in: strings.Repeat("x", 64) + ".example.net", + ok: false, + }, + { + desc: "domain-too-long", + in: strings.Repeat("testing.", 256/(len("testing."))) + "example.net", + ok: false, + }, + { + desc: "hostname", + in: "example", + ok: true, + }, + { + desc: "domain", + in: "example.net", + ok: true, + }, + { + desc: "subdomain", + in: "test.example.net", + ok: true, + }, + { + desc: "dashes", + in: "test-with-dash.example.net", + ok: true, + }, + { + desc: "dash-prefix", + in: "-dash-prefix.example.net", + ok: false, + }, + { + desc: "dash-suffix", + in: "dash-suffix-.example.net", + ok: false, + }, + { + desc: "underscore", + in: "under_score.example.net", + ok: false, + }, + { + desc: "plus", + in: "pl+us.example.net", + ok: false, + }, + { + desc: "brackets", + in: "bra[k]ets.example.net", + ok: false, + }, + { + desc: "parens", + in: "pa[re]ns.example.net", + ok: false, + }, + { + desc: "wild", + in: "*.example.net", + ok: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if ok := isDNS1123Domain(tt.in); ok != tt.ok { + t.Errorf("isDNS1123Domain(%q); got: %v; want: %v", tt.in, ok, tt.ok) + } + }) + } +}