Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hot-reload TLS certificates #2744

Merged
merged 1 commit into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cmd/daemon/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
handler = cors.New(options).Handler(handler)
}

certs := c.GetTSLCertificatesForPublic(ctx)
certs := c.GetTLSCertificatesForPublic(ctx)

if tracer := r.Tracer(ctx); tracer.IsLoaded() {
handler = x.TraceHandler(handler)
Expand All @@ -130,7 +130,7 @@ func ServePublic(r driver.Registry, cmd *cobra.Command, args []string, slOpts *s
// #nosec G112 - the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Handler: handler,
TLSConfig: &tls.Config{Certificates: certs, MinVersion: tls.VersionTLS12},
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
})
addr := c.PublicListenOn(ctx)

Expand Down Expand Up @@ -186,7 +186,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
r.PrometheusManager().RegisterRouter(router.Router)

n.UseHandler(router)
certs := c.GetTSLCertificatesForAdmin(ctx)
certs := c.GetTLSCertificatesForAdmin(ctx)

var handler http.Handler = n
if tracer := r.Tracer(ctx); tracer.IsLoaded() {
Expand All @@ -196,7 +196,7 @@ func ServeAdmin(r driver.Registry, cmd *cobra.Command, args []string, slOpts *se
// #nosec G112 - the correct settings are set by graceful.WithDefaults
server := graceful.WithDefaults(&http.Server{
Handler: handler,
TLSConfig: &tls.Config{Certificates: certs, MinVersion: tls.VersionTLS12},
TLSConfig: &tls.Config{GetCertificate: certs, MinVersion: tls.VersionTLS12},
})

addr := c.AdminListenOn(ctx)
Expand Down
26 changes: 2 additions & 24 deletions cmd/serve/root_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
package serve_test

import (
"encoding/base64"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"

"github.com/ory/kratos/x"

"github.com/ory/kratos/internal/testhelpers"
)

Expand All @@ -18,19 +11,7 @@ func TestServe(t *testing.T) {
}

func TestServeTLSBase64(t *testing.T) {
certPath := filepath.Join(os.TempDir(), "e2e_test_cert_"+x.NewUUID().String()+".pem")
keyPath := filepath.Join(os.TempDir(), "e2e_test_key_"+x.NewUUID().String()+".pem")

testhelpers.GenerateTLSCertificateFilesForTests(t, certPath, keyPath)

certRaw, err := os.ReadFile(certPath)
require.NoError(t, err)

keyRaw, err := os.ReadFile(keyPath)
require.NoError(t, err)

certBase64 := base64.StdEncoding.EncodeToString(certRaw)
keyBase64 := base64.StdEncoding.EncodeToString(keyRaw)
_, _, certBase64, keyBase64 := testhelpers.GenerateTLSCertificateFilesForTests(t)
publicPort, adminPort := testhelpers.StartE2EServerOnly(t,
"./stub/kratos.yml",
true,
Expand All @@ -45,10 +26,7 @@ func TestServeTLSBase64(t *testing.T) {
}

func TestServeTLSPaths(t *testing.T) {
certPath := filepath.Join(os.TempDir(), "e2e_test_cert_"+x.NewUUID().String()+".pem")
keyPath := filepath.Join(os.TempDir(), "e2e_test_key_"+x.NewUUID().String()+".pem")

testhelpers.GenerateTLSCertificateFilesForTests(t, certPath, keyPath)
certPath, keyPath, _, _ := testhelpers.GenerateTLSCertificateFilesForTests(t)

publicPort, adminPort := testhelpers.StartE2EServerOnly(t,
"./stub/kratos.yml",
Expand Down
43 changes: 31 additions & 12 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1284,8 +1284,11 @@ func (p *Config) CipherAlgorithm(ctx context.Context) string {
}
}

func (p *Config) GetTSLCertificatesForPublic(ctx context.Context) []tls.Certificate {
return p.getTSLCertificates(
type CertFunc = func(*tls.ClientHelloInfo) (*tls.Certificate, error)

func (p *Config) GetTLSCertificatesForPublic(ctx context.Context) CertFunc {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a few seconds to spot the typo 😂 nice find

return p.getTLSCertificates(
ctx,
"public",
p.GetProvider(ctx).String(ViperKeyPublicTLSCertBase64),
p.GetProvider(ctx).String(ViperKeyPublicTLSKeyBase64),
Expand All @@ -1294,8 +1297,9 @@ func (p *Config) GetTSLCertificatesForPublic(ctx context.Context) []tls.Certific
)
}

func (p *Config) GetTSLCertificatesForAdmin(ctx context.Context) []tls.Certificate {
return p.getTSLCertificates(
func (p *Config) GetTLSCertificatesForAdmin(ctx context.Context) CertFunc {
return p.getTLSCertificates(
ctx,
"admin",
p.GetProvider(ctx).String(ViperKeyAdminTLSCertBase64),
p.GetProvider(ctx).String(ViperKeyAdminTLSKeyBase64),
Expand All @@ -1304,16 +1308,31 @@ func (p *Config) GetTSLCertificatesForAdmin(ctx context.Context) []tls.Certifica
)
}

func (p *Config) getTSLCertificates(daemon, certBase64, keyBase64, certPath, keyPath string) []tls.Certificate {
cert, err := tlsx.Certificate(certBase64, keyBase64, certPath, keyPath)

if err == nil {
func (p *Config) getTLSCertificates(ctx context.Context, daemon, certBase64, keyBase64, certPath, keyPath string) CertFunc {
if certBase64 != "" && keyBase64 != "" {
cert, err := tlsx.CertificateFromBase64(certBase64, keyBase64)
if err != nil {
p.l.WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
return nil // reachable in unit tests when Fatalf is hooked
}
p.l.Infof("Setting up HTTPS for %s", daemon)
return cert
} else if !errors.Is(err, tlsx.ErrNoCertificatesConfigured) {
p.l.WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return &cert, nil }
}
if certPath != "" && keyPath != "" {
errs := make(chan error, 1)
getCert, err := tlsx.GetCertificate(ctx, certPath, keyPath, errs)
if err != nil {
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
p.l.WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
return nil // reachable in unit tests when Fatalf is hooked
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
}
go func() {
for err := range errs {
p.l.WithError(err).Error("Failed to reload TLS certificates, using previous certificates")
}
}()
aeneasr marked this conversation as resolved.
Show resolved Hide resolved
p.l.Infof("Setting up HTTPS for %s (automatic certificate reloading active)", daemon)
return getCert
}

p.l.Infof("TLS has not been configured for %s, skipping", daemon)
return nil
}
Expand Down
128 changes: 53 additions & 75 deletions driver/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (

"github.com/ory/x/watcherx"

"github.com/ory/kratos/x"

"github.com/ory/kratos/internal/testhelpers"

"github.com/ory/x/configx"
Expand Down Expand Up @@ -775,126 +773,106 @@ func TestViperProvider_HaveIBeenPwned(t *testing.T) {
})
}

func newTestConfig(t *testing.T) (_ *config.Config, _ *test.Hook, exited *bool) {
l := logrusx.New("", "")
h := new(test.Hook)
exited = new(bool)
l.Logger.Hooks.Add(h)
l.Logger.ExitFunc = func(code int) { *exited = true }
config := config.MustNew(t, l, os.Stderr, configx.SkipValidation())
return config, h, exited
}

func TestLoadingTLSConfig(t *testing.T) {
ctx := context.Background()
t.Parallel()

certPath := filepath.Join(os.TempDir(), "e2e_test_cert_"+x.NewUUID().String()+".pem")
keyPath := filepath.Join(os.TempDir(), "e2e_test_key_"+x.NewUUID().String()+".pem")

testhelpers.GenerateTLSCertificateFilesForTests(t, certPath, keyPath)

certRaw, err := os.ReadFile(certPath)
assert.Nil(t, err)
certPath, keyPath, certBase64, keyBase64 := testhelpers.GenerateTLSCertificateFilesForTests(t)

keyRaw, err := os.ReadFile(keyPath)
assert.Nil(t, err)
t.Run("case=public: no TLS config", func(t *testing.T) {
p, hook, exited := newTestConfig(t)
assert.Nil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "TLS has not been configured for public, skipping", hook.LastEntry().Message)
assert.False(t, *exited)
})

certBase64 := base64.StdEncoding.EncodeToString(certRaw)
keyBase64 := base64.StdEncoding.EncodeToString(keyRaw)
t.Run("case=admin: no TLS config", func(t *testing.T) {
p, hook, exited := newTestConfig(t)
assert.Nil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "TLS has not been configured for admin, skipping", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=public: loading inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyBase64, keyBase64)
p.MustSet(ctx, config.ViperKeyPublicTLSCertBase64, certBase64)
assert.NotNil(t, p.GetTSLCertificatesForPublic(ctx))
assert.NotNil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Setting up HTTPS for public", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=public: loading certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyPath, keyPath)
p.MustSet(ctx, config.ViperKeyPublicTLSCertPath, certPath)
assert.NotNil(t, p.GetTSLCertificatesForPublic(ctx))
assert.Equal(t, "Setting up HTTPS for public", hook.LastEntry().Message)
assert.NotNil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Setting up HTTPS for public (automatic certificate reloading active)", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=public: failing to load inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyBase64, "empty")
p.MustSet(ctx, config.ViperKeyPublicTLSCertBase64, certBase64)
assert.Nil(t, p.GetTSLCertificatesForPublic(ctx))
assert.Equal(t, "TLS has not been configured for public, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

t.Run("case=public: failing to load certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyPublicTLSKeyPath, "/dev/null")
p.MustSet(ctx, config.ViperKeyPublicTLSCertPath, certPath)
assert.Nil(t, p.GetTSLCertificatesForPublic(ctx))
assert.Equal(t, "TLS has not been configured for public, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForPublic(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

t.Run("case=admin: loading inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyBase64, keyBase64)
p.MustSet(ctx, config.ViperKeyAdminTLSCertBase64, certBase64)
assert.NotNil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.NotNil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Setting up HTTPS for admin", hook.LastEntry().Message)
assert.False(t, *exited)
})

t.Run("case=admin: loading certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { panic("") }
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyPath, keyPath)
p.MustSet(ctx, config.ViperKeyAdminTLSCertPath, certPath)
assert.NotNil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.Equal(t, "Setting up HTTPS for admin", hook.LastEntry().Message)
assert.NotNil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Setting up HTTPS for admin (automatic certificate reloading active)", hook.LastEntry().Message)
alnr marked this conversation as resolved.
Show resolved Hide resolved
assert.False(t, *exited)
})

t.Run("case=admin: failing to load inline base64 certificate", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyBase64, "empty")
p.MustSet(ctx, config.ViperKeyAdminTLSCertBase64, certBase64)
assert.Nil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.Equal(t, "TLS has not been configured for admin, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

t.Run("case=admin: failing to load certificate from a file", func(t *testing.T) {
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) {}
hook := new(test.Hook)
logger.Logger.Hooks.Add(hook)

p := config.MustNew(t, logger, os.Stderr, configx.SkipValidation())
p, hook, exited := newTestConfig(t)
p.MustSet(ctx, config.ViperKeyAdminTLSKeyPath, "/dev/null")
p.MustSet(ctx, config.ViperKeyAdminTLSCertPath, certPath)
assert.Nil(t, p.GetTSLCertificatesForAdmin(ctx))
assert.Equal(t, "TLS has not been configured for admin, skipping", hook.LastEntry().Message)
assert.Nil(t, p.GetTLSCertificatesForAdmin(ctx))
assert.Equal(t, "Unable to load HTTPS TLS Certificate", hook.LastEntry().Message)
assert.True(t, *exited)
})

}
Expand Down
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ require (
github.com/ory/kratos-client-go v0.6.3-alpha.1
github.com/ory/mail/v3 v3.0.0
github.com/ory/nosurf v1.2.7
github.com/ory/x v0.0.470
github.com/ory/x v0.0.474
github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.3.0
Expand Down Expand Up @@ -199,7 +199,7 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/serf v0.9.7 // indirect
github.com/huandu/xstrings v1.3.2 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.12.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
Expand Down Expand Up @@ -267,11 +267,11 @@ require (
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d // indirect
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e // indirect
github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/viper v1.12.0 // indirect
github.com/subosito/gotenv v1.3.0 // indirect
github.com/subosito/gotenv v1.4.1 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/timtadh/data-structures v0.5.3 // indirect
Expand Down Expand Up @@ -315,7 +315,7 @@ require (
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.17.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 // indirect
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
Expand All @@ -327,7 +327,7 @@ require (
gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect
gopkg.in/ini.v1 v1.66.4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
Expand Down
Loading