Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure consistent use of read/write deadlines #1529

Merged
merged 1 commit into from
Oct 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 37 additions & 57 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,26 @@ func (b *Broker) DescribeLogDirs(request *DescribeLogDirsRequest) (*DescribeLogD
return response, nil
}

// readFull ensures the conn ReadDeadline has been setup before making a
// call to io.ReadFull
func (b *Broker) readFull(buf []byte) (n int, err error) {
if err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)); err != nil {
return 0, err
}

return io.ReadFull(b.conn, buf)
}

// write ensures the conn WriteDeadline has been setup before making a
// call to conn.Write
func (b *Broker) write(buf []byte) (n int, err error) {
if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
return 0, err
}

return b.conn.Write(buf)
}

func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) {
b.lock.Lock()
defer b.lock.Unlock()
Expand All @@ -692,14 +712,9 @@ func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise,
return nil, err
}

err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
if err != nil {
return nil, err
}

requestTime := time.Now()
bytes, err := b.conn.Write(buf)
b.updateOutgoingCommunicationMetrics(bytes) //TODO: should it be after error check
bytes, err := b.write(buf)
b.updateOutgoingCommunicationMetrics(bytes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -806,14 +821,7 @@ func (b *Broker) responseReceiver() {
continue
}

err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout))
if err != nil {
dead = err
response.errors <- err
continue
}

bytesReadHeader, err := io.ReadFull(b.conn, header)
bytesReadHeader, err := b.readFull(header)
requestLatency := time.Since(response.requestTime)
if err != nil {
b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
Expand All @@ -840,7 +848,7 @@ func (b *Broker) responseReceiver() {
}

buf := make([]byte, decodedHeader.length-4)
bytesReadBody, err := io.ReadFull(b.conn, buf)
bytesReadBody, err := b.readFull(buf)
b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency)
if err != nil {
dead = err
Expand Down Expand Up @@ -883,30 +891,25 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
return err
}

err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
if err != nil {
return err
}

requestTime := time.Now()
bytes, err := b.conn.Write(buf)
bytes, err := b.write(buf)
b.updateOutgoingCommunicationMetrics(bytes)
if err != nil {
Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error())
return err
}
b.correlationID++
//wait for the response

header := make([]byte, 8) // response header
_, err = io.ReadFull(b.conn, header)
_, err = b.readFull(header)
if err != nil {
Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
return err
}

length := binary.BigEndian.Uint32(header[:4])
payload := make([]byte, length-4)
n, err := io.ReadFull(b.conn, payload)
n, err := b.readFull(payload)
if err != nil {
Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
return err
Expand Down Expand Up @@ -980,22 +983,16 @@ func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
binary.BigEndian.PutUint32(authBytes, uint32(length))
copy(authBytes[4:], []byte("\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password))

err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
if err != nil {
Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
return err
}

requestTime := time.Now()
bytesWritten, err := b.conn.Write(authBytes)
bytesWritten, err := b.write(authBytes)
b.updateOutgoingCommunicationMetrics(bytesWritten)
if err != nil {
Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
return err
}

header := make([]byte, 4)
n, err := io.ReadFull(b.conn, header)
n, err := b.readFull(header)
b.updateIncomingCommunicationMetrics(n, time.Since(requestTime))
// If the credentials are valid, we would get a 4 byte response filled with null characters.
// Otherwise, the broker closes the connection and we get an EOF
Expand Down Expand Up @@ -1151,16 +1148,12 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i
return 0, err
}

if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
return 0, err
}

return b.conn.Write(buf)
return b.write(buf)
}

func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
_, err := io.ReadFull(b.conn, buf)
_, err := b.readFull(buf)
if err != nil {
return nil, err
}
Expand All @@ -1176,7 +1169,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e
}

buf = make([]byte, header.length-correlationIDSize)
_, err = io.ReadFull(b.conn, buf)
_, err = b.readFull(buf)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1231,12 +1224,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro
return 0, err
}

err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
if err != nil {
Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
return 0, err
}
return b.conn.Write(buf)
return b.write(buf)
}

func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
Expand All @@ -1250,24 +1238,17 @@ func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlatio
return 0, err
}

if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
return 0, err
}

return b.conn.Write(buf)
return b.write(buf)
}

func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {

buf := make([]byte, responseLengthSize+correlationIDSize)

bytesRead, err := io.ReadFull(b.conn, buf)
bytesRead, err := b.readFull(buf)
if err != nil {
return bytesRead, err
}

header := responseHeader{}

err = decode(buf, &header)
if err != nil {
return bytesRead, err
Expand All @@ -1278,8 +1259,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
}

buf = make([]byte, header.length-correlationIDSize)

c, err := io.ReadFull(b.conn, buf)
c, err := b.readFull(buf)
bytesRead += c
if err != nil {
return bytesRead, err
Expand Down
49 changes: 49 additions & 0 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,55 @@ func TestSASLPlainAuth(t *testing.T) {
}
}

// TestSASLReadTimeout ensures that the broker connection won't block forever
// if the remote end never responds after the handshake
func TestSASLReadTimeout(t *testing.T) {
mockBroker := NewMockBroker(t, 0)
defer mockBroker.Close()

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`))

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
})

broker := NewBroker(mockBroker.Addr())
{
broker.requestRate = metrics.NilMeter{}
broker.outgoingByteRate = metrics.NilMeter{}
broker.incomingByteRate = metrics.NilMeter{}
broker.requestSize = metrics.NilHistogram{}
broker.responseSize = metrics.NilHistogram{}
broker.responseRate = metrics.NilMeter{}
broker.requestLatency = metrics.NilHistogram{}
}

conf := NewConfig()
{
conf.Net.ReadTimeout = time.Millisecond
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"
conf.Net.SASL.Version = SASLHandshakeV1
}

broker.conf = conf
broker.conf.Version = V1_0_0_0
dialer := net.Dialer{}

conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
if err != nil {
t.Fatal(err)
}

broker.conn = conn
err = broker.authenticateViaSASL()
if err == nil {
t.Errorf("should never happen - expected read timeout")
}
}

func TestGSSAPIKerberosAuth_Authorize(t *testing.T) {

testTable := []struct {
Expand Down