Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NewMockBrokerListener() so that it's possible to test TLS connections #1019

Merged
merged 1 commit into from
Jan 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions client_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package sarama

import (
"math/big"
"net"
"testing"
"time"

"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
)

func TestTLS(t *testing.T) {
cakey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

clientkey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

hostkey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

nvb := time.Now().Add(-1 * time.Hour)
nva := time.Now().Add(1 * time.Hour)

caTemplate := &x509.Certificate{
Subject: pkix.Name{CommonName: "ca"},
Issuer: pkix.Name{CommonName: "ca"},
SerialNumber: big.NewInt(0),
NotAfter: nva,
NotBefore: nvb,
IsCA: true,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageCertSign,
}
caDer, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &cakey.PublicKey, cakey)
if err != nil {
t.Fatal(err)
}
caFinalCert, err := x509.ParseCertificate(caDer)
if err != nil {
t.Fatal(err)
}

hostDer, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "host"},
Issuer: pkix.Name{CommonName: "ca"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
SerialNumber: big.NewInt(0),
NotAfter: nva,
NotBefore: nvb,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}, caFinalCert, &hostkey.PublicKey, cakey)
if err != nil {
t.Fatal(err)
}

clientDer, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
Subject: pkix.Name{CommonName: "client"},
Issuer: pkix.Name{CommonName: "ca"},
SerialNumber: big.NewInt(0),
NotAfter: nva,
NotBefore: nvb,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}, caFinalCert, &clientkey.PublicKey, cakey)
if err != nil {
t.Fatal(err)
}

pool := x509.NewCertPool()
pool.AddCert(caFinalCert)

systemCerts, err := x509.SystemCertPool()
if err != nil {
t.Fatal(err)
}

// Keep server the same - it's the client that we're testing
serverTLSConfig := &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{hostDer},
PrivateKey: hostkey,
}},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
}

for _, tc := range []struct {
Succeed bool
Server, Client *tls.Config
}{
{ // Verify client fails if wrong CA cert pool is specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: systemCerts,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: clientkey,
}},
},
},
{ // Verify client fails if wrong key is specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: hostkey,
}},
},
},
{ // Verify client fails if wrong cert is specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{hostDer},
PrivateKey: clientkey,
}},
},
},
{ // Verify client fails if no CAs are specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: clientkey,
}},
},
},
{ // Verify client fails if no keys are specified
Succeed: false,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
},
},
{ // Finally, verify it all works happily with client and server cert in place
Succeed: true,
Server: serverTLSConfig,
Client: &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{clientDer},
PrivateKey: clientkey,
}},
},
},
} {
doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client)
}
}

func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientConfig *tls.Config) {
serverConfig.BuildNameToCertificate()
clientConfig.BuildNameToCertificate()

seedListener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatal("cannot open listener", err)
}

var childT *testing.T
if expectSuccess {
childT = t
} else {
childT = &testing.T{} // we want to swallow errors
}

seedBroker := NewMockBrokerListener(childT, 1, seedListener)
defer seedBroker.Close()

seedBroker.Returns(new(MetadataResponse))

config := NewConfig()
config.Net.TLS.Enable = true
config.Net.TLS.Config = clientConfig

client, err := NewClient([]string{seedBroker.Addr()}, config)
if err == nil {
safeClose(t, client)
}

if expectSuccess {
if err != nil {
t.Fatal(err)
}
} else {
if err == nil {
t.Fatal("expected failure")
}
}
}
14 changes: 10 additions & 4 deletions mockbroker.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
// NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
// it rather than just some ephemeral port.
func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
listener, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
}
return NewMockBrokerListener(t, brokerID, listener)
}

// NewMockBrokerListener behaves like newMockBrokerAddr but accepts connections on the listener specified.
func NewMockBrokerListener(t TestReporter, brokerID int32, listener net.Listener) *MockBroker {
var err error

broker := &MockBroker{
Expand All @@ -296,13 +305,10 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
t: t,
brokerID: brokerID,
expectations: make(chan encoder, 512),
listener: listener,
}
broker.handler = broker.defaultRequestHandler

broker.listener, err = net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
}
Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
_, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
if err != nil {
Expand Down