From 2d952ac877fcb86923cb38a7925ed872cade374c Mon Sep 17 00:00:00 2001 From: Brad Ison Date: Mon, 20 May 2019 11:54:58 -0400 Subject: [PATCH] webhook: Handle TLS certificate rotation This watches the configured TLS certificate and key files for admission webhooks. On any change, it attempts to reload them from disk. --- Gopkg.lock | 1 + .../internal/certwatcher/certwatcher.go | 162 ++++++++++++++++++ pkg/webhook/server.go | 18 +- 3 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 pkg/webhook/internal/certwatcher/certwatcher.go diff --git a/Gopkg.lock b/Gopkg.lock index 6b46885c86..dbd73bbbc9 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -930,6 +930,7 @@ "go.uber.org/zap", "go.uber.org/zap/buffer", "go.uber.org/zap/zapcore", + "gopkg.in/fsnotify.v1", "k8s.io/api/admission/v1beta1", "k8s.io/api/apps/v1", "k8s.io/api/apps/v1beta1", diff --git a/pkg/webhook/internal/certwatcher/certwatcher.go b/pkg/webhook/internal/certwatcher/certwatcher.go new file mode 100644 index 0000000000..4a20319b50 --- /dev/null +++ b/pkg/webhook/internal/certwatcher/certwatcher.go @@ -0,0 +1,162 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package certwatcher + +import ( + "crypto/tls" + "sync" + + "gopkg.in/fsnotify.v1" + logf "sigs.k8s.io/controller-runtime/pkg/internal/log" +) + +var log = logf.RuntimeLog.WithName("certwatcher") + +// CertWatcher watches certificate and key files for changes. When either file +// changes, it reads and parses both and calls an optional callback with the new +// certificate. +type CertWatcher struct { + sync.Mutex + + currentCert *tls.Certificate + watcher *fsnotify.Watcher + + certPath string + keyPath string +} + +// New returns a new CertWatcher watching the given certificate and key. +func New(certPath, keyPath string) (*CertWatcher, error) { + var err error + + cw := &CertWatcher{ + certPath: certPath, + keyPath: keyPath, + } + + // Initial read of certificate and key. + if err := cw.ReadCertificate(); err != nil { + return nil, err + } + + cw.watcher, err = fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + return cw, nil +} + +// GetCertificate fetches the currently loaded certificate, which may be nil. +func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + cw.Lock() + defer cw.Unlock() + return cw.currentCert, nil +} + +// Start starts the watch on the certificate and key files. +func (cw *CertWatcher) Start(stopCh <-chan struct{}) error { + files := []string{cw.certPath, cw.keyPath} + + for _, f := range files { + if err := cw.watcher.Add(f); err != nil { + return err + } + } + + go cw.Watch() + + log.Info("Starting certificate watcher") + + // Block until the stop channel is closed. + <-stopCh + + return cw.watcher.Close() +} + +// Watch reads events from the watcher's channel and reacts to changes. +func (cw *CertWatcher) Watch() { + for { + select { + case event, ok := <-cw.watcher.Events: + // Channel is closed. + if !ok { + return + } + + cw.handleEvent(event) + + case err, ok := <-cw.watcher.Errors: + // Channel is closed. + if !ok { + return + } + + log.Error(err, "certificate watch error") + } + } +} + +// ReadCertificate reads the certificate and key files from disk, parses them, +// and updates the current certificate on the watcher. If a callback is set, it +// is invoked with the new certificate. +func (cw *CertWatcher) ReadCertificate() error { + cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath) + if err != nil { + return err + } + + cw.Lock() + cw.currentCert = &cert + cw.Unlock() + + log.Info("Updated current TLS certiface") + + return nil +} + +func (cw *CertWatcher) handleEvent(event fsnotify.Event) { + // Only care about events which may modify the contents of the file. + if !(isWrite(event) || isRemove(event) || isCreate(event)) { + return + } + + log.V(1).Info("certificate event", "event", event) + + // If the file was removed, re-add the watch. + if isRemove(event) { + if err := cw.watcher.Add(event.Name); err != nil { + log.Error(err, "error re-watching file") + } + } + + if err := cw.ReadCertificate(); err != nil { + log.Error(err, "error re-reading certificate") + } +} + +func isWrite(event fsnotify.Event) bool { + return event.Op&fsnotify.Write == fsnotify.Write +} + +func isCreate(event fsnotify.Event) bool { + return event.Op&fsnotify.Create == fsnotify.Create +} + +func isRemove(event fsnotify.Event) bool { + return event.Op&fsnotify.Remove == fsnotify.Remove +} diff --git a/pkg/webhook/server.go b/pkg/webhook/server.go index 7e8e032d55..241eb3ca29 100644 --- a/pkg/webhook/server.go +++ b/pkg/webhook/server.go @@ -23,11 +23,13 @@ import ( "net" "net/http" "path" + "path/filepath" "strconv" "sync" "time" "sigs.k8s.io/controller-runtime/pkg/runtime/inject" + "sigs.k8s.io/controller-runtime/pkg/webhook/internal/certwatcher" "sigs.k8s.io/controller-runtime/pkg/webhook/internal/metrics" ) @@ -132,15 +134,23 @@ func (s *Server) Start(stop <-chan struct{}) error { } } - // TODO: watch the cert dir. Reload the cert if it changes - cert, err := tls.LoadX509KeyPair(path.Join(s.CertDir, certName), path.Join(s.CertDir, keyName)) + certPath := filepath.Join(s.CertDir, certName) + keyPath := filepath.Join(s.CertDir, keyName) + + certWatcher, err := certwatcher.New(certPath, keyPath) if err != nil { return err } + go func() { + if err := certWatcher.Start(stop); err != nil { + log.Error(err, "certificate watcher error") + } + }() + cfg := &tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{"h2"}, + NextProtos: []string{"h2"}, + GetCertificate: certWatcher.GetCertificate, } listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(int(s.Port))), cfg)