Skip to content

Commit

Permalink
🐛 fix issue when webhook server refreshing cert
Browse files Browse the repository at this point in the history
  • Loading branch information
Mengqi Yu committed Jan 9, 2019
1 parent b3d5de2 commit 43a3f31
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 12 deletions.
40 changes: 28 additions & 12 deletions pkg/webhook/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"k8s.io/apimachinery/pkg/runtime"
apitypes "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/manager"
"sigs.k8s.io/controller-runtime/pkg/runtime/inject"
Expand All @@ -36,6 +37,9 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/types"
)

// default interval for checking cert is 90 days (~3 months)
var defaultCertRefreshInterval = 3 * 30 * 24 * time.Hour

// ServerOptions are options for configuring an admission webhook server.
type ServerOptions struct {
// Port is the port number that the server will serve.
Expand Down Expand Up @@ -128,6 +132,9 @@ type Server struct {
// manager is the manager that this webhook server will be registered.
manager manager.Manager

// httpServer is the actual server that serves the traffic.
httpServer *http.Server

once sync.Once
}

Expand Down Expand Up @@ -209,21 +216,21 @@ func (s *Server) Start(stop <-chan struct{}) error {
return s.run(stop)
}

func (s *Server) run(stop <-chan struct{}) error {
srv := &http.Server{
Addr: fmt.Sprintf(":%v", s.Port),
Handler: s.sMux,
}
func (s *Server) run(stop <-chan struct{}) error { // nolint: gocyclo
errCh := make(chan error)
serveFn := func() {
errCh <- srv.ListenAndServeTLS(path.Join(s.CertDir, writer.ServerCertName), path.Join(s.CertDir, writer.ServerKeyName))
s.httpServer = &http.Server{
Addr: fmt.Sprintf(":%v", s.Port),
Handler: s.sMux,
}
log.Info("starting the webhook server.")
errCh <- s.httpServer.ListenAndServeTLS(path.Join(s.CertDir, writer.ServerCertName), path.Join(s.CertDir, writer.ServerKeyName))
}

shutdownHappend := false
timer := time.Tick(wait.Jitter(defaultCertRefreshInterval, 0.1))
go serveFn()
for {
// TODO(mengqiy): add jitter to the timer
// Could use https://godoc.org/k8s.io/apimachinery/pkg/util/wait#Jitter
timer := time.Tick(6 * 30 * 24 * time.Hour)
select {
case <-timer:
changed, err := s.RefreshCert()
Expand All @@ -232,19 +239,28 @@ func (s *Server) run(stop <-chan struct{}) error {
return err
}
if !changed {
log.Info("no need to reload the certificates.")
continue
}
log.Info("server is shutting down to reload the certificates.")
err = srv.Shutdown(context.Background())
shutdownHappend = true
err = s.httpServer.Shutdown(context.Background())
if err != nil {
log.Error(err, "encountering error when shutting down")
return err
}
timer = time.Tick(wait.Jitter(defaultCertRefreshInterval, 0.1))
go serveFn()
case <-stop:
return nil
return s.httpServer.Shutdown(context.Background())
case e := <-errCh:
return e
// Don't exit when getting an http.ErrServerClosed error due to restarting the server.
if shutdownHappend && e == http.ErrServerClosed {
shutdownHappend = false
} else if e != nil {
log.Error(e, "server returns an unexpected error")
return e
}
}
}
}
Expand Down
160 changes: 160 additions & 0 deletions pkg/webhook/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package webhook

import (
"context"
"io/ioutil"
"net/http"
"time"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

"k8s.io/apimachinery/pkg/runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/internal/cert"
"sigs.k8s.io/controller-runtime/pkg/webhook/internal/cert/generator"
"sigs.k8s.io/controller-runtime/pkg/webhook/internal/cert/writer"
"sigs.k8s.io/testing_frameworks/integration/addr"
)

type fakeCertWriter struct {
changed bool
}

func (cw *fakeCertWriter) EnsureCert(dnsName string) (*generator.Artifacts, bool, error) {
return &generator.Artifacts{}, cw.changed, nil
}

func (cw *fakeCertWriter) Inject(objs ...runtime.Object) error {
return nil
}

var _ = Describe("webhook server", func() {
Describe("run", func() {
var stop chan struct{}
var s *Server
var cn = "example.com"

BeforeEach(func() {
port, _, err := addr.Suggest()
Expect(err).NotTo(HaveOccurred())
s = &Server{
sMux: http.NewServeMux(),
ServerOptions: ServerOptions{
Port: int32(port),
BootstrapOptions: &BootstrapOptions{
Host: &cn,
},
},
}

cg := &generator.SelfSignedCertGenerator{}
s.CertDir, err = ioutil.TempDir("/tmp", "controller-runtime-")
Expect(err).NotTo(HaveOccurred())
certWriter, err := writer.NewFSCertWriter(writer.FSCertWriterOptions{CertGenerator: cg, Path: s.CertDir})
Expect(err).NotTo(HaveOccurred())
_, _, err = certWriter.EnsureCert(cn)
Expect(err).NotTo(HaveOccurred())

stop = make(chan struct{})
})

It("should stop if the stop channel is closed", func() {
var e error
go func() {
defer GinkgoRecover()
e = s.run(stop)
}()

Eventually(func() *http.Server {
return s.httpServer
}).ShouldNot(BeNil())

close(stop)
Expect(e).NotTo(HaveOccurred())
})

It("should exit if the server encounter an unexpected error", func() {
var e error
go func() {
defer GinkgoRecover()
e = s.run(stop)
}()

Eventually(func() *http.Server {
return s.httpServer
}).ShouldNot(BeNil())

err := s.httpServer.Shutdown(context.Background())
Expect(err).NotTo(HaveOccurred())

Eventually(func() error {
return e
}).Should(Equal(http.ErrServerClosed))

close(stop)
})

It("should be able to keep existing valid cert when timer fires", func() {
var e error
defaultCertRefreshInterval = 500 * time.Millisecond

s.certProvisioner = &cert.Provisioner{
CertWriter: &fakeCertWriter{changed: false},
}

go func() {
defer GinkgoRecover()
e = s.run(stop)
}()

// Wait for multiple cycles of timer firing
time.Sleep(2 * time.Second)
Expect(e).NotTo(HaveOccurred())

close(stop)
})

It("should be able to rotate the cert when timer fires", func() {
var e error
defaultCertRefreshInterval = 500 * time.Millisecond
s.certProvisioner = &cert.Provisioner{
CertWriter: &fakeCertWriter{changed: true},
}

go func() {
defer GinkgoRecover()
e = s.run(stop)
}()

Eventually(func() *http.Server {
return s.httpServer
}).ShouldNot(BeNil())

// Wait for multiple cycles of timer firing
time.Sleep(2 * time.Second)
Expect(e).NotTo(HaveOccurred())

close(stop)
})

AfterEach(func() {
defaultCertRefreshInterval = 3 * 30 * 24 * time.Hour
}, 60)
})
})
57 changes: 57 additions & 0 deletions pkg/webhook/webhook_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package webhook

import (
"testing"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"sigs.k8s.io/controller-runtime/pkg/envtest"
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"
)

func TestSource(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecsWithDefaultAndCustomReporters(t, "Webhook Integration Suite", []Reporter{envtest.NewlineReporter{}})
}

var testenv *envtest.Environment
var cfg *rest.Config
var clientset *kubernetes.Clientset

var _ = BeforeSuite(func(done Done) {
logf.SetLogger(logf.ZapLoggerTo(GinkgoWriter, true))

testenv = &envtest.Environment{}

var err error
cfg, err = testenv.Start()
Expect(err).NotTo(HaveOccurred())

clientset, err = kubernetes.NewForConfig(cfg)
Expect(err).NotTo(HaveOccurred())

close(done)
}, 60)

var _ = AfterSuite(func() {
testenv.Stop()
})

0 comments on commit 43a3f31

Please sign in to comment.