diff --git a/pkg/webhook/server.go b/pkg/webhook/server.go index 6b68ae3ed5..1e21da71d2 100644 --- a/pkg/webhook/server.go +++ b/pkg/webhook/server.go @@ -60,9 +60,13 @@ type Server struct { CertDir string // CertName is the server certificate name. Defaults to tls.crt. + // + // Note: This option should only be set when TLSOpts does not override GetCertificate. CertName string // KeyName is the server key name. Defaults to tls.key. + // + // Note: This option should only be set when TLSOpts does not override GetCertificate. KeyName string // ClientCAName is the CA certificate name which server used to verify remote(client)'s certificate. @@ -169,32 +173,40 @@ func (s *Server) Start(ctx context.Context) error { baseHookLog := log.WithName("webhooks") baseHookLog.Info("Starting webhook server") - certPath := filepath.Join(s.CertDir, s.CertName) - keyPath := filepath.Join(s.CertDir, s.KeyName) - - certWatcher, err := certwatcher.New(certPath, keyPath) - if err != nil { - return err - } - - go func() { - if err := certWatcher.Start(ctx); err != nil { - log.Error(err, "certificate watcher error") - } - }() - tlsMinVersion, err := tlsVersion(s.TLSMinVersion) if err != nil { return err } cfg := &tls.Config{ //nolint:gosec - NextProtos: []string{"h2"}, - GetCertificate: certWatcher.GetCertificate, - MinVersion: tlsMinVersion, + NextProtos: []string{"h2"}, + MinVersion: tlsMinVersion, + } + // fallback TLS config ready, will now mutate if passer wants full control over it + for _, op := range s.TLSOpts { + op(cfg) + } + + if cfg.GetCertificate == nil { + certPath := filepath.Join(s.CertDir, s.CertName) + keyPath := filepath.Join(s.CertDir, s.KeyName) + + // Create the certificate watcher and + // set the config's GetCertificate on the TLSConfig + certWatcher, err := certwatcher.New(certPath, keyPath) + if err != nil { + return err + } + cfg.GetCertificate = certWatcher.GetCertificate + + go func() { + if err := certWatcher.Start(ctx); err != nil { + log.Error(err, "certificate watcher error") + } + }() } - // load CA to verify client certificate + // Load CA to verify client certificate, if configured. if s.ClientCAName != "" { certPool := x509.NewCertPool() clientCABytes, err := os.ReadFile(filepath.Join(s.CertDir, s.ClientCAName)) @@ -211,11 +223,6 @@ func (s *Server) Start(ctx context.Context) error { cfg.ClientAuth = tls.RequireAndVerifyClientCert } - // fallback TLS config ready, will now mutate if passer wants full control over it - for _, op := range s.TLSOpts { - op(cfg) - } - listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)), cfg) if err != nil { return err diff --git a/pkg/webhook/server_test.go b/pkg/webhook/server_test.go index e9b40a1542..fea2a99f41 100644 --- a/pkg/webhook/server_test.go +++ b/pkg/webhook/server_test.go @@ -23,6 +23,8 @@ import ( "io" "net" "net/http" + "path" + "reflect" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -52,7 +54,7 @@ var _ = Describe("Webhook Server", func() { // bypass needing to set up the x509 cert pool, etc ourselves clientTransport, err := rest.TransportFor(&rest.Config{ - TLSClientConfig: rest.TLSClientConfig{CAData: servingOpts.LocalServingCAData}, + TLSClientConfig: rest.TLSClientConfig{Insecure: true}, }) Expect(err).NotTo(HaveOccurred()) client = &http.Client{ @@ -181,7 +183,7 @@ var _ = Describe("Webhook Server", func() { } server.Register("/somepath", &testHandler{}) doneCh := genericStartServer(func(ctx context.Context) { - Expect(server.Start(ctx)) + Expect(server.Start(ctx)).To(Succeed()) }) Eventually(func() ([]byte, error) { @@ -199,6 +201,53 @@ var _ = Describe("Webhook Server", func() { ctxCancel() Eventually(doneCh, "4s").Should(BeClosed()) }) + + It("should prefer GetCertificate through TLSOpts", func() { + var finalCfg *tls.Config + finalCert, err := tls.LoadX509KeyPair( + path.Join(servingOpts.LocalServingCertDir, "tls.crt"), + path.Join(servingOpts.LocalServingCertDir, "tls.key"), + ) + Expect(err).NotTo(HaveOccurred()) + finalGetCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam + return &finalCert, nil + } + server = &webhook.Server{ + Host: servingOpts.LocalServingHost, + Port: servingOpts.LocalServingPort, + CertDir: servingOpts.LocalServingCertDir, + TLSMinVersion: "1.2", + TLSOpts: []func(*tls.Config){ + func(cfg *tls.Config) { + cfg.GetCertificate = finalGetCertificate + // save cfg after changes to test against + finalCfg = cfg + }, + }, + } + server.Register("/somepath", &testHandler{}) + doneCh := genericStartServer(func(ctx context.Context) { + Expect(server.Start(ctx)).To(Succeed()) + }) + + Eventually(func() ([]byte, error) { + resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort)) + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + return io.ReadAll(resp.Body) + }).Should(Equal([]byte("gadzooks!"))) + Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12))) + // We can't compare the functions directly, but we can compare their pointers + if reflect.ValueOf(finalCfg.GetCertificate).Pointer() != reflect.ValueOf(finalGetCertificate).Pointer() { + Fail("GetCertificate was not set properly, or overwritten") + } + cert, err := finalCfg.GetCertificate(nil) + Expect(err).NotTo(HaveOccurred()) + Expect(cert).To(BeEquivalentTo(&finalCert)) + + ctxCancel() + Eventually(doneCh, "4s").Should(BeClosed()) + }) }) type testHandler struct {