diff --git a/pkg/transport/listener_tls.go b/pkg/transport/listener_tls.go index 5d1a9cd44cb..86511860335 100644 --- a/pkg/transport/listener_tls.go +++ b/pkg/transport/listener_tls.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "fmt" "net" + "strings" "sync" ) @@ -151,20 +152,62 @@ func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) e } } if len(cert.DNSNames) > 0 { - for _, dns := range cert.DNSNames { - addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) - if lerr != nil { - continue + ok, err := isHostInDNS(ctx, h, cert.DNSNames) + if ok { + return nil + } + errStr := "" + if err != nil { + errStr = " (" + err.Error() + ")" + } + return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames) + } + return nil +} + +func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) { + // reverse lookup + wildcards, names := []string{}, []string{} + for _, dns := range dnsNames { + if strings.HasPrefix(dns, "*.") { + wildcards = append(wildcards, dns[1:]) + } else { + names = append(names, dns) + } + } + lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host) + for _, name := range lnames { + // strip trailing '.' from PTR record + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + for _, wc := range wildcards { + if strings.HasSuffix(name, wc) { + return true, nil } - for _, addr := range addrs { - if addr == h { - return nil - } + } + for _, n := range names { + if n == name { + return true, nil } } - return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames) } - return nil + err = lerr + + // forward lookup + for _, dns := range names { + addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) + if lerr != nil { + err = lerr + continue + } + for _, addr := range addrs { + if addr == host { + return true, nil + } + } + } + return false, err } func (l *tlsListener) Close() error {