diff --git a/broker.go b/broker.go index b9f0db56c..d049e9b47 100644 --- a/broker.go +++ b/broker.go @@ -429,6 +429,7 @@ type ProduceCallback func(*ProduceResponse, error) // // Make sure not to Close the broker in the callback as it will lead to a deadlock. func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error { + metricRegistry := b.metricRegistry needAcks := request.RequiredAcks != NoResponse // Use a nil promise when no acks is required var promise *responsePromise @@ -446,7 +447,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error return } - if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil { + if err := versionedDecode(packets, res, request.version(), metricRegistry); err != nil { // Malformed response cb(nil, err) return @@ -459,6 +460,8 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error } } + b.lock.Lock() + defer b.lock.Unlock() return b.sendWithPromise(request, promise) } @@ -939,6 +942,7 @@ func (b *Broker) write(buf []byte) (n int, err error) { return b.conn.Write(buf) } +// b.lock must be haled by caller func (b *Broker) send(rb protocolBody, promiseResponse bool, responseHeaderVersion int16) (*responsePromise, error) { var promise *responsePromise if promiseResponse { @@ -963,10 +967,8 @@ func makeResponsePromise(responseHeaderVersion int16) *responsePromise { return promise } +// b.lock must be held by caller func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error { - b.lock.Lock() - defer b.lock.Unlock() - if b.conn == nil { if b.connErr != nil { return b.connErr @@ -1022,6 +1024,8 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error { } func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error { + b.lock.Lock() + defer b.lock.Unlock() responseHeaderVersion := int16(-1) if res != nil { responseHeaderVersion = res.headerVersion() @@ -1036,13 +1040,13 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error { return nil } - return b.handleResponsePromise(req, res, promise) + return handleResponsePromise(req, res, promise, b.metricRegistry) } -func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error { +func handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise, metricRegistry metrics.Registry) error { select { case buf := <-promise.packets: - return versionedDecode(buf, res, req.version(), b.metricRegistry) + return versionedDecode(buf, res, req.version(), metricRegistry) case err := <-promise.errors: return err } @@ -1185,6 +1189,7 @@ func (b *Broker) authenticateViaSASLv0() error { } func (b *Broker) authenticateViaSASLv1() error { + metricRegistry := b.metricRegistry if b.conf.Net.SASL.Handshake { handshakeRequest := &SaslHandshakeRequest{Mechanism: string(b.conf.Net.SASL.Mechanism), Version: b.conf.Net.SASL.Version} handshakeResponse := new(SaslHandshakeResponse) @@ -1195,7 +1200,7 @@ func (b *Broker) authenticateViaSASLv1() error { Logger.Printf("Error while performing SASL handshake %s\n", b.addr) return handshakeErr } - handshakeErr = b.handleResponsePromise(handshakeRequest, handshakeResponse, prom) + handshakeErr = handleResponsePromise(handshakeRequest, handshakeResponse, prom, metricRegistry) if handshakeErr != nil { Logger.Printf("Error while performing SASL handshake %s\n", b.addr) return handshakeErr @@ -1215,7 +1220,7 @@ func (b *Broker) authenticateViaSASLv1() error { Logger.Printf("Error while performing SASL Auth %s\n", b.addr) return nil, authErr } - authErr = b.handleResponsePromise(authenticateRequest, authenticateResponse, prom) + authErr = handleResponsePromise(authenticateRequest, authenticateResponse, prom, metricRegistry) if authErr != nil { Logger.Printf("Error while performing SASL Auth %s\n", b.addr) return nil, authErr diff --git a/broker_test.go b/broker_test.go index 52a4e4bae..32ddb8694 100644 --- a/broker_test.go +++ b/broker_test.go @@ -123,7 +123,7 @@ func TestSimpleBrokerCommunication(t *testing.T) { pendingNotify <- brokerMetrics{bytesRead, bytesWritten} }) broker := NewBroker(mb.Addr()) - // Set the broker id in order to validate local broujhjker metrics + // Set the broker id in order to validate local broker metrics broker.id = 0 conf := NewTestConfig() conf.ApiVersionsRequest = false diff --git a/consumer_group_test.go b/consumer_group_test.go index 7a69da382..15a49c5f9 100644 --- a/consumer_group_test.go +++ b/consumer_group_test.go @@ -2,8 +2,10 @@ package sarama import ( "context" + "errors" "sync" "testing" + "time" ) type handler struct { @@ -93,3 +95,108 @@ func TestConsumerGroupNewSessionDuringOffsetLoad(t *testing.T) { }() wg.Wait() } + +func TestConsume_RaceTest(t *testing.T) { + const groupID = "test-group" + const topic = "test-topic" + const offsetStart = int64(1234) + + cfg := NewConfig() + cfg.Version = V2_8_1_0 + cfg.Consumer.Return.Errors = true + + seedBroker := NewMockBroker(t, 1) + + joinGroupResponse := &JoinGroupResponse{} + + syncGroupResponse := &SyncGroupResponse{ + Version: 3, // sarama > 2.3.0.0 uses version 3 + } + // Leverage mock response to get the MemberAssignment bytes + mockSyncGroupResponse := NewMockSyncGroupResponse(t).SetMemberAssignment(&ConsumerGroupMemberAssignment{ + Version: 1, + Topics: map[string][]int32{topic: {0}}, // map "test-topic" to partition 0 + UserData: []byte{0x01}, + }) + syncGroupResponse.MemberAssignment = mockSyncGroupResponse.MemberAssignment + + heartbeatResponse := &HeartbeatResponse{ + Err: ErrNoError, + } + offsetFetchResponse := &OffsetFetchResponse{ + Version: 1, + ThrottleTimeMs: 0, + Err: ErrNoError, + } + offsetFetchResponse.AddBlock(topic, 0, &OffsetFetchResponseBlock{ + Offset: offsetStart, + LeaderEpoch: 0, + Metadata: "", + Err: ErrNoError}) + + offsetResponse := &OffsetResponse{ + Version: 1, + } + offsetResponse.AddTopicPartition(topic, 0, offsetStart) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(seedBroker.Addr(), seedBroker.BrokerID()) + metadataResponse.AddTopic("mismatched-topic", ErrUnknownTopicOrPartition) + + handlerMap := map[string]MockResponse{ + "ApiVersionsRequest": NewMockApiVersionsResponse(t), + "MetadataRequest": NewMockSequence(metadataResponse), + "OffsetRequest": NewMockSequence(offsetResponse), + "OffsetFetchRequest": NewMockSequence(offsetFetchResponse), + "FindCoordinatorRequest": NewMockSequence(NewMockFindCoordinatorResponse(t). + SetCoordinator(CoordinatorGroup, groupID, seedBroker)), + "JoinGroupRequest": NewMockSequence(joinGroupResponse), + "SyncGroupRequest": NewMockSequence(syncGroupResponse), + "HeartbeatRequest": NewMockSequence(heartbeatResponse), + } + seedBroker.SetHandlerByMap(handlerMap) + + cancelCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(4*time.Second)) + + defer seedBroker.Close() + + retryWait := 20 * time.Millisecond + var err error + clientRetries := 0 +outerFor: + for { + _, err = NewConsumerGroup([]string{seedBroker.Addr()}, groupID, cfg) + if err == nil { + break + } + + if retryWait < time.Minute { + retryWait *= 2 + } + + clientRetries++ + + timer := time.NewTimer(retryWait) + select { + case <-cancelCtx.Done(): + err = cancelCtx.Err() + timer.Stop() + break outerFor + case <-timer.C: + } + timer.Stop() + } + if err == nil { + t.Fatalf("should not proceed to Consume") + } + + if clientRetries <= 0 { + t.Errorf("clientRetries = %v; want > 0", clientRetries) + } + + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + + cancel() +}