Skip to content

Commit

Permalink
Merge pull request #1119 from mreiferson/nsqd-tls-cn
Browse files Browse the repository at this point in the history
nsqd: send client TLS cert common_name on authd requests
  • Loading branch information
mreiferson authored Jan 5, 2019
2 parents cf4a8c8 + 64a4b65 commit cf04b97
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 30 deletions.
13 changes: 9 additions & 4 deletions internal/auth/authorizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ func (a *State) IsExpired() bool {
return false
}

func QueryAnyAuthd(authd []string, remoteIP, tlsEnabled, authSecret string,
func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
for _, a := range authd {
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, authSecret, connectTimeout, requestTimeout)
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, connectTimeout, requestTimeout)
if err != nil {
log.Printf("Error: failed auth against %s %s", a, err)
continue
Expand All @@ -89,12 +89,17 @@ func QueryAnyAuthd(authd []string, remoteIP, tlsEnabled, authSecret string,
return nil, errors.New("Unable to access auth server")
}

func QueryAuthd(authd, remoteIP, tlsEnabled, authSecret string,
func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
v := url.Values{}
v.Set("remote_ip", remoteIP)
v.Set("tls", tlsEnabled)
if tlsEnabled {
v.Set("tls", "true")
} else {
v.Set("tls", "false")
}
v.Set("secret", authSecret)
v.Set("common_name", commonName)

endpoint := fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())

Expand Down
14 changes: 9 additions & 5 deletions nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,14 +569,18 @@ func (c *clientV2) QueryAuthd() error {
return err
}

tls := atomic.LoadInt32(&c.TLS) == 1
tlsEnabled := "false"
if tls {
tlsEnabled = "true"
tlsEnabled := atomic.LoadInt32(&c.TLS) == 1
commonName := ""
if tlsEnabled {
tlsConnState := c.tlsConn.ConnectionState()
if len(tlsConnState.PeerCertificates) > 0 {
commonName = tlsConnState.PeerCertificates[0].Subject.CommonName
}
}

authState, err := auth.QueryAnyAuthd(c.ctx.nsqd.getOpts().AuthHTTPAddresses,
remoteIP, tlsEnabled, c.AuthSecret, c.ctx.nsqd.getOpts().HTTPClientConnectTimeout,
remoteIP, tlsEnabled, commonName, c.AuthSecret,
c.ctx.nsqd.getOpts().HTTPClientConnectTimeout,
c.ctx.nsqd.getOpts().HTTPClientRequestTimeout)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions nsqd/protocol_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,21 +518,21 @@ func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
}

if client.HasAuthorizations() {
return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH Already set")
return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH already set")
}

if !client.ctx.nsqd.IsAuthEnabled() {
return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH Disabled")
return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH disabled")
}

if err := client.Auth(string(body)); err != nil {
// we don't want to leak errors contacting the auth server to untrusted clients
p.ctx.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] Auth Failed %s", client, err)
p.ctx.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
return nil, protocol.NewFatalClientErr(err, "E_AUTH_FAILED", "AUTH failed")
}

if !client.HasAuthorizations() {
return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH No authorizations found")
return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH no authorizations found")
}

