Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

馃尡 Handle TLSOpts.GetCertificate in webhook #2291

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 30 additions & 23 deletions pkg/webhook/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ type Server struct {
CertDir string

// CertName is the server certificate name. Defaults to tls.crt.
//
// Note: This option should only be set when TLSOpts does not override GetCertificate.
CertName string

// KeyName is the server key name. Defaults to tls.key.
//
// Note: This option should only be set when TLSOpts does not override GetCertificate.
KeyName string

// ClientCAName is the CA certificate name which server used to verify remote(client)'s certificate.
Expand Down Expand Up @@ -169,32 +173,40 @@ func (s *Server) Start(ctx context.Context) error {
baseHookLog := log.WithName("webhooks")
baseHookLog.Info("Starting webhook server")

certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return err
}

go func() {
if err := certWatcher.Start(ctx); err != nil {
log.Error(err, "certificate watcher error")
}
}()

tlsMinVersion, err := tlsVersion(s.TLSMinVersion)
if err != nil {
return err
}

cfg := &tls.Config{ //nolint:gosec
NextProtos: []string{"h2"},
GetCertificate: certWatcher.GetCertificate,
MinVersion: tlsMinVersion,
NextProtos: []string{"h2"},
MinVersion: tlsMinVersion,
}
// fallback TLS config ready, will now mutate if passer wants full control over it
for _, op := range s.TLSOpts {
op(cfg)
}

if cfg.GetCertificate == nil {
certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

// Create the certificate watcher and
// set the config's GetCertificate on the TLSConfig
certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return err
}
cfg.GetCertificate = certWatcher.GetCertificate

go func() {
if err := certWatcher.Start(ctx); err != nil {
log.Error(err, "certificate watcher error")
}
}()
}

// load CA to verify client certificate
// Load CA to verify client certificate, if configured.
if s.ClientCAName != "" {
certPool := x509.NewCertPool()
clientCABytes, err := os.ReadFile(filepath.Join(s.CertDir, s.ClientCAName))
Expand All @@ -211,11 +223,6 @@ func (s *Server) Start(ctx context.Context) error {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

// fallback TLS config ready, will now mutate if passer wants full control over it
for _, op := range s.TLSOpts {
op(cfg)
}

listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)), cfg)
if err != nil {
return err
Expand Down
51 changes: 50 additions & 1 deletion pkg/webhook/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"io"
"net"
"net/http"
"path"
"reflect"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -181,7 +183,7 @@ var _ = Describe("Webhook Server", func() {
}
server.Register("/somepath", &testHandler{})
doneCh := genericStartServer(func(ctx context.Context) {
Expect(server.Start(ctx))
Expect(server.Start(ctx)).To(Succeed())
})

Eventually(func() ([]byte, error) {
Expand All @@ -199,6 +201,53 @@ var _ = Describe("Webhook Server", func() {
ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
})

It("should prefer GetCertificate through TLSOpts", func() {
var finalCfg *tls.Config
finalCert, err := tls.LoadX509KeyPair(
path.Join(servingOpts.LocalServingCertDir, "tls.crt"),
path.Join(servingOpts.LocalServingCertDir, "tls.key"),
)
Expect(err).NotTo(HaveOccurred())
finalGetCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
return &finalCert, nil
}
server = &webhook.Server{
Host: servingOpts.LocalServingHost,
Port: servingOpts.LocalServingPort,
CertDir: servingOpts.LocalServingCertDir,
TLSMinVersion: "1.2",
TLSOpts: []func(*tls.Config){
func(cfg *tls.Config) {
cfg.GetCertificate = finalGetCertificate
// save cfg after changes to test against
finalCfg = cfg
},
},
}
server.Register("/somepath", &testHandler{})
doneCh := genericStartServer(func(ctx context.Context) {
Expect(server.Start(ctx)).To(Succeed())
})

Eventually(func() ([]byte, error) {
resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}).Should(Equal([]byte("gadzooks!")))
Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
// We can't compare the functions directly, but we can compare their pointers
if reflect.ValueOf(finalCfg.GetCertificate).Pointer() != reflect.ValueOf(finalGetCertificate).Pointer() {
Fail("GetCertificate was not set properly, or overwritten")
}
cert, err := finalCfg.GetCertificate(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cert).To(BeEquivalentTo(&finalCert))

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
})
})

type testHandler struct {
Expand Down