Skip to content

Commit

Permalink
Merge pull request #1428 from mkaminski1988/master
Browse files Browse the repository at this point in the history
Handle SASL/OAUTHBEARER token rejection
  • Loading branch information
bai committed Jul 17, 2019
2 parents dde3ddd + 6c040e1 commit d194841
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 68 deletions.
59 changes: 38 additions & 21 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(correlationID)
bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))

// With v1 sasl we get an error message set in the response we can return
Expand All @@ -1037,26 +1037,53 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
return err
}

message, err := buildClientFirstMessage(token)
if err != nil {
return err
}

challenged, err := b.sendClientMessage(message)
if err != nil {
return err
}

if challenged {
// Abort the token exchange. The broker returns the failure code.
_, err = b.sendClientMessage([]byte(`\x01`))
}

return err
}

// sendClientMessage sends a SASL/OAUTHBEARER client message and returns true
// if the broker responds with a challenge, in which case the token is
// rejected.
func (b *Broker) sendClientMessage(message []byte) (bool, error) {

requestTime := time.Now()
correlationID := b.correlationID

bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
if err != nil {
return err
return false, err
}

b.updateOutgoingCommunicationMetrics(bytesWritten)
b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(correlationID)
if err != nil {
return err
}
res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID)

requestLatency := time.Since(requestTime)
b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)

return nil
isChallenge := len(res.SaslAuthBytes) > 0

if isChallenge && err != nil {
Logger.Printf("Broker rejected authentication token: %s", res.SaslAuthBytes)
}

return isChallenge, err
}

func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
Expand Down Expand Up @@ -1154,7 +1181,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e

// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
// https://tools.ietf.org/html/rfc7628
func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
func buildClientFirstMessage(token *AccessToken) ([]byte, error) {
var ext string

if token.Extensions != nil && len(token.Extensions) > 0 {
Expand Down Expand Up @@ -1200,11 +1227,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro
return b.conn.Write(buf)
}

func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
initialResp, err := buildClientInitialResponse(token)
if err != nil {
return 0, err
}
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {

rb := &SaslAuthenticateRequest{initialResp}

Expand All @@ -1222,7 +1245,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati
return b.conn.Write(buf)
}

func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {

buf := make([]byte, responseLengthSize+correlationIDSize)

Expand Down Expand Up @@ -1250,8 +1273,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
return bytesRead, err
}

res := &SaslAuthenticateResponse{}

if err := versionedDecode(buf, res, 0); err != nil {
return bytesRead, err
}
Expand All @@ -1260,10 +1281,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
return bytesRead, res.Err
}

if len(res.SaslAuthBytes) > 0 {
Logger.Printf("Received SASL auth response: %s", res.SaslAuthBytes)
}

return bytesRead, nil
}

Expand Down
108 changes: 61 additions & 47 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package sarama
import (
"errors"
"fmt"
"gopkg.in/jcmturner/gokrb5.v7/krberror"
"net"
"reflect"
"testing"
"time"

"github.com/rcrowley/go-metrics"
"gopkg.in/jcmturner/gokrb5.v7/krberror"
)

func ExampleBroker() {
Expand Down Expand Up @@ -132,42 +132,66 @@ func newTokenProvider(token *AccessToken, err error) *TokenProvider {
func TestSASLOAuthBearer(t *testing.T) {

testTable := []struct {
name string
mockAuthErr KError // Mock and expect error returned from SaslAuthenticateRequest
mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
expectClientErr bool // Expect an internal client-side error
tokProvider *TokenProvider
name string
mockSASLHandshakeResponse MockResponse // Mock SaslHandshakeRequest response from broker
mockSASLAuthResponse MockResponse // Mock SaslAuthenticateRequest response from broker
expectClientErr bool // Expect an internal client-side error
expectedBrokerError KError // Expected Kafka error returned by client
tokProvider *TokenProvider
}{
{
name: "SASL/OAUTHBEARER OK server response",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
name: "SASL/OAUTHBEARER OK server response",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: false,
expectedBrokerError: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
},
{
name: "SASL/OAUTHBEARER authentication failure response",
mockAuthErr: ErrSASLAuthenticationFailed,
mockHandshakeErr: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
name: "SASL/OAUTHBEARER authentication failure response",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSequence(
// First, the broker response with a challenge
NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`{"status":"invalid_request1"}`)),
// Next, the client terminates the token exchange. Finally, the
// broker responds with an error message.
NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`{"status":"invalid_request2"}`)).
SetError(ErrSASLAuthenticationFailed),
),
expectClientErr: true,
expectedBrokerError: ErrSASLAuthenticationFailed,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
},
{
name: "SASL/OAUTHBEARER handshake failure response",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrSASLAuthenticationFailed,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
name: "SASL/OAUTHBEARER handshake failure response",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}).
SetError(ErrSASLAuthenticationFailed),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: true,
expectedBrokerError: ErrSASLAuthenticationFailed,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
},
{
name: "SASL/OAUTHBEARER token generation error",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
expectClientErr: true,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
name: "SASL/OAUTHBEARER token generation error",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: true,
expectedBrokerError: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
},
{
name: "SASL/OAUTHBEARER invalid extension",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
expectClientErr: true,
name: "SASL/OAUTHBEARER invalid extension",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: true,
expectedBrokerError: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{
Token: "access-token-123",
Extensions: map[string]string{"auth": "auth-value"},
Expand All @@ -180,19 +204,9 @@ func TestSASLOAuthBearer(t *testing.T) {
// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
if test.mockAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
}

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"SaslAuthenticateRequest": test.mockSASLAuthResponse,
"SaslHandshakeRequest": test.mockSASLHandshakeResponse,
})

// broker executes SASL requests against mockBroker
Expand Down Expand Up @@ -227,13 +241,13 @@ func TestSASLOAuthBearer(t *testing.T) {

err = broker.authenticateViaSASL()

if test.mockAuthErr != ErrNoError {
if test.mockAuthErr != err {
t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
if test.expectedBrokerError != ErrNoError {
if test.expectedBrokerError != err {
t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err)
}
} else if test.mockHandshakeErr != ErrNoError {
if test.mockHandshakeErr != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
} else if test.expectedBrokerError != ErrNoError {
if test.expectedBrokerError != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.expectedBrokerError, err)
}
} else if test.expectClientErr && err == nil {
t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
Expand Down Expand Up @@ -599,7 +613,7 @@ func TestGSSAPIKerberosAuth_Authorize(t *testing.T) {

}

func TestBuildClientInitialResponse(t *testing.T) {
func TestBuildClientFirstMessage(t *testing.T) {

testTable := []struct {
name string
Expand Down Expand Up @@ -638,7 +652,7 @@ func TestBuildClientInitialResponse(t *testing.T) {

for i, test := range testTable {

actual, err := buildClientInitialResponse(test.token)
actual, err := buildClientFirstMessage(test.token)

if !reflect.DeepEqual(test.expected, actual) {
t.Errorf("Expected %s, got %s\n", test.expected, actual)
Expand Down

0 comments on commit d194841

Please sign in to comment.