diff --git a/broker.go b/broker.go index 3de202655..81467498c 100644 --- a/broker.go +++ b/broker.go @@ -671,6 +671,26 @@ func (b *Broker) DescribeLogDirs(request *DescribeLogDirsRequest) (*DescribeLogD return response, nil } +// readFull ensures the conn ReadDeadline has been setup before making a +// call to io.ReadFull +func (b *Broker) readFull(buf []byte) (n int, err error) { + if err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)); err != nil { + return 0, err + } + + return io.ReadFull(b.conn, buf) +} + +// write ensures the conn WriteDeadline has been setup before making a +// call to conn.Write +func (b *Broker) write(buf []byte) (n int, err error) { + if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil { + return 0, err + } + + return b.conn.Write(buf) +} + func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, error) { b.lock.Lock() defer b.lock.Unlock() @@ -692,14 +712,9 @@ func (b *Broker) send(rb protocolBody, promiseResponse bool) (*responsePromise, return nil, err } - err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) - if err != nil { - return nil, err - } - requestTime := time.Now() - bytes, err := b.conn.Write(buf) - b.updateOutgoingCommunicationMetrics(bytes) //TODO: should it be after error check + bytes, err := b.write(buf) + b.updateOutgoingCommunicationMetrics(bytes) if err != nil { return nil, err } @@ -806,14 +821,7 @@ func (b *Broker) responseReceiver() { continue } - err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)) - if err != nil { - dead = err - response.errors <- err - continue - } - - bytesReadHeader, err := io.ReadFull(b.conn, header) + bytesReadHeader, err := b.readFull(header) requestLatency := time.Since(response.requestTime) if err != nil { b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) @@ -840,7 +848,7 @@ func (b *Broker) responseReceiver() { } buf := make([]byte, decodedHeader.length-4) - bytesReadBody, err := io.ReadFull(b.conn, buf) + bytesReadBody, err := b.readFull(buf) b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency) if err != nil { dead = err @@ -883,22 +891,17 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int return err } - err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) - if err != nil { - return err - } - requestTime := time.Now() - bytes, err := b.conn.Write(buf) + bytes, err := b.write(buf) b.updateOutgoingCommunicationMetrics(bytes) if err != nil { Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error()) return err } b.correlationID++ - //wait for the response + header := make([]byte, 8) // response header - _, err = io.ReadFull(b.conn, header) + _, err = b.readFull(header) if err != nil { Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error()) return err @@ -906,7 +909,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int length := binary.BigEndian.Uint32(header[:4]) payload := make([]byte, length-4) - n, err := io.ReadFull(b.conn, payload) + n, err := b.readFull(payload) if err != nil { Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error()) return err @@ -980,14 +983,8 @@ func (b *Broker) sendAndReceiveV0SASLPlainAuth() error { binary.BigEndian.PutUint32(authBytes, uint32(length)) copy(authBytes[4:], []byte("\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password)) - err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) - if err != nil { - Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error()) - return err - } - requestTime := time.Now() - bytesWritten, err := b.conn.Write(authBytes) + bytesWritten, err := b.write(authBytes) b.updateOutgoingCommunicationMetrics(bytesWritten) if err != nil { Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error()) @@ -995,7 +992,7 @@ func (b *Broker) sendAndReceiveV0SASLPlainAuth() error { } header := make([]byte, 4) - n, err := io.ReadFull(b.conn, header) + n, err := b.readFull(header) b.updateIncomingCommunicationMetrics(n, time.Since(requestTime)) // If the credentials are valid, we would get a 4 byte response filled with null characters. // Otherwise, the broker closes the connection and we get an EOF @@ -1151,16 +1148,12 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i return 0, err } - if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil { - return 0, err - } - - return b.conn.Write(buf) + return b.write(buf) } func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) { buf := make([]byte, responseLengthSize+correlationIDSize) - _, err := io.ReadFull(b.conn, buf) + _, err := b.readFull(buf) if err != nil { return nil, err } @@ -1176,7 +1169,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e } buf = make([]byte, header.length-correlationIDSize) - _, err = io.ReadFull(b.conn, buf) + _, err = b.readFull(buf) if err != nil { return nil, err } @@ -1231,12 +1224,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro return 0, err } - err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)) - if err != nil { - Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error()) - return 0, err - } - return b.conn.Write(buf) + return b.write(buf) } func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) { @@ -1250,24 +1238,17 @@ func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlatio return 0, err } - if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil { - return 0, err - } - - return b.conn.Write(buf) + return b.write(buf) } func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) { - buf := make([]byte, responseLengthSize+correlationIDSize) - - bytesRead, err := io.ReadFull(b.conn, buf) + bytesRead, err := b.readFull(buf) if err != nil { return bytesRead, err } header := responseHeader{} - err = decode(buf, &header) if err != nil { return bytesRead, err @@ -1278,8 +1259,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl } buf = make([]byte, header.length-correlationIDSize) - - c, err := io.ReadFull(b.conn, buf) + c, err := b.readFull(buf) bytesRead += c if err != nil { return bytesRead, err diff --git a/broker_test.go b/broker_test.go index f12c16901..fe170ab86 100644 --- a/broker_test.go +++ b/broker_test.go @@ -493,6 +493,55 @@ func TestSASLPlainAuth(t *testing.T) { } } +// TestSASLReadTimeout ensures that the broker connection won't block forever +// if the remote end never responds after the handshake +func TestSASLReadTimeout(t *testing.T) { + mockBroker := NewMockBroker(t, 0) + defer mockBroker.Close() + + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t). + SetAuthBytes([]byte(`response_payload`)) + + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthResponse, + }) + + broker := NewBroker(mockBroker.Addr()) + { + broker.requestRate = metrics.NilMeter{} + broker.outgoingByteRate = metrics.NilMeter{} + broker.incomingByteRate = metrics.NilMeter{} + broker.requestSize = metrics.NilHistogram{} + broker.responseSize = metrics.NilHistogram{} + broker.responseRate = metrics.NilMeter{} + broker.requestLatency = metrics.NilHistogram{} + } + + conf := NewConfig() + { + conf.Net.ReadTimeout = time.Millisecond + conf.Net.SASL.Mechanism = SASLTypePlaintext + conf.Net.SASL.User = "token" + conf.Net.SASL.Password = "password" + conf.Net.SASL.Version = SASLHandshakeV1 + } + + broker.conf = conf + broker.conf.Version = V1_0_0_0 + dialer := net.Dialer{} + + conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + broker.conn = conn + err = broker.authenticateViaSASL() + if err == nil { + t.Errorf("should never happen - expected read timeout") + } +} + func TestGSSAPIKerberosAuth_Authorize(t *testing.T) { testTable := []struct {