resp, err := json.Marshal(struct {
Expand Down Expand Up @@ -568,7 +568,7 @@ func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName str
ok, err := client.IsAuthorized(topicName, channelName)
if err != nil {
// we don't want to leak errors contacting the auth server to untrusted clients
p.ctx.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] Auth Failed %s", client, err)
p.ctx.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
}
if !ok {
Expand Down
75 changes: 59 additions & 16 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1466,30 +1466,41 @@ func TestReqTimeoutRange(t *testing.T) {
func TestClientAuth(t *testing.T) {
authResponse := `{"ttl":1, "authorizations":[]}`
authSecret := "testsecret"
authError := "E_UNAUTHORIZED AUTH No authorizations found"
authError := "E_UNAUTHORIZED AUTH no authorizations found"
authSuccess := ""
runAuthTest(t, authResponse, authSecret, authError, authSuccess)
tlsEnabled := false
commonName := ""
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)

// now one that will succeed
authResponse = `{"ttl":10, "authorizations":
[{"topic":"test", "channels":[".*"], "permissions":["subscribe","publish"]}]
}`
authError = ""
authSuccess = `{"identity":"","identity_url":"","permission_count":1}`
runAuthTest(t, authResponse, authSecret, authError, authSuccess)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)

// one with TLS enabled
tlsEnabled = true
commonName = "test.local"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
}

func runAuthTest(t *testing.T, authResponse, authSecret, authError, authSuccess string) {
func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string,
authSuccess string, tlsEnabled bool, commonName string) {
var err error
var expectedAuthIP string
expectedAuthTLS := "false"
var expectedRemoteIP string
expectedTLS := "false"
if tlsEnabled {
expectedTLS = "true"
}

authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("in test auth handler %s", r.RequestURI)
r.ParseForm()
test.Equal(t, expectedAuthIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedAuthTLS, r.Form.Get("tls"))
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedTLS, r.Form.Get("tls"))
test.Equal(t, commonName, r.Form.Get("common_name"))
test.Equal(t, authSecret, r.Form.Get("secret"))
fmt.Fprint(w, authResponse)
}))
Expand All @@ -1502,6 +1513,11 @@ func runAuthTest(t *testing.T, authResponse, authSecret, authError, authSuccess
opts.Logger = test.NewTestLogger(t)
opts.LogLevel = "debug"
opts.AuthHTTPAddresses = []string{addr.Host}
if tlsEnabled {
opts.TLSCert = "./test/certs/server.pem"
opts.TLSKey = "./test/certs/server.key"
opts.TLSClientAuthPolicy = "require"
}
tcpAddr, _, nsqd := mustStartNSQD(opts)
defer os.RemoveAll(opts.DataPath)
defer nsqd.Exit()
Expand All @@ -1510,19 +1526,46 @@ func runAuthTest(t *testing.T, authResponse, authSecret, authError, authSuccess
test.Nil(t, err)
defer conn.Close()

expectedAuthIP, _, _ = net.SplitHostPort(conn.LocalAddr().String())
data := identify(t, conn, map[string]interface{}{
"tls_v1": tlsEnabled,
}, frameTypeResponse)
r := struct {
TLSv1 bool `json:"tls_v1"`
}{}
err = json.Unmarshal(data, &r)
test.Nil(t, err)
test.Equal(t, tlsEnabled, r.TLSv1)

identify(t, conn, map[string]interface{}{
"tls_v1": false,
}, nsq.FrameTypeResponse)
var c io.ReadWriter
var tlsConn *tls.Conn
c = conn
if tlsEnabled {
cert, err := tls.LoadX509KeyPair("./test/certs/cert.pem", "./test/certs/key.pem")
test.Nil(t, err)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}
tlsConn = tls.Client(conn, tlsConfig)
err = tlsConn.Handshake()
test.Nil(t, err)
c = tlsConn

authCmd(t, conn, authSecret, authSuccess)
resp, _ := nsq.ReadResponse(tlsConn)
frameType, data, _ := nsq.UnpackResponse(resp)
t.Logf("frameType: %d, data: %s", frameType, data)
test.Equal(t, frameTypeResponse, frameType)
test.Equal(t, []byte("OK"), data)
}

expectedRemoteIP, _, _ = net.SplitHostPort(conn.LocalAddr().String())

authCmd(t, c, authSecret, authSuccess)
if authError != "" {
readValidate(t, conn, nsq.FrameTypeError, authError)
readValidate(t, c, frameTypeError, authError)
} else {
sub(t, conn, "test", "ch")
sub(t, c, "test", "ch")
}

}

func TestIOLoopReturnsClientErrWhenSendFails(t *testing.T) {
Expand Down

0 comments on commit cf04b97

Please sign in to comment.