Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vault PKI cert source #315

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NOTICES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,8 @@ github.com/rogpeppe/fastuuid
https://github.com/rogpeppe/fastuuid.git
License: BSD 3-clause (https://github.com/google/uuid/LICENSE)
Copyright © 2014, Roger Peppe All rights reserved.

golang.org/x/sync/singleflight
https://golang.org/x/sync/singleflight
License: BSD 3-clause (https://golang.org/x/sync/LICENSE)
Copyright (c) 2009 The Go Authors. All rights reserved.
62 changes: 51 additions & 11 deletions cert/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"

"github.com/fabiolb/fabio/config"
"golang.org/x/sync/singleflight"
)

// Source provides the interface for dynamic certificate sources.
Expand All @@ -22,6 +22,14 @@ type Source interface {
LoadClientCAs() (*x509.CertPool, error)
}

// Issuer is the interface implemented by sources that can issue certificates
// on-demand.
type Issuer interface {
// Issue issues a new certificate for the given common name. Issue must
// return a certificate or an error, never (nil, nil).
Issue(commonName string) (*tls.Certificate, error)
}

// NewSource generates a cert source from the config options.
func NewSource(cfg config.CertSource) (Source, error) {
switch cfg.Type {
Expand Down Expand Up @@ -58,41 +66,73 @@ func NewSource(cfg config.CertSource) (Source, error) {

case "vault":
return &VaultSource{
Addr: os.Getenv("VAULT_ADDR"),
CertPath: cfg.CertPath,
ClientCAPath: cfg.ClientCAPath,
CAUpgradeCN: cfg.CAUpgradeCN,
Refresh: cfg.Refresh,
vaultToken: os.Getenv("VAULT_TOKEN"),
Client: DefaultVaultClient,
}, nil
case "vault-pki":
src := NewVaultPKISource()
src.CertPath = cfg.CertPath
src.ClientCAPath = cfg.ClientCAPath
src.CAUpgradeCN = cfg.CAUpgradeCN
src.Refresh = cfg.Refresh
src.Client = DefaultVaultClient
return src, nil

default:
return nil, fmt.Errorf("invalid certificate source %q", cfg.Type)
}
}

// TLSConfig creates a tls.Config which sets the
// GetCertificate field to a certificate store
// which uses the given source to update the
// the certificates on demand.
// TLSConfig creates a tls.Config which sets the GetCertificate field to a
// certificate store which uses the given source to update the the certificates
// on-demand.
//
// It also sets the ClientCAs field if
// src.LoadClientCAs returns a non-nil value
// and sets ClientAuth to RequireAndVerifyClientCert.
// It also sets the ClientCAs field if src.LoadClientCAs returns a non-nil
// value and sets ClientAuth to RequireAndVerifyClientCert.
func TLSConfig(src Source, strictMatch bool, minVersion, maxVersion uint16, cipherSuites []uint16) (*tls.Config, error) {
clientCAs, err := src.LoadClientCAs()
if err != nil {
return nil, err
}

sf := &singleflight.Group{}
store := NewStore()
x := &tls.Config{
MinVersion: minVersion,
MaxVersion: maxVersion,
CipherSuites: cipherSuites,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
return getCertificate(store.certstore(), clientHello, strictMatch)
cert, err = getCertificate(store.certstore(), clientHello, strictMatch)
if cert != nil {
return
}

switch err {
case nil, ErrNoCertsStored:
// Store doesn't contain a suitable cert. Perhaps the source can issue one?
default:
// an unrecoverable error
return
}

ca, ok := src.(Issuer)
if !ok {
return
}

serverName := clientHello.ServerName
x, err, _ := sf.Do(serverName, func() (interface{}, error) {
return ca.Issue(serverName)
})
if err != nil {
return cert, err
}

return x.(*tls.Certificate), nil
},
}

Expand Down
183 changes: 123 additions & 60 deletions cert/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ func TestNewSource(t *testing.T) {
desc: "vault",
cfg: certsource("vault"),
src: &VaultSource{
Addr: os.Getenv("VAULT_ADDR"),
vaultToken: os.Getenv("VAULT_TOKEN"),
Client: DefaultVaultClient,
CertPath: "cert",
ClientCAPath: "clientca",
CAUpgradeCN: "upgcn",
Expand Down Expand Up @@ -205,7 +204,7 @@ func TestPathSource(t *testing.T) {
defer os.RemoveAll(dir)
certPEM, keyPEM := makePEM("localhost", time.Minute)
saveCert(dir, "localhost", certPEM, keyPEM)
testSource(t, PathSource{CertPath: dir}, makeCertPool(certPEM), 0)
testSource(t, PathSource{CertPath: dir}, makeCertPool(certPEM), 10*time.Millisecond)
}

func TestHTTPSource(t *testing.T) {
Expand Down Expand Up @@ -339,6 +338,10 @@ func vaultServer(t *testing.T, addr, rootToken string) (*exec.Cmd, *vaultapi.Cli
path "secret/fabio/cert/*" {
capabilities = ["read"]
}

path "test-pki/issue/fabio" {
capabilities = ["update"]
}
`

if err := c.Sys().PutPolicy("fabio", policy); err != nil {
Expand Down Expand Up @@ -371,6 +374,43 @@ func makeToken(t *testing.T, c *vaultapi.Client, wrapTTL string, req *vaultapi.T
return resp.Auth.ClientToken
}

var vaultTestCases = []struct {
desc string
wrapTTL string
req *vaultapi.TokenCreateRequest
dropWarn bool
}{
{
desc: "renewable token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Policies: []string{"fabio"}},
},
{
desc: "non-renewable token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Renewable: new(bool), Policies: []string{"fabio"}},
dropWarn: true,
},
{
desc: "renewable orphan token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", NoParent: true, Policies: []string{"fabio"}},
},
{
desc: "non-renewable orphan token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", NoParent: true, Renewable: new(bool), Policies: []string{"fabio"}},
dropWarn: true,
},
{
desc: "renewable wrapped token",
wrapTTL: "10s",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Policies: []string{"fabio"}},
},
{
desc: "non-renewable wrapped token",
wrapTTL: "10s",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Renewable: new(bool), Policies: []string{"fabio"}},
dropWarn: true,
},
}

func TestVaultSource(t *testing.T) {
const (
addr = "127.0.0.1:58421"
Expand All @@ -389,55 +429,17 @@ func TestVaultSource(t *testing.T) {
t.Fatalf("logical.Write failed: %s", err)
}

newBool := func(b bool) *bool { return &b }

// run tests
tests := []struct {
desc string
wrapTTL string
req *vaultapi.TokenCreateRequest
dropWarn bool
}{
{
desc: "renewable token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Policies: []string{"fabio"}},
},
{
desc: "non-renewable token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Renewable: newBool(false), Policies: []string{"fabio"}},
dropWarn: true,
},
{
desc: "renewable orphan token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", NoParent: true, Policies: []string{"fabio"}},
},
{
desc: "non-renewable orphan token",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", NoParent: true, Renewable: newBool(false), Policies: []string{"fabio"}},
dropWarn: true,
},
{
desc: "renewable wrapped token",
wrapTTL: "10s",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Policies: []string{"fabio"}},
},
{
desc: "non-renewable wrapped token",
wrapTTL: "10s",
req: &vaultapi.TokenCreateRequest{Lease: "1m", TTL: "1m", Renewable: newBool(false), Policies: []string{"fabio"}},
dropWarn: true,
},
}

pool := makeCertPool(certPEM)
timeout := 500 * time.Millisecond
for _, tt := range tests {
for _, tt := range vaultTestCases {
tt := tt // capture loop var
t.Run(tt.desc, func(t *testing.T) {
src := &VaultSource{
Addr: "http://" + addr,
CertPath: certPath,
vaultToken: makeToken(t, client, tt.wrapTTL, tt.req),
Client: &vaultClient{
addr: "http://" + addr,
token: makeToken(t, client, tt.wrapTTL, tt.req),
},
CertPath: certPath,
}

// suppress the log warning about a non-renewable token
Expand All @@ -449,6 +451,70 @@ func TestVaultSource(t *testing.T) {
}
}

func TestVaultPKISource(t *testing.T) {
const (
addr = "127.0.0.1:58421"
rootToken = "token"
certPath = "test-pki/issue/fabio"
)

// start a vault server
vault, client := vaultServer(t, addr, rootToken)
defer vault.Process.Kill()

// mount the PKI backend
err := client.Sys().Mount("test-pki", &vaultapi.MountInput{
Type: "pki",
Config: vaultapi.MountConfigInput{
DefaultLeaseTTL: "1h", // default validity period of issued certificates
MaxLeaseTTL: "2h", // maximum validity period of issued certificates
},
})
if err != nil {
t.Fatalf("Mount pki backend failed: %s", err)
}

// generate root CA cert
resp, err := client.Logical().Write("test-pki/root/generate/internal", map[string]interface{}{
"common_name": "Fabio Test CA",
"ttl": "2h",
})
if err != nil {
t.Fatalf("Generate root failed: %s", err)
}
caPool := makeCertPool([]byte(resp.Data["certificate"].(string)))

// create role
role := filepath.Base(certPath)
_, err = client.Logical().Write("test-pki/roles/"+role, map[string]interface{}{
"allowed_domains": "",
"allow_localhost": true,
"allow_ip_sans": true,
"organization": "Fabio Test",
})
if err != nil {
t.Fatalf("Write role failed: %s", err)
}

for _, tt := range vaultTestCases {
tt := tt // capture loop var
t.Run(tt.desc, func(t *testing.T) {
src := NewVaultPKISource()
src.Client = &vaultClient{
addr: "http://" + addr,
token: makeToken(t, client, tt.wrapTTL, tt.req),
}
src.CertPath = certPath

// suppress the log warning about a non-renewable token
// since this is the expected behavior.
dropNotRenewableWarning = tt.dropWarn
testSource(t, src, caPool, 0)
dropNotRenewableWarning = false
})
}
}

// testSource runs an integration test by making an HTTPS request
// to https://localhost/ expecting that the source provides a valid
// certificate for "localhost". rootCAs is expected to contain a
Expand Down Expand Up @@ -505,19 +571,18 @@ func testSource(t *testing.T, source Source, rootCAs *x509.CertPool, sleep time.
}
}

// make a call for which certificate validation fails.
fail(http11)
fail(http20)

// now make the call that should succeed
// make a call for which certificate validation succeeds.
succeed(http11, "OK HTTP/1.1")
succeed(http20, "OK HTTP/2.0")

// now make the call that should fail.
fail(http11)
fail(http20)
}

// roundtrip starts a TLS server with the given server configuration and
// then calls "https://<host>/" with the given client. "host" must resolve
// to 127.0.0.1.
func roundtrip(host string, srvConfig *tls.Config, client *http.Client) (code int, body string, err error) {
// then sends an SNI request with the given serverName.
func roundtrip(serverName string, srvConfig *tls.Config, client *http.Client) (code int, body string, err error) {
// create an HTTPS server and start it. It will be listening on 127.0.0.1
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "OK ", r.Proto)
Expand All @@ -526,11 +591,9 @@ func roundtrip(host string, srvConfig *tls.Config, client *http.Client) (code in
srv.StartTLS()
defer srv.Close()

// for the certificate validation to work we need to use a hostname
// in the URL which resolves to 127.0.0.1. We can't fake the hostname
// via the Host header.
url := strings.Replace(srv.URL, "127.0.0.1", host, 1)
resp, err := client.Get(url)
// configure SNI
client.Transport.(*http.Transport).TLSClientConfig.ServerName = serverName
resp, err := client.Get(srv.URL)
if err != nil {
return 0, "", err
}
Expand Down
4 changes: 3 additions & 1 deletion cert/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ func (s *Store) certstore() certstore {
return s.cs.Load().(certstore)
}

var ErrNoCertsStored = errors.New("cert: no certificates stored")

func getCertificate(cs certstore, clientHello *tls.ClientHelloInfo, strictMatch bool) (cert *tls.Certificate, err error) {
if len(cs.Certificates) == 0 {
return nil, errors.New("cert: no certificates stored")
return nil, ErrNoCertsStored
}

// There's only one choice, so no point doing any work.
Expand Down
Loading