Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pluggable certificate storage (following on from #284) #332

Merged
merged 3 commits into from
Apr 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions ctx.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goproxy

import (
"crypto/tls"
"net/http"
"regexp"
)
Expand All @@ -19,14 +20,19 @@ type ProxyCtx struct {
// call of RespHandler
UserData interface{}
// Will connect a request to a response
Session int64
proxy *ProxyHttpServer
Session int64
certStore CertStorage
proxy *ProxyHttpServer
}

type RoundTripper interface {
RoundTrip(req *http.Request, ctx *ProxyCtx) (*http.Response, error)
}

type CertStorage interface {
Fetch(hostname string, gen func() (*tls.Certificate, error)) (*tls.Certificate, error)
}

type RoundTripperFunc func(req *http.Request, ctx *ProxyCtx) (*http.Response, error)

func (f RoundTripperFunc) RoundTrip(req *http.Request, ctx *ProxyCtx) (*http.Response, error) {
Expand Down
20 changes: 17 additions & 3 deletions https.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (proxy *ProxyHttpServer) connectDial(network, addr string) (c net.Conn, err
}

func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request) {
ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy}
ctx := &ProxyCtx{Req: r, Session: atomic.AddInt64(&proxy.sess, 1), proxy: proxy, certStore: proxy.CertStore}

hij, ok := w.(http.Hijacker)
if !ok {
Expand Down Expand Up @@ -408,14 +408,28 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(https_proxy strin

func TLSConfigFromCA(ca *tls.Certificate) func(host string, ctx *ProxyCtx) (*tls.Config, error) {
return func(host string, ctx *ProxyCtx) (*tls.Config, error) {
var err error
var cert *tls.Certificate

hostname := stripPort(host)
config := *defaultTLSConfig
ctx.Logf("signing for %s", stripPort(host))
cert, err := signHost(*ca, []string{stripPort(host)})

genCert := func() (*tls.Certificate, error) {
return signHost(*ca, []string{hostname})
}
if ctx.certStore != nil {
cert, err = ctx.certStore.Fetch(hostname, genCert)
} else {
cert, err = genCert()
}

if err != nil {
ctx.Warnf("Cannot sign host certificate with provided CA: %s", err)
return nil, err
}
config.Certificates = append(config.Certificates, cert)

config.Certificates = append(config.Certificates, *cert)
return &config, nil
}
}
1 change: 1 addition & 0 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type ProxyHttpServer struct {
// ConnectDial will be used to create TCP connections for CONNECT requests
// if nil Tr.Dial will be used
ConnectDial func(network string, addr string) (net.Conn, error)
CertStore CertStorage
}

var hasPort = regexp.MustCompile(`:\d+$`)
Expand Down
87 changes: 86 additions & 1 deletion proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"image"
"io"
"io/ioutil"
Expand All @@ -19,7 +20,7 @@ import (
"testing"

"github.com/elazarl/goproxy"
"github.com/elazarl/goproxy/ext/image"
goproxy_image "github.com/elazarl/goproxy/ext/image"
)

var acceptAllCerts = &tls.Config{InsecureSkipVerify: true}
Expand Down Expand Up @@ -766,6 +767,90 @@ func TestHasGoproxyCA(t *testing.T) {
}
}

type TestCertStorage struct {
certs map[string]*tls.Certificate
hits int
misses int
}

func (tcs *TestCertStorage) Fetch(hostname string, gen func() (*tls.Certificate, error)) (*tls.Certificate, error) {
var cert *tls.Certificate
var err error
cert, ok := tcs.certs[hostname]
if ok {
fmt.Printf("hit %v\n", cert == nil)
tcs.hits++
} else {
cert, err = gen()
if err != nil {
return nil, err
}
fmt.Printf("miss %v\n", cert == nil)
tcs.certs[hostname] = cert
tcs.misses++
}
return cert, err
}

func (tcs *TestCertStorage) statHits() int {
return tcs.hits
}

func (tcs *TestCertStorage) statMisses() int {
return tcs.misses
}

func newTestCertStorage() *TestCertStorage {
tcs := &TestCertStorage{}
tcs.certs = make(map[string]*tls.Certificate)

return tcs
}

func TestProxyWithCertStorage(t *testing.T) {
tcs := newTestCertStorage()
t.Logf("TestProxyWithCertStorage started")
proxy := goproxy.NewProxyHttpServer()
proxy.CertStore = tcs
proxy.OnRequest().HandleConnect(goproxy.AlwaysMitm)
proxy.OnRequest().DoFunc(func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
req.URL.Path = "/bobo"
return req, nil
})

s := httptest.NewServer(proxy)

proxyUrl, _ := url.Parse(s.URL)
goproxyCA := x509.NewCertPool()
goproxyCA.AddCert(goproxy.GoproxyCa.Leaf)

tr := &http.Transport{TLSClientConfig: &tls.Config{RootCAs: goproxyCA}, Proxy: http.ProxyURL(proxyUrl)}
client := &http.Client{Transport: tr}

if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" {
t.Error("Wrong response when mitm", resp, "expected bobo")
}

if tcs.statHits() != 0 {
t.Fatalf("Expected 0 cache hits, got %d", tcs.statHits())
}
if tcs.statMisses() != 1 {
t.Fatalf("Expected 1 cache miss, got %d", tcs.statMisses())
}

// Another round - this time the certificate can be loaded
if resp := string(getOrFail(https.URL+"/bobo", client, t)); resp != "bobo" {
t.Error("Wrong response when mitm", resp, "expected bobo")
}

if tcs.statHits() != 1 {
t.Fatalf("Expected 1 cache hit, got %d", tcs.statHits())
}
if tcs.statMisses() != 1 {
t.Fatalf("Expected 1 cache miss, got %d", tcs.statMisses())
}
}

func TestHttpsMitmURLRewrite(t *testing.T) {
scheme := "https"

Expand Down
4 changes: 2 additions & 2 deletions signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func hashSortedBigInt(lst []string) *big.Int {

var goproxySignerVersion = ":goroxy1"

func signHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err error) {
func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) {
var x509ca *x509.Certificate

// Use the provided ca and not the global GoproxyCa for certificate generation.
Expand Down Expand Up @@ -81,7 +81,7 @@ func signHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err err
if derBytes, err = x509.CreateCertificate(&csprng, &template, x509ca, &certpriv.PublicKey, ca.PrivateKey); err != nil {
return
}
return tls.Certificate{
return &tls.Certificate{
Certificate: [][]byte{derBytes, ca.Certificate[0]},
PrivateKey: certpriv,
}, nil
Expand Down
2 changes: 1 addition & 1 deletion signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestSingerTls(t *testing.T) {
expected := "key verifies with Go"
server := httptest.NewUnstartedServer(ConstantHanlder(expected))
defer server.Close()
server.TLS = &tls.Config{Certificates: []tls.Certificate{cert, GoproxyCa}}
server.TLS = &tls.Config{Certificates: []tls.Certificate{*cert, GoproxyCa}}
server.TLS.BuildNameToCertificate()
server.StartTLS()
certpool := x509.NewCertPool()
Expand Down