Skip to content

Commit

Permalink
Handle TLSOpts.GetCertificate in webhook
Browse files Browse the repository at this point in the history
This change rewrites some of the webhook server logic to better handle
the user setting a custom configuration for the TLS options by providing
a custom GetCertificate function. When that's present, we won't start
a certwatcher routine. The change also updates the documentation to
better clarify when each of the options are effective.

Signed-off-by: Vince Prignano <vincepri@redhat.com>
  • Loading branch information
vincepri committed May 2, 2023
1 parent 0ef0753 commit bd12701
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 24 deletions.
53 changes: 30 additions & 23 deletions pkg/webhook/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ type Server struct {
CertDir string

// CertName is the server certificate name. Defaults to tls.crt.
//
// Note: This option should only be set when TLSOpts does not override GetCertificate.
CertName string

// KeyName is the server key name. Defaults to tls.key.
//
// Note: This option should only be set when TLSOpts does not override GetCertificate.
KeyName string

// ClientCAName is the CA certificate name which server used to verify remote(client)'s certificate.
Expand Down Expand Up @@ -169,32 +173,40 @@ func (s *Server) Start(ctx context.Context) error {
baseHookLog := log.WithName("webhooks")
baseHookLog.Info("Starting webhook server")

certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return err
}

go func() {
if err := certWatcher.Start(ctx); err != nil {
log.Error(err, "certificate watcher error")
}
}()

tlsMinVersion, err := tlsVersion(s.TLSMinVersion)
if err != nil {
return err
}

cfg := &tls.Config{ //nolint:gosec
NextProtos: []string{"h2"},
GetCertificate: certWatcher.GetCertificate,
MinVersion: tlsMinVersion,
NextProtos: []string{"h2"},
MinVersion: tlsMinVersion,
}
// fallback TLS config ready, will now mutate if passer wants full control over it
for _, op := range s.TLSOpts {
op(cfg)
}

if cfg.GetCertificate == nil {
certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

// Create the certificate watcher and
// set the config's GetCertificate on the TLSConfig
certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return err
}
cfg.GetCertificate = certWatcher.GetCertificate

go func() {
if err := certWatcher.Start(ctx); err != nil {
log.Error(err, "certificate watcher error")
}
}()
}

// load CA to verify client certificate
// Load CA to verify client certificate, if configured.
if s.ClientCAName != "" {
certPool := x509.NewCertPool()
clientCABytes, err := os.ReadFile(filepath.Join(s.CertDir, s.ClientCAName))
Expand All @@ -211,11 +223,6 @@ func (s *Server) Start(ctx context.Context) error {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

// fallback TLS config ready, will now mutate if passer wants full control over it
for _, op := range s.TLSOpts {
op(cfg)
}

listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)), cfg)
if err != nil {
return err
Expand Down
51 changes: 50 additions & 1 deletion pkg/webhook/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"io"
"net"
"net/http"
"path"
"reflect"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -181,7 +183,7 @@ var _ = Describe("Webhook Server", func() {
}
server.Register("/somepath", &testHandler{})
doneCh := genericStartServer(func(ctx context.Context) {
Expect(server.Start(ctx))
Expect(server.Start(ctx)).To(Succeed())
})

Eventually(func() ([]byte, error) {
Expand All @@ -199,6 +201,53 @@ var _ = Describe("Webhook Server", func() {
ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
})

It("should prefer GetCertificate through TLSOpts", func() {
var finalCfg *tls.Config
finalCert, err := tls.LoadX509KeyPair(
path.Join(servingOpts.LocalServingCertDir, "tls.crt"),
path.Join(servingOpts.LocalServingCertDir, "tls.key"),
)
Expect(err).NotTo(HaveOccurred())
finalGetCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
return &finalCert, nil
}
server = &webhook.Server{
Host: servingOpts.LocalServingHost,
Port: servingOpts.LocalServingPort,
CertDir: servingOpts.LocalServingCertDir,
TLSMinVersion: "1.2",
TLSOpts: []func(*tls.Config){
func(cfg *tls.Config) {
cfg.GetCertificate = finalGetCertificate
// save cfg after changes to test against
finalCfg = cfg
},
},
}
server.Register("/somepath", &testHandler{})
doneCh := genericStartServer(func(ctx context.Context) {
Expect(server.Start(ctx)).To(Succeed())
})

Eventually(func() ([]byte, error) {
resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}).Should(Equal([]byte("gadzooks!")))
Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
// We can't compare the functions directly, but we can compare their pointers
if reflect.ValueOf(finalCfg.GetCertificate).Pointer() != reflect.ValueOf(finalGetCertificate).Pointer() {
Fail("GetCertificate was not set properly, or overwritten")
}
cert, err := finalCfg.GetCertificate(nil)
Expect(err).NotTo(HaveOccurred())
Expect(cert).To(BeEquivalentTo(&finalCert))

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
})
})

type testHandler struct {
Expand Down

0 comments on commit bd12701

Please sign in to comment.