diff --git a/consumer_group.go b/consumer_group.go
index 67a9f9241..de119d520 100644
--- a/consumer_group.go
+++ b/consumer_group.go
@@ -210,11 +210,6 @@ func (c *consumerGroup) Consume(ctx context.Context, topics []string, handler Co
     return err
   }
 
-  // loop check topic partition numbers changed
-  // will trigger rebalance when any topic partitions number had changed
-  // avoid Consume function called again that will generate more than loopCheckPartitionNumbers coroutine
-  go c.loopCheckPartitionNumbers(topics, sess)
-
   // Wait for session exit signal
   <-sess.ctx.Done()
 
@@ -347,13 +342,15 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler
   // Prepare distribution plan if we joined as the leader
   var plan BalanceStrategyPlan
   var members map[string]ConsumerGroupMemberMetadata
+  var allSubscribedTopicPartitions map[string][]int32
+  var allSubscribedTopics []string
   if join.LeaderId == join.MemberId {
     members, err = join.GetMembers()
     if err != nil {
       return nil, err
     }
 
-    plan, err = c.balance(strategy, members)
+    allSubscribedTopicPartitions, allSubscribedTopics, plan, err = c.balance(strategy, members)
     if err != nil {
       return nil, err
     }
@@ -421,7 +418,17 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler
     }
   }
 
-  return newConsumerGroupSession(ctx, c, claims, join.MemberId, join.GenerationId, handler)
+  session, err := newConsumerGroupSession(ctx, c, claims, join.MemberId, join.GenerationId, handler)
+  if err != nil {
+    return nil, err
+  }
+
+  // only the leader needs to check whether there are newly-added partitions in order to trigger a rebalance
+  if join.LeaderId == join.MemberId {
+    go c.loopCheckPartitionNumbers(allSubscribedTopicPartitions, allSubscribedTopics, session)
+  }
+
+  return session, err
 }
 
 func (c *consumerGroup) joinGroupRequest(coordinator *Broker, topics []string) (*JoinGroupResponse, error) {
@@ -551,23 +558,36 @@ func (c *consumerGroup) heartbeatRequest(coordinator *Broker, memberID string, g
   return coordinator.Heartbeat(req)
 }
 
-func (c *consumerGroup) balance(strategy BalanceStrategy, members map[string]ConsumerGroupMemberMetadata) (BalanceStrategyPlan, error) {
-  topics := make(map[string][]int32)
+func (c *consumerGroup) balance(strategy BalanceStrategy, members map[string]ConsumerGroupMemberMetadata) (map[string][]int32, []string, BalanceStrategyPlan, error) {
+  topicPartitions := make(map[string][]int32)
   for _, meta := range members {
     for _, topic := range meta.Topics {
-      topics[topic] = nil
+      topicPartitions[topic] = nil
     }
   }
 
-  for topic := range topics {
+  allSubscribedTopics := make([]string, 0, len(topicPartitions))
+  for topic := range topicPartitions {
+    allSubscribedTopics = append(allSubscribedTopics, topic)
+  }
+
+  // refresh metadata for all the subscribed topics in the consumer group
+  // to avoid using stale metadata when assigning partitions
+  err := c.client.RefreshMetadata(allSubscribedTopics...)
+  if err != nil {
+    return nil, nil, nil, err
+  }
+
+  for topic := range topicPartitions {
     partitions, err := c.client.Partitions(topic)
     if err != nil {
-      return nil, err
+      return nil, nil, nil, err
    }
-    topics[topic] = partitions
+    topicPartitions[topic] = partitions
   }
 
-  return strategy.Plan(members, topics)
+  plan, err := strategy.Plan(members, topicPartitions)
+  return topicPartitions, allSubscribedTopics, plan, err
 }
 
 // Leaves the cluster, called by Close.
@@ -653,24 +673,29 @@ func (c *consumerGroup) handleError(err error, topic string, partition int32) {
   }
 }
 
-func (c *consumerGroup) loopCheckPartitionNumbers(topics []string, session *consumerGroupSession) {
+func (c *consumerGroup) loopCheckPartitionNumbers(allSubscribedTopicPartitions map[string][]int32, topics []string, session *consumerGroupSession) {
   if c.config.Metadata.RefreshFrequency == time.Duration(0) {
     return
   }
-  pause := time.NewTicker(c.config.Metadata.RefreshFrequency)
+
   defer session.cancel()
-  defer pause.Stop()
-  var oldTopicToPartitionNum map[string]int
-  var err error
-  if oldTopicToPartitionNum, err = c.topicToPartitionNumbers(topics); err != nil {
-    return
+
+  oldTopicToPartitionNum := make(map[string]int, len(allSubscribedTopicPartitions))
+  for topic, partitions := range allSubscribedTopicPartitions {
+    oldTopicToPartitionNum[topic] = len(partitions)
   }
+
+  pause := time.NewTicker(c.config.Metadata.RefreshFrequency)
+  defer pause.Stop()
   for {
     if newTopicToPartitionNum, err := c.topicToPartitionNumbers(topics); err != nil {
       return
     } else {
       for topic, num := range oldTopicToPartitionNum {
         if newTopicToPartitionNum[topic] != num {
+          Logger.Printf(
+            "consumergroup/%s loop check partition number goroutine found that partitions in topics %s changed from %d to %d\n",
+            c.groupID, topics, num, newTopicToPartitionNum[topic])
           return // trigger the end of the session on exit
         }
       }
@@ -679,7 +704,7 @@ func (c *consumerGroup) loopCheckPartitionNumbers(topics []string, session *cons
     case <-pause.C:
     case <-session.ctx.Done():
       Logger.Printf(
-        "consumergroup/%s loop check partition number coroutine will exit, topics %s\n",
+        "consumergroup/%s loop check partition number goroutine will exit, topics %s\n",
         c.groupID, topics)
       // if session closed by other, should be exited
       return
@@ -1054,7 +1079,7 @@ type ConsumerGroupClaim interface {
   // InitialOffset returns the initial offset that was used as a starting point for this claim.
   InitialOffset() int64
 
-  // HighWaterMarkOffset returns the high water mark offset of the partition,
+  // HighWaterMarkOffset returns the high watermark offset of the partition,
   // i.e. the offset that will be used for the next message that will be produced.
   // You can use this to determine how far behind the processing is.
   HighWaterMarkOffset() int64
diff --git a/functional_consumer_group_test.go b/functional_consumer_group_test.go
index f05666185..19ca48348 100644
--- a/functional_consumer_group_test.go
+++ b/functional_consumer_group_test.go
@@ -135,6 +135,51 @@ func TestFuncConsumerGroupExcessConsumers(t *testing.T) {
   m5.AssertCleanShutdown()
 }
 
+func TestFuncConsumerGroupRebalanceAfterAddingPartitions(t *testing.T) {
+  checkKafkaVersion(t, "0.10.2")
+  setupFunctionalTest(t)
+  defer teardownFunctionalTest(t)
+
+  config := NewTestConfig()
+  config.Version = V2_3_0_0
+  admin, err := NewClusterAdmin(FunctionalTestEnv.KafkaBrokerAddrs, config)
+  if err != nil {
+    t.Fatal(err)
+  }
+  defer func() {
+    _ = admin.Close()
+  }()
+
+  groupID := testFuncConsumerGroupID(t)
+
+  // start M1
+  m1 := runTestFuncConsumerGroupMember(t, groupID, "M1", 0, nil, "test.1")
+  defer m1.Stop()
+  m1.WaitForClaims(map[string]int{"test.1": 1})
+  m1.WaitForHandlers(1)
+
+  // start M2
+  m2 := runTestFuncConsumerGroupMember(t, groupID, "M2", 0, nil, "test.1_to_2")
+  defer m2.Stop()
+  m2.WaitForClaims(map[string]int{"test.1_to_2": 1})
+  m1.WaitForHandlers(1)
+
+  // add a new partition to topic "test.1_to_2"
+  err = admin.CreatePartitions("test.1_to_2", 2, nil, false)
+  if err != nil {
+    t.Fatal(err)
+  }
+
+  // assert that claims are shared among both members
+  m2.WaitForClaims(map[string]int{"test.1_to_2": 2})
+  m2.WaitForHandlers(2)
+  m1.WaitForClaims(map[string]int{"test.1": 1})
+  m1.WaitForHandlers(1)
+
+  m1.AssertCleanShutdown()
+  m2.AssertCleanShutdown()
+}
+
 func TestFuncConsumerGroupFuzzy(t *testing.T) {
   checkKafkaVersion(t, "0.10.2")
   setupFunctionalTest(t)
   defer teardownFunctionalTest(t)
@@ -360,6 +405,8 @@ func defaultConfig(clientID string) *Config {
   config.Consumer.Return.Errors = true
   config.Consumer.Offsets.Initial = OffsetOldest
   config.Consumer.Group.Rebalance.Timeout = 10 * time.Second
+  config.Metadata.Full = false
+  config.Metadata.RefreshFrequency = 5 * time.Second
   return config
 }
 
diff --git a/functional_test.go b/functional_test.go
index fad313a5d..602c1c45a 100644
--- a/functional_test.go
+++ b/functional_test.go
@@ -40,6 +40,10 @@ var (
       NumPartitions:     1,
       ReplicationFactor: 3,
     },
+    "test.1_to_2": {
+      NumPartitions:     1,
+      ReplicationFactor: 3,
+    },
   }
 
   FunctionalTestEnv *testEnvironment
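
Reviewer note, not part of the patch: a minimal, illustrative consumer sketch showing how the change is exercised from application code. The broker address, group ID, topic name, and the Shopify-era import path below are placeholders. The leader-side partition check started in newSession only runs when Metadata.RefreshFrequency is non-zero (the goroutine returns immediately otherwise), which is also why the functional-test defaultConfig above sets it to 5 seconds; when the leader detects a partition-count change it cancels the session, Consume returns, and the loop below rejoins for the next generation.

// Illustrative sketch only; names and addresses are placeholders.
package main

import (
	"context"
	"log"
	"time"

	"github.com/Shopify/sarama" // adjust to the module path in use
)

type noopHandler struct{}

func (noopHandler) Setup(sarama.ConsumerGroupSession) error   { return nil }
func (noopHandler) Cleanup(sarama.ConsumerGroupSession) error { return nil }
func (noopHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error {
	// Drain and mark messages; the claim channel closes on rebalance.
	for msg := range claim.Messages() {
		sess.MarkMessage(msg, "")
	}
	return nil
}

func main() {
	config := sarama.NewConfig()
	config.Version = sarama.V2_3_0_0
	// A full metadata fetch is no longer needed for the partition check, because
	// balance() now refreshes metadata for the subscribed topics explicitly; the
	// check itself requires a non-zero RefreshFrequency.
	config.Metadata.Full = false
	config.Metadata.RefreshFrequency = 5 * time.Second

	group, err := sarama.NewConsumerGroup([]string{"localhost:9092"}, "example-group", config)
	if err != nil {
		log.Fatal(err)
	}
	defer group.Close()

	ctx := context.Background()
	for {
		// Consume returns when the session ends, e.g. after the group leader
		// detects newly-added partitions; looping rejoins the group so the new
		// partitions get assigned.
		if err := group.Consume(ctx, []string{"example-topic"}, noopHandler{}); err != nil {
			log.Fatal(err)
		}
	}
}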