Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle SASL/OAUTHBEARER token rejection #1428

Merged
merged 1 commit into from
Jul 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(correlationID)
bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))

// With v1 sasl we get an error message set in the response we can return
Expand All @@ -1037,26 +1037,53 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
return err
}

message, err := buildClientFirstMessage(token)
if err != nil {
return err
}

challenged, err := b.sendClientMessage(message)
if err != nil {
return err
}

if challenged {
// Abort the token exchange. The broker returns the failure code.
_, err = b.sendClientMessage([]byte(`\x01`))
}

return err
}

// sendClientMessage sends a SASL/OAUTHBEARER client message and returns true
// if the broker responds with a challenge, in which case the token is
// rejected.
func (b *Broker) sendClientMessage(message []byte) (bool, error) {

requestTime := time.Now()
correlationID := b.correlationID

bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
if err != nil {
return err
return false, err
}

b.updateOutgoingCommunicationMetrics(bytesWritten)
b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(correlationID)
if err != nil {
return err
}
res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID)

requestLatency := time.Since(requestTime)
b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)

return nil
isChallenge := len(res.SaslAuthBytes) > 0

if isChallenge && err != nil {
Logger.Printf("Broker rejected authentication token: %s", res.SaslAuthBytes)
}

return isChallenge, err
}

func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
Expand Down Expand Up @@ -1154,7 +1181,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e

// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
// https://tools.ietf.org/html/rfc7628
func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
func buildClientFirstMessage(token *AccessToken) ([]byte, error) {
var ext string

if token.Extensions != nil && len(token.Extensions) > 0 {
Expand Down Expand Up @@ -1200,11 +1227,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro
return b.conn.Write(buf)
}

func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
initialResp, err := buildClientInitialResponse(token)
if err != nil {
return 0, err
}
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {

rb := &SaslAuthenticateRequest{initialResp}

Expand All @@ -1222,7 +1245,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati
return b.conn.Write(buf)
}

func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {

buf := make([]byte, responseLengthSize+correlationIDSize)

Expand Down Expand Up @@ -1250,8 +1273,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
return bytesRead, err
}

res := &SaslAuthenticateResponse{}

if err := versionedDecode(buf, res, 0); err != nil {
return bytesRead, err
}
Expand All @@ -1260,10 +1281,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
return bytesRead, res.Err
}

if len(res.SaslAuthBytes) > 0 {
Logger.Printf("Received SASL auth response: %s", res.SaslAuthBytes)
}

return bytesRead, nil
}

Expand Down
108 changes: 61 additions & 47 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package sarama
import (
"errors"
"fmt"
"gopkg.in/jcmturner/gokrb5.v7/krberror"
"net"
"reflect"
"testing"
"time"

"github.com/rcrowley/go-metrics"
"gopkg.in/jcmturner/gokrb5.v7/krberror"
)

func ExampleBroker() {
Expand Down Expand Up @@ -132,42 +132,66 @@ func newTokenProvider(token *AccessToken, err error) *TokenProvider {
func TestSASLOAuthBearer(t *testing.T) {

testTable := []struct {
name string
mockAuthErr KError // Mock and expect error returned from SaslAuthenticateRequest
mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
expectClientErr bool // Expect an internal client-side error
tokProvider *TokenProvider
name string
mockSASLHandshakeResponse MockResponse // Mock SaslHandshakeRequest response from broker
mockSASLAuthResponse MockResponse // Mock SaslAuthenticateRequest response from broker
expectClientErr bool // Expect an internal client-side error
expectedBrokerError KError // Expected Kafka error returned by client
tokProvider *TokenProvider
}{
{
name: "SASL/OAUTHBEARER OK server response",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
name: "SASL/OAUTHBEARER OK server response",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: false,
expectedBrokerError: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
},
{
name: "SASL/OAUTHBEARER authentication failure response",
mockAuthErr: ErrSASLAuthenticationFailed,
mockHandshakeErr: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
name: "SASL/OAUTHBEARER authentication failure response",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSequence(
// First, the broker response with a challenge
NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`{"status":"invalid_request1"}`)),
// Next, the client terminates the token exchange. Finally, the
// broker responds with an error message.
NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`{"status":"invalid_request2"}`)).
SetError(ErrSASLAuthenticationFailed),
),
expectClientErr: true,
expectedBrokerError: ErrSASLAuthenticationFailed,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
},
{
name: "SASL/OAUTHBEARER handshake failure response",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrSASLAuthenticationFailed,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
name: "SASL/OAUTHBEARER handshake failure response",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}).
SetError(ErrSASLAuthenticationFailed),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: true,
expectedBrokerError: ErrSASLAuthenticationFailed,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
},
{
name: "SASL/OAUTHBEARER token generation error",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
expectClientErr: true,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
name: "SASL/OAUTHBEARER token generation error",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: true,
expectedBrokerError: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
},
{
name: "SASL/OAUTHBEARER invalid extension",
mockAuthErr: ErrNoError,
mockHandshakeErr: ErrNoError,
expectClientErr: true,
name: "SASL/OAUTHBEARER invalid extension",
mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypeOAuth}),
mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
expectClientErr: true,
expectedBrokerError: ErrNoError,
tokProvider: newTokenProvider(&AccessToken{
Token: "access-token-123",
Extensions: map[string]string{"auth": "auth-value"},
Expand All @@ -180,19 +204,9 @@ func TestSASLOAuthBearer(t *testing.T) {
// mockBroker mocks underlying network logic and broker responses
mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
if test.mockAuthErr != ErrNoError {
mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
}

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
if test.mockHandshakeErr != ErrNoError {
mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"SaslAuthenticateRequest": test.mockSASLAuthResponse,
"SaslHandshakeRequest": test.mockSASLHandshakeResponse,
})

// broker executes SASL requests against mockBroker
Expand Down Expand Up @@ -227,13 +241,13 @@ func TestSASLOAuthBearer(t *testing.T) {

err = broker.authenticateViaSASL()

if test.mockAuthErr != ErrNoError {
if test.mockAuthErr != err {
t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
if test.expectedBrokerError != ErrNoError {
if test.expectedBrokerError != err {
t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err)
}
} else if test.mockHandshakeErr != ErrNoError {
if test.mockHandshakeErr != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
} else if test.expectedBrokerError != ErrNoError {
if test.expectedBrokerError != err {
t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.expectedBrokerError, err)
}
} else if test.expectClientErr && err == nil {
t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
Expand Down Expand Up @@ -599,7 +613,7 @@ func TestGSSAPIKerberosAuth_Authorize(t *testing.T) {

}

func TestBuildClientInitialResponse(t *testing.T) {
func TestBuildClientFirstMessage(t *testing.T) {

testTable := []struct {
name string
Expand Down Expand Up @@ -638,7 +652,7 @@ func TestBuildClientInitialResponse(t *testing.T) {

for i, test := range testTable {

actual, err := buildClientInitialResponse(test.token)
actual, err := buildClientFirstMessage(test.token)

if !reflect.DeepEqual(test.expected, actual) {
t.Errorf("Expected %s, got %s\n", test.expected, actual)
Expand Down