From 02e323205d6d4f992954f4eaae9efb1d67ce1a25 Mon Sep 17 00:00:00 2001 From: Tayler Geiger Date: Wed, 8 May 2024 15:26:15 -0500 Subject: [PATCH] Add file watcher for TLS cert and key Also adds error for missing either tls-key or tls-cert arguments. Signed-off-by: Tayler Geiger --- cmd/manager/main.go | 35 ++++++---- go.mod | 2 +- internal/certwatcher/tls_cert_watcher.go | 87 ++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 14 deletions(-) create mode 100644 internal/certwatcher/tls_cert_watcher.go diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 512147f..157eaf3 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -40,6 +40,7 @@ import ( metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" "github.com/operator-framework/catalogd/api/core/v1alpha1" + "github.com/operator-framework/catalogd/internal/certwatcher" "github.com/operator-framework/catalogd/internal/garbagecollection" "github.com/operator-framework/catalogd/internal/source" "github.com/operator-framework/catalogd/internal/third_party/server" @@ -92,8 +93,8 @@ func main() { flag.StringVar(&cacheDir, "cache-dir", "/var/cache/", "The directory in the filesystem that catalogd will use for file based caching") flag.BoolVar(&catalogdVersion, "version", false, "print the catalogd version and exit") flag.DurationVar(&gcInterval, "gc-interval", 12*time.Hour, "interval in which garbage collection should be run against the catalog content cache") - flag.StringVar(&certFile, "tls-cert", "", "The certificate file used for serving catalog contents over HTTPS") - flag.StringVar(&keyFile, "tls-key", "", "The key file used for serving catalog contents over HTTPS") + flag.StringVar(&certFile, "tls-cert", "", "The certificate file used for serving catalog contents over HTTPS. Requires tls-key.") + flag.StringVar(&keyFile, "tls-key", "", "The key file used for serving catalog contents over HTTPS. Requires tls-cert.") opts := zap.Options{ Development: true, } @@ -150,29 +151,37 @@ func main() { os.Exit(1) } - if certFile != "" && keyFile != "" { - cert, err := tls.LoadX509KeyPair(certFile, keyFile) + switch { + case certFile == "" && keyFile == "": + listener, err = net.Listen("tcp", catalogServerAddr) if err != nil { - setupLog.Error(err, "unable to load certificate key pair") + setupLog.Error(err, "unable to create HTTP server listener") os.Exit(1) } + externalAddr = "http://" + externalAddr + case certFile != "" && keyFile != "": + tlsFileWatcher := certwatcher.TLSCertificateWatcher{ + Logger: ctrl.Log.WithName("tls-file-watcher"), + CertFile: certFile, + KeyFile: keyFile, + } config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS13, + GetCertificate: tlsFileWatcher.GetCertificate, + MinVersion: tls.VersionTLS13, } listener, err = tls.Listen("tcp", catalogServerAddr, config) if err != nil { setupLog.Error(err, "unable to create HTTPS server listener") os.Exit(1) } - externalAddr = "https://" + externalAddr - } else { - listener, err = net.Listen("tcp", catalogServerAddr) - if err != nil { - setupLog.Error(err, "unable to create HTTP server listener") + if err := mgr.Add(&tlsFileWatcher); err != nil { + setupLog.Error(err, "unable to start TLS file watcher") os.Exit(1) } - externalAddr = "http://" + externalAddr + externalAddr = "https://" + externalAddr + default: + setupLog.Error(nil, "unable to configure TLS certificates, both tls-cert and tls-key arguments are required") + os.Exit(1) } baseStorageURL, err := url.Parse(fmt.Sprintf("%s/catalogs/", externalAddr)) diff --git a/go.mod b/go.mod index de672e5..55db4ab 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.21.8 require ( github.com/blang/semver/v4 v4.0.0 github.com/containerd/containerd v1.7.11 + github.com/fsnotify/fsnotify v1.6.0 github.com/go-logr/logr v1.4.1 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.15.2 @@ -75,7 +76,6 @@ require ( github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.6.0 // indirect - github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.5.0 // indirect github.com/go-git/go-git/v5 v5.11.0 // indirect diff --git a/internal/certwatcher/tls_cert_watcher.go b/internal/certwatcher/tls_cert_watcher.go new file mode 100644 index 0000000..b95d7ed --- /dev/null +++ b/internal/certwatcher/tls_cert_watcher.go @@ -0,0 +1,87 @@ +package certwatcher + +import ( + "context" + "crypto/tls" + "path" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/manager" +) + +var _ manager.Runnable = (*TLSCertificateWatcher)(nil) + +type TLSCertificateWatcher struct { + sync.Mutex + certificate *tls.Certificate + Logger logr.Logger + CertFile string + KeyFile string +} + +func (t *TLSCertificateWatcher) Start(ctx context.Context) error { + // Run once on startup + err := t.loadCertificate() + if err != nil { + return err + } + t.Logger.Info("tls-key and tls-cert initialized", "tls-cert", t.CertFile, "tls-key", t.KeyFile) + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + defer watcher.Close() + + if path.Dir(t.CertFile) == path.Dir(t.KeyFile) { + err = watcher.Add(path.Dir(t.CertFile)) + if err != nil { + return err + } + } else { + for _, i := range []string{t.CertFile, t.KeyFile} { + err = watcher.Add(path.Dir(i)) + if err != nil { + return err + } + } + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case event := <-watcher.Events: + if event.Has(fsnotify.Write) { + err := t.loadCertificate() + if err != nil { + return err + } + t.Logger.Info("write detected, reloaded tls-cert and tls-key", "tls-cert", t.CertFile, "tls-key", t.KeyFile) + } + case err := <-watcher.Errors: + t.Logger.Error(err, "issue watching TLS certificate files") + } + } +} + +func (t *TLSCertificateWatcher) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + t.Lock() + defer t.Unlock() + return t.certificate, nil +} + +func (t *TLSCertificateWatcher) loadCertificate() error { + t.Lock() + defer t.Unlock() + + cert, err := tls.LoadX509KeyPair(t.CertFile, t.KeyFile) + if err != nil { + return err + } + + t.certificate = &cert + return nil +}