Skip to content

Commit

Permalink
config: reload cert files from disk automatically
Browse files Browse the repository at this point in the history
Signed-off-by: Simon Pasquier <spasquie@redhat.com>
  • Loading branch information
simonpasquier committed Mar 4, 2019
1 parent 7a3416f commit 8c3ca15
Show file tree
Hide file tree
Showing 6 changed files with 642 additions and 70 deletions.
200 changes: 160 additions & 40 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
package config

import (
"bytes"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/mwitkow/go-conntrack"
Expand Down Expand Up @@ -124,42 +127,51 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string) (*http.Client, error
// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string) (http.RoundTripper, error) {
newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: false,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
}

// If a bearer token is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if len(cfg.BearerToken) > 0 {
rt = NewBearerAuthRoundTripper(cfg.BearerToken, rt)
} else if len(cfg.BearerTokenFile) > 0 {
rt = NewBearerAuthFileRoundTripper(cfg.BearerTokenFile, rt)
}

if cfg.BasicAuth != nil {
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
}
// Return a new configured RoundTripper.
return rt, nil
}

tlsConfig, err := NewTLSConfig(&cfg.TLSConfig)
if err != nil {
return nil, err
}
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: false,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
}

// If a bearer token is provided, create a round tripper that will set the
// Authorization header correctly on each request.
if len(cfg.BearerToken) > 0 {
rt = NewBearerAuthRoundTripper(cfg.BearerToken, rt)
} else if len(cfg.BearerTokenFile) > 0 {
rt = NewBearerAuthFileRoundTripper(cfg.BearerTokenFile, rt)
if len(cfg.TLSConfig.CAFile) == 0 {
// No need for a RoundTripper that reloads the CA file automatically.
return newRT(tlsConfig)
}

if cfg.BasicAuth != nil {
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
}

// Return a new configured RoundTripper.
return rt, nil
return newTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
}

type bearerAuthRoundTripper struct {
Expand Down Expand Up @@ -258,14 +270,13 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
// If a CA cert is provided then let's read it in so we can validate the
// scrape target's certificate properly.
if len(cfg.CAFile) > 0 {
caCertPool := x509.NewCertPool()
// Load CA cert.
caCert, err := ioutil.ReadFile(cfg.CAFile)
b, err := readCAFile(cfg.CAFile)
if err != nil {
return nil, fmt.Errorf("unable to use specified CA cert %s: %s", cfg.CAFile, err)
return nil, err
}
if !updateRootCA(tlsConfig, b) {
return nil, fmt.Errorf("unable to use specified CA cert %s", cfg.CAFile)
}
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = caCertPool
}

if len(cfg.ServerName) > 0 {
Expand All @@ -277,13 +288,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
} else if len(cfg.KeyFile) > 0 && len(cfg.CertFile) == 0 {
return nil, fmt.Errorf("client key file %q specified without client cert file", cfg.KeyFile)
} else if len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 {
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", cfg.CertFile, cfg.KeyFile, err)
// Verify that client cert and key are valid.
if _, err := cfg.getClientCertificate(nil); err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
tlsConfig.GetClientCertificate = cfg.getClientCertificate
}
tlsConfig.BuildNameToCertificate()

return tlsConfig, nil
}
Expand All @@ -308,6 +318,116 @@ func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
return unmarshal((*plain)(c))
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}
return &cert, nil
}

// readCAFile reads the CA cert file from disk.
func readCAFile(f string) ([]byte, error) {
data, err := ioutil.ReadFile(f)
if err != nil {
return nil, fmt.Errorf("unable to load specified CA cert %s: %s", f, err)
}
return data, nil
}

// updateRootCA parses the given byte slice as a series of PEM encoded certificates and updates tls.Config.RootCAs.
func updateRootCA(cfg *tls.Config, b []byte) bool {
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(b) {
return false
}
cfg.RootCAs = caCertPool
return true
}

// tlsRoundTripper is a RoundTripper that updates automatically its TLS
// configuration whenever the content of the CA file changes.
type tlsRoundTripper struct {
caFile string
// newRT returns a new RoundTripper.
newRT func(*tls.Config) (http.RoundTripper, error)

mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
tlsConfig *tls.Config
}

func newTLSRoundTripper(
cfg *tls.Config,
caFile string,
newRT func(*tls.Config) (http.RoundTripper, error),
) (http.RoundTripper, error) {
t := &tlsRoundTripper{
caFile: caFile,
newRT: newRT,
tlsConfig: cfg,
}

rt, err := t.newRT(t.tlsConfig)
if err != nil {
return nil, err
}
t.rt = rt

_, t.hashCAFile, err = t.getCAWithHash()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tlsRoundTripper) getCAWithHash() ([]byte, []byte, error) {
b, err := readCAFile(t.caFile)
if err != nil {
return nil, nil, err
}
h := md5.Sum(b)
return b, h[:], nil

}

// RoundTrip implements the http.RoundTrip interface.
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
b, h, err := t.getCAWithHash()
if err != nil {
return nil, err
}

t.mtx.RLock()
equal := bytes.Equal(h[:], t.hashCAFile)
rt := t.rt
t.mtx.RUnlock()
if equal {
// The CA cert hasn't changed, use the existing RoundTripper.
return rt.RoundTrip(req)
}

// Create a new RoundTripper.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, b) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
}
rt, err = t.newRT(tlsConfig)
if err != nil {
return nil, err
}

t.mtx.Lock()
t.rt = rt
t.hashCAFile = h[:]
t.mtx.Unlock()

return rt.RoundTrip(req)
}

func (c HTTPClientConfig) String() string {
b, err := yaml.Marshal(c)
if err != nil {
Expand Down
Loading

0 comments on commit 8c3ca15

Please sign in to comment.