Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: deny incoming peer certs with wrong IP SAN #7687

Merged
merged 2 commits into from
Apr 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions embed/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
}()

for i, u := range cfg.LPUrls {
var tlscfg *tls.Config
if u.Scheme == "http" {
if !cfg.PeerTLSInfo.Empty() {
plog.Warningf("The scheme of peer url %s is HTTP while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
Expand All @@ -210,12 +209,7 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String())
}
}
if !cfg.PeerTLSInfo.Empty() {
if tlscfg, err = cfg.PeerTLSInfo.ServerConfig(); err != nil {
return nil, err
}
}
if plns[i], err = rafthttp.NewListener(u, tlscfg); err != nil {
if plns[i], err = rafthttp.NewListener(u, &cfg.PeerTLSInfo); err != nil {
return nil, err
}
plog.Info("listening for peers on ", u.String())
Expand Down
19 changes: 6 additions & 13 deletions etcdmain/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -305,18 +304,7 @@ func startProxy(cfg *config) error {
}
// Start a proxy server goroutine for each listen address
for _, u := range cfg.LCUrls {
var (
l net.Listener
tlscfg *tls.Config
)
if !cfg.ClientTLSInfo.Empty() {
tlscfg, err = cfg.ClientTLSInfo.ServerConfig()
if err != nil {
return err
}
}

l, err := transport.NewListener(u.Host, u.Scheme, tlscfg)
l, err := transport.NewListener(u.Host, u.Scheme, &cfg.ClientTLSInfo)
if err != nil {
return err
}
Expand Down Expand Up @@ -369,6 +357,11 @@ func identifyDataDirOrDie(dir string) dirType {
}

func setupLogging(cfg *config) {
cfg.ClientTLSInfo.HandshakeFailure = func(conn *tls.Conn, err error) {
plog.Infof("rejected connection from %q (%v)", conn.RemoteAddr().String(), err)
}
cfg.PeerTLSInfo.HandshakeFailure = cfg.ClientTLSInfo.HandshakeFailure

capnslog.SetGlobalLogLevel(capnslog.INFO)
if cfg.Debug {
capnslog.SetGlobalLogLevel(capnslog.DEBUG)
Expand Down
7 changes: 3 additions & 4 deletions pkg/transport/keepalive_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"crypto/tls"
"net"
"net/http"
"os"
"testing"
)

Expand Down Expand Up @@ -50,12 +49,12 @@ func TestNewKeepAliveListener(t *testing.T) {
}

// tls
tmp, err := createTempFile([]byte("XXX"))
tlsinfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create tmpfile: %v", err)
}
defer os.Remove(tmp)
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
defer del()
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
tlscfg, err := tlsInfo.ServerConfig()
if err != nil {
Expand Down
16 changes: 8 additions & 8 deletions pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ import (
"github.com/coreos/etcd/pkg/tlsutil"
)

func NewListener(addr, scheme string, tlscfg *tls.Config) (l net.Listener, err error) {
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
if l, err = newListener(addr, scheme); err != nil {
return nil, err
}
return wrapTLS(addr, scheme, tlscfg, l)
return wrapTLS(addr, scheme, tlsinfo, l)
}

func newListener(addr string, scheme string) (net.Listener, error) {
Expand All @@ -47,15 +47,11 @@ func newListener(addr string, scheme string) (net.Listener, error) {
return net.Listen("tcp", addr)
}

func wrapTLS(addr, scheme string, tlscfg *tls.Config, l net.Listener) (net.Listener, error) {
func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
if scheme != "https" && scheme != "unixs" {
return l, nil
}
if tlscfg == nil {
l.Close()
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
}
return tls.NewListener(l, tlscfg), nil
return newTLSListener(l, tlsinfo)
}

type TLSInfo struct {
Expand All @@ -68,6 +64,10 @@ type TLSInfo struct {
// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
ServerName string

// HandshakeFailure is optinally called when a connection fails to handshake. The
// connection will be closed immediately afterwards.
HandshakeFailure func(*tls.Conn, error)

selfCert bool

// parseFunc exists to simplify testing. Typically, parseFunc
Expand Down
94 changes: 44 additions & 50 deletions pkg/transport/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,16 @@ import (
"time"
)

func createTempFile(b []byte) (string, error) {
f, err := ioutil.TempFile("", "etcd-test-tls-")
if err != nil {
return "", err
func createSelfCert() (*TLSInfo, func(), error) {
d, terr := ioutil.TempDir("", "etcd-test-tls-")
if terr != nil {
return nil, nil, terr
}
defer f.Close()

if _, err = f.Write(b); err != nil {
return "", err
info, err := SelfCert(d, []string{"127.0.0.1"})
if err != nil {
return nil, nil, err
}

return f.Name(), nil
return &info, func() { os.RemoveAll(d) }, nil
}

func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
Expand All @@ -47,28 +45,25 @@ func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBloc
// TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns
// a TLS listener that accepts TLS connections.
func TestNewListenerTLSInfo(t *testing.T) {
tmp, err := createTempFile([]byte("XXX"))
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create tmpfile: %v", err)
t.Fatalf("unable to create cert: %v", err)
}
defer os.Remove(tmp)
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
testNewListenerTLSInfoAccept(t, tlsInfo)
defer del()
testNewListenerTLSInfoAccept(t, *tlsInfo)
}

func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
tlscfg, err := tlsInfo.ServerConfig()
if err != nil {
t.Fatalf("unexpected serverConfig error: %v", err)
}
ln, err := NewListener("127.0.0.1:0", "https", tlscfg)
ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
if err != nil {
t.Fatalf("unexpected NewListener error: %v", err)
}
defer ln.Close()

go http.Get("https://" + ln.Addr().String())
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
cli := &http.Client{Transport: tr}
go cli.Get("https://" + ln.Addr().String())

conn, err := ln.Accept()
if err != nil {
t.Fatalf("unexpected Accept error: %v", err)
Expand All @@ -87,25 +82,25 @@ func TestNewListenerTLSEmptyInfo(t *testing.T) {
}

func TestNewTransportTLSInfo(t *testing.T) {
tmp, err := createTempFile([]byte("XXX"))
tlsinfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("Unable to prepare tmpfile: %v", err)
t.Fatalf("unable to create cert: %v", err)
}
defer os.Remove(tmp)
defer del()

tests := []TLSInfo{
{},
{
CertFile: tmp,
KeyFile: tmp,
CertFile: tlsinfo.CertFile,
KeyFile: tlsinfo.KeyFile,
},
{
CertFile: tmp,
KeyFile: tmp,
CAFile: tmp,
CertFile: tlsinfo.CertFile,
KeyFile: tlsinfo.KeyFile,
CAFile: tlsinfo.CAFile,
},
{
CAFile: tmp,
CAFile: tlsinfo.CAFile,
},
}

Expand Down Expand Up @@ -159,17 +154,17 @@ func TestTLSInfoEmpty(t *testing.T) {
}

func TestTLSInfoMissingFields(t *testing.T) {
tmp, err := createTempFile([]byte("XXX"))
tlsinfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("Unable to prepare tmpfile: %v", err)
t.Fatalf("unable to create cert: %v", err)
}
defer os.Remove(tmp)
defer del()

tests := []TLSInfo{
{CertFile: tmp},
{KeyFile: tmp},
{CertFile: tmp, CAFile: tmp},
{KeyFile: tmp, CAFile: tmp},
{CertFile: tlsinfo.CertFile},
{KeyFile: tlsinfo.KeyFile},
{CertFile: tlsinfo.CertFile, CAFile: tlsinfo.CAFile},
{KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CAFile},
}

for i, info := range tests {
Expand All @@ -184,44 +179,43 @@ func TestTLSInfoMissingFields(t *testing.T) {
}

func TestTLSInfoParseFuncError(t *testing.T) {
tmp, err := createTempFile([]byte("XXX"))
tlsinfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("Unable to prepare tmpfile: %v", err)
t.Fatalf("unable to create cert: %v", err)
}
defer os.Remove(tmp)
defer del()

info := TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp}
info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
tlsinfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))

if _, err = info.ServerConfig(); err == nil {
if _, err = tlsinfo.ServerConfig(); err == nil {
t.Errorf("expected non-nil error from ServerConfig()")
}

if _, err = info.ClientConfig(); err == nil {
if _, err = tlsinfo.ClientConfig(); err == nil {
t.Errorf("expected non-nil error from ClientConfig()")
}
}

func TestTLSInfoConfigFuncs(t *testing.T) {
tmp, err := createTempFile([]byte("XXX"))
tlsinfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("Unable to prepare tmpfile: %v", err)
t.Fatalf("unable to create cert: %v", err)
}
defer os.Remove(tmp)
defer del()

tests := []struct {
info TLSInfo
clientAuth tls.ClientAuthType
wantCAs bool
}{
{
info: TLSInfo{CertFile: tmp, KeyFile: tmp},
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile},
clientAuth: tls.NoClientCert,
wantCAs: false,
},

{
info: TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp},
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CertFile},
clientAuth: tls.RequireAndVerifyClientCert,
wantCAs: true,
},
Expand Down
Loading