Skip to content

Commit

Permalink
feat: Certificates management (#1086)
Browse files Browse the repository at this point in the history
* basics of mTLS working
* allow mtls to be enabled or disabled
* use db for cert backend
  • Loading branch information
ssoroka authored Mar 1, 2022
1 parent fb7f3fe commit 09802a3
Show file tree
Hide file tree
Showing 18 changed files with 1,052 additions and 269 deletions.
2 changes: 1 addition & 1 deletion internal/certs/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func SelfSignedCert(hosts []string) ([]byte, []byte, error) {
Subject: pkix.Name{
Organization: []string{"Infra"},
},
NotBefore: time.Now(),
NotBefore: time.Now().Add(-5 * time.Minute),
NotAfter: time.Now().AddDate(0, 0, 365),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
Expand Down
10 changes: 8 additions & 2 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func newServerCmd() *cobra.Command {
strcase.ConfigureAcronym("enable-ui", "enableUI")
strcase.ConfigureAcronym("ui-proxy-url", "uiProxyURL")

var options server.Options
options := server.Options{}
if err := parseOptions(cmd, &options, "INFRA_SERVER"); err != nil {
return err
}
Expand Down Expand Up @@ -570,11 +570,14 @@ func newVersionCmd() *cobra.Command {
}
}

var nonInteractiveMode bool

func NewRootCmd() (*cobra.Command, error) {
cobra.EnableCommandSorting = false

type rootOptions struct {
LogLevel string `mapstructure:"logLevel"`
LogLevel string `mapstructure:"logLevel"`
NonInteractive bool `mapstructure:"nonInteractive"`
}

rootCmd := &cobra.Command{
Expand All @@ -588,6 +591,8 @@ func NewRootCmd() (*cobra.Command, error) {
return err
}

nonInteractiveMode = options.NonInteractive

return logging.SetLevel(options.LogLevel)
},
}
Expand All @@ -609,6 +614,7 @@ func NewRootCmd() (*cobra.Command, error) {
rootCmd.AddCommand(newVersionCmd())

rootCmd.PersistentFlags().String("log-level", "info", "Set the log level. One of error, warn, info, or debug")
rootCmd.PersistentFlags().Bool("non-interactive", false, "don't assume an interactive terminal, even if there is one")

return rootCmd, nil
}
Expand Down
19 changes: 16 additions & 3 deletions internal/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const cliLoginRedirectURL = "http://localhost:8301"

func relogin() error {
// TODO (https://github.com/infrahq/infra/issues/488): support non-interactive login
if !term.IsTerminal(int(os.Stdin.Fd())) {
if !isInteractiveMode() {
return errors.New("Non-interactive login is not supported")
}

Expand Down Expand Up @@ -69,9 +69,22 @@ func relogin() error {
return finishLogin(currentConfig.Host, uid.NewUserPolymorphicID(loginRes.ID), loginRes.Name, loginRes.AccessKey, currentConfig.SkipTLSVerify, 0)
}

func isInteractiveMode() bool {
if nonInteractiveMode {
// user explicitly asked for a non-interactive terminal
return false
}

if os.Stdin == nil {
return false
}

return term.IsTerminal(int(os.Stdin.Fd()))
}

func login(host string) error {
// TODO (https://github.com/infrahq/infra/issues/488): support non-interactive login
if !term.IsTerminal(int(os.Stdin.Fd())) {
if !isInteractiveMode() {
return errors.New("Non-interactive login is not supported")
}

Expand Down Expand Up @@ -162,7 +175,7 @@ func login(host string) error {
}

// access key
if option == len(options) - 1 {
if option == len(options)-1 {
if accessKey == "" {
err = survey.AskOne(&survey.Password{Message: "Access Key:"}, &accessKey, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions internal/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ var (
// ErrForbidden means you don't have permissions to the requested resource
ErrForbidden = fmt.Errorf("forbidden")

ErrDuplicate = fmt.Errorf("duplicate record")
ErrNotFound = fmt.Errorf("record not found")
ErrBadRequest = fmt.Errorf("bad request")
ErrDuplicate = fmt.Errorf("duplicate record")
ErrNotFound = fmt.Errorf("record not found")
ErrBadRequest = fmt.Errorf("bad request")
ErrNotImplemented = fmt.Errorf("not implemented")
)
104 changes: 104 additions & 0 deletions internal/server/certificates_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package server

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/pki"
"github.com/infrahq/infra/uid"
"github.com/stretchr/testify/require"
)

func TestCertificateSigningWorks(t *testing.T) {
db := setupDB(t)

cp, err := pki.NewNativeCertificateProvider(db, pki.NativeCertificateProviderConfig{
FullKeyRotationDurationInDays: 2,
})
require.NoError(t, err)

err = cp.CreateCA()
require.NoError(t, err)

err = cp.RotateCA()
require.NoError(t, err)

user := &models.User{
Model: models.Model{ID: uid.New()},
Email: "joe@example.com",
}

keyPair, err := pki.MakeUserCert("User "+user.ID.String(), 24*time.Hour)
require.NoError(t, err)

// happens on the server, needs to be a request for this.
signedCert, signedRaw, err := pki.SignUserCert(cp, keyPair.Cert, user)
require.NoError(t, err)
keyPair.SignedCert = signedCert
keyPair.SignedCertPEM = signedRaw

// create a test server and client to make sure the certs work.
requireMutualTLSWorks(t, keyPair, cp)
}

func requireMutualTLSWorks(t *testing.T, clientKeypair *pki.KeyPair, cp pki.CertificateProvider) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "success!")
}))

serverTLSCerts, err := cp.TLSCertificates()
require.NoError(t, err)

caPool := x509.NewCertPool()

for _, cert := range cp.ActiveCAs() {
cert := cert
caPool.AddCert(&cert)
}

server.TLS = &tls.Config{
Certificates: serverTLSCerts,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caPool,
MinVersion: tls.VersionTLS12,
}

server.StartTLS()
defer server.Close()

// This will response with HTTP 200 OK and a body containing success!. We can now set up the client to trust the CA, and send a request to the server:

clientTLSCert, err := clientKeypair.TLSCertificate()
require.NoError(t, err)

transport := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*clientTLSCert},
ClientCAs: caPool,
RootCAs: caPool,
MinVersion: tls.VersionTLS12,
},
}
http := http.Client{
Transport: transport,
}

resp, err := http.Get(server.URL)
require.NoError(t, err)

// If no errors occurred, we now have our success! response from the server, and can verify it:

respBodyBytes, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)

body := strings.TrimSpace(string(respBodyBytes[:]))
require.Equal(t, "success!", body)
}
2 changes: 2 additions & 0 deletions internal/server/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func NewDB(connection gorm.Dialector) (*gorm.DB, error) {
&models.AccessKey{},
&models.Settings{},
&models.EncryptionKey{},
&models.TrustedCertificate{},
&models.RootCertificate{},
}
for _, table := range tables {
if err := db.AutoMigrate(table); err != nil {
Expand Down
68 changes: 68 additions & 0 deletions internal/server/data/trustedcerts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package data

import (
"encoding/base64"
"errors"
"time"

"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/server/models"
"gorm.io/gorm"
)

// TrustPublicKey trusts a public key (in base64 format) from a user or service
// Callers must have received the key from a mTLS/e2ee (mutually encrypted), trusted source.
func TrustPublicKey(db *gorm.DB, tc *models.TrustedCertificate) error {
_, err := get[models.TrustedCertificate](db, ByPublicKey(tc.PublicKey))
if err != nil && !errors.Is(err, internal.ErrNotFound) {
return err
}

if err == nil {
// this one already exists
return nil
}

return add(db, tc)
}

func ListTrustedClientCertificates(db *gorm.DB) ([]models.TrustedCertificate, error) {
return list[models.TrustedCertificate](db)
}

func ListRootCertificates(db *gorm.DB) ([]models.RootCertificate, error) {
return list[models.RootCertificate](db, OrderBy("id desc"), ByNotExpired(), Limit(2))
}

func GetRootCertificate(db *gorm.DB, selectors ...SelectorFunc) (*models.RootCertificate, error) {
return get[models.RootCertificate](db, selectors...)
}

func AddRootCertificate(db *gorm.DB, cert *models.RootCertificate) error {
return add(db, cert)
}

func ByPublicKey(key []byte) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
k := base64.StdEncoding.EncodeToString(key)
return db.Where("public_key = ?", k)
}
}

