diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
index 804772efb2fd..71535d234b75 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java
@@ -77,6 +77,8 @@
  */
 class KafkaUnboundedReader<K, V> extends UnboundedReader<KafkaRecord<K, V>> {
 
+  volatile boolean atLeastOnePollCompleted = false; // Updated by the consumer poll thread.
+
   ///////////////////// Reader API ////////////////////////////////////////////////////////////
   @SuppressWarnings("FutureReturnValueIgnored")
   @Override
@@ -158,7 +160,6 @@ public boolean advance() throws IOException {
      */
     while (true) {
       if (curBatch.hasNext()) {
-        // data from the next partition?
         PartitionState<K, V> pState = curBatch.next();
 
         if (!pState.recordIter.hasNext()) { // -- (c)
@@ -229,16 +230,14 @@ public boolean advance() throws IOException {
         for (Map.Entry<String, Long> backlogSplit : perPartitionBacklogMetrics.entrySet()) {
           backlogBytesOfSplit.set(backlogSplit.getValue());
         }
-        return true; // record has been read and proccessed, so we return. (only read a record at a time)
+        return true;
       } else { // -- (b)
-        nextBatch(); // void, returns nothing, can this be done in the background instead of when we call advance?
-
-        if (!curBatch.hasNext()) { // returns false because nothing returned in time?
+        nextBatch();
+        atLeastOnePollCompleted = false; // Reset for the next call.
+        if (!curBatch.hasNext()) {
           return false;
         }
 
-        // Gives no such element exception
-        // return true; //? returns, then will call advance again. Whats the difference between repeatedly calling advance vs iterating over constatnly?
       }
     }
   }
@@ -333,7 +332,7 @@ public long getSplitBacklogBytes() {
 
   private final KafkaUnboundedSource<K, V> source;
   private final String name;
-  private @Nullable Consumer<byte[], byte[]> consumer = null;
+  @VisibleForTesting @Nullable Consumer<byte[], byte[]> consumer = null;
   private final List<PartitionState<K, V>> partitionStates;
   private @Nullable KafkaRecord<K, V> curRecord = null;
   private @Nullable Instant curTimestamp = null;
@@ -572,6 +571,7 @@ private void consumerPollLoop() {
         try {
           if (records.isEmpty()) {
             records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis());
+            atLeastOnePollCompleted = true;
           } else if (availableRecordsQueue.offer(
               records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) {
             records = ConsumerRecords.empty();
@@ -579,15 +579,18 @@
           }
           commitCheckpointMark();
         } catch (InterruptedException e) {
+          atLeastOnePollCompleted = true;
          LOG.warn("{}: consumer thread is interrupted", this, e); // not expected
           break;
         } catch (WakeupException e) {
+          atLeastOnePollCompleted = true;
           break;
         }
       }
       LOG.info("{}: Returning from consumer pool loop", this);
     } catch (Exception e) { // mostly an unrecoverable KafkaException.
       LOG.error("{}: Exception while reading from Kafka", this, e);
+      atLeastOnePollCompleted = true;
       consumerPollException.set(e);
       throw e;
     }
@@ -624,19 +627,24 @@ void finalizeCheckpointMarkAsync(KafkaCheckpointMark checkpointMark) {
     checkpointMarkCommitsEnqueued.inc();
   }
 
+  // Ensure at least one consumer poll has completed since this was last called.
   private void nextBatch() throws IOException {
     curBatch = Collections.emptyIterator();
 
-    ConsumerRecords<byte[], byte[]> records;
-    try {
-      // poll available records, wait (if necessary) up to the specified timeout.
-      records =
-          availableRecordsQueue.poll(recordsDequeuePollTimeout.getMillis(), TimeUnit.MILLISECONDS);
-    } catch (InterruptedException e) {
-      Thread.currentThread().interrupt();
-      LOG.warn("{}: Unexpected", this, e);
-      return;
-    }
+    ConsumerRecords<byte[], byte[]> records;
+    do {
+      // Retry until the background consumer thread has completed at least one poll.
+      try {
+        // poll available records, wait (if necessary) up to the specified timeout.
+        records =
+            availableRecordsQueue.poll(
+                recordsDequeuePollTimeout.getMillis(), TimeUnit.MILLISECONDS);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        LOG.warn("{}: Unexpected", this, e);
+        return;
+      }
+
+    } while (!atLeastOnePollCompleted);
 
     if (records == null) {
       // Check if the poll thread failed with an exception.
@@ -656,7 +664,8 @@ private void nextBatch() throws IOException {
       LOG.debug("Record count: " + records.count());
     }
 
-    partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator());
+    ConsumerRecords<byte[], byte[]> finalRecords = records;
+    partitionStates.forEach(p -> p.recordIter = finalRecords.records(p.topicPartition).iterator());
 
     // cycle through the partitions in order to interleave records from each.
     curBatch = Iterators.cycle(new ArrayList<>(partitionStates));
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index 73aee5aeeef0..af9df9d81cea 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -117,6 +117,7 @@
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
@@ -186,104 +187,149 @@ public class KafkaIOTest {
   private static final String TIMESTAMP_START_MILLIS_CONFIG = "test.timestamp.start.millis";
   private static final String TIMESTAMP_TYPE_CONFIG = "test.timestamp.type";
 
-  // Update mock consumer with records distributed among the given topics, each with given number
-  // of partitions. Records are assigned in round-robin order among the partitions.
-  private static MockConsumer<byte[], byte[]> mkMockConsumer(
-      List<String> topics,
-      int partitionsPerTopic,
-      int numElements,
-      OffsetResetStrategy offsetResetStrategy,
-      Map<String, Object> config,
-      SerializableFunction<Integer, byte[]> keyFunction,
-      SerializableFunction<Integer, byte[]> valueFunction) {
+  // A MockConsumer that counts poll() invocations so tests can assert that a poll happened.
+  static class CustomMockConsumer extends MockConsumer<byte[], byte[]> {
     final List<TopicPartition> partitions = new ArrayList<>();
-    final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> records = new HashMap<>();
+    Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> records = new HashMap<>();
     Map<String, List<PartitionInfo>> partitionMap = new HashMap<>();
+    int pollCounter = 0;
+
+    // This is updated when reader assigns partitions.
+    AtomicReference<List<TopicPartition>> assignedPartitions =
+        new AtomicReference<>(Collections.emptyList());
+
+    long[] offsets;
+
+    public CustomMockConsumer(
+        List<String> topics,
+        int partitionsPerTopic,
+        int numElements,
+        OffsetResetStrategy offsetResetStrategy,
+        Map<String, Object> config,
+        SerializableFunction<Integer, byte[]> keyFunction,
+        SerializableFunction<Integer, byte[]> valueFunction) {
+
+      super(offsetResetStrategy);
 
-    for (String topic : topics) {
-      List<PartitionInfo> partIds = new ArrayList<>(partitionsPerTopic);
-      for (int i = 0; i < partitionsPerTopic; i++) {
-        TopicPartition tp = new TopicPartition(topic, i);
-        partitions.add(tp);
-        partIds.add(new PartitionInfo(topic, i, null, null, null));
-        records.put(tp, new ArrayList<>());
+      for (String topic : topics) {
+        List<PartitionInfo> partIds = new ArrayList<>(partitionsPerTopic);
+        for (int i = 0; i < partitionsPerTopic; i++) {
+          TopicPartition tp = new TopicPartition(topic, i);
+          partitions.add(tp);
+          partIds.add(new PartitionInfo(topic, i, null, null, null));
+          records.put(tp, new ArrayList<>());
+        }
+        partitionMap.put(topic, partIds);
       }
-      partitionMap.put(topic, partIds);
-    }
 
-    int numPartitions = partitions.size();
-    final long[] offsets = new long[numPartitions];
+      int numPartitions = partitions.size();
+      offsets = new long[numPartitions];
 
-    long timestampStartMillis =
-        (Long)
-            config.getOrDefault(TIMESTAMP_START_MILLIS_CONFIG, LOG_APPEND_START_TIME.getMillis());
-    TimestampType timestampType =
-        TimestampType.forName(
-            (String)
-                config.getOrDefault(
-                    TIMESTAMP_TYPE_CONFIG, TimestampType.LOG_APPEND_TIME.toString()));
+      long timestampStartMillis =
+          (Long)
+              config.getOrDefault(TIMESTAMP_START_MILLIS_CONFIG, LOG_APPEND_START_TIME.getMillis());
+      TimestampType timestampType =
+          TimestampType.forName(
+              (String)
+                  config.getOrDefault(
+                      TIMESTAMP_TYPE_CONFIG, TimestampType.LOG_APPEND_TIME.toString()));
 
-    for (int i = 0; i < numElements; i++) {
-      int pIdx = i % numPartitions;
-      TopicPartition tp = partitions.get(pIdx);
-
-      byte[] key = keyFunction.apply(i);
-      byte[] value = valueFunction.apply(i);
-
-      records
-          .get(tp)
-          .add(
-              new ConsumerRecord<>(
-                  tp.topic(),
-                  tp.partition(),
-                  offsets[pIdx]++,
-                  timestampStartMillis + Duration.standardSeconds(i).getMillis(),
-                  timestampType,
-                  0,
-                  key.length,
-                  value.length,
-                  key,
-                  value));
+      for (int i = 0; i < numElements; i++) {
+        int pIdx = i % numPartitions;
+        TopicPartition tp = partitions.get(pIdx);
+
+        byte[] key = keyFunction.apply(i);
+        byte[] value = valueFunction.apply(i);
+
+        records
+            .get(tp)
+            .add(
+                new ConsumerRecord<>(
+                    tp.topic(),
+                    tp.partition(),
+                    offsets[pIdx]++,
+                    timestampStartMillis + Duration.standardSeconds(i).getMillis(),
+                    timestampType,
+                    0,
+                    key.length,
+                    value.length,
+                    key,
+                    value));
+      }
+
+      for (String topic : topics) {
+        super.updatePartitions(topic, partitionMap.get(topic));
+      }
     }
 
-    // This is updated when reader assigns partitions.
-    final AtomicReference<List<TopicPartition>> assignedPartitions =
-        new AtomicReference<>(Collections.emptyList());
+    public int pollCounter() {
+      return pollCounter;
+    }
 
-    final MockConsumer<byte[], byte[]> consumer =
-        new MockConsumer<byte[], byte[]>(offsetResetStrategy) {
-          @Override
-          public synchronized void assign(final Collection<TopicPartition> assigned) {
-            super.assign(assigned);
-            assignedPartitions.set(ImmutableList.copyOf(assigned));
-            for (TopicPartition tp : assigned) {
-              updateBeginningOffsets(ImmutableMap.of(tp, 0L));
-              updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size()));
-            }
-          }
-          // Override offsetsForTimes() in order to look up the offsets by timestamp.
-          @Override
-          public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
-              Map<TopicPartition, Long> timestampsToSearch) {
-            return timestampsToSearch.entrySet().stream()
-                .map(
-                    e -> {
-                      // In test scope, timestamp == offset.
-                      long maxOffset = offsets[partitions.indexOf(e.getKey())];
-                      long offset = e.getValue();
-                      OffsetAndTimestamp value =
-                          (offset >= maxOffset) ? null : new OffsetAndTimestamp(offset, offset);
-                      return new SimpleEntry<>(e.getKey(), value);
-                    })
-                .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue));
-          }
-        };
+    @Override
+    public synchronized void assign(final Collection<TopicPartition> assigned) {
+      super.assign(assigned);
+      assignedPartitions.set(ImmutableList.copyOf(assigned));
+      for (TopicPartition tp : assigned) {
+        updateBeginningOffsets(ImmutableMap.of(tp, 0L));
+        updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size()));
+      }
+    }
+
+    // Add a count for each time poll is triggered.
+    @Override
+    public synchronized ConsumerRecords<byte[], byte[]> poll(java.time.Duration timeout) {
+      ConsumerRecords<byte[], byte[]> records = super.poll(timeout);
+      pollCounter++;
+      return records;
+    }
 
-    for (String topic : topics) {
-      consumer.updatePartitions(topic, partitionMap.get(topic));
+    // Needed to support older Consumer versions
+    @Override
+    public synchronized ConsumerRecords<byte[], byte[]> poll(long timeout) {
+      ConsumerRecords<byte[], byte[]> records = super.poll(timeout);
+      pollCounter++;
+      return records;
     }
 
+    // Override offsetsForTimes() in order to look up the offsets by timestamp.
+    @Override
+    public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
+        Map<TopicPartition, Long> timestampsToSearch) {
+      return timestampsToSearch.entrySet().stream()
+          .map(
+              e -> {
+                // In test scope, timestamp == offset.
+                long maxOffset = offsets[partitions.indexOf(e.getKey())];
+                long offset = e.getValue();
+                OffsetAndTimestamp value =
+                    (offset >= maxOffset) ? null : new OffsetAndTimestamp(offset, offset);
+                return new SimpleEntry<>(e.getKey(), value);
+              })
+          .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue));
+    }
+  }
+
+  // Update mock consumer with records distributed among the given topics, each with given number
+  // of partitions. Records are assigned in round-robin order among the partitions.
+  private static MockConsumer<byte[], byte[]> mkMockConsumer(
+      List<String> topics,
+      int partitionsPerTopic,
+      int numElements,
+      OffsetResetStrategy offsetResetStrategy,
+      Map<String, Object> config,
+      SerializableFunction<Integer, byte[]> keyFunction,
+      SerializableFunction<Integer, byte[]> valueFunction) {
+    CustomMockConsumer customMockConsumer =
+        new CustomMockConsumer(
+            topics,
+            partitionsPerTopic,
+            numElements,
+            offsetResetStrategy,
+            config,
+            keyFunction,
+            valueFunction);
+
     // MockConsumer does not maintain any relationship between partition seek position and the
     // records added. e.g. if we add 10 records to a partition and then seek to end of the
     // partition, MockConsumer is still going to return the 10 records in next poll. It is
@@ -296,18 +342,19 @@ public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
           public void run() {
            // add all the records with offset >= current partition position.
            int recordsAdded = 0;
-            for (TopicPartition tp : assignedPartitions.get()) {
-              long curPos = consumer.position(tp);
-              for (ConsumerRecord<byte[], byte[]> r : records.get(tp)) {
+            for (TopicPartition tp : customMockConsumer.assignedPartitions.get()) {
+              long curPos = customMockConsumer.position(tp);
+              for (ConsumerRecord<byte[], byte[]> r : customMockConsumer.records.get(tp)) {
                 if (r.offset() >= curPos) {
-                  consumer.addRecord(r);
+                  customMockConsumer.addRecord(r);
                   recordsAdded++;
                 }
               }
             }
             if (recordsAdded == 0) {
               if (config.get("inject.error.at.eof") != null) {
-                consumer.setException(new KafkaException("Injected error in consumer.poll()"));
+                customMockConsumer.setException(
+                    new KafkaException("Injected error in consumer.poll()"));
               }
               // MockConsumer.poll(timeout) does not actually wait even when there aren't any
               // records.
@@ -316,12 +363,12 @@ public void run() {
               // TODO: BEAM-4086: testUnboundedSourceWithoutBoundedWrapper() occasionally hangs
               //     without this wait. Need to look into it.
             }
-            consumer.schedulePollTask(this);
+            customMockConsumer.schedulePollTask(this);
           }
         };
 
-    consumer.schedulePollTask(recordEnqueueTask);
-    return consumer;
+    customMockConsumer.schedulePollTask(recordEnqueueTask);
+    return customMockConsumer;
   }
 
   private static class ConsumerFactoryFn
@@ -1199,6 +1246,25 @@ private static void advanceOnce(UnboundedReader<?> reader, boolean isStarted) th
     }
   }
 
+  // Ensure that the reader waits for a consumer poll before returning from nextBatch().
+  @Test
+  public void testNextBatch() throws Exception {
+    int numElements = 10;
+
+    // create a single split:
+    UnboundedSource<KafkaRecord<Integer, Long>, KafkaCheckpointMark> source =
+        mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
+            .makeSource()
+            .split(1, PipelineOptionsFactory.create())
+            .get(0);
+
+    UnboundedReader<KafkaRecord<Integer, Long>> reader = source.createReader(null, null);
+    reader.start();
+
+    advanceOnce(reader, true);
+    KafkaUnboundedReader<Integer, Long> kafkaReader = (KafkaUnboundedReader<Integer, Long>) reader;
+    assertTrue(((CustomMockConsumer) kafkaReader.consumer).pollCounter() > 0);
+  }
+
   @Test
   public void testUnboundedSourceCheckpointMark() throws Exception {
     int numElements = 85; // 85 to make sure some partitions have more records than other.
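Note on the synchronization pattern (illustrative only, not part of the patch): the change relies on a small hand-off between the background `consumerPollLoop()` thread and the reader thread in `nextBatch()`. The poll thread publishes "at least one poll has completed" through a flag, and the reader keeps retrying the bounded dequeue until that flag is set, so an empty first dequeue is not mistaken for "no data available". The sketch below shows that pattern in isolation with hypothetical names (`PollHandshakeSketch`, `handoff`, `pollSource` are not part of Beam or this PR); it also loosens the loop condition so the reader stops as soon as a batch arrives, which the patch itself does not do.

```java
import java.util.Collections;
import java.util.List;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;

/** Standalone sketch of the poll-completion hand-off; all names here are hypothetical. */
public class PollHandshakeSketch {
  // Stand-ins for availableRecordsQueue and the atLeastOnePollCompleted flag in the reader.
  private final SynchronousQueue<List<String>> handoff = new SynchronousQueue<>();
  private volatile boolean atLeastOnePollCompleted = false;

  /** Producer side, roughly what consumerPollLoop() does around consumer.poll(). */
  void pollLoop() {
    while (!Thread.currentThread().isInterrupted()) {
      List<String> batch = pollSource(); // may legitimately return an empty batch
      atLeastOnePollCompleted = true; // volatile write publishes the fact that a poll ran
      if (!batch.isEmpty()) {
        try {
          handoff.put(batch); // block until the reader side takes the batch
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          return;
        }
      }
    }
  }

  /** Reader side, roughly what nextBatch() does; null means "a real poll found nothing". */
  List<String> nextBatch(long timeoutMillis) throws InterruptedException {
    List<String> batch;
    do {
      batch = handoff.poll(timeoutMillis, TimeUnit.MILLISECONDS);
    } while (batch == null && !atLeastOnePollCompleted);
    atLeastOnePollCompleted = false; // reset so the next call waits for a fresh poll
    return batch;
  }

  private List<String> pollSource() {
    return Collections.emptyList(); // placeholder source, always empty in this sketch
  }
}
```

A `CountDownLatch` (for the first poll only) or a `Phaser` would be an alternative way to express the same hand-off; the flag-plus-retry form stays closest to what the diff adds. Because the flag is shared between the poll thread and the reader thread, the reconstruction above declares it `volatile`, and the same qualifier is used in the edited diff.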