Skip to content

Commit

Permalink
Merge pull request #1019 from govau/newmockbrokerlistener
Browse files Browse the repository at this point in the history
Add NewMockBrokerListener() so that it's possible to test TLS connections
  • Loading branch information
eapache committed Jan 17, 2018
2 parents 08ccbbb + f933fb4 commit d0cc7fe
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 4 deletions.
206 changes: 206 additions & 0 deletions client_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package sarama

import (
"math/big"
"net"
"testing"
"time"

"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
)

func TestTLS(t *testing.T) {
cakey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

clientkey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

hostkey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

nvb := time.Now().Add(-1 * time.Hour)
nva := time.Now().Add(1 * time.Hour)

caTemplate := &x509.Certificate{
Subject: pkix.Name{CommonName: "ca"},
Issuer: pkix.Name{CommonName: "ca"},
SerialNumber: big.NewInt(0),
NotAfter: nva,
NotBefore: nvb,
IsCA: true,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign,
}
caDer, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &cakey.PublicKey, cakey)
if err != nil {
t.Fatal(err)
}
caFinalCert, err := x509.ParseCertificate(caDer)
if err != nil {
t.Fatal(err)
}

hostDer, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "host"},
Issuer: pkix.Name{CommonName: "ca"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
SerialNumber: big.NewInt(0),
NotAfter: nva,
NotBefore: nvb,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}, caFinalCert, &hostkey.PublicKey, cakey)
if err != nil {
t.Fatal(err)
}

clientDer, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "client"},
Issuer: pkix.Name{CommonName: "ca"},
SerialNumber: big.NewInt(0),
NotAfter: nva,
NotBefore: nvb,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}, caFinalCert, &clientkey.PublicKey, cakey)
if err != nil {
t.Fatal(err)
}

pool := x509.NewCertPool()
pool.AddCert(caFinalCert)

systemCerts, err := x509.SystemCertPool()
if err != nil {
t.Fatal(err)
}

// Keep server the same - it's the client that we're testing
serverTLSConfig := &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{hostDer},
PrivateKey: hostkey,
}},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
}

for _, tc := range []struct {
Succeed bool
Server, Client *tls.Config
}{
{ // Verify client fails if wrong CA cert pool is specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: systemCerts,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: clientkey,
}},
},
},
{ // Verify client fails if wrong key is specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: hostkey,
}},
},
},
{ // Verify client fails if wrong cert is specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{hostDer},
PrivateKey: clientkey,
}},
},
},
{ // Verify client fails if no CAs are specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: clientkey,
}},
},
},
{ // Verify client fails if no keys are specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
},
},
{ // Finally, verify it all works happily with client and server cert in place
Succeed: true,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: clientkey,
}},
},
},
} {
doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client)
}
}

func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientConfig *tls.Config) {
serverConfig.BuildNameToCertificate()
clientConfig.BuildNameToCertificate()

seedListener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatal("cannot open listener", err)
}

var childT *testing.T
if expectSuccess {
childT = t
} else {
childT = &testing.T{} // we want to swallow errors
}

seedBroker := NewMockBrokerListener(childT, 1, seedListener)
defer seedBroker.Close()

seedBroker.Returns(new(MetadataResponse))

config := NewConfig()
config.Net.TLS.Enable = true
config.Net.TLS.Config = clientConfig

client, err := NewClient([]string{seedBroker.Addr()}, config)
if err == nil {
safeClose(t, client)
}

if expectSuccess {
if err != nil {
t.Fatal(err)
}
} else {
if err == nil {
t.Fatal("expected failure")
}
}
}
14 changes: 10 additions & 4 deletions mockbroker.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
// NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
// it rather than just some ephemeral port.
func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
listener, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
}
return NewMockBrokerListener(t, brokerID, listener)
}

// NewMockBrokerListener behaves like newMockBrokerAddr but accepts connections on the listener specified.
func NewMockBrokerListener(t TestReporter, brokerID int32, listener net.Listener) *MockBroker {
var err error

broker := &MockBroker{
Expand All @@ -296,13 +305,10 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
t: t,
brokerID: brokerID,
expectations: make(chan encoder, 512),
listener: listener,
}
broker.handler = broker.defaultRequestHandler

broker.listener, err = net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
}
Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
_, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
if err != nil {
Expand Down

0 comments on commit d0cc7fe

Please sign in to comment.