func OrderBy(order string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Order(order)
}
}

func Limit(limit int) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Limit(limit)
}
}

func ByNotExpired() SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("expires_at is null or expires_at > ?", time.Now().UTC())
}
}
13 changes: 13 additions & 0 deletions internal/server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
Expand All @@ -15,6 +16,7 @@ import (
"github.com/infrahq/infra/internal/generate"
"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/secrets"
"github.com/infrahq/infra/uid"
)

Expand All @@ -25,6 +27,17 @@ func setupDB(t *testing.T) *gorm.DB {
db, err := data.NewDB(driver)
require.NoError(t, err)

fp := secrets.NewFileSecretProviderFromConfig(secrets.FileConfig{
Path: os.TempDir(),
})

kp := secrets.NewNativeSecretProvider(fp)

key, err := kp.GenerateDataKey("")
require.NoError(t, err)

models.SymmetricKey = key

return db
}

Expand Down
28 changes: 28 additions & 0 deletions internal/server/models/trustedcertificate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package models

import (
"time"
)

type TrustedCertificate struct {
Model

KeyAlgorithm string `validate:"required"`
SigningAlgorithm string `validate:"required"`
PublicKey Base64 `validate:"required"`
CertPEM []byte `validate:"required"` // pem encoded
Identity string `validate:"required"`
ExpiresAt time.Time
OneTimeUse bool
}

