Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(metrics): fix race when accessing metric registry #2409

Merged
merged 1 commit into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ type ProduceCallback func(*ProduceResponse, error)
//
// Make sure not to Close the broker in the callback as it will lead to a deadlock.
func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error {
metricRegistry := b.metricRegistry
needAcks := request.RequiredAcks != NoResponse
// Use a nil promise when no acks is required
var promise *responsePromise
Expand All @@ -446,7 +447,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
return
}

if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil {
if err := versionedDecode(packets, res, request.version(), metricRegistry); err != nil {
// Malformed response
cb(nil, err)
return
Expand All @@ -459,6 +460,8 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
}
}

b.lock.Lock()
defer b.lock.Unlock()
return b.sendWithPromise(request, promise)
}

Expand Down Expand Up @@ -939,6 +942,7 @@ func (b *Broker) write(buf []byte) (n int, err error) {
return b.conn.Write(buf)
}

// b.lock must be haled by caller
func (b *Broker) send(rb protocolBody, promiseResponse bool, responseHeaderVersion int16) (*responsePromise, error) {
var promise *responsePromise
if promiseResponse {
Expand All @@ -963,10 +967,8 @@ func makeResponsePromise(responseHeaderVersion int16) *responsePromise {
return promise
}

// b.lock must be held by caller
func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error {
b.lock.Lock()
defer b.lock.Unlock()

if b.conn == nil {
if b.connErr != nil {
return b.connErr
Expand Down Expand Up @@ -1022,6 +1024,8 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error {
}

func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
b.lock.Lock()
defer b.lock.Unlock()
responseHeaderVersion := int16(-1)
if res != nil {
responseHeaderVersion = res.headerVersion()
Expand All @@ -1036,13 +1040,13 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
return nil
}

return b.handleResponsePromise(req, res, promise)
return handleResponsePromise(req, res, promise, b.metricRegistry)
}

func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error {
func handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise, metricRegistry metrics.Registry) error {
select {
case buf := <-promise.packets:
return versionedDecode(buf, res, req.version(), b.metricRegistry)
return versionedDecode(buf, res, req.version(), metricRegistry)
case err := <-promise.errors:
return err
}
Expand Down Expand Up @@ -1185,6 +1189,7 @@ func (b *Broker) authenticateViaSASLv0() error {
}

func (b *Broker) authenticateViaSASLv1() error {
metricRegistry := b.metricRegistry
if b.conf.Net.SASL.Handshake {
handshakeRequest := &SaslHandshakeRequest{Mechanism: string(b.conf.Net.SASL.Mechanism), Version: b.conf.Net.SASL.Version}
handshakeResponse := new(SaslHandshakeResponse)
Expand All @@ -1195,7 +1200,7 @@ func (b *Broker) authenticateViaSASLv1() error {
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
return handshakeErr
}
handshakeErr = b.handleResponsePromise(handshakeRequest, handshakeResponse, prom)
handshakeErr = handleResponsePromise(handshakeRequest, handshakeResponse, prom, metricRegistry)
if handshakeErr != nil {
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
return handshakeErr
Expand All @@ -1215,7 +1220,7 @@ func (b *Broker) authenticateViaSASLv1() error {
Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
return nil, authErr
}
authErr = b.handleResponsePromise(authenticateRequest, authenticateResponse, prom)
authErr = handleResponsePromise(authenticateRequest, authenticateResponse, prom, metricRegistry)
if authErr != nil {
Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
return nil, authErr
Expand Down
2 changes: 1 addition & 1 deletion broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestSimpleBrokerCommunication(t *testing.T) {
pendingNotify <- brokerMetrics{bytesRead, bytesWritten}
})
broker := NewBroker(mb.Addr())
// Set the broker id in order to validate local broujhjker metrics
// Set the broker id in order to validate local broker metrics
broker.id = 0
conf := NewTestConfig()
conf.ApiVersionsRequest = false
Expand Down
107 changes: 107 additions & 0 deletions consumer_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package sarama

import (
"context"
"errors"
"sync"
"testing"
"time"
)

type handler struct {
Expand Down Expand Up @@ -93,3 +95,108 @@ func TestConsumerGroupNewSessionDuringOffsetLoad(t *testing.T) {
}()
wg.Wait()
}

func TestConsume_RaceTest(t *testing.T) {
const groupID = "test-group"
const topic = "test-topic"
const offsetStart = int64(1234)

cfg := NewConfig()
cfg.Version = V2_8_1_0
cfg.Consumer.Return.Errors = true

seedBroker := NewMockBroker(t, 1)

joinGroupResponse := &JoinGroupResponse{}

syncGroupResponse := &SyncGroupResponse{
Version: 3, // sarama > 2.3.0.0 uses version 3
}
// Leverage mock response to get the MemberAssignment bytes
mockSyncGroupResponse := NewMockSyncGroupResponse(t).SetMemberAssignment(&ConsumerGroupMemberAssignment{
Version: 1,
Topics: map[string][]int32{topic: {0}}, // map "test-topic" to partition 0
UserData: []byte{0x01},
})
syncGroupResponse.MemberAssignment = mockSyncGroupResponse.MemberAssignment

heartbeatResponse := &HeartbeatResponse{
Err: ErrNoError,
}
offsetFetchResponse := &OffsetFetchResponse{
Version: 1,
ThrottleTimeMs: 0,
Err: ErrNoError,
}
offsetFetchResponse.AddBlock(topic, 0, &OffsetFetchResponseBlock{
Offset: offsetStart,
LeaderEpoch: 0,
Metadata: "",
Err: ErrNoError})

offsetResponse := &OffsetResponse{
Version: 1,
}
offsetResponse.AddTopicPartition(topic, 0, offsetStart)

metadataResponse := new(MetadataResponse)
metadataResponse.AddBroker(seedBroker.Addr(), seedBroker.BrokerID())
metadataResponse.AddTopic("mismatched-topic", ErrUnknownTopicOrPartition)

handlerMap := map[string]MockResponse{
"ApiVersionsRequest": NewMockApiVersionsResponse(t),
"MetadataRequest": NewMockSequence(metadataResponse),
"OffsetRequest": NewMockSequence(offsetResponse),
"OffsetFetchRequest": NewMockSequence(offsetFetchResponse),
"FindCoordinatorRequest": NewMockSequence(NewMockFindCoordinatorResponse(t).
SetCoordinator(CoordinatorGroup, groupID, seedBroker)),
"JoinGroupRequest": NewMockSequence(joinGroupResponse),
"SyncGroupRequest": NewMockSequence(syncGroupResponse),
"HeartbeatRequest": NewMockSequence(heartbeatResponse),
}
seedBroker.SetHandlerByMap(handlerMap)

cancelCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(4*time.Second))

defer seedBroker.Close()

retryWait := 20 * time.Millisecond
var err error
clientRetries := 0
outerFor:
for {
_, err = NewConsumerGroup([]string{seedBroker.Addr()}, groupID, cfg)
if err == nil {
break
}

if retryWait < time.Minute {
retryWait *= 2
}

clientRetries++

timer := time.NewTimer(retryWait)
select {
case <-cancelCtx.Done():
err = cancelCtx.Err()
timer.Stop()
break outerFor
case <-timer.C:
}
timer.Stop()
}
if err == nil {
t.Fatalf("should not proceed to Consume")
}

if clientRetries <= 0 {
t.Errorf("clientRetries = %v; want > 0", clientRetries)
}

if err != nil && !errors.Is(err, context.DeadlineExceeded) {
t.Fatal(err)
}

cancel()
}