From 032f2507c0bfb24e528d8811f907f0ad3b7051b8 Mon Sep 17 00:00:00 2001 From: Chris Hein Date: Mon, 22 Mar 2021 01:58:02 -0700 Subject: [PATCH] adding certwatcher as external package Signed-off-by: Chris Hein --- .../internal => }/certwatcher/certwatcher.go | 8 +- pkg/certwatcher/certwatcher_suite_test.go | 52 +++++ pkg/certwatcher/certwatcher_test.go | 187 ++++++++++++++++++ pkg/certwatcher/doc.go | 23 +++ pkg/certwatcher/example_test.go | 77 ++++++++ pkg/certwatcher/testdata/.gitkeep | 0 pkg/webhook/server.go | 2 +- 7 files changed, 344 insertions(+), 5 deletions(-) rename pkg/{webhook/internal => }/certwatcher/certwatcher.go (97%) create mode 100644 pkg/certwatcher/certwatcher_suite_test.go create mode 100644 pkg/certwatcher/certwatcher_test.go create mode 100644 pkg/certwatcher/doc.go create mode 100644 pkg/certwatcher/example_test.go create mode 100644 pkg/certwatcher/testdata/.gitkeep diff --git a/pkg/webhook/internal/certwatcher/certwatcher.go b/pkg/certwatcher/certwatcher.go similarity index 97% rename from pkg/webhook/internal/certwatcher/certwatcher.go rename to pkg/certwatcher/certwatcher.go index d681ef2a6b..e8e0e17a2b 100644 --- a/pkg/webhook/internal/certwatcher/certwatcher.go +++ b/pkg/certwatcher/certwatcher.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2021 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ var log = logf.RuntimeLog.WithName("certwatcher") // changes, it reads and parses both and calls an optional callback with the new // certificate. type CertWatcher struct { - sync.Mutex + sync.RWMutex currentCert *tls.Certificate watcher *fsnotify.Watcher @@ -64,8 +64,8 @@ func New(certPath, keyPath string) (*CertWatcher, error) { // GetCertificate fetches the currently loaded certificate, which may be nil. func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { - cw.Lock() - defer cw.Unlock() + cw.RLock() + defer cw.RUnlock() return cw.currentCert, nil } diff --git a/pkg/certwatcher/certwatcher_suite_test.go b/pkg/certwatcher/certwatcher_suite_test.go new file mode 100644 index 0000000000..dfbd40a524 --- /dev/null +++ b/pkg/certwatcher/certwatcher_suite_test.go @@ -0,0 +1,52 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package certwatcher_test + +import ( + "os" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "sigs.k8s.io/controller-runtime/pkg/envtest/printer" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" +) + +var ( + certPath = "testdata/tls.crt" + keyPath = "testdata/tls.key" +) + +func TestSource(t *testing.T) { + RegisterFailHandler(Fail) + suiteName := "CertWatcher Suite" + RunSpecsWithDefaultAndCustomReporters(t, suiteName, []Reporter{printer.NewlineReporter{}, printer.NewProwReporter(suiteName)}) +} + +var _ = BeforeSuite(func(done Done) { + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + + close(done) +}, 60) + +var _ = AfterSuite(func(done Done) { + for _, file := range []string{certPath, keyPath} { + _ = os.Remove(file) + } + close(done) +}, 60) diff --git a/pkg/certwatcher/certwatcher_test.go b/pkg/certwatcher/certwatcher_test.go new file mode 100644 index 0000000000..5761768381 --- /dev/null +++ b/pkg/certwatcher/certwatcher_test.go @@ -0,0 +1,187 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package certwatcher_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" +) + +var _ = Describe("CertWatcher", func() { + var _ = Describe("certwatcher New", func() { + It("should errors without cert/key", func() { + _, err := certwatcher.New("", "") + Expect(err).ToNot(BeNil()) + }) + }) + + var _ = Describe("certwatcher Start", func() { + var ( + ctx context.Context + ctxCancel context.CancelFunc + watcher *certwatcher.CertWatcher + ) + + BeforeEach(func() { + ctx, ctxCancel = context.WithCancel(context.Background()) + + err := writeCerts(certPath, keyPath, "127.0.0.1") + Expect(err).To(BeNil()) + + Eventually(func() error { + for _, file := range []string{certPath, keyPath} { + _, err := os.ReadFile(file) + if err != nil { + return err + } + continue + } + + return nil + }).Should(Succeed()) + + watcher, err = certwatcher.New(certPath, keyPath) + Expect(err).To(BeNil()) + }) + + startWatcher := func() (done <-chan struct{}) { + doneCh := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(doneCh) + Expect(watcher.Start(ctx)).To(Succeed()) + }() + // wait till we read first cert + Eventually(func() error { + err := watcher.ReadCertificate() + return err + }).Should(Succeed()) + return doneCh + } + + It("should read the initial cert/key", func() { + doneCh := startWatcher() + + ctxCancel() + Eventually(doneCh, "4s").Should(BeClosed()) + }) + + It("should reload currentCert when changed", func() { + doneCh := startWatcher() + + firstcert, _ := watcher.GetCertificate(nil) + + err := writeCerts(certPath, keyPath, "192.168.0.1") + Expect(err).To(BeNil()) + + Eventually(func() bool { + secondcert, _ := watcher.GetCertificate(nil) + first := firstcert.PrivateKey.(*rsa.PrivateKey) + return first.Equal(secondcert.PrivateKey) + }).ShouldNot(BeTrue()) + + ctxCancel() + Eventually(doneCh, "4s").Should(BeClosed()) + }) + }) +}) + +func writeCerts(certPath, keyPath, ip string) error { + var priv interface{} + var err error + priv, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + + keyUsage := x509.KeyUsageDigitalSignature + if _, isRSA := priv.(*rsa.PrivateKey); isRSA { + keyUsage |= x509.KeyUsageKeyEncipherment + } + + var notBefore time.Time + notBefore = time.Now() + + notAfter := notBefore.Add(1 * time.Hour) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Kubernetes"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + template.IPAddresses = append(template.IPAddresses, net.ParseIP(ip)) + + privkey := priv.(*rsa.PrivateKey) + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privkey.PublicKey, priv) + if err != nil { + return err + } + + certOut, err := os.Create(certPath) + if err != nil { + return err + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + return err + } + if err := certOut.Close(); err != nil { + return err + } + + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return err + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + return err + } + if err := keyOut.Close(); err != nil { + return err + } + return nil +} diff --git a/pkg/certwatcher/doc.go b/pkg/certwatcher/doc.go new file mode 100644 index 0000000000..40c2fc0bfb --- /dev/null +++ b/pkg/certwatcher/doc.go @@ -0,0 +1,23 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Package certwatcher is a helper for reloading Certificates from disk to be used +with tls servers. It provides a helper func `GetCertificate` which can be +called from `tls.Config` and passed into your tls.Listener. For a detailed +example server view pkg/webhook/server.go. +*/ +package certwatcher diff --git a/pkg/certwatcher/example_test.go b/pkg/certwatcher/example_test.go new file mode 100644 index 0000000000..d4285e4af7 --- /dev/null +++ b/pkg/certwatcher/example_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package certwatcher_test + +import ( + "context" + "crypto/tls" + "net/http" + + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" +) + +type sampleServer struct { +} + +func Example() { + // Setup Context + ctx := ctrl.SetupSignalHandler() + + // Initialize a new cert watcher with cert/key pari + watcher, err := certwatcher.New("ssl/tls.crt", "ssl/tls.key") + if err != nil { + panic(err) + } + + // Start goroutine with certwatcher running fsnotify against supplied certdir + go func() { + if err := watcher.Start(ctx); err != nil { + panic(err) + } + }() + + // Setup TLS listener using GetCertficate for fetching the cert when changes + listener, err := tls.Listen("tcp", "localhost:9443", &tls.Config{ + GetCertificate: watcher.GetCertificate, + }) + if err != nil { + panic(err) + } + + // Initialize your tls server + srv := &http.Server{ + Handler: &sampleServer{}, + } + + // Start goroutine for handling server shutdown. + go func() { + <-ctx.Done() + if err := srv.Shutdown(context.Background()); err != nil { + panic(err) + } + }() + + // Serve t + if err := srv.Serve(listener); err != nil && err != http.ErrServerClosed { + panic(err) + } +} + +func (s *sampleServer) ServeHTTP(http.ResponseWriter, *http.Request) { + +} diff --git a/pkg/certwatcher/testdata/.gitkeep b/pkg/certwatcher/testdata/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pkg/webhook/server.go b/pkg/webhook/server.go index 721df490a0..9fefc9a697 100644 --- a/pkg/webhook/server.go +++ b/pkg/webhook/server.go @@ -31,8 +31,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" "sigs.k8s.io/controller-runtime/pkg/runtime/inject" - "sigs.k8s.io/controller-runtime/pkg/webhook/internal/certwatcher" "sigs.k8s.io/controller-runtime/pkg/webhook/internal/metrics" )