type RootCertificate struct {
Model

KeyAlgorithm string `validate:"required"`
SigningAlgorithm string `validate:"required"`
PublicKey Base64 `validate:"required"`
PrivateKey EncryptedAtRest `validate:"required"`
SignedCert EncryptedAtRest `validate:"required"` // contains private key? probably not pem encoded
ExpiresAt time.Time `validate:"required"`
}
30 changes: 30 additions & 0 deletions internal/server/models/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package models

import (
"database/sql/driver"
"encoding/base64"
"fmt"
)

type Base64 []byte

func (f Base64) Value() (driver.Value, error) {
r := base64.StdEncoding.EncodeToString([]byte(f))

return r, nil
}

func (f *Base64) Scan(v interface{}) error {
b, err := base64.StdEncoding.DecodeString(string(v.(string)))
if err != nil {
return fmt.Errorf("base64 decoding field: %w", err)
}

*f = Base64(b)

return nil
}

func (f Base64) GormDataType() string {
return "text"
}
8 changes: 3 additions & 5 deletions internal/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ import (
"github.com/infrahq/infra/metrics"
)

type (
ReqHandlerFunc[Req any] func(c *gin.Context, req *Req) error
ResHandlerFunc[Res any] func(c *gin.Context) (Res, error)
ReqResHandlerFunc[Req, Res any] func(c *gin.Context, req *Req) (Res, error)
)
type ReqHandlerFunc[Req any] func(c *gin.Context, req *Req) error
type ResHandlerFunc[Res any] func(c *gin.Context) (Res, error)
type ReqResHandlerFunc[Req, Res any] func(c *gin.Context, req *Req) (Res, error)

func (a *API) registerRoutes(router *gin.RouterGroup) {
router.Use(
Expand Down
Loading

0 comments on commit 09802a3

Please sign in to comment.