diff --git a/broker.go b/broker.go index 7b32a03d3..9c3e5a04a 100644 --- a/broker.go +++ b/broker.go @@ -1013,7 +1013,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { b.correlationID++ - bytesRead, err := b.receiveSASLServerResponse(correlationID) + bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID) b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime)) // With v1 sasl we get an error message set in the response we can return @@ -1037,26 +1037,53 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error { return err } + message, err := buildClientFirstMessage(token) + if err != nil { + return err + } + + challenged, err := b.sendClientMessage(message) + if err != nil { + return err + } + + if challenged { + // Abort the token exchange. The broker returns the failure code. + _, err = b.sendClientMessage([]byte(`\x01`)) + } + + return err +} + +// sendClientMessage sends a SASL/OAUTHBEARER client message and returns true +// if the broker responds with a challenge, in which case the token is +// rejected. +func (b *Broker) sendClientMessage(message []byte) (bool, error) { + requestTime := time.Now() correlationID := b.correlationID - bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID) + bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID) if err != nil { - return err + return false, err } b.updateOutgoingCommunicationMetrics(bytesWritten) b.correlationID++ - bytesRead, err := b.receiveSASLServerResponse(correlationID) - if err != nil { - return err - } + res := &SaslAuthenticateResponse{} + bytesRead, err := b.receiveSASLServerResponse(res, correlationID) requestLatency := time.Since(requestTime) b.updateIncomingCommunicationMetrics(bytesRead, requestLatency) - return nil + isChallenge := len(res.SaslAuthBytes) > 0 + + if isChallenge && err != nil { + Logger.Printf("Broker rejected authentication token: %s", res.SaslAuthBytes) + } + + return isChallenge, err } func (b *Broker) sendAndReceiveSASLSCRAMv1() error { @@ -1154,7 +1181,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e // Build SASL/OAUTHBEARER initial client response as described by RFC-7628 // https://tools.ietf.org/html/rfc7628 -func buildClientInitialResponse(token *AccessToken) ([]byte, error) { +func buildClientFirstMessage(token *AccessToken) ([]byte, error) { var ext string if token.Extensions != nil && len(token.Extensions) > 0 { @@ -1200,11 +1227,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro return b.conn.Write(buf) } -func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) { - initialResp, err := buildClientInitialResponse(token) - if err != nil { - return 0, err - } +func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) { rb := &SaslAuthenticateRequest{initialResp} @@ -1222,7 +1245,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati return b.conn.Write(buf) } -func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) { +func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) { buf := make([]byte, responseLengthSize+correlationIDSize) @@ -1250,8 +1273,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) { return bytesRead, err } - res := &SaslAuthenticateResponse{} - if err := versionedDecode(buf, res, 0); err != nil { return bytesRead, err } @@ -1260,10 +1281,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) { return bytesRead, res.Err } - if len(res.SaslAuthBytes) > 0 { - Logger.Printf("Received SASL auth response: %s", res.SaslAuthBytes) - } - return bytesRead, nil } diff --git a/broker_test.go b/broker_test.go index 7bf020fcd..f12c16901 100644 --- a/broker_test.go +++ b/broker_test.go @@ -3,13 +3,13 @@ package sarama import ( "errors" "fmt" - "gopkg.in/jcmturner/gokrb5.v7/krberror" "net" "reflect" "testing" "time" "github.com/rcrowley/go-metrics" + "gopkg.in/jcmturner/gokrb5.v7/krberror" ) func ExampleBroker() { @@ -132,42 +132,66 @@ func newTokenProvider(token *AccessToken, err error) *TokenProvider { func TestSASLOAuthBearer(t *testing.T) { testTable := []struct { - name string - mockAuthErr KError // Mock and expect error returned from SaslAuthenticateRequest - mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest - expectClientErr bool // Expect an internal client-side error - tokProvider *TokenProvider + name string + mockSASLHandshakeResponse MockResponse // Mock SaslHandshakeRequest response from broker + mockSASLAuthResponse MockResponse // Mock SaslAuthenticateRequest response from broker + expectClientErr bool // Expect an internal client-side error + expectedBrokerError KError // Expected Kafka error returned by client + tokProvider *TokenProvider }{ { - name: "SASL/OAUTHBEARER OK server response", - mockAuthErr: ErrNoError, - mockHandshakeErr: ErrNoError, - tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil), + name: "SASL/OAUTHBEARER OK server response", + mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypeOAuth}), + mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t), + expectClientErr: false, + expectedBrokerError: ErrNoError, + tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil), }, { - name: "SASL/OAUTHBEARER authentication failure response", - mockAuthErr: ErrSASLAuthenticationFailed, - mockHandshakeErr: ErrNoError, - tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil), + name: "SASL/OAUTHBEARER authentication failure response", + mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypeOAuth}), + mockSASLAuthResponse: NewMockSequence( + // First, the broker response with a challenge + NewMockSaslAuthenticateResponse(t). + SetAuthBytes([]byte(`{"status":"invalid_request1"}`)), + // Next, the client terminates the token exchange. Finally, the + // broker responds with an error message. + NewMockSaslAuthenticateResponse(t). + SetAuthBytes([]byte(`{"status":"invalid_request2"}`)). + SetError(ErrSASLAuthenticationFailed), + ), + expectClientErr: true, + expectedBrokerError: ErrSASLAuthenticationFailed, + tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil), }, { - name: "SASL/OAUTHBEARER handshake failure response", - mockAuthErr: ErrNoError, - mockHandshakeErr: ErrSASLAuthenticationFailed, - tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil), + name: "SASL/OAUTHBEARER handshake failure response", + mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypeOAuth}). + SetError(ErrSASLAuthenticationFailed), + mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t), + expectClientErr: true, + expectedBrokerError: ErrSASLAuthenticationFailed, + tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil), }, { - name: "SASL/OAUTHBEARER token generation error", - mockAuthErr: ErrNoError, - mockHandshakeErr: ErrNoError, - expectClientErr: true, - tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure), + name: "SASL/OAUTHBEARER token generation error", + mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypeOAuth}), + mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t), + expectClientErr: true, + expectedBrokerError: ErrNoError, + tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure), }, { - name: "SASL/OAUTHBEARER invalid extension", - mockAuthErr: ErrNoError, - mockHandshakeErr: ErrNoError, - expectClientErr: true, + name: "SASL/OAUTHBEARER invalid extension", + mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypeOAuth}), + mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t), + expectClientErr: true, + expectedBrokerError: ErrNoError, tokProvider: newTokenProvider(&AccessToken{ Token: "access-token-123", Extensions: map[string]string{"auth": "auth-value"}, @@ -180,19 +204,9 @@ func TestSASLOAuthBearer(t *testing.T) { // mockBroker mocks underlying network logic and broker responses mockBroker := NewMockBroker(t, 0) - mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload")) - if test.mockAuthErr != ErrNoError { - mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr) - } - - mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth}) - if test.mockHandshakeErr != ErrNoError { - mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr) - } - mockBroker.SetHandlerByMap(map[string]MockResponse{ - "SaslAuthenticateRequest": mockSASLAuthResponse, - "SaslHandshakeRequest": mockSASLHandshakeResponse, + "SaslAuthenticateRequest": test.mockSASLAuthResponse, + "SaslHandshakeRequest": test.mockSASLHandshakeResponse, }) // broker executes SASL requests against mockBroker @@ -227,13 +241,13 @@ func TestSASLOAuthBearer(t *testing.T) { err = broker.authenticateViaSASL() - if test.mockAuthErr != ErrNoError { - if test.mockAuthErr != err { - t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err) + if test.expectedBrokerError != ErrNoError { + if test.expectedBrokerError != err { + t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err) } - } else if test.mockHandshakeErr != ErrNoError { - if test.mockHandshakeErr != err { - t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err) + } else if test.expectedBrokerError != ErrNoError { + if test.expectedBrokerError != err { + t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.expectedBrokerError, err) } } else if test.expectClientErr && err == nil { t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name) @@ -599,7 +613,7 @@ func TestGSSAPIKerberosAuth_Authorize(t *testing.T) { } -func TestBuildClientInitialResponse(t *testing.T) { +func TestBuildClientFirstMessage(t *testing.T) { testTable := []struct { name string @@ -638,7 +652,7 @@ func TestBuildClientInitialResponse(t *testing.T) { for i, test := range testTable { - actual, err := buildClientInitialResponse(test.token) + actual, err := buildClientFirstMessage(test.token) if !reflect.DeepEqual(test.expected, actual) { t.Errorf("Expected %s, got %s\n", test.expected, actual)