diff --git a/examples/scratch-env/go.mod b/examples/scratch-env/go.mod index dceb4d12aa..8cf3ce725b 100644 --- a/examples/scratch-env/go.mod +++ b/examples/scratch-env/go.mod @@ -14,7 +14,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect diff --git a/examples/scratch-env/go.sum b/examples/scratch-env/go.sum index 89d30c15c1..a920487db5 100644 --- a/examples/scratch-env/go.sum +++ b/examples/scratch-env/go.sum @@ -13,8 +13,6 @@ github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8 github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= diff --git a/go.mod b/go.mod index 3fd1aa9562..a0a01baa74 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.22.0 require ( github.com/evanphx/json-patch/v5 v5.9.0 - github.com/fsnotify/fsnotify v1.7.0 github.com/go-logr/logr v1.4.2 github.com/go-logr/zapr v1.3.0 github.com/google/go-cmp v0.6.0 @@ -40,6 +39,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect diff --git a/pkg/certwatcher/certwatcher.go b/pkg/certwatcher/certwatcher.go index fe15fc0dd7..d295b29864 100644 --- a/pkg/certwatcher/certwatcher.go +++ b/pkg/certwatcher/certwatcher.go @@ -17,58 +17,55 @@ limitations under the License. package certwatcher import ( + "bytes" "context" "crypto/tls" - "fmt" + "os" "sync" "time" - "github.com/fsnotify/fsnotify" - kerrors "k8s.io/apimachinery/pkg/util/errors" - "k8s.io/apimachinery/pkg/util/sets" - "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics" logf "sigs.k8s.io/controller-runtime/pkg/internal/log" ) var log = logf.RuntimeLog.WithName("certwatcher") -// CertWatcher watches certificate and key files for changes. When either file -// changes, it reads and parses both and calls an optional callback with the new -// certificate. +const defaultWatchInterval = 10 * time.Second + +// CertWatcher watches certificate and key files for changes. +// It always returns the cached version, +// but periodically reads and parses certificate and key for changes +// and calls an optional callback with the new certificate. type CertWatcher struct { sync.RWMutex currentCert *tls.Certificate - watcher *fsnotify.Watcher + interval time.Duration certPath string keyPath string + cachedKeyPEMBlock []byte + // callback is a function to be invoked when the certificate changes. callback func(tls.Certificate) } // New returns a new CertWatcher watching the given certificate and key. func New(certPath, keyPath string) (*CertWatcher, error) { - var err error - cw := &CertWatcher{ certPath: certPath, keyPath: keyPath, + interval: defaultWatchInterval, } - // Initial read of certificate and key. - if err := cw.ReadCertificate(); err != nil { - return nil, err - } - - cw.watcher, err = fsnotify.NewWatcher() - if err != nil { - return nil, err - } + return cw, cw.ReadCertificate() +} - return cw, nil +// WithWatchInterval sets the watch interval and returns the CertWatcher pointer +func (cw *CertWatcher) WithWatchInterval(interval time.Duration) *CertWatcher { + cw.interval = interval + return cw } // RegisterCallback registers a callback to be invoked when the certificate changes. @@ -91,72 +88,64 @@ func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, // Start starts the watch on the certificate and key files. func (cw *CertWatcher) Start(ctx context.Context) error { - files := sets.New(cw.certPath, cw.keyPath) - - { - var watchErr error - if err := wait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Second, true, func(ctx context.Context) (done bool, err error) { - for _, f := range files.UnsortedList() { - if err := cw.watcher.Add(f); err != nil { - watchErr = err - return false, nil //nolint:nilerr // We want to keep trying. - } - // We've added the watch, remove it from the set. - files.Delete(f) - } - return true, nil - }); err != nil { - return fmt.Errorf("failed to add watches: %w", kerrors.NewAggregate([]error{err, watchErr})) - } - } - - go cw.Watch() + ticker := time.NewTicker(cw.interval) + defer ticker.Stop() log.Info("Starting certificate watcher") - - // Block until the context is done. - <-ctx.Done() - - return cw.watcher.Close() -} - -// Watch reads events from the watcher's channel and reacts to changes. -func (cw *CertWatcher) Watch() { for { select { - case event, ok := <-cw.watcher.Events: - // Channel is closed. - if !ok { - return + case <-ctx.Done(): + return nil + case <-ticker.C: + if err := cw.ReadCertificate(); err != nil { + log.Error(err, "failed read certificate") } + } + } +} - cw.handleEvent(event) - - case err, ok := <-cw.watcher.Errors: - // Channel is closed. - if !ok { - return - } +// updateCachedCertificate checks if the new certificate differs from the cache, +// updates it and returns the result if it was updated or not +func (cw *CertWatcher) updateCachedCertificate(cert *tls.Certificate, keyPEMBlock []byte) bool { + cw.Lock() + defer cw.Unlock() - log.Error(err, "certificate watch error") - } + if cw.currentCert != nil && + bytes.Equal(cw.currentCert.Certificate[0], cert.Certificate[0]) && + bytes.Equal(cw.cachedKeyPEMBlock, keyPEMBlock) { + log.V(7).Info("certificate already cached") + return false } + cw.currentCert = cert + cw.cachedKeyPEMBlock = keyPEMBlock + return true } // ReadCertificate reads the certificate and key files from disk, parses them, -// and updates the current certificate on the watcher. If a callback is set, it +// and updates the current certificate on the watcher if updated. If a callback is set, it // is invoked with the new certificate. func (cw *CertWatcher) ReadCertificate() error { metrics.ReadCertificateTotal.Inc() - cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath) + certPEMBlock, err := os.ReadFile(cw.certPath) + if err != nil { + metrics.ReadCertificateErrors.Inc() + return err + } + keyPEMBlock, err := os.ReadFile(cw.keyPath) if err != nil { metrics.ReadCertificateErrors.Inc() return err } - cw.Lock() - cw.currentCert = &cert - cw.Unlock() + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + metrics.ReadCertificateErrors.Inc() + return err + } + + if !cw.updateCachedCertificate(&cert, keyPEMBlock) { + return nil + } log.Info("Updated current TLS certificate") @@ -170,39 +159,3 @@ func (cw *CertWatcher) ReadCertificate() error { } return nil } - -func (cw *CertWatcher) handleEvent(event fsnotify.Event) { - // Only care about events which may modify the contents of the file. - if !(isWrite(event) || isRemove(event) || isCreate(event) || isChmod(event)) { - return - } - - log.V(1).Info("certificate event", "event", event) - - // If the file was removed or renamed, re-add the watch to the previous name - if isRemove(event) || isChmod(event) { - if err := cw.watcher.Add(event.Name); err != nil { - log.Error(err, "error re-watching file") - } - } - - if err := cw.ReadCertificate(); err != nil { - log.Error(err, "error re-reading certificate") - } -} - -func isWrite(event fsnotify.Event) bool { - return event.Op.Has(fsnotify.Write) -} - -func isCreate(event fsnotify.Event) bool { - return event.Op.Has(fsnotify.Create) -} - -func isRemove(event fsnotify.Event) bool { - return event.Op.Has(fsnotify.Remove) -} - -func isChmod(event fsnotify.Event) bool { - return event.Op.Has(fsnotify.Chmod) -} diff --git a/pkg/certwatcher/certwatcher_suite_test.go b/pkg/certwatcher/certwatcher_suite_test.go index 2d0f677685..d0d9a72a62 100644 --- a/pkg/certwatcher/certwatcher_suite_test.go +++ b/pkg/certwatcher/certwatcher_suite_test.go @@ -22,6 +22,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" ) diff --git a/pkg/certwatcher/certwatcher_test.go b/pkg/certwatcher/certwatcher_test.go index 1fb247581f..fa7e09adae 100644 --- a/pkg/certwatcher/certwatcher_test.go +++ b/pkg/certwatcher/certwatcher_test.go @@ -17,6 +17,7 @@ limitations under the License. package certwatcher_test import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -34,6 +35,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/prometheus/client_golang/prometheus/testutil" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics" ) @@ -80,7 +82,7 @@ var _ = Describe("CertWatcher", func() { go func() { defer GinkgoRecover() defer close(doneCh) - Expect(watcher.Start(ctx)).To(Succeed()) + Expect(watcher.WithWatchInterval(time.Second).Start(ctx)).To(Succeed()) }() // wait till we read first cert Eventually(func() error { @@ -113,7 +115,7 @@ var _ = Describe("CertWatcher", func() { Eventually(func() bool { secondcert, _ := watcher.GetCertificate(nil) first := firstcert.PrivateKey.(*rsa.PrivateKey) - return first.Equal(secondcert.PrivateKey) + return first.Equal(secondcert.PrivateKey) || bytes.Equal(firstcert.Certificate[0], secondcert.Certificate[0]) }).ShouldNot(BeTrue()) ctxCancel() @@ -143,7 +145,7 @@ var _ = Describe("CertWatcher", func() { Eventually(func() bool { secondcert, _ := watcher.GetCertificate(nil) first := firstcert.PrivateKey.(*rsa.PrivateKey) - return first.Equal(secondcert.PrivateKey) + return first.Equal(secondcert.PrivateKey) || bytes.Equal(firstcert.Certificate[0], secondcert.Certificate[0]) }).ShouldNot(BeTrue()) ctxCancel() @@ -151,6 +153,33 @@ var _ = Describe("CertWatcher", func() { Expect(called.Load()).To(BeNumerically(">=", 1)) }) + It("should reload currentCert after move out", func() { + doneCh := startWatcher() + called := atomic.Int64{} + watcher.RegisterCallback(func(crt tls.Certificate) { + called.Add(1) + Expect(crt.Certificate).ToNot(BeEmpty()) + }) + + firstcert, _ := watcher.GetCertificate(nil) + + Expect(os.Rename(certPath, certPath+".old")).To(Succeed()) + Expect(os.Rename(keyPath, keyPath+".old")).To(Succeed()) + + err := writeCerts(certPath, keyPath, "192.168.0.3") + Expect(err).ToNot(HaveOccurred()) + + Eventually(func() bool { + secondcert, _ := watcher.GetCertificate(nil) + first := firstcert.PrivateKey.(*rsa.PrivateKey) + return first.Equal(secondcert.PrivateKey) || bytes.Equal(firstcert.Certificate[0], secondcert.Certificate[0]) + }, "10s", "1s").ShouldNot(BeTrue()) + + ctxCancel() + Eventually(doneCh, "4s").Should(BeClosed()) + Expect(called.Load()).To(BeNumerically(">=", 1)) + }) + Context("prometheus metric read_certificate_total", func() { var readCertificateTotalBefore float64 var readCertificateErrorsBefore float64 @@ -165,8 +194,8 @@ var _ = Describe("CertWatcher", func() { Eventually(func() error { readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal) - if readCertificateTotalAfter != readCertificateTotalBefore+1.0 { - return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter) + if readCertificateTotalAfter < readCertificateTotalBefore+1.0 { + return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter) } return nil }, "4s").Should(Succeed()) @@ -180,8 +209,8 @@ var _ = Describe("CertWatcher", func() { Eventually(func() error { readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal) - if readCertificateTotalAfter != readCertificateTotalBefore+1.0 { - return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter) + if readCertificateTotalAfter < readCertificateTotalBefore+1.0 { + return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter) } readCertificateTotalBefore = readCertificateTotalAfter return nil @@ -192,15 +221,15 @@ var _ = Describe("CertWatcher", func() { // Note, we are checking two errors here, because os.Remove generates two fsnotify events: Chmod + Remove Eventually(func() error { readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal) - if readCertificateTotalAfter != readCertificateTotalBefore+2.0 { - return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter) + if readCertificateTotalAfter < readCertificateTotalBefore+2.0 { + return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter) } return nil }, "4s").Should(Succeed()) Eventually(func() error { readCertificateErrorsAfter := testutil.ToFloat64(metrics.ReadCertificateErrors) - if readCertificateErrorsAfter != readCertificateErrorsBefore+2.0 { - return fmt.Errorf("metric read certificate errors expected: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter) + if readCertificateErrorsAfter < readCertificateErrorsBefore+2.0 { + return fmt.Errorf("metric read certificate errors expected at least: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter) } return nil }, "4s").Should(Succeed()) diff --git a/pkg/certwatcher/example_test.go b/pkg/certwatcher/example_test.go index e322aeebfc..e85b2403cb 100644 --- a/pkg/certwatcher/example_test.go +++ b/pkg/certwatcher/example_test.go @@ -39,7 +39,7 @@ func Example() { panic(err) } - // Start goroutine with certwatcher running fsnotify against supplied certdir + // Start goroutine with certwatcher running against supplied cert go func() { if err := watcher.Start(ctx); err != nil { panic(err) diff --git a/pkg/certwatcher/metrics/metrics.go b/pkg/certwatcher/metrics/metrics.go index 05869eff03..f128abbcf0 100644 --- a/pkg/certwatcher/metrics/metrics.go +++ b/pkg/certwatcher/metrics/metrics.go @@ -18,6 +18,7 @@ package metrics import ( "github.com/prometheus/client_golang/prometheus" + "sigs.k8s.io/controller-runtime/pkg/metrics" )