From 09802a3b7563f95ef2d52e0ff9b913b3a2d0a259 Mon Sep 17 00:00:00 2001 From: Steven Soroka Date: Tue, 1 Mar 2022 17:37:35 -0500 Subject: [PATCH] feat: Certificates management (#1086) * basics of mTLS working * allow mtls to be enabled or disabled * use db for cert backend --- internal/certs/certs.go | 2 +- internal/cmd/cmd.go | 10 +- internal/cmd/login.go | 19 +- internal/errors.go | 7 +- internal/server/certificates_test.go | 104 +++++ internal/server/data/data.go | 2 + internal/server/data/trustedcerts.go | 68 +++ internal/server/middleware_test.go | 13 + internal/server/models/trustedcertificate.go | 28 ++ internal/server/models/types.go | 30 ++ internal/server/routes.go | 8 +- internal/server/server.go | 257 +++++++++-- pki/certificates.go | 111 ++++- pki/certificates_test.go | 16 +- pki/keypair.go | 79 ++++ pki/native.go | 446 +++++++++++-------- pki/native_test.go | 52 ++- pki/pem.go | 69 +++ 18 files changed, 1052 insertions(+), 269 deletions(-) create mode 100644 internal/server/certificates_test.go create mode 100644 internal/server/data/trustedcerts.go create mode 100644 internal/server/models/trustedcertificate.go create mode 100644 internal/server/models/types.go create mode 100644 pki/keypair.go create mode 100644 pki/pem.go diff --git a/internal/certs/certs.go b/internal/certs/certs.go index 85cc07082c..c81fe4f073 100644 --- a/internal/certs/certs.go +++ b/internal/certs/certs.go @@ -34,7 +34,7 @@ func SelfSignedCert(hosts []string) ([]byte, []byte, error) { Subject: pkix.Name{ Organization: []string{"Infra"}, }, - NotBefore: time.Now(), + NotBefore: time.Now().Add(-5 * time.Minute), NotAfter: time.Now().AddDate(0, 0, 365), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 7fde5e9b84..814c9a4748 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -330,7 +330,7 @@ func newServerCmd() *cobra.Command { strcase.ConfigureAcronym("enable-ui", "enableUI") strcase.ConfigureAcronym("ui-proxy-url", "uiProxyURL") - var options server.Options + options := server.Options{} if err := parseOptions(cmd, &options, "INFRA_SERVER"); err != nil { return err } @@ -570,11 +570,14 @@ func newVersionCmd() *cobra.Command { } } +var nonInteractiveMode bool + func NewRootCmd() (*cobra.Command, error) { cobra.EnableCommandSorting = false type rootOptions struct { - LogLevel string `mapstructure:"logLevel"` + LogLevel string `mapstructure:"logLevel"` + NonInteractive bool `mapstructure:"nonInteractive"` } rootCmd := &cobra.Command{ @@ -588,6 +591,8 @@ func NewRootCmd() (*cobra.Command, error) { return err } + nonInteractiveMode = options.NonInteractive + return logging.SetLevel(options.LogLevel) }, } @@ -609,6 +614,7 @@ func NewRootCmd() (*cobra.Command, error) { rootCmd.AddCommand(newVersionCmd()) rootCmd.PersistentFlags().String("log-level", "info", "Set the log level. One of error, warn, info, or debug") + rootCmd.PersistentFlags().Bool("non-interactive", false, "don't assume an interactive terminal, even if there is one") return rootCmd, nil } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index caf10e1634..48326cf7d0 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -27,7 +27,7 @@ const cliLoginRedirectURL = "http://localhost:8301" func relogin() error { // TODO (https://github.com/infrahq/infra/issues/488): support non-interactive login - if !term.IsTerminal(int(os.Stdin.Fd())) { + if !isInteractiveMode() { return errors.New("Non-interactive login is not supported") } @@ -69,9 +69,22 @@ func relogin() error { return finishLogin(currentConfig.Host, uid.NewUserPolymorphicID(loginRes.ID), loginRes.Name, loginRes.AccessKey, currentConfig.SkipTLSVerify, 0) } +func isInteractiveMode() bool { + if nonInteractiveMode { + // user explicitly asked for a non-interactive terminal + return false + } + + if os.Stdin == nil { + return false + } + + return term.IsTerminal(int(os.Stdin.Fd())) +} + func login(host string) error { // TODO (https://github.com/infrahq/infra/issues/488): support non-interactive login - if !term.IsTerminal(int(os.Stdin.Fd())) { + if !isInteractiveMode() { return errors.New("Non-interactive login is not supported") } @@ -162,7 +175,7 @@ func login(host string) error { } // access key - if option == len(options) - 1 { + if option == len(options)-1 { if accessKey == "" { err = survey.AskOne(&survey.Password{Message: "Access Key:"}, &accessKey, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) if err != nil { diff --git a/internal/errors.go b/internal/errors.go index e978a270be..a8fd089554 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -12,7 +12,8 @@ var ( // ErrForbidden means you don't have permissions to the requested resource ErrForbidden = fmt.Errorf("forbidden") - ErrDuplicate = fmt.Errorf("duplicate record") - ErrNotFound = fmt.Errorf("record not found") - ErrBadRequest = fmt.Errorf("bad request") + ErrDuplicate = fmt.Errorf("duplicate record") + ErrNotFound = fmt.Errorf("record not found") + ErrBadRequest = fmt.Errorf("bad request") + ErrNotImplemented = fmt.Errorf("not implemented") ) diff --git a/internal/server/certificates_test.go b/internal/server/certificates_test.go new file mode 100644 index 0000000000..6888ce42d6 --- /dev/null +++ b/internal/server/certificates_test.go @@ -0,0 +1,104 @@ +package server + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/pki" + "github.com/infrahq/infra/uid" + "github.com/stretchr/testify/require" +) + +func TestCertificateSigningWorks(t *testing.T) { + db := setupDB(t) + + cp, err := pki.NewNativeCertificateProvider(db, pki.NativeCertificateProviderConfig{ + FullKeyRotationDurationInDays: 2, + }) + require.NoError(t, err) + + err = cp.CreateCA() + require.NoError(t, err) + + err = cp.RotateCA() + require.NoError(t, err) + + user := &models.User{ + Model: models.Model{ID: uid.New()}, + Email: "joe@example.com", + } + + keyPair, err := pki.MakeUserCert("User "+user.ID.String(), 24*time.Hour) + require.NoError(t, err) + + // happens on the server, needs to be a request for this. + signedCert, signedRaw, err := pki.SignUserCert(cp, keyPair.Cert, user) + require.NoError(t, err) + keyPair.SignedCert = signedCert + keyPair.SignedCertPEM = signedRaw + + // create a test server and client to make sure the certs work. + requireMutualTLSWorks(t, keyPair, cp) +} + +func requireMutualTLSWorks(t *testing.T, clientKeypair *pki.KeyPair, cp pki.CertificateProvider) { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "success!") + })) + + serverTLSCerts, err := cp.TLSCertificates() + require.NoError(t, err) + + caPool := x509.NewCertPool() + + for _, cert := range cp.ActiveCAs() { + cert := cert + caPool.AddCert(&cert) + } + + server.TLS = &tls.Config{ + Certificates: serverTLSCerts, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caPool, + MinVersion: tls.VersionTLS12, + } + + server.StartTLS() + defer server.Close() + + // This will response with HTTP 200 OK and a body containing success!. We can now set up the client to trust the CA, and send a request to the server: + + clientTLSCert, err := clientKeypair.TLSCertificate() + require.NoError(t, err) + + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{*clientTLSCert}, + ClientCAs: caPool, + RootCAs: caPool, + MinVersion: tls.VersionTLS12, + }, + } + http := http.Client{ + Transport: transport, + } + + resp, err := http.Get(server.URL) + require.NoError(t, err) + + // If no errors occurred, we now have our success! response from the server, and can verify it: + + respBodyBytes, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + body := strings.TrimSpace(string(respBodyBytes[:])) + require.Equal(t, "success!", body) +} diff --git a/internal/server/data/data.go b/internal/server/data/data.go index a9c5712807..b2bb555523 100644 --- a/internal/server/data/data.go +++ b/internal/server/data/data.go @@ -49,6 +49,8 @@ func NewDB(connection gorm.Dialector) (*gorm.DB, error) { &models.AccessKey{}, &models.Settings{}, &models.EncryptionKey{}, + &models.TrustedCertificate{}, + &models.RootCertificate{}, } for _, table := range tables { if err := db.AutoMigrate(table); err != nil { diff --git a/internal/server/data/trustedcerts.go b/internal/server/data/trustedcerts.go new file mode 100644 index 0000000000..a5ef049e8d --- /dev/null +++ b/internal/server/data/trustedcerts.go @@ -0,0 +1,68 @@ +package data + +import ( + "encoding/base64" + "errors" + "time" + + "github.com/infrahq/infra/internal" + "github.com/infrahq/infra/internal/server/models" + "gorm.io/gorm" +) + +// TrustPublicKey trusts a public key (in base64 format) from a user or service +// Callers must have received the key from a mTLS/e2ee (mutually encrypted), trusted source. +func TrustPublicKey(db *gorm.DB, tc *models.TrustedCertificate) error { + _, err := get[models.TrustedCertificate](db, ByPublicKey(tc.PublicKey)) + if err != nil && !errors.Is(err, internal.ErrNotFound) { + return err + } + + if err == nil { + // this one already exists + return nil + } + + return add(db, tc) +} + +func ListTrustedClientCertificates(db *gorm.DB) ([]models.TrustedCertificate, error) { + return list[models.TrustedCertificate](db) +} + +func ListRootCertificates(db *gorm.DB) ([]models.RootCertificate, error) { + return list[models.RootCertificate](db, OrderBy("id desc"), ByNotExpired(), Limit(2)) +} + +func GetRootCertificate(db *gorm.DB, selectors ...SelectorFunc) (*models.RootCertificate, error) { + return get[models.RootCertificate](db, selectors...) +} + +func AddRootCertificate(db *gorm.DB, cert *models.RootCertificate) error { + return add(db, cert) +} + +func ByPublicKey(key []byte) SelectorFunc { + return func(db *gorm.DB) *gorm.DB { + k := base64.StdEncoding.EncodeToString(key) + return db.Where("public_key = ?", k) + } +} + +func OrderBy(order string) SelectorFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Order(order) + } +} + +func Limit(limit int) SelectorFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Limit(limit) + } +} + +func ByNotExpired() SelectorFunc { + return func(db *gorm.DB) *gorm.DB { + return db.Where("expires_at is null or expires_at > ?", time.Now().UTC()) + } +} diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go index 80cf0738fd..f01bbb0b89 100644 --- a/internal/server/middleware_test.go +++ b/internal/server/middleware_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" "strings" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/infrahq/infra/internal/generate" "github.com/infrahq/infra/internal/server/data" "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/secrets" "github.com/infrahq/infra/uid" ) @@ -25,6 +27,17 @@ func setupDB(t *testing.T) *gorm.DB { db, err := data.NewDB(driver) require.NoError(t, err) + fp := secrets.NewFileSecretProviderFromConfig(secrets.FileConfig{ + Path: os.TempDir(), + }) + + kp := secrets.NewNativeSecretProvider(fp) + + key, err := kp.GenerateDataKey("") + require.NoError(t, err) + + models.SymmetricKey = key + return db } diff --git a/internal/server/models/trustedcertificate.go b/internal/server/models/trustedcertificate.go new file mode 100644 index 0000000000..d077a8243e --- /dev/null +++ b/internal/server/models/trustedcertificate.go @@ -0,0 +1,28 @@ +package models + +import ( + "time" +) + +type TrustedCertificate struct { + Model + + KeyAlgorithm string `validate:"required"` + SigningAlgorithm string `validate:"required"` + PublicKey Base64 `validate:"required"` + CertPEM []byte `validate:"required"` // pem encoded + Identity string `validate:"required"` + ExpiresAt time.Time + OneTimeUse bool +} + +type RootCertificate struct { + Model + + KeyAlgorithm string `validate:"required"` + SigningAlgorithm string `validate:"required"` + PublicKey Base64 `validate:"required"` + PrivateKey EncryptedAtRest `validate:"required"` + SignedCert EncryptedAtRest `validate:"required"` // contains private key? probably not pem encoded + ExpiresAt time.Time `validate:"required"` +} diff --git a/internal/server/models/types.go b/internal/server/models/types.go new file mode 100644 index 0000000000..32e4a88bc1 --- /dev/null +++ b/internal/server/models/types.go @@ -0,0 +1,30 @@ +package models + +import ( + "database/sql/driver" + "encoding/base64" + "fmt" +) + +type Base64 []byte + +func (f Base64) Value() (driver.Value, error) { + r := base64.StdEncoding.EncodeToString([]byte(f)) + + return r, nil +} + +func (f *Base64) Scan(v interface{}) error { + b, err := base64.StdEncoding.DecodeString(string(v.(string))) + if err != nil { + return fmt.Errorf("base64 decoding field: %w", err) + } + + *f = Base64(b) + + return nil +} + +func (f Base64) GormDataType() string { + return "text" +} diff --git a/internal/server/routes.go b/internal/server/routes.go index c15aed50ef..70ad7bc8d3 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -11,11 +11,9 @@ import ( "github.com/infrahq/infra/metrics" ) -type ( - ReqHandlerFunc[Req any] func(c *gin.Context, req *Req) error - ResHandlerFunc[Res any] func(c *gin.Context) (Res, error) - ReqResHandlerFunc[Req, Res any] func(c *gin.Context, req *Req) (Res, error) -) +type ReqHandlerFunc[Req any] func(c *gin.Context, req *Req) error +type ResHandlerFunc[Res any] func(c *gin.Context) (Res, error) +type ReqResHandlerFunc[Req, Res any] func(c *gin.Context, req *Req) (Res, error) func (a *API) registerRoutes(router *gin.RouterGroup) { router.Use( diff --git a/internal/server/server.go b/internal/server/server.go index 4e68405c45..d26a68f6ba 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,6 +4,11 @@ package server import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" "errors" "fmt" "io" @@ -33,6 +38,7 @@ import ( "github.com/infrahq/infra/internal/server/data" "github.com/infrahq/infra/internal/server/models" timer "github.com/infrahq/infra/internal/timer" + "github.com/infrahq/infra/pki" "github.com/infrahq/infra/secrets" ) @@ -61,14 +67,21 @@ type Options struct { Secrets []SecretProvider `mapstructure:"secrets"` Import *config.Config `mapstructure:"import"` + + NetworkEncryption string `mapstructure:"networkEncryption"` // mtls (default), e2ee, none. + TrustInitialClientPublicKey string `mapstructure:"trustInitialClientPublicKey"` + InitialRootCACert string `mapstructure:"initialRootCACert"` + InitialRootCAPublicKey string `mapstructure:"initialRootCAPublicKey"` + FullKeyRotationInDays int `mapstructure:"fullKeyRotationInDays"` // 365 default } type Server struct { - options Options - db *gorm.DB - tel *Telemetry - secrets map[string]secrets.SecretStorage - keys map[string]secrets.SymmetricKeyProvider + options Options + db *gorm.DB + tel *Telemetry + secrets map[string]secrets.SecretStorage + keys map[string]secrets.SymmetricKeyProvider + certificateProvider pki.CertificateProvider } func Run(options Options) (err error) { @@ -108,6 +121,10 @@ func Run(options Options) (err error) { return fmt.Errorf("loading database key: %w", err) } + if err = server.loadCertificates(); err != nil { + return fmt.Errorf("loading certificate provider: %w", err) + } + if options.EnableTelemetry { if err := configureTelemetry(server.db); err != nil { return fmt.Errorf("configuring telemetry: %w", err) @@ -135,17 +152,14 @@ func Run(options Options) (err error) { scope.SetContext("serverId", settings.ID) }) - // TODO: this should instead happen after runserver and we should wait for the server to close - go func() { - if err := server.importConfig(); err != nil { - logging.S.Error(fmt.Errorf("import config: %w", err)) - } - }() - if err := server.runServer(); err != nil { return fmt.Errorf("running server: %w", err) } + if err := server.importConfig(); err != nil { + logging.S.Error(fmt.Errorf("import config: %w", err)) + } + return logging.L.Sync() } @@ -165,45 +179,123 @@ func configureTelemetry(db *gorm.DB) error { return nil } -func serve(server *http.Server) { - if err := server.ListenAndServe(); err != nil { - logging.S.Errorf("server: %w", err) +func (s *Server) loadCertificates() (err error) { + if s.options.FullKeyRotationInDays == 0 { + s.options.FullKeyRotationInDays = 365 } -} -func (s *Server) runServer() error { - gin.SetMode(gin.ReleaseMode) + fullRotationInDays := s.options.FullKeyRotationInDays - router := gin.New() + // TODO: check certificate provider from config + s.certificateProvider, err = pki.NewNativeCertificateProvider(s.db, pki.NativeCertificateProviderConfig{ + FullKeyRotationDurationInDays: fullRotationInDays, + InitialRootCAPublicKey: []byte(s.options.InitialRootCAPublicKey), + InitialRootCACert: []byte(s.options.InitialRootCACert), + }) + if err != nil { + return err + } - router.Use(gin.Recovery()) - router.GET("/.well-known/jwks.json", func(c *gin.Context) { - settings, err := data.GetSettings(s.db) + // if there's no active CAs, try loading them from options. + cert := s.options.InitialRootCACert + key := s.options.InitialRootCAPublicKey + + if len(s.certificateProvider.ActiveCAs()) == 0 && len(cert) > 0 && len(key) > 0 { + jsonBytes := fmt.Sprintf(`{"ServerKey":{"CertPEM":"%s", "PublicKey":"%s"}}`, cert, key) + kp := &pki.KeyPair{} + err := json.Unmarshal([]byte(jsonBytes), kp) if err != nil { - sendAPIError(c, fmt.Errorf("could not get JWKs")) - return + return fmt.Errorf("reading initialRootCACert and initialRootCAPublicKey: %w", err) } - var pubKey jose.JSONWebKey - if err := pubKey.UnmarshalJSON(settings.PublicJWK); err != nil { - sendAPIError(c, fmt.Errorf("could not get JWKs")) - return + err = s.certificateProvider.Preload(kp.CertPEM, kp.PublicKey) + if err != nil && err.Error() != internal.ErrNotImplemented.Error() { + return fmt.Errorf("preloading initialRootCACert and initialRootCAPublicKey: %w", err) } + } - c.JSON(http.StatusOK, struct { - Keys []jose.JSONWebKey `json:"keys"` - }{ - []jose.JSONWebKey{pubKey}, - }) - }) + // if still no active CAs, create them + if len(s.certificateProvider.ActiveCAs()) == 0 { + logging.S.Info("Creating Root CA certificate") + if err := s.certificateProvider.CreateCA(); err != nil { + return fmt.Errorf("creating CA certificates: %w", err) + } + } + + // automatically rotate CAs as the oldest one expires + if len(s.certificateProvider.ActiveCAs()) == 1 { + logging.S.Info("Rotating Root CA certificate") + if err := s.certificateProvider.RotateCA(); err != nil { + return fmt.Errorf("rotating CA: %w", err) + } + } + + // if the current cert is going to expire in less than FullKeyRotationDurationInDays/2 days, rotate. + rotationWindow := time.Now().AddDate(0, 0, fullRotationInDays/2) + activeCAs := s.certificateProvider.ActiveCAs() + if len(activeCAs) < 2 || activeCAs[1].NotAfter.Before(rotationWindow) { + logging.S.Info("Half-Rotating Root CA certificate") + if err := s.certificateProvider.RotateCA(); err != nil { + return fmt.Errorf("rotating CA: %w", err) + } + } + + if len(s.options.TrustInitialClientPublicKey) > 0 { + key := s.options.TrustInitialClientPublicKey + rawKey, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return fmt.Errorf("reading trustInitialClientPublicKey: %w", err) + } + + tc := &models.TrustedCertificate{ + KeyAlgorithm: x509.PureEd25519.String(), + SigningAlgorithm: x509.Ed25519.String(), + PublicKey: models.Base64(rawKey), + // CertPEM: raw, + // ExpiresAt: cert.NotAfter, + // Identity: ident, + } + + err = data.TrustPublicKey(s.db, tc) + if err != nil { + return fmt.Errorf("saving trusted public key: %w", err) + } + } + + return nil +} + +func serve(server *http.Server) { + if err := server.ListenAndServe(); err != nil { + logging.S.Errorf("server: %w", err) + } +} - router.GET("/healthz", func(c *gin.Context) { - c.Status(http.StatusOK) +func (s *Server) wellKnownJWKsHandler(c *gin.Context) { + settings, err := data.GetSettings(s.db) + if err != nil { + sendAPIError(c, fmt.Errorf("could not get JWKs")) + return + } + + var pubKey jose.JSONWebKey + if err := pubKey.UnmarshalJSON(settings.PublicJWK); err != nil { + sendAPIError(c, fmt.Errorf("could not get JWKs")) + return + } + + c.JSON(http.StatusOK, struct { + Keys []jose.JSONWebKey `json:"keys"` + }{ + []jose.JSONWebKey{pubKey}, }) +} - NewAPIMux(s, router.Group("/v1")) +func (s *Server) healthHandler(c *gin.Context) { + c.Status(http.StatusOK) +} - // UI +func (s *Server) ui(router *gin.Engine) error { if s.options.EnableUI { if s.options.UIProxyURL != "" { remote, err := urlx.Parse(s.options.UIProxyURL) @@ -250,6 +342,24 @@ func (s *Server) runServer() error { } } + return nil +} + +func (s *Server) runServer() error { + gin.SetMode(gin.ReleaseMode) + + router := gin.New() + + router.Use(gin.Recovery()) + router.GET("/.well-known/jwks.json", s.wellKnownJWKsHandler) + router.GET("/healthz", s.healthHandler) + + NewAPIMux(s, router.Group("/v1")) + + if err := s.ui(router); err != nil { + return err + } + sentryHandler := sentryhttp.New(sentryhttp.Options{}) metrics := gin.New() @@ -277,16 +387,11 @@ func (s *Server) runServer() error { return fmt.Errorf("create tls cache: %w", err) } - manager := &autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(s.options.TLSCache), + tlsConfig, err := s.serverTLSConfig() + if err != nil { + return fmt.Errorf("tls config: %w", err) } - tlsConfig := manager.TLSConfig() - tlsConfig.GetCertificate = certs.SelfSignedOrLetsEncryptCert(manager, func() string { - return "" - }) - tlsServer := &http.Server{ Addr: ":443", TLSConfig: tlsConfig, @@ -301,6 +406,66 @@ func (s *Server) runServer() error { return nil } +func (s *Server) serverTLSConfig() (*tls.Config, error) { + switch s.options.NetworkEncryption { + case "mtls": + serverTLSCerts, err := s.certificateProvider.TLSCertificates() + if err != nil { + return nil, fmt.Errorf("getting tls certs: %w", err) + } + + caPool := x509.NewCertPool() + + for _, cert := range s.certificateProvider.ActiveCAs() { + cert := cert + caPool.AddCert(&cert) + } + + tcerts, err := data.ListTrustedClientCertificates(s.db) + if err != nil { + return nil, err + } + + for _, tcert := range tcerts { + p, _ := pem.Decode(tcert.CertPEM) + + cert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + return nil, err + } + + if cert.NotAfter.After(time.Now()) { + logging.S.Debugf("Trusting user certificate %q\n", cert.Subject.CommonName) + caPool.AddCert(cert) + } + } + + return &tls.Config{ + Certificates: serverTLSCerts, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caPool, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + }, nil + default: // "none" or blank + manager := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(s.options.TLSCache), + } + tlsConfig := manager.TLSConfig() + tlsConfig.GetCertificate = certs.SelfSignedOrLetsEncryptCert(manager, func() string { + return "" + }) + + return tlsConfig, nil + } +} + // configureSentry returns ok:true when sentry is configured and initialized, or false otherwise. It can be used to know if `defer recoverWithSentryHub(sentry.CurrentHub())` can be called func (s *Server) configureSentry() (err error, ok bool) { if s.options.EnableCrashReporting && internal.CrashReportingDSN != "" { diff --git a/pki/certificates.go b/pki/certificates.go index 8d3e7f5a96..9a185da709 100644 --- a/pki/certificates.go +++ b/pki/certificates.go @@ -1,28 +1,123 @@ package pki -import "crypto/x509" +import ( + "crypto/ed25519" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "math/rand" + "strings" + "time" + + "github.com/infrahq/infra/internal/server/models" +) // the pki package defines an interface and implementations of public key encryption, specifically around certificates. type CertificateProvider interface { // A setup step; create a root CA. this happens only once. CreateCA() error + // rotate the current CA. This does a half-rotation. the current cert becomes the previous cert, and there are always two active certificates RotateCA() error + // return the two active CA certificates. This always returns two, and the second one is always the most recent ActiveCAs() []x509.Certificate - // return the chain of certificates, active or not, since x - // CertificateChain() error - // IssueCert() // don't think I need this + + // return active CAs as tls certificates, this includes the private keys; it's used for the servers to listen for requests and be able to read the responses. + TLSCertificates() ([]tls.Certificate, error) // Sign a cert with the latest active CA. // Caller should have already validated that it's okay to sign this certificate by verifying the sender's authenticity, and that they own the resources they're asking to be certified for. // A Certificate Signing Request can be parsed with `x509.ParseCertificateRequest()` SignCertificate(csr x509.CertificateRequest) (pemBytes []byte, err error) + + // Preload attempts to preload the root certificate into the system. If this is not possible in this implementation of the certificate provider, it should return internal.ErrNotImplemented or a simple errors.New("not implemented") + Preload(rootCACertificate, publicKey []byte) error } -// type Signer interface { -// SignCert() -// } +func MakeUserCert(commonName string, lifetime time.Duration) (*KeyPair, error) { + pub, prv, err := ed25519.GenerateKey(randReader) + if err != nil { + return nil, fmt.Errorf("generating keys: %w", err) + } + + certTemplate := x509.Certificate{ + PublicKeyAlgorithm: x509.Ed25519, + PublicKey: pub, + SerialNumber: big.NewInt(rand.Int63()), //nolint:gosec + Subject: pkix.Name{CommonName: commonName}, + NotBefore: time.Now().Add(-5 * time.Minute), + NotAfter: time.Now().Add(lifetime), + KeyUsage: x509.KeyUsageDataEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + rawCert, err := x509.CreateCertificate(randReader, &certTemplate, &certTemplate, pub, prv) + if err != nil { + return nil, fmt.Errorf("creating certificate: %w", err) + } + + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return nil, fmt.Errorf("parsing self-created certificate: %w", err) + } + + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: rawCert, + }) + + keyPair := &KeyPair{ + Cert: cert, + CertPEM: pemBytes, + PublicKey: pub, + PrivateKey: prv, + } -func ValidateCertificate() {} + return keyPair, nil +} + +func SignUserCert(cp CertificateProvider, cert *x509.Certificate, user *models.User) (*x509.Certificate, []byte, error) { + rawCert := cert.Raw + if len(cert.Raw) == 0 { + panic("cert.Raw is missing") + } + + if !strings.HasPrefix(cert.Subject.CommonName, "User ") { + return nil, nil, fmt.Errorf("invalid certificate common name for user certificate") + } + + pem1, err := cp.SignCertificate(x509.CertificateRequest{ + Raw: rawCert, + PublicKeyAlgorithm: cert.PublicKeyAlgorithm, + PublicKey: cert.PublicKey, + Subject: cert.Subject, + EmailAddresses: []string{user.Email}, + Extensions: cert.Extensions, + ExtraExtensions: cert.ExtraExtensions, + SignatureAlgorithm: x509.PureEd25519, + }) + if err != nil { + return nil, nil, err + } + + p, rest := pem.Decode(pem1) + if p == nil { + return nil, nil, fmt.Errorf("decoding certificate: %w", err) + } + + if len(rest) > 0 { + panic("forgot part of cert chain") + } + + newCert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("parsing certificate: %w", err) + } + + return newCert, pem1, nil +} diff --git a/pki/certificates_test.go b/pki/certificates_test.go index 57f20a6fd5..5fead733bf 100644 --- a/pki/certificates_test.go +++ b/pki/certificates_test.go @@ -64,10 +64,9 @@ func eachProvider(t *testing.T, eachFunc func(t *testing.T, p CertificateProvide require.NoError(t, err) defer os.RemoveAll(tmpDir) + db := setupDB(t) - p, err := NewNativeCertificateProvider(NativeCertificateProviderConfig{ - StoragePath: tmpDir, - }) + p, err := NewNativeCertificateProvider(db, NativeCertificateProviderConfig{}) require.NoError(t, err) providers["native"] = p @@ -105,7 +104,7 @@ func TestCertificatesImplementations(t *testing.T) { } t.Run("signing Cert Signing Requests", func(t *testing.T) { - cert, err := generateClientCertificate("Engine") + cert, err := generateClientCertificate("Connector") require.NoError(t, err) csr := x509.CertificateRequest{ @@ -136,6 +135,7 @@ func TestCertificatesImplementations(t *testing.T) { } func init() { + // only used in tests randReader = rand.New(rand.NewSource(0)) //nolint:gosec } @@ -145,7 +145,7 @@ func generateClientCertificate(subject string) (*x509.Certificate, error) { return nil, fmt.Errorf("generating keys: %w", err) } - kp := keyPair{ + kp := KeyPair{ PublicKey: pub, PrivateKey: prv, } @@ -158,7 +158,7 @@ func generateClientCertificate(subject string) (*x509.Certificate, error) { return cert, nil } -func createClientCertSignedBy(signer, signee keyPair, subject string, lifetime time.Duration) (*x509.Certificate, []byte, error) { +func createClientCertSignedBy(signer, signee KeyPair, subject string, lifetime time.Duration) (*x509.Certificate, []byte, error) { sig := ed25519.Sign(signer.PrivateKey, signee.PublicKey) if !ed25519.Verify(signer.PublicKey, signee.PublicKey, sig) { return nil, nil, errors.New("self-signed certificate doesn't match signature") @@ -171,10 +171,10 @@ func createClientCertSignedBy(signer, signee keyPair, subject string, lifetime t PublicKey: signee.PublicKey, SerialNumber: big.NewInt(rand.Int63()), //nolint:gosec Subject: pkix.Name{CommonName: subject}, - NotBefore: time.Now(), + NotBefore: time.Now().Add(-5 * time.Minute), NotAfter: time.Now().Add(lifetime), KeyUsage: x509.KeyUsageDataEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, } // create client certificate from template and CA public key diff --git a/pki/keypair.go b/pki/keypair.go new file mode 100644 index 0000000000..87fe88474b --- /dev/null +++ b/pki/keypair.go @@ -0,0 +1,79 @@ +package pki + +import ( + "crypto" + "crypto/ed25519" + "crypto/tls" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" +) + +// PrivateKey is an interface compatible with the go stdlib ed25519, ecdsa, and rsa keys +type PrivateKey interface { + Public() crypto.PublicKey + Equal(x crypto.PrivateKey) bool +} + +type KeyPair struct { + KeyAlgorithm string + SigningAlgorithm string + PublicKey ed25519.PublicKey + PrivateKey ed25519.PrivateKey `json:",omitempty"` + CertPEM []byte `json:",omitempty"` // pem encoded + SignedCertPEM []byte `json:",omitempty"` // pem encoded + Cert *x509.Certificate `json:"-"` + SignedCert *x509.Certificate `json:"-"` +} + +func (k *KeyPair) TLSCertificate() (*tls.Certificate, error) { + bytes := k.SignedCertPEM + if len(bytes) == 0 { + bytes = k.CertPEM + } + + keyPEM, err := MarshalPrivateKey(k.PrivateKey) + + cert, err := tls.X509KeyPair(bytes, keyPEM) + if err != nil { + return nil, fmt.Errorf("reading keypair: %w", err) + } + + return &cert, nil +} + +func (k *KeyPair) UnmarshalJSON(data []byte) error { + type TmpKeyPair KeyPair + tmpKeyPair := &TmpKeyPair{} + + err := json.Unmarshal(data, &tmpKeyPair) + if err != nil { + return err + } + + k.PublicKey = tmpKeyPair.PublicKey + k.PrivateKey = tmpKeyPair.PrivateKey + k.CertPEM = tmpKeyPair.CertPEM + k.SignedCertPEM = tmpKeyPair.SignedCertPEM + + p, _ := pem.Decode(k.CertPEM) + cert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + return fmt.Errorf("parsing raw certificate: %w", err) + } + + k.Cert = cert + + if len(k.SignedCertPEM) > 0 { + p, _ = pem.Decode(k.SignedCertPEM) + cert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + return fmt.Errorf("parsing signed certificate: %w", err) + } + + k.SignedCert = cert + } + + return nil +} diff --git a/pki/native.go b/pki/native.go index 42b811a1e7..6a8c8d648a 100644 --- a/pki/native.go +++ b/pki/native.go @@ -1,21 +1,26 @@ package pki import ( - "bytes" "crypto/ed25519" "crypto/rand" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/pem" "errors" "fmt" - "io/ioutil" - "log" "math/big" + "net" "os" - "path" + "strings" "time" + + "github.com/infrahq/infra/internal" + "github.com/infrahq/infra/internal/server/data" + "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/secrets" + "gorm.io/gorm" ) const ( @@ -39,44 +44,129 @@ var ( type NativeCertificateProvider struct { NativeCertificateProviderConfig - activeKeypair keyPair - previousKeypair keyPair + db *gorm.DB + + activeKeypair KeyPair + previousKeypair KeyPair - // TODO: support arbitrary storage - // secretStorage secrets.SecretStorage + secretStorage secrets.SecretStorage // secretKeyProvider secrets.SymmetricKeyProvider } type NativeCertificateProviderConfig struct { - StoragePath string FullKeyRotationDurationInDays int - // Algorithm string // only ed25519 so far. -} - -type keyPair struct { - PublicKey ed25519.PublicKey - PrivateKey ed25519.PrivateKey - certRaw []byte - Cert *x509.Certificate + KeyAlgorithm string // only ed25519 so far. + SigningAlgorithm string + InitialRootCAPublicKey []byte + InitialRootCACert []byte + InitialRootCAPrivateKey []byte } -func NewNativeCertificateProvider(cfg NativeCertificateProviderConfig) (*NativeCertificateProvider, error) { +func NewNativeCertificateProvider(db *gorm.DB, cfg NativeCertificateProviderConfig) (*NativeCertificateProvider, error) { if cfg.FullKeyRotationDurationInDays == 0 { cfg.FullKeyRotationDurationInDays = 365 } p := &NativeCertificateProvider{ NativeCertificateProviderConfig: cfg, + db: db, } - if err := p.loadFromDisk(); err != nil { - if !errors.Is(err, os.ErrNotExist) { - return nil, err + + if err := p.loadFromDB(); err != nil { + return nil, err + } + + if p.activeKeypair.SignedCert == nil && + len(cfg.InitialRootCAPublicKey) > 0 && + len(cfg.InitialRootCACert) > 0 && + len(cfg.InitialRootCAPrivateKey) > 0 { + pubKey, err := base64.StdEncoding.DecodeString(string(cfg.InitialRootCAPublicKey)) + if err != nil { + return nil, fmt.Errorf("reading initialRootCAPublicKey: %w", err) + } + + cert, err := base64.StdEncoding.DecodeString(string(cfg.InitialRootCACert)) + if err != nil { + return nil, fmt.Errorf("reading initialRootCACert: %w", err) + } + + prvKey, err := base64.StdEncoding.DecodeString(string(cfg.InitialRootCAPrivateKey)) + if err != nil { + return nil, fmt.Errorf("reading initialRootCAPrivateKey: %w", err) + } + + c, err := x509.ParseCertificate(cert) + if err != nil { + return nil, fmt.Errorf("parsing initialRootCACert: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + }) + + p.activeKeypair = KeyPair{ + KeyAlgorithm: p.KeyAlgorithm, + SigningAlgorithm: p.SigningAlgorithm, + PublicKey: pubKey, + PrivateKey: prvKey, + SignedCertPEM: certPEM, + SignedCert: c, } } return p, nil } +func (n *NativeCertificateProvider) Preload(rootCACertificate, publicKey []byte) (err error) { + if n.activeKeypair.SignedCert != nil { + return fmt.Errorf("cannot preload a certificate when another one is already loaded.") + } + + partsFound := 0 + rest := rootCACertificate + var p *pem.Block + var cert *x509.Certificate + var privateKey ed25519.PrivateKey + for len(rest) > 0 { + partsFound++ + p, rest = pem.Decode(rootCACertificate) + + switch { + case strings.Contains(p.Type, "PRIVATE KEY"): + key, err := x509.ParsePKCS8PrivateKey(p.Bytes) + if err != nil { + return fmt.Errorf("parsing private key from certificate: %w", err) + } + privateKey = key.(ed25519.PrivateKey) + case strings.Contains(p.Type, "CERTIFICATE"): + cert, err = x509.ParseCertificate(p.Bytes) + if err != nil { + return fmt.Errorf("parsing root certificate: %w", err) + } + } + } + + if partsFound > 2 { + return fmt.Errorf("expected one certificate and one private key, but got certificate chain") + } + + if partsFound < 2 { + return fmt.Errorf("expected one certificate and one private key") + } + + n.activeKeypair = KeyPair{ + KeyAlgorithm: cert.PublicKeyAlgorithm.String(), + SigningAlgorithm: cert.SignatureAlgorithm.String(), + PublicKey: publicKey, + PrivateKey: privateKey, + SignedCertPEM: rootCACertificate, + SignedCert: cert, + } + + return n.RotateCA() +} + // CreateCA creates a new root CA and immediately does a half-rotation. // the new active key after rotation is the one that should be used. func (n *NativeCertificateProvider) CreateCA() error { @@ -95,21 +185,29 @@ func (n *NativeCertificateProvider) CreateCA() error { return err } - n.activeKeypair.certRaw = raw - n.activeKeypair.Cert = cert + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: raw, + }) + + n.activeKeypair.SignedCertPEM = pemBytes + n.activeKeypair.SignedCert = cert + n.activeKeypair.KeyAlgorithm = x509.Ed25519.String() + n.activeKeypair.SigningAlgorithm = x509.PureEd25519.String() return n.RotateCA() } +// ActiveCAs returns the currently in-use CAs, the newest cert is always the last in the list func (n *NativeCertificateProvider) ActiveCAs() []x509.Certificate { result := []x509.Certificate{} - if certActive(n.previousKeypair.Cert) { - result = append(result, *n.previousKeypair.Cert) + if n.previousKeypair.SignedCert != nil && certActive(n.previousKeypair.SignedCert) { + result = append(result, *n.previousKeypair.SignedCert) } - if certActive(n.activeKeypair.Cert) { - result = append(result, *n.activeKeypair.Cert) + if n.activeKeypair.SignedCert != nil && certActive(n.activeKeypair.SignedCert) { + result = append(result, *n.activeKeypair.SignedCert) } return result @@ -127,11 +225,14 @@ func certActive(cert *x509.Certificate) bool { return true } +// TODO: SignCertificate should be renamed to SignUserCertificate? func (n *NativeCertificateProvider) SignCertificate(csr x509.CertificateRequest) (pemBytes []byte, err error) { - switch csr.Subject.CommonName { - case rootCAName: + switch { + case csr.Subject.CommonName == rootCAName: return nil, fmt.Errorf("cannot sign cert pretending to be the root CA") - case "Engine", "Server", "Client": + case strings.HasPrefix(csr.Subject.CommonName, "Connector"): + case strings.HasPrefix(csr.Subject.CommonName, "Infra Server"): + case strings.HasPrefix(csr.Subject.CommonName, "User"): // these are ok. default: return nil, fmt.Errorf("invalid Subject name %q", csr.Subject.CommonName) @@ -148,27 +249,36 @@ func (n *NativeCertificateProvider) SignCertificate(csr x509.CertificateRequest) certTemplate := &x509.Certificate{ Signature: csr.Signature, SignatureAlgorithm: csr.SignatureAlgorithm, - PublicKeyAlgorithm: csr.PublicKeyAlgorithm, PublicKey: csr.PublicKey, + SerialNumber: big.NewInt(2), + Issuer: n.activeKeypair.SignedCert.Subject, + Subject: csr.Subject, + EmailAddresses: csr.EmailAddresses, + Extensions: csr.Extensions, // TODO: security issue? + ExtraExtensions: csr.ExtraExtensions, // TODO: security issue? + NotBefore: time.Now().Add(-5 * time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + } + + if !n.activeKeypair.SignedCert.IsCA { + panic("not ca") + } - SerialNumber: big.NewInt(2), - Issuer: n.activeKeypair.Cert.Subject, - Subject: csr.Subject, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + if n.activeKeypair.SignedCert.KeyUsage&x509.KeyUsageCertSign != x509.KeyUsageCertSign { + panic("can't sign keys with this cert") } - rawCert, err := x509.CreateCertificate(randReader, certTemplate, n.activeKeypair.Cert, csr.PublicKey, n.activeKeypair.PrivateKey) + signedCert, err := x509.CreateCertificate(randReader, certTemplate, n.activeKeypair.SignedCert, csr.PublicKey, n.activeKeypair.PrivateKey) if err != nil { return nil, fmt.Errorf("creating cert: %w", err) } pemBytes = pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", - Bytes: rawCert, + Bytes: signedCert, }) return pemBytes, nil @@ -177,6 +287,7 @@ func (n *NativeCertificateProvider) SignCertificate(csr x509.CertificateRequest) // RotateCA does a half-rotation. the current cert becomes the previous cert, and there are always two active certificates func (n *NativeCertificateProvider) RotateCA() error { n.previousKeypair = n.activeKeypair + n.activeKeypair = KeyPair{} pub, prv, err := ed25519.GenerateKey(randReader) if err != nil { @@ -193,19 +304,21 @@ func (n *NativeCertificateProvider) RotateCA() error { return err } - n.activeKeypair.certRaw = raw - n.activeKeypair.Cert = cert + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: raw, + }) + + n.activeKeypair.SignedCertPEM = pemBytes + n.activeKeypair.SignedCert = cert + n.activeKeypair.KeyAlgorithm = x509.Ed25519.String() + n.activeKeypair.SigningAlgorithm = x509.PureEd25519.String() - return n.saveToDisk() + return n.saveToDB() } // createCertSignedBy signs the signee public key using the signer private key, allowing anyone to verify the signature with the signer public key. Certificate expires after _lifetime_ -func createCertSignedBy(signer, signee keyPair, lifetime time.Duration) (*x509.Certificate, []byte, error) { - sig := ed25519.Sign(signer.PrivateKey, signee.PublicKey) - if !ed25519.Verify(signer.PublicKey, signee.PublicKey, sig) { - return nil, nil, errors.New("self-signed certificate doesn't match signature") - } - +func createCertSignedBy(signer, signee KeyPair, lifetime time.Duration) (*x509.Certificate, []byte, error) { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serial, err := rand.Int(randReader, serialNumberLimit) @@ -213,76 +326,57 @@ func createCertSignedBy(signer, signee keyPair, lifetime time.Duration) (*x509.C return nil, nil, fmt.Errorf("creating random serial: %w", err) } - certTemplate := x509.Certificate{ - Signature: sig, + certTemplate := &x509.Certificate{ SignatureAlgorithm: x509.PureEd25519, PublicKeyAlgorithm: x509.Ed25519, PublicKey: signee.PublicKey, SerialNumber: serial, Issuer: pkix.Name{CommonName: rootCAName}, Subject: pkix.Name{CommonName: rootCAName}, - NotBefore: time.Now(), + NotBefore: time.Now().Add(-5 * time.Minute), NotAfter: time.Now().Add(lifetime), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{}, - IsCA: true, - } + KeyUsage: x509.KeyUsageCertSign | + x509.KeyUsageDigitalSignature | + x509.KeyUsageCRLSign | + x509.KeyUsageKeyAgreement | + x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + IsCA: true, + BasicConstraintsValid: true, - // create client certificate from template and CA public key - rawCert, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, signee.PublicKey, signee.PrivateKey) - if err != nil { - return nil, nil, fmt.Errorf("creating certificate: %w", err) + // SubjectAltName values + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback}, + DNSNames: []string{"localhost"}, // TODO: Support domain names for services? } - cert, err := x509.ParseCertificate(rawCert) - if err != nil { - return nil, nil, fmt.Errorf("parsing self-created certificate: %w", err) + signeeCert := signee.Cert + if signeeCert == nil { + signeeCert = certTemplate } - cert.IsCA = true // this isn't persisted - - return cert, rawCert, nil -} - -func (n *NativeCertificateProvider) saveToDisk() error { - err := os.MkdirAll(n.StoragePath, 0o600) - if err != nil && !os.IsExist(err) { - log.Printf("creating directory %q", n.StoragePath) + if !signeeCert.IsCA { + return nil, nil, fmt.Errorf("signee cert is not a CA") } - err = writePEMToFile(path.Join(n.StoragePath, "root.crt"), &pem.Block{ - Type: "CERTIFICATE", - Bytes: n.activeKeypair.certRaw, - }) - if err != nil { - return fmt.Errorf("writing PEM: %w", err) - } - - err = writePEMToFile(path.Join(n.StoragePath, "root.key"), &pem.Block{ - Type: "PRIVATE KEY", - Bytes: []byte(base64.StdEncoding.EncodeToString(n.activeKeypair.PrivateKey)), - }) + // create client certificate from template and CA public key + rawCert, err := x509.CreateCertificate(rand.Reader, certTemplate, signeeCert, signee.PublicKey, signee.PrivateKey) if err != nil { - return fmt.Errorf("writing PEM: %w", err) + return nil, nil, fmt.Errorf("creating certificate: %w", err) } - err = writePEMToFile(path.Join(n.StoragePath, "root-previous.crt"), &pem.Block{ - Type: "CERTIFICATE", - Bytes: n.previousKeypair.certRaw, - }) + cert, err := x509.ParseCertificate(rawCert) if err != nil { - return fmt.Errorf("writing PEM: %w", err) + return nil, nil, fmt.Errorf("parsing self-created certificate: %w", err) } - err = writePEMToFile(path.Join(n.StoragePath, "root-previous.key"), &pem.Block{ - Type: "PRIVATE KEY", - Bytes: []byte(base64.StdEncoding.EncodeToString(n.activeKeypair.PrivateKey)), - }) - if err != nil { - return fmt.Errorf("writing PEM: %w", err) + if !cert.IsCA { + return nil, nil, fmt.Errorf("signed cert is not a CA?") } - return nil + return cert, rawCert, nil } func writePEMToFile(file string, p *pem.Block) error { @@ -299,126 +393,102 @@ func writePEMToFile(file string, p *pem.Block) error { return f.Close() } -func (n *NativeCertificateProvider) loadFromDisk() error { - var ok bool - - err := os.MkdirAll(n.StoragePath, 0o600) - if err != nil && !os.IsExist(err) { - log.Printf("creating directory %q", n.StoragePath) - } - - pems, err := readFromPEMFile(path.Join(n.StoragePath, "root.crt")) - if err != nil { - return fmt.Errorf("reading PEM: %w", err) - } - - n.activeKeypair.certRaw = pems[0].Bytes - - cert, err := x509.ParseCertificate(n.activeKeypair.certRaw) +func (n *NativeCertificateProvider) loadFromDB() error { + certs, err := data.ListRootCertificates(n.db) if err != nil { - return fmt.Errorf("parsing certificate: %w", err) + return err } - cert.IsCA = true - n.activeKeypair.Cert = cert + if len(certs) >= 1 { + n.activeKeypair, err = certificateToKeyPair(&certs[0]) + if err != nil { + return err + } - // nolint:exhaustive - switch cert.PublicKeyAlgorithm { - case x509.Ed25519: - n.activeKeypair.PublicKey, ok = cert.PublicKey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("unexpected key type %t, expected ed25519", cert.PublicKey) + if len(certs) >= 2 { + n.previousKeypair, err = certificateToKeyPair(&certs[1]) + if err != nil { + return err + } } - default: - panic("unexpected key algorithm " + cert.PublicKeyAlgorithm.String()) } - pems, err = readFromPEMFile(path.Join(n.StoragePath, "root.key")) - if err != nil { - return fmt.Errorf("reading PEM: %w", err) - } + return nil +} - b, err := base64.StdEncoding.DecodeString(string(pems[0].Bytes)) +func certificateToKeyPair(c *models.RootCertificate) (KeyPair, error) { + // the certificate doesn't have pem armoring on it. + cert, err := x509.ParseCertificate([]byte(c.SignedCert)) if err != nil { - return fmt.Errorf("decoding key: %w", err) + return KeyPair{}, fmt.Errorf("couldn't read certificate from db: %w", err) } - n.activeKeypair.PrivateKey = b - - pems, err = readFromPEMFile(path.Join(n.StoragePath, "root-previous.crt")) - if err != nil { - return fmt.Errorf("reading PEM: %w", err) - } + return KeyPair{ + KeyAlgorithm: c.KeyAlgorithm, + SigningAlgorithm: c.SigningAlgorithm, + PublicKey: ed25519.PublicKey(c.PublicKey), + PrivateKey: ed25519.PrivateKey(c.PrivateKey), + SignedCertPEM: []byte(c.SignedCert), + SignedCert: cert, + }, nil +} - n.previousKeypair.certRaw = pems[0].Bytes +func keyPairToCertificate(k KeyPair) *models.RootCertificate { + // don't store the certificate with pem encoding; it's padding that only assists a known-plaintext attack + b, _ := pem.Decode(k.SignedCertPEM) - cert, err = x509.ParseCertificate(n.previousKeypair.certRaw) - if err != nil { - return fmt.Errorf("parsing certificate: %w", err) + return &models.RootCertificate{ + KeyAlgorithm: k.KeyAlgorithm, + SigningAlgorithm: k.SigningAlgorithm, + PublicKey: models.Base64(k.PublicKey), + PrivateKey: models.EncryptedAtRest(k.PrivateKey), + SignedCert: models.EncryptedAtRest(b.Bytes), + ExpiresAt: k.SignedCert.NotAfter, } +} - cert.IsCA = true - n.previousKeypair.Cert = cert - - // nolint:exhaustive - switch cert.PublicKeyAlgorithm { - case x509.Ed25519: - n.previousKeypair.PublicKey, ok = cert.PublicKey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("unexpected key type %t, expected ed25519", cert.PublicKey) +// saveToDB stores new certs to the database. Used when rotating keys. +func (n *NativeCertificateProvider) saveToDB() error { + certs := []*models.RootCertificate{ + keyPairToCertificate(n.previousKeypair), + keyPairToCertificate(n.activeKeypair), + } + // only create the previous keypair if it doesn't already exist. + for _, cert := range certs { + c, err := data.GetRootCertificate(n.db, data.ByPublicKey(cert.PublicKey)) + if c != nil { + continue + } + if !errors.Is(err, internal.ErrNotFound) { + return fmt.Errorf("checking for existing cert: %w", err) } - default: - panic("unexpected key algorithm " + cert.PublicKeyAlgorithm.String()) - } - - pems, err = readFromPEMFile(path.Join(n.StoragePath, "root-previous.key")) - if err != nil { - return fmt.Errorf("reading PEM: %w", err) - } - b, err = base64.StdEncoding.DecodeString(string(pems[0].Bytes)) - if err != nil { - return fmt.Errorf("decoding key: %w", err) + if err := data.AddRootCertificate(n.db, cert); err != nil { + return fmt.Errorf("adding CA certificate: %w", err) + } } - n.previousKeypair.PrivateKey = b - return nil } -func readFromPEMFile(file string) (pems []*pem.Block, err error) { - // nicer errors from os.Stat. it'll be an errors.Is(err, os.ErrNotExist) if it doesn't exist. - if _, err = os.Stat(file); err != nil { - return nil, err - } +func (n *NativeCertificateProvider) TLSCertificates() ([]tls.Certificate, error) { + result := []tls.Certificate{} - f, err := os.Open(file) - if err != nil { - return nil, fmt.Errorf("opening %q: %w", file, err) + keyPairs := []KeyPair{ + n.previousKeypair, + n.activeKeypair, } - defer f.Close() - b, err := ioutil.ReadAll(f) - if err != nil { - return nil, fmt.Errorf("reading %q: %w", file, err) - } - - for { - block, rest := pem.Decode(b) - if block == nil && bytes.Equal(rest, b) { - return nil, fmt.Errorf("%q contains no pem data", file) - } - - if block != nil { - pems = append(pems, block) + for _, keyPair := range keyPairs { + cert, err := keyPair.TLSCertificate() + if err != nil { + return nil, err } - if len(rest) == 0 { - break - } + result = append(result, *cert) } - return pems, nil + return result, nil } func isAllowedSignatureAlgorithm(alg x509.SignatureAlgorithm) bool { diff --git a/pki/native_test.go b/pki/native_test.go index 2127931d8c..c18122b5a3 100644 --- a/pki/native_test.go +++ b/pki/native_test.go @@ -4,18 +4,42 @@ import ( "os" "testing" + "github.com/infrahq/infra/internal/server/data" + "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/secrets" "github.com/stretchr/testify/require" + "gorm.io/gorm" ) -func TestCertificateDiskStorage(t *testing.T) { - s, err := os.MkdirTemp(os.TempDir(), "certs") +func setupDB(t *testing.T) *gorm.DB { + driver, err := data.NewSQLiteDriver("file::memory:") require.NoError(t, err) + db, err := data.NewDB(driver) + require.NoError(t, err) + + fp := secrets.NewFileSecretProviderFromConfig(secrets.FileConfig{ + Path: os.TempDir(), + }) + + kp := secrets.NewNativeSecretProvider(fp) + + key, err := kp.GenerateDataKey("") + require.NoError(t, err) + + models.SymmetricKey = key + + return db +} + +func TestCertificateStorage(t *testing.T) { cfg := NativeCertificateProviderConfig{ - StoragePath: s, FullKeyRotationDurationInDays: 2, } - p, err := NewNativeCertificateProvider(cfg) + + db := setupDB(t) + + p, err := NewNativeCertificateProvider(db, cfg) require.NoError(t, err) err = p.CreateCA() @@ -25,7 +49,7 @@ func TestCertificateDiskStorage(t *testing.T) { require.Len(t, activeCAs, 2) // reload - p, err = NewNativeCertificateProvider(cfg) + p, err = NewNativeCertificateProvider(db, cfg) require.NoError(t, err) reloadedActiveCAs := p.ActiveCAs() @@ -33,3 +57,21 @@ func TestCertificateDiskStorage(t *testing.T) { require.Equal(t, activeCAs, reloadedActiveCAs) } + +func TestTLSCertificates(t *testing.T) { + cfg := NativeCertificateProviderConfig{ + FullKeyRotationDurationInDays: 2, + } + p, err := NewNativeCertificateProvider(setupDB(t), cfg) + require.NoError(t, err) + + err = p.CreateCA() + require.NoError(t, err) + + activeCAs := p.ActiveCAs() + require.Len(t, activeCAs, 2) + + certs, err := p.TLSCertificates() + require.NoError(t, err) + require.Len(t, certs, 2) +} diff --git a/pki/pem.go b/pki/pem.go new file mode 100644 index 0000000000..a617b8d19d --- /dev/null +++ b/pki/pem.go @@ -0,0 +1,69 @@ +package pki + +import ( + "bytes" + "crypto/ed25519" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" +) + +func ReadFromPEMFile(file string) (pems []*pem.Block, pemBytes []byte, err error) { + b, err := readFile(file) + if err != nil { + return nil, nil, err + } + + for { + block, rest := pem.Decode(b) + if block == nil && bytes.Equal(rest, b) { + return nil, nil, fmt.Errorf("%q contains no pem data", file) + } + + if block != nil { + pems = append(pems, block) + } + + if len(rest) == 0 { + break + } + } + + return pems, b, nil +} + +func MarshalPrivateKey(key ed25519.PrivateKey) ([]byte, error) { + marshalledPrvKey, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return nil, fmt.Errorf("marshalling private key: %w", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: marshalledPrvKey, + }) + + return keyPEM, nil +} + +func readFile(file string) ([]byte, error) { + // nicer errors from os.Stat. it'll be an errors.Is(err, os.ErrNotExist) if it doesn't exist. + if _, err := os.Stat(file); err != nil { + return nil, err + } + + f, err := os.Open(file) + if err != nil { + return nil, fmt.Errorf("opening %q: %w", file, err) + } + defer f.Close() + + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, fmt.Errorf("reading %q: %w", file, err) + } + + return b, nil +}