diff --git a/.travis.yml b/.travis.yml index 2637dcb16..c6c3279bc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,3 +2,11 @@ language: go go: - 1.6.2 + +before_script: + - wget https://releases.hashicorp.com/consul/0.6.4/consul_0.6.4_linux_amd64.zip + - unzip consul_0.6.4_linux_amd64.zip + - ./consul --version + +script: + - ../consul agent -server -bootstrap-expect 1 -data-dir /tmp/consul & \ No newline at end of file diff --git a/cert/consul_source.go b/cert/consul_source.go new file mode 100644 index 000000000..601a19515 --- /dev/null +++ b/cert/consul_source.go @@ -0,0 +1,132 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "log" + "net/url" + "path" + "reflect" + "time" + + "github.com/hashicorp/consul/api" +) + +type ConsulSource struct { + CertURL string + ClientCAURL string +} + +const kvURLPrefix = "/v1/kv/" + +func parseConsulURL(consulURL, stripPrefix string) (client *api.Client, key string, err error) { + u, err := url.Parse(consulURL) + if err != nil { + return nil, "", err + } + var token string + if len(u.Query()["token"]) > 0 { + token = u.Query()["token"][0] + } + client, err = api.NewClient(&api.Config{Address: u.Host, Scheme: u.Scheme, Token: token}) + if err != nil { + return nil, "", err + } + key = u.RequestURI()[len(stripPrefix):] + return client, key, nil +} + +func (s ConsulSource) LoadClientCAs() (*x509.CertPool, error) { + if s.ClientCAURL == "" { + return nil, nil + } + + client, key, err := parseConsulURL(s.ClientCAURL, kvURLPrefix) + if err != nil { + return nil, err + } + + pemBlocks, _, err := getCerts(client, key, 0) + if err != nil { + return nil, err + } + + if len(pemBlocks) == 0 { + return nil, nil + } + + x := x509.NewCertPool() + for name, pemBlock := range pemBlocks { + if !x.AppendCertsFromPEM(pemBlock) { + log.Printf("[WARN] cert: Failed to add client CA certificate from %s", name) + continue + } + } + + log.Printf("[INFO] cert: Load client CA certs from %s", s.ClientCAURL) + return x, nil +} + +func (s ConsulSource) Certificates() chan []tls.Certificate { + if s.CertURL == "" { + return nil + } + + client, key, err := parseConsulURL(s.CertURL, kvURLPrefix) + if err != nil { + log.Printf("[ERROR] cert: Failed to create consul client. %s", err) + } + + pemBlocksCh := make(chan map[string][]byte, 1) + go watchKV(client, key, pemBlocksCh) + + ch := make(chan []tls.Certificate, 1) + go func() { + for pemBlocks := range pemBlocksCh { + certs, err := loadCertificates(pemBlocks) + if err != nil { + log.Printf("[ERROR] cert: Failed to load certificates. %s", err) + continue + } + ch <- certs + } + }() + return ch +} + +// watchKV monitors a key in the KV store for changes. +func watchKV(client *api.Client, key string, pemBlocks chan map[string][]byte) { + var lastIndex uint64 + var lastValue map[string][]byte + + for { + value, index, err := getCerts(client, key, lastIndex) + if err != nil { + log.Printf("[WARN] cert: Error fetching certificates from %s. %v", key, err) + time.Sleep(time.Second) + continue + } + + if !reflect.DeepEqual(value, lastValue) || index != lastIndex { + log.Printf("[INFO] cert: Certificate index changed to #%d", index) + pemBlocks <- value + lastValue, lastIndex = value, index + } + } +} + +func getCerts(client *api.Client, key string, waitIndex uint64) (pemBlocks map[string][]byte, lastIndex uint64, err error) { + q := &api.QueryOptions{RequireConsistent: true, WaitIndex: waitIndex} + kvpairs, meta, err := client.KV().List(key, q) + if err != nil { + return nil, 0, err + } + if len(kvpairs) == 0 { + return nil, meta.LastIndex, nil + } + pemBlocks = map[string][]byte{} + for _, kvpair := range kvpairs { + pemBlocks[path.Base(kvpair.Key)] = kvpair.Value + } + return pemBlocks, meta.LastIndex, nil +} diff --git a/cert/file_source.go b/cert/file_source.go new file mode 100644 index 000000000..fd9623478 --- /dev/null +++ b/cert/file_source.go @@ -0,0 +1,62 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" +) + +// FileSource implements a certificate source for one +// TLS and one client authentication certificate. +// The certificates are loaded during startup and are cached +// in memory until the program exits. +// It exists to support the legacy configuration only. The +// PathSource should be used instead. +type FileSource struct { + CertFile string + KeyFile string + ClientAuthFile string +} + +func (s FileSource) LoadClientCAs() (*x509.CertPool, error) { + if s.ClientAuthFile == "" { + return nil, nil + } + + pemBlock, err := ioutil.ReadFile(s.ClientAuthFile) + if err != nil { + return nil, fmt.Errorf("cert: cannot load client CAs. %s", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pemBlock) { + return nil, fmt.Errorf("cert: failed to add client auth certs from %s", s.ClientAuthFile) + } + + return pool, nil +} + +func (s FileSource) Certificates() chan []tls.Certificate { + ch := make(chan []tls.Certificate, 1) + ch <- []tls.Certificate{loadX509KeyPair(s.CertFile, s.KeyFile)} + close(ch) + return ch +} + +func loadX509KeyPair(certFile, keyFile string) tls.Certificate { + if certFile == "" { + log.Fatalf("[FATAL] cert: CertFile is required") + } + + if keyFile == "" { + keyFile = certFile + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalf("[FATAL] cert: Error loading certificate. %s", err) + } + return cert +} diff --git a/cert/http_source.go b/cert/http_source.go new file mode 100644 index 000000000..3b1ed5cd1 --- /dev/null +++ b/cert/http_source.go @@ -0,0 +1,68 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "log" + "time" +) + +// HTTPSource implements a certificate source which loads +// TLS and client authentication certificates from an HTTP/HTTPS server. +// The CertURL/ClientCAURL must point to a text file in the directory +// of the certificates. The text file contains all files that should +// be loaded from this directory - one filename per line. +// The TLS certificates are updated automatically when Refresh +// is not zero. Refresh cannot be less than one second to prevent +// busy loops. +type HTTPSource struct { + CertURL string + ClientCAURL string + Refresh time.Duration +} + +func (s HTTPSource) LoadClientCAs() (*x509.CertPool, error) { + pemBlocks, err := loadURL(s.ClientCAURL) + if err != nil { + return nil, err + } + + if len(pemBlocks) == 0 { + return nil, nil + } + + x := x509.NewCertPool() + for name, pemBlock := range pemBlocks { + if !x.AppendCertsFromPEM(pemBlock) { + log.Printf("[WARN] cert: Could not add client CA certificate from %s", name) + continue + } + } + + log.Printf("[INFO] cert: Load client CA certs from %s", s.ClientCAURL) + return x, nil +} + +func (s HTTPSource) Certificates() chan []tls.Certificate { + loadCerts := func() ([]tls.Certificate, error) { + pemBlocks, err := loadURL(s.CertURL) + if err != nil { + return nil, err + } + + if len(pemBlocks) == 0 { + return nil, nil + } + + certs, err := loadCertificates(pemBlocks) + if err != nil { + return nil, err + } + + return certs, nil + } + + ch := make(chan []tls.Certificate, 1) + go watch(ch, s.Refresh, s.CertURL, loadCerts) + return ch +} diff --git a/cert/load.go b/cert/load.go new file mode 100644 index 000000000..77e325754 --- /dev/null +++ b/cert/load.go @@ -0,0 +1,167 @@ +package cert + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "sort" + "strings" +) + +const MaxSize = 1 << 20 // 1MB + +func loadURL(listURL string) (pemBlocks map[string][]byte, err error) { + if listURL == "" { + return nil, nil + } + + baseURL, err := base(listURL) + if err != nil { + return nil, fmt.Errorf("cert: %s", err) + } + + fetch := func(url string) (buf []byte, err error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return ioutil.ReadAll(resp.Body) + } + + // fetch the file with the list of filenames + list, err := fetch(listURL) + if err != nil { + return nil, fmt.Errorf("cert: %s", err) + } + + // fetch the individual files + pemBlocks = map[string][]byte{} + for _, p := range strings.Split(string(list), "\n") { + if p == "" { + continue + } + + path := baseURL + p + + buf, err := fetch(path) + if err != nil { + return nil, fmt.Errorf("cert: %s", err) + } + + pemBlocks[path] = buf + } + + return pemBlocks, nil +} + +func loadPath(root string) (pemBlocks map[string][]byte, err error) { + if root == "" { + return nil, nil + } + + pemBlocks = map[string][]byte{} + err = filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() || filepath.Ext(info.Name()) != ".pem" || strings.HasPrefix(info.Name(), ".") { + return nil + } + + if info.Size() > MaxSize { + log.Print("[WARN] cert: File too large %s", info.Name) + return nil + } + + buf, err := ioutil.ReadFile(path) + if err != nil { + return fmt.Errorf("cert: %s", err) + } + + pemBlocks[path] = buf + return nil + }) + + if err != nil { + return nil, err + } + + return pemBlocks, nil +} + +func loadCertificates(pemBlocks map[string][]byte) ([]tls.Certificate, error) { + var n []string + x := map[string]tls.Certificate{} + + for name := range pemBlocks { + var certFile, keyFile string + switch { + case strings.HasSuffix(name, "-cert.pem"): + certFile, keyFile = name, replaceSuffix(name, "-cert.pem", "-key.pem") + case strings.HasSuffix(name, "-key.pem"): + certFile, keyFile = replaceSuffix(name, "-key.pem", "-cert.pem"), name + case strings.HasSuffix(name, ".pem"): + certFile, keyFile = name, name + default: + continue + } + + if _, exists := x[certFile]; exists { + continue + } + + cert, key := pemBlocks[certFile], pemBlocks[keyFile] + if cert == nil || key == nil { + return nil, fmt.Errorf("cert: cannot load certificate %s", name) + } + + c, err := tls.X509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("cert: invalid certificate %s. %s", name, err) + } + + x[certFile] = c + n = append(n, certFile) + } + + // append certificates in alphabetical order of the + // cert filenames. This determines which certificate + // becomes the default certificate (the first one) + sort.Strings(n) + var certs []tls.Certificate + for _, certFile := range n { + certs = append(certs, x[certFile]) + } + + return certs, nil +} + +// base returns the rawurl with the last element of the path +// removed. http://foo.com/x/y becomes http://foo.com/x +func base(rawurl string) (string, error) { + if rawurl == "" { + return "", nil + } + u, err := url.Parse(rawurl) + if err != nil { + return "", err + } + if u.Path != "/" { + u.Path = path.Dir(u.Path) + } + return u.String(), nil +} + +// replaceSuffix replaces oldSuffix with newSuffix in s. +// It is only valid when s has oldSuffix and oldSuffix is not empty. +func replaceSuffix(s string, oldSuffix, newSuffix string) string { + return s[:len(s)-len(oldSuffix)] + newSuffix +} diff --git a/cert/load_test.go b/cert/load_test.go new file mode 100644 index 000000000..05147f0bd --- /dev/null +++ b/cert/load_test.go @@ -0,0 +1,36 @@ +package cert + +import "testing" + +func TestBase(t *testing.T) { + tests := []struct { + in, out, err string + }{ + {"", "", ""}, + {"http://foo.com/x/y", "http://foo.com/x", ""}, + {"http://foo.com/x/y?p=q", "http://foo.com/x?p=q", ""}, + } + + for i, tt := range tests { + u, err := base(tt.in) + if err != nil { + if got, want := err.Error(), tt.err; got != want { + t.Errorf("%d: got %v want %v", i, got, want) + continue + } + } + if tt.err != "" { + t.Errorf("%d: got nil want %v", i, tt.err) + continue + } + if got, want := u, tt.out; got != want { + t.Errorf("%d: got %v want %v", i, got, want) + } + } +} + +func TestReplaceSuffix(t *testing.T) { + if got, want := replaceSuffix("ab", "b", "c"), "ac"; got != want { + t.Errorf("got %q want %q", got, want) + } +} diff --git a/cert/path_source.go b/cert/path_source.go new file mode 100644 index 000000000..c704d97b5 --- /dev/null +++ b/cert/path_source.go @@ -0,0 +1,63 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "log" + "time" +) + +// PathSource implements a certificate source which loads +// TLS and client authentication certificates from a directory. +// The TLS certificates are updated automatically when Refresh +// is not zero. Refresh cannot be less than one second to prevent +// busy loops. +type PathSource struct { + CertPath string + ClientCAPath string + Refresh time.Duration +} + +func (s PathSource) LoadClientCAs() (*x509.CertPool, error) { + pemBlocks, err := loadPath(s.ClientCAPath) + if err != nil { + return nil, err + } + + if len(pemBlocks) == 0 { + return nil, nil + } + + x := x509.NewCertPool() + for name, pemBlock := range pemBlocks { + if !x.AppendCertsFromPEM(pemBlock) { + log.Printf("[WARN] cert: Could not add client CA certificate from %s", name) + continue + } + } + return x, nil +} + +func (s PathSource) Certificates() chan []tls.Certificate { + loadCerts := func() ([]tls.Certificate, error) { + pemBlocks, err := loadPath(s.CertPath) + if err != nil { + return nil, err + } + + if len(pemBlocks) == 0 { + return nil, nil + } + + certs, err := loadCertificates(pemBlocks) + if err != nil { + return nil, err + } + + return certs, nil + } + + ch := make(chan []tls.Certificate, 1) + go watch(ch, s.Refresh, s.CertPath, loadCerts) + return ch +} diff --git a/cert/source.go b/cert/source.go new file mode 100644 index 000000000..f3e8b8cb6 --- /dev/null +++ b/cert/source.go @@ -0,0 +1,51 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" +) + +// Source provides the interface for dynamic certificate sources. +// +// Certificates() loads certificates for TLS connections. +// The first certificate is used as the default certificate +// if the client does not support SNI or no matching certificate +// could be found. TLS certificates can be updated at runtime. +// +// LoadClientCAs() provides certificates for client certificate +// authentication. +type Source interface { + Certificates() chan []tls.Certificate + LoadClientCAs() (*x509.CertPool, error) +} + +// TLSConfig creates a tls.Config which sets the +// GetCertificate field to a certificate store +// which uses the given source to update the +// the certificates on demand. +// +// It also sets the ClientCAs field if +// src.LoadClientCAs returns a non-nil value +// and sets ClientAuth to RequireAndVerifyClientCert. +func TLSConfig(src Source) (*tls.Config, error) { + clientCAs, err := src.LoadClientCAs() + if err != nil { + return nil, err + } + + store := NewStore() + x := &tls.Config{GetCertificate: store.GetCertificate} + + if clientCAs != nil { + x.ClientCAs = clientCAs + x.ClientAuth = tls.RequireAndVerifyClientCert + } + + go func() { + for certs := range src.Certificates() { + store.SetCertificates(certs) + } + }() + + return x, nil +} diff --git a/cert/source_test.go b/cert/source_test.go new file mode 100644 index 000000000..d51fe562d --- /dev/null +++ b/cert/source_test.go @@ -0,0 +1,239 @@ +package cert + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "log" + "math/big" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/hashicorp/consul/api" +) + +type StaticSource struct { + cert tls.Certificate +} + +func (s StaticSource) Certificates() chan []tls.Certificate { + ch := make(chan []tls.Certificate, 1) + ch <- []tls.Certificate{s.cert} + close(ch) + return ch +} + +func (s StaticSource) LoadClientCAs() (*x509.CertPool, error) { + return nil, nil +} + +func TestStaticSource(t *testing.T) { + certPEM, keyPEM := makeCert("localhost", time.Minute) + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("X509KeyPair: got %s want nil", err) + } + testSource(t, StaticSource{cert}, newCertPool(certPEM), 0) +} + +func TestFileSource(t *testing.T) { + dir := tempDir() + defer os.RemoveAll(dir) + certPEM, keyPEM := makeCert("localhost", time.Minute) + certFile, keyFile := saveCert(dir, "localhost", certPEM, keyPEM) + testSource(t, FileSource{CertFile: certFile, KeyFile: keyFile}, newCertPool(certPEM), 0) +} + +func TestPathSource(t *testing.T) { + dir := tempDir() + defer os.RemoveAll(dir) + certPEM, keyPEM := makeCert("localhost", time.Minute) + saveCert(dir, "localhost", certPEM, keyPEM) + testSource(t, PathSource{CertPath: dir}, newCertPool(certPEM), 0) +} + +func TestHTTPSource(t *testing.T) { + dir := tempDir() + defer os.RemoveAll(dir) + certPEM, keyPEM := makeCert("localhost", time.Minute) + certFile, keyFile := saveCert(dir, "localhost", certPEM, keyPEM) + listFile := filepath.Base(certFile) + "\n" + filepath.Base(keyFile) + "\n" + writeFile(filepath.Join(dir, "list"), []byte(listFile)) + + srv := httptest.NewServer(http.FileServer(http.Dir(dir))) + defer srv.Close() + + testSource(t, HTTPSource{CertURL: srv.URL + "/list"}, newCertPool(certPEM), 50*time.Millisecond) +} + +func TestConsulSource(t *testing.T) { + const certURL = "http://localhost:8500/v1/kv/fabio/test/consul-server" + client, key, err := parseConsulURL(certURL, kvURLPrefix) + if err != nil { + t.Fatalf("Failed to create consul client: %s", err) + } + defer func() { client.KV().DeleteTree(key, &api.WriteOptions{}) }() + + write := func(name string, value []byte) { + p := &api.KVPair{Key: key + "/" + name, Value: value} + _, err := client.KV().Put(p, &api.WriteOptions{}) + if err != nil { + t.Fatalf("Failed to write %q to consul: %s", p.Key, err) + } + } + + certPEM, keyPEM := makeCert("localhost", time.Minute) + write("localhost-cert.pem", certPEM) + write("localhost-key.pem", keyPEM) + + testSource(t, ConsulSource{CertURL: certURL}, newCertPool(certPEM), 50*time.Millisecond) +} + +// testSource runs an integration test by making an HTTPS request +// to https://localhost/ expecting that the source provides a valid +// certificate for "localhost". rootCAs is expected to contain a +// valid root certificate or the server certificate itself so that +// the HTTPS client can validate the certificate presented by the +// server. +func testSource(t *testing.T, source Source, rootCAs *x509.CertPool, sleep time.Duration) { + srvConfig, err := TLSConfig(source) + if err != nil { + t.Fatalf("TLSConfig: got %q want nil", err) + } + + // give the source some time to initialize if necessary + time.Sleep(sleep) + + // create the https server and start it + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "OK") + })) + srv.TLS = srvConfig + srv.StartTLS() + defer srv.Close() + + // create an http client that will accept the root CAs + // otherwise the HTTPS client will not verify the + // certificate presented by the server. + client := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + }, + }, + } + + call := func(host string) (statusCode int, body string, err error) { + // we need to call https://localhost:xxxxx/ + // Setting the Host header to provide the hostname does not work. + resp, err := client.Get(strings.Replace(srv.URL, "127.0.0.1", host, 1)) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return 0, "", err + } + + return resp.StatusCode, string(data), nil + } + + // disable log output for the next call to prevent + // confusing log messages since they are expected + // http: TLS handshake error from 127.0.0.1:55044: remote error: bad certificate + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + // calling https://foo.com/ should fail since there is only a cert for localhost + _, _, err = call("foo.com") + if got, want := err, "x509: certificate is valid for localhost, not foo.com"; got == nil || !strings.Contains(got.Error(), want) { + t.Fatalf("got %v want %v", got, want) + } + + statusCode, body, err := call("localhost") + if err != nil { + t.Fatalf("got %v want nil", err) + } + if got, want := statusCode, 200; got != want { + t.Fatalf("got %v want %v", got, want) + } + if got, want := body, "OK"; got != want { + t.Fatalf("got %v want %v", got, want) + } +} + +func tempDir() string { + dir, err := ioutil.TempDir("", "fabio") + if err != nil { + log.Fatal(err) + } + return dir +} + +func writeFile(filename string, data []byte) { + if err := ioutil.WriteFile(filename, data, 0644); err != nil { + log.Fatal(err) + } +} + +func newCertPool(x ...[]byte) *x509.CertPool { + p := x509.NewCertPool() + for _, b := range x { + p.AppendCertsFromPEM(b) + } + return p +} + +func saveCert(dir, host string, certPEM, keyPEM []byte) (certFile, keyFile string) { + certFile, keyFile = filepath.Join(dir, host+"-cert.pem"), filepath.Join(dir, host+"-key.pem") + writeFile(certFile, certPEM) + writeFile(keyFile, keyPEM) + return certFile, keyFile +} + +// makeCert creates a self-signed RSA certificate. +// taken from crypto/tls/generate_cert.go +func makeCert(host string, validFor time.Duration) (certPEM, keyPEM []byte) { + const bits = 1024 + priv, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + log.Fatalf("Failed to generate private key: %s", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Fabio Co"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(validFor), + IsCA: true, + DNSNames: []string{host}, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %s", err) + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + return cert.Bytes(), key.Bytes() +} diff --git a/cert/store.go b/cert/store.go new file mode 100644 index 000000000..106e67d23 --- /dev/null +++ b/cert/store.go @@ -0,0 +1,100 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "log" + "strings" + "sync/atomic" +) + +// Store provides a dynamic certificate store which can be updated at +// runtime and is safe for concurrent use. +type Store struct { + cfg atomic.Value +} + +// NewStore creates an empty certificate store. +func NewStore() *Store { + s := new(Store) + s.cfg.Store(config{}) + return s +} + +// SetCertificates replaces the certificates of the store. +func (s *Store) SetCertificates(certs []tls.Certificate) { + cfg := config{Certificates: certs} + cfg.BuildNameToCertificate() + s.cfg.Store(cfg) + var names []string + for name := range cfg.NameToCertificate { + names = append(names, name) + } + log.Printf("[INFO] cert: Store has certificates for [%q]", strings.Join(names, ",")) +} + +// GetCertificate returns a matching certificate for the given clientHello if possible +// or the first certificate from the store. +func (s *Store) GetCertificate(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) { + return getCertificate(s.cfg.Load().(config), clientHello) +} + +func getCertificate(cfg config, clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) { + if len(cfg.Certificates) == 0 { + return nil, errors.New("cert: no certificates configured") + } + + if len(cfg.Certificates) == 1 || cfg.NameToCertificate == nil { + // There's only one choice, so no point doing any work. + return &cfg.Certificates[0], nil + } + + name := strings.ToLower(clientHello.ServerName) + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + + if cert, ok := cfg.NameToCertificate[name]; ok { + return cert, nil + } + + // try replacing labels in the name with wildcards until we get a + // match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := cfg.NameToCertificate[candidate]; ok { + return cert, nil + } + } + + // If nothing matches, return the first certificate. + return &cfg.Certificates[0], nil +} + +type config struct { + Certificates []tls.Certificate + NameToCertificate map[string]*tls.Certificate +} + +// BuildNameToCertificate parses Certificates and builds NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +func (c *config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*tls.Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + continue + } + if len(x509Cert.Subject.CommonName) > 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} diff --git a/cert/watch.go b/cert/watch.go new file mode 100644 index 000000000..578ac0214 --- /dev/null +++ b/cert/watch.go @@ -0,0 +1,40 @@ +package cert + +import ( + "crypto/tls" + "log" + "reflect" + "time" +) + +// watch monitors the result of the loadFn function for changes. +func watch(ch chan []tls.Certificate, refresh time.Duration, path string, loadFn func() ([]tls.Certificate, error)) { + once := refresh <= 0 + + // do not refresh more often than once a second to prevent busy loops + if refresh < time.Second { + refresh = time.Second + } + + var last []tls.Certificate + for { + next, err := loadFn() + if err != nil { + log.Printf("[ERROR] cert: Cannot load certificates from %s. %s", path, err) + time.Sleep(refresh) + continue + } + + if reflect.DeepEqual(next, last) { + time.Sleep(refresh) + continue + } + + ch <- next + last = next + + if once { + return + } + } +} diff --git a/config/config.go b/config/config.go index 73177ad74..19af32743 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,9 @@ package config -import "time" +import ( + "net/http" + "time" +) type Config struct { Proxy Proxy @@ -11,14 +14,25 @@ type Config struct { Runtime Runtime } -type Listen struct { - Addr string - KeyFile string +type TLSConfig struct { + Type string + CertPath string + ClientCAPath string + Refresh time.Duration + Header http.Header + + // legacy config CertFile string + KeyFile string ClientAuthFile string - TLS bool - ReadTimeout time.Duration - WriteTimeout time.Duration +} + +type Listen struct { + Addr string + Scheme string + ReadTimeout time.Duration + WriteTimeout time.Duration + TLSConfig TLSConfig } type UI struct { diff --git a/config/load.go b/config/load.go index 55d53c9e3..b837727ee 100644 --- a/config/load.go +++ b/config/load.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net/http" "os" "runtime" "strings" @@ -107,7 +108,7 @@ func load(p *properties.Properties) (cfg *Config, err error) { cfg.Registry.Consul.Scheme, cfg.Registry.Consul.Addr = parseScheme(cfg.Registry.Consul.Addr) - cfg.Listen, err = parseListen(cfg.Proxy.ListenerAddr, cfg.Proxy.ReadTimeout, cfg.Proxy.WriteTimeout) + cfg.Listen, err = parseListeners(cfg.Proxy.ListenerAddr, cfg.Proxy.ReadTimeout, cfg.Proxy.WriteTimeout) if err != nil { return nil, err } @@ -151,33 +152,119 @@ func parseScheme(s string) (scheme, addr string) { return "http", s } -func parseListen(addrs string, readTimeout, writeTimeout time.Duration) ([]Listen, error) { - listen := []Listen{} - for _, addr := range strings.Split(addrs, ",") { - addr = strings.TrimSpace(addr) - if addr == "" { +func parseListeners(cfgs string, readTimeout, writeTimeout time.Duration) (listen []Listen, err error) { + for _, cfg := range strings.Split(cfgs, ",") { + cfg = strings.TrimSpace(cfg) + if cfg == "" { continue } - var l Listen - p := strings.Split(addr, ";") - switch len(p) { - case 1: - l.Addr = p[0] - case 2: - l.Addr, l.CertFile, l.KeyFile, l.TLS = p[0], p[1], p[1], true - case 3: - l.Addr, l.CertFile, l.KeyFile, l.TLS = p[0], p[1], p[2], true - case 4: - l.Addr, l.CertFile, l.KeyFile, l.ClientAuthFile, l.TLS = p[0], p[1], p[2], p[3], true - default: - return nil, fmt.Errorf("invalid address %s", addr) + l, err := parseListen(cfg, readTimeout, writeTimeout) + if err != nil { + return nil, err } - l.ReadTimeout = readTimeout - l.WriteTimeout = writeTimeout + listen = append(listen, l) } - return listen, nil + return +} + +func parseListen(cfg string, readTimeout, writeTimeout time.Duration) (l Listen, err error) { + if cfg == "" { + return Listen{}, nil + } + + kv := func(s string) (k, v string) { + p := strings.SplitN(s, "=", 2) + if len(p) == 1 { + return p[0], "" + } + return p[0], p[1] + } + + opts := strings.Split(cfg, ";") + + if len(opts) > 1 && !strings.Contains(opts[1], "=") { + return parseLegacyListen(cfg, readTimeout, writeTimeout) + } + + l = Listen{ + Addr: opts[0], + Scheme: "http", + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + } + + for _, opt := range opts[1:] { + k, v := kv(opt) + switch k { + case "rt": // read timeout + d, err := time.ParseDuration(v) + if err != nil { + return Listen{}, err + } + l.ReadTimeout = d + case "wt": // write timeout + d, err := time.ParseDuration(v) + if err != nil { + return Listen{}, err + } + l.WriteTimeout = d + case "cs": // cert store + l.TLSConfig.Type = v + l.TLSConfig.Refresh = 3 * time.Second + l.Scheme = "https" + case "cert": + l.TLSConfig.CertPath = v + case "clientca": + l.TLSConfig.ClientCAPath = v + case "refresh": + d, err := time.ParseDuration(v) + if err != nil { + return Listen{}, err + } + l.TLSConfig.Refresh = d + case "hdr": + p := strings.SplitN(v, ": ", 2) + if len(p) != 2 { + return Listen{}, fmt.Errorf("invalid header %s", v) + } + if l.TLSConfig.Header == nil { + l.TLSConfig.Header = http.Header{} + } + l.TLSConfig.Header.Set(p[0], p[1]) + } + } + return +} + +func parseLegacyListen(cfg string, readTimeout, writeTimeout time.Duration) (l Listen, err error) { + opts := strings.Split(cfg, ";") + + l = Listen{ + Addr: opts[0], + Scheme: "http", + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + } + + if len(opts) > 1 { + l.Scheme = "https" + l.TLSConfig.Type = "file" + l.TLSConfig.CertFile = opts[1] + } + if len(opts) > 2 { + l.TLSConfig.KeyFile = opts[2] + } + if len(opts) > 3 { + l.TLSConfig.ClientAuthFile = opts[3] + } + if len(opts) > 4 { + return Listen{}, fmt.Errorf("invalid listener configuration") + } + + log.Printf("[WARN] proxy.addr legacy configuration for certificates is deprecated. Use cs=path configuration") + return l, nil } type tags []string diff --git a/config/load_test.go b/config/load_test.go index e6685149f..56f61e677 100644 --- a/config/load_test.go +++ b/config/load_test.go @@ -1,6 +1,7 @@ package config import ( + "net/http" "reflect" "testing" "time" @@ -90,6 +91,7 @@ ui.title = fabfab Listen: []Listen{ { Addr: ":1234", + Scheme: "http", ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, }, @@ -147,59 +149,65 @@ func TestParseScheme(t *testing.T) { } } -func TestParseAddr(t *testing.T) { +func TestParseListen(t *testing.T) { tests := []struct { in string - out []Listen + out Listen err string }{ { "", - []Listen{}, + Listen{}, "", }, { ":123", - []Listen{ - {Addr: ":123"}, - }, + Listen{Addr: ":123", Scheme: "http"}, "", }, { - ":123;cert.pem", - []Listen{ - {Addr: ":123", CertFile: "cert.pem", KeyFile: "cert.pem", TLS: true}, - }, + ":123;rt=5s;wt=5s", + Listen{Addr: ":123", Scheme: "http", ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second}, "", }, { - ":123;cert.pem;key.pem", - []Listen{ - {Addr: ":123", CertFile: "cert.pem", KeyFile: "key.pem", TLS: true}, + ":123;pathA;pathB;pathC", + Listen{ + Addr: ":123", + Scheme: "https", + TLSConfig: TLSConfig{ + Type: "file", + CertFile: "pathA", + KeyFile: "pathB", + ClientAuthFile: "pathC", + }, }, "", }, { - ":123;cert.pem;key.pem;client.pem", - []Listen{ - {Addr: ":123", CertFile: "cert.pem", KeyFile: "key.pem", ClientAuthFile: "client.pem", TLS: true}, + ":123;cs=foo;cert=pathtocert;clientca=pathtoclientca;watch=true;hdr=auth: X;hdr=close: Y", + Listen{ + Addr: ":123", + Scheme: "https", + TLSConfig: TLSConfig{ + Type: "foo", + CertPath: "pathtocert", + ClientCAPath: "pathtoclientca", + Header: http.Header{"Auth": []string{"X"}, "Close": []string{"Y"}}, + Refresh: 3 * time.Second, + }, }, "", }, - { - ":123;cert.pem;key.pem;client.pem;", - nil, - "invalid address :123;cert.pem;key.pem;client.pem;", - }, } for i, tt := range tests { l, err := parseListen(tt.in, time.Duration(0), time.Duration(0)) if got, want := err, tt.err; (got != nil || want != "") && got.Error() != want { - t.Errorf("%d: got %v want %v", i, got, want) + t.Errorf("%d: got %+v want %+v", i, got, want) } if got, want := l, tt.out; !reflect.DeepEqual(got, want) { - t.Errorf("%d: got %v want %v", i, got, want) + t.Errorf("%d: got %+v want %+v", i, got, want) } } } diff --git a/fabio.properties b/fabio.properties index 08f9772ca..28a72d788 100644 --- a/fabio.properties +++ b/fabio.properties @@ -1,37 +1,145 @@ -# proxy.addr configures the HTTP and HTTPS listeners as a comma separated list. +# proxy.addr configures the HTTP and HTTPS listeners. # -# To configure an HTTP listener provide [host]:port. -# To configure an HTTPS listener provide [host]:port;certFile;keyFile;clientAuthFile. -# certFile and keyFile contain the public/private key pair for that listener -# in PEM format. If certFile contains both the public and private key then -# keyFile can be omittted. -# clientAuthFile contains the root CAs for client certificate validation. -# When clientAuthFile is provided the TLS configuration is set to -# RequireAndVerifyClientCert. +# Each listener is configured with and address and a +# list of optional arguments in the form of # -# Configure a single HTTP listener on port 9999: +# [host]:port;opt=arg;opt[=arg];... # +# General options: +# +# read timeout: rt= +# write timeout: wt= +# +# HTTPS listeners require a certificate source which is +# configured as follows: +# +# File +# +# The file certificate source supports one certificate +# which is loaded at startup and is cached until the service exits. +# It also supports the deprecated legacy format of configuring +# certificates. +# +# # legacy configuration (deprecated) +# path/to/cert.pem;path/to/key.pem;path/to/clientAuth.pem +# +# # new configuration +# cs=file;cert=p/a-cert.pem;cert=p/b-cert.pem;clientAuth=p/clientAuth.pem +# +# Path +# +# The path certificate source loads certificates from a directory in +# alphabetical order. Certificates need to be in PEM format. +# +# The cert option provides the path to the TLS certificates +# and the clientca option provides the path to the certificates for +# client authentication. +# +# TLS certificates are stored either in one or two files: +# +# www.example.com.pem or www.example.com-{cert,key}.pem +# +# The filename is not relevant only the {-cert,-key}.pem suffix matters +# but certificate and key file must have the same prefix. +# +# TLS certificates are loaded in alphabetical order +# and the first certificate is the default for clients which +# do not support SNI. +# +# The refresh option can be set to specify the refresh interval for the +# TLS certificates. Client authentication certificates cannot be refreshed since +# Go does not provide a mechanism for that yet. +# +# The default refresh interval is 3 seconds and cannot be lower than 1 second +# to prevent busy loops. To load the certificates only once +# and disable automatic refreshing set refresh to zero. +# +# cs=path;path=;cert=path/to/certs;clientca=path/to/clientcas;refresh=3s +# +# HTTP +# +# The http certificate source loads certificates from an HTTP/HTTPS server. +# +# The cert option provides a URL to a text file which contains all files that +# should be loaded from this directory. The filenames follow the same rules as +# for the Path source. The text file can be generated with: +# +# ls -1 *.pem > list +# +# The clientca option provides a URL for the client authentication certificates +# analogous to the cert option. +# +# +# Authentication credentials can be provided in the URL as request parameter, +# as basic authentication parameters or through a header. +# +# The refresh option can be set to specify the refresh interval for the +# TLS certificates. Client authentication certificates cannot be refreshed since +# Go does not provide a mechanism for that yet. +# +# The default refresh interval is 3 seconds and cannot be lower than 1 second +# to prevent busy loops. To load the certificates only once +# and disable automatic refreshing set refresh to zero. +# +# cs=http;cert=https://host.com/path/to/cert/list&token=123 +# cs=http;cert=https://user:pass@host.com/path/to/cert/list +# cs=http;cert=https://host.com/path/to/cert/list;hdr=Authorization: Bearer 1234 +# +# Consul +# +# The consul certificate source loads certificates from consul. +# +# The cert option provides a URL to a path in the KV store where the +# the TLS certificates are stored. +# +# The clientca option provides a URL to a path in the KV store where the +# the client authentication certificates are stored. +# +# The filenames follow the same rules as for the Path source. +# +# The TLS certificates are updated automatically whenever the KV store changes. +# The client authentication certificates cannot be updated automatically since +# Go does not provide a mechanism for that yet. +# +# cs=consul;cert=http://localhost:8500/v1/kv/path/to/cert&token=123 +# +# Vault +# +# The Vault certificate store uses HashiCorp Vault as the certificate +# store. +# +# cs=vault;url=http://localhost:1234/some/path +# +# Examples: +# +# # HTTP listener on port 9999 # proxy.addr = :9999 # -# Configure both an HTTP and HTTPS listener: +# # HTTP listener on IPv4 with read timeout +# proxy.addr = 1.2.3.4:9999;rt=3s +# +# # HTTP listener on IPv6 with write timeout +# proxy.addr = [2001:DB8::A/32]:9999;wt=5s +# +# # Multiple listeners +# proxy.addr = 1.2.3.4:9999;rt=3s,[2001:DB8::A/32]:9999;wt=5s +# +# # HTTPS listener on port 443 with file certificate store +# proxy.addr = :443;cs=file;cert=p/a-cert.pem;key=p/a-key.pem # -# proxy.addr = :9999,:443;path/to/cert.pem;path/to/key.pem;path/to/clientauth.pem +# # HTTPS listener with path based certificate store +# proxy.addr = :443;cs=path;path=path/to/certs # -# Configure multiple HTTP and HTTPS listeners on IPv4 and IPv6: +# # HTTPS listener with server based certificate store +# proxy.addr = :443;cs=http;url=https://host:port/path/to/certs?token=abc123 # -# proxy.addr = \ -# 1.2.3.4:9999, \ -# 5.6.7.8:9999, \ -# [2001:DB8::A/32]:9999, \ -# [2001:DB8::B/32]:9999, \ -# 1.2.3.4:443;path/to/certA.pem;path/to/keyA.pem, \ -# 5.6.7.8:443;path/to/certB.pem;path/to/keyB.pem, \ -# [2001:DB8::A/32]:443;path/to/certA.pem;path/to/keyA.pem, \ -# [2001:DB8::B/32]:443;path/to/certB.pem;path/to/keyB.pem +# # HTTPS listener with Vault certificate store +# proxy.addr = :443;cs=vault;url=https://host:port/path/to/certs # # The default is # # proxy.addr = :9999 +proxy.addr = :9999;cs=path;path=demo/cert # proxy.localip configures the ip address of the proxy which is added diff --git a/listen.go b/listen.go index 527e70f6b..b5d7581e7 100644 --- a/listen.go +++ b/listen.go @@ -2,9 +2,7 @@ package main import ( "crypto/tls" - "crypto/x509" - "errors" - "io/ioutil" + "fmt" "log" "net" "net/http" @@ -13,6 +11,7 @@ import ( "time" "github.com/armon/go-proxyproto" + "github.com/eBay/fabio/cert" "github.com/eBay/fabio/config" "github.com/eBay/fabio/exit" "github.com/eBay/fabio/proxy" @@ -45,15 +44,29 @@ func startListeners(listen []config.Listen, wait time.Duration, h http.Handler) } func listenAndServe(l config.Listen, h http.Handler) { - srv, err := newServer(l, h) - if err != nil { - log.Fatal("[FATAL] ", err) + srv := &http.Server{ + Handler: h, + Addr: l.Addr, + ReadTimeout: l.ReadTimeout, + WriteTimeout: l.WriteTimeout, + } + + if l.Scheme == "https" { + src, err := makeCertSource(l.TLSConfig) + if err != nil { + log.Fatal("[FATAL] ", err) + } + + srv.TLSConfig, err = cert.TLSConfig(src) + if err != nil { + log.Fatal("[FATAL] ", err) + } } if srv.TLSConfig != nil { - log.Printf("[INFO] HTTPS proxy listening on %s with certificate %s", l.Addr, l.CertFile) + log.Printf("[INFO] HTTPS proxy listening on %s", l.Addr) if srv.TLSConfig.ClientAuth == tls.RequireAndVerifyClientCert { - log.Printf("[INFO] Client certificate authentication enabled on %s with certificates from %s", l.Addr, l.ClientAuthFile) + log.Printf("[INFO] Client certificate authentication enabled on %s", l.Addr) } } else { log.Printf("[INFO] HTTP proxy listening on %s", l.Addr) @@ -64,41 +77,32 @@ func listenAndServe(l config.Listen, h http.Handler) { } } -var tlsLoadX509KeyPair = tls.LoadX509KeyPair - -func newServer(l config.Listen, h http.Handler) (*http.Server, error) { - srv := &http.Server{ - Addr: l.Addr, - Handler: h, - ReadTimeout: l.ReadTimeout, - WriteTimeout: l.WriteTimeout, - } - - if l.CertFile != "" { - cert, err := tlsLoadX509KeyPair(l.CertFile, l.KeyFile) - if err != nil { - return nil, err - } - - srv.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - - if l.ClientAuthFile != "" { - pemBlock, err := ioutil.ReadFile(l.ClientAuthFile) - if err != nil { - return nil, err - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(pemBlock) { - return nil, errors.New("failed to add client auth certs") - } - srv.TLSConfig.ClientCAs = pool - srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert - } +func makeCertSource(cfg config.TLSConfig) (cert.Source, error) { + switch cfg.Type { + case "file": + return cert.FileSource{ + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + ClientAuthFile: cfg.ClientAuthFile, + }, nil + + case "path": + return cert.PathSource{ + CertPath: cfg.CertPath, + ClientCAPath: cfg.ClientCAPath, + Refresh: cfg.Refresh, + }, nil + + case "http": + return cert.HTTPSource{ + CertURL: cfg.CertPath, + ClientCAURL: cfg.ClientCAPath, + Refresh: cfg.Refresh, + }, nil + + default: + return nil, fmt.Errorf("invalid certificate source %q", cfg.Type) } - - return srv, nil } func serve(srv *http.Server) error { diff --git a/listen_test.go b/listen_test.go index 270dff35b..6ca76099c 100644 --- a/listen_test.go +++ b/listen_test.go @@ -1,10 +1,8 @@ package main import ( - "crypto/tls" "net/http" "net/http/httptest" - "reflect" "sync" "testing" "time" @@ -14,48 +12,6 @@ import ( "github.com/eBay/fabio/route" ) -func TestNewServer(t *testing.T) { - h := http.DefaultServeMux - cert := tls.Certificate{} - tlsLoadX509KeyPair = func(string, string) (tls.Certificate, error) { - return cert, nil - } - defer func() { tlsLoadX509KeyPair = tls.LoadX509KeyPair }() - - tests := []struct { - in config.Listen - out *http.Server - err string - }{ - { - config.Listen{Addr: ":123"}, - &http.Server{Addr: ":123", Handler: h}, - "", - }, - { - config.Listen{Addr: ":123", CertFile: "cert.pem"}, - &http.Server{ - Addr: ":123", - Handler: h, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - }, - }, - "", - }, - } - - for i, tt := range tests { - srv, err := newServer(tt.in, h) - if got, want := err, tt.err; (got != nil || want != "") && got.Error() != want { - t.Errorf("%d: got %v want %v", i, got, want) - } - if got, want := srv, tt.out; !reflect.DeepEqual(got, want) { - t.Errorf("%d: got %v want %v", i, got, want) - } - } -} - func TestGracefulShutdown(t *testing.T) { req := func(url string) int { resp, err := http.Get(url)