diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 417e607..b02e933 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -8,18 +8,17 @@ import ( "github.com/VictoriaMetrics/metrics" ) -const defaultPushInterval = time.Minute - // ForNerds captures metrics from various bifrost processes var ForNerds = metrics.NewSet() -// MaybePushMetrics pushes metrics to url if url is not empty +// MaybePushMetrics pushes metrics to url if url is not empty. +// If interval is zero, a one minute interval is used func MaybePushMetrics(url string, interval time.Duration) { if url == "" { return } if interval == 0 { - interval = defaultPushInterval + interval = time.Minute } log.Printf("pushing metrics to %s every %.2fs\n", url, interval.Seconds()) diff --git a/pkg/club/bouncer.go b/pkg/club/bouncer.go index b6e710b..b0ca3ad 100644 --- a/pkg/club/bouncer.go +++ b/pkg/club/bouncer.go @@ -44,8 +44,8 @@ func Bouncer(rp *httputil.ReverseProxy) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { startTime := time.Now() - if len(r.TLS.PeerCertificates) == 0 { - panic("TLS Client Authentication is required") + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + panic("request must have tls client certificate") } peerCert := r.TLS.PeerCertificates[0] diff --git a/pkg/club/bouncer_test.go b/pkg/club/bouncer_test.go new file mode 100644 index 0000000..5adfac1 --- /dev/null +++ b/pkg/club/bouncer_test.go @@ -0,0 +1,123 @@ +package club + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "math/big" + "math/rand" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" + "time" +) + +func TestBouncerNoTLS(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("this should panic but did not") + } + }() + + backendServer := httptest.NewServer(nil) + defer backendServer.Close() + backendUrl, _ := url.Parse(backendServer.URL) + + br := Bouncer(httputil.NewSingleHostReverseProxy(backendUrl)) + rr := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/", nil) + br.ServeHTTP(rr, request) +} + +func TestBouncer(t *testing.T) { + randReader := rand.New(rand.NewSource(42)) + // generate key pair and certificate + priv, err := ecdsa.GenerateKey(elliptic.P256(), randReader) + if err != nil { + t.Errorf("error generating private key %s", err) + } + privBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + t.Errorf("error marshaling private key %s", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + crtBytes, err := x509.CreateCertificate(randReader, &template, &template, priv.Public(), priv) + if err != nil { + t.Errorf("error creating certificate %s", err) + } + crtPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: crtBytes}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + crt, err := tls.X509KeyPair(crtPem, keyPem) + if err != nil { + t.Errorf("error loading certificate %s", err) + } + + // backend server handler checks if request has expected header + backendServer := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + rctx := r.Header.Get(RequestContextHeader) + if rctx == "" { + t.Errorf("expected %s header in request", RequestContextHeader) + } + requestContext := RequestContext{} + if err := json.Unmarshal([]byte(rctx), &requestContext); err != nil { + t.Errorf("error unmarshaling request context %s", err) + } + + if string(requestContext.Authentication.ClientCert.ClientCertPEM) != string(crtPem) { + t.Errorf("unexpected certificate in request context header") + } + })) + defer backendServer.Close() + backendUrl, err := url.Parse(backendServer.URL) + if err != nil { + t.Errorf("error parsing backedn url %s", err) + } + + // bouncer wraps around a reverse proxy that proxies requests to the HTTP backend + br := Bouncer(httputil.NewSingleHostReverseProxy(backendUrl)) + + // TLS server accepts client requests requiring TLS client cert auth + server := httptest.NewUnstartedServer(br) + server.TLS = &tls.Config{ + ClientAuth: tls.RequireAnyClientCert, + } + server.StartTLS() + defer server.Close() + + // add generated certs to client + client := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{crt}, + InsecureSkipVerify: true, + }, + }, + } + + // create request to TLS server + request, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Errorf("error creating request %s", err) + } + + if _, err := client.Do(request); err != nil { + t.Errorf("error doing request %s", err) + } +}