[SPARK-49259][SS] Size based partition creation during kafka read
### What changes were proposed in this pull request?
Adds support for size-based partition creation during Kafka reads.

### Why are the changes needed?
Currently, Spark Structured Streaming provides the `minPartitions` config to create more partitions than Kafka has. This is helpful for increasing parallelism, but the value cannot be changed dynamically.

It would be better to scale the number of Spark partitions dynamically based on input size: when the input size is high, create more partitions. With this change, more partitions can be created dynamically to handle varying loads.
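For illustration, the resulting task count behaves roughly like the sketch below. The helper `approxSparkTasks` and the sample record counts are hypothetical and not part of this PR; only the `maxRecordsPerPartition` / `minPartitions` semantics come from the change itself.

```scala
// Minimal sketch of the sizing logic, assuming we know how many records each
// Kafka topic partition has pending in the current batch.
def approxSparkTasks(
    recordsPerTopicPartition: Seq[Long],
    maxRecordsPerPartition: Option[Long],
    minPartitions: Option[Int]): Long = {
  // Each Kafka partition is split so that no Spark task reads more than
  // maxRecordsPerPartition records; minPartitions is still honored afterwards.
  val sizeBased = recordsPerTopicPartition.map { records =>
    maxRecordsPerPartition
      .map(max => math.ceil(records.toDouble / max).toLong)
      .getOrElse(1L)
  }.sum
  math.max(sizeBased, minPartitions.getOrElse(0).toLong)
}

// e.g. three Kafka partitions with 40, 50 and 60 records and maxRecordsPerPartition = 20
// yield approximately 2 + 3 + 3 = 8 Spark tasks.
approxSparkTasks(Seq(40L, 50L, 60L), Some(20L), None)  // 8
```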

### Does this PR introduce _any_ user-facing change?
Yes. An additional option (`maxRecordsPerPartition`) is accepted by the Kafka source provider; a usage sketch follows the screenshots below.

<img width="940" alt="Screenshot 2024-10-17 at 11 13 27 AM" src="https://github.com/user-attachments/assets/29ecc65e-98fa-40ff-8565-480eeb207ff7">

<img width="1580" alt="Screenshot 2024-10-17 at 11 11 51 AM" src="https://github.com/user-attachments/assets/63652f82-f24f-4a24-ab24-acd3feb5e0d6">
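A minimal usage sketch for a streaming read; the bootstrap servers, topic name, and record limit are placeholder values, and the option key is matched case-insensitively like the other Kafka source options:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("kafka-size-based-partitions")
  .getOrCreate()

// Each Spark task reads at most ~100000 records from its slice of a Kafka partition;
// topic partitions with more pending records are split into multiple tasks.
val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092,host2:9092")
  .option("subscribe", "events")
  .option("maxRecordsPerPartition", "100000")
  .load()
```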

### How was this patch tested?
Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47927 from SubhamSinghal/SPARK-49259-structured-streaming-size-based-partition-creation-kafka.

Authored-by: subham611 <subhamsinghal@sharechat.co>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
subham611 authored and HeartSaVioR committed Oct 17, 2024
1 parent 8405c9b commit f96a6f8
Showing 6 changed files with 201 additions and 56 deletions.
@@ -21,22 +21,26 @@ import org.apache.kafka.common.TopicPartition

import org.apache.spark.sql.util.CaseInsensitiveStringMap


/**
* Class to calculate offset ranges to process based on the from and until offsets, and
* the configured `minPartitions`.
* Class to calculate offset ranges to process based on the from and until offsets, and the
* configured `minPartitions` and `maxRecordsPerPartition`.
*/
private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) {
private[kafka010] class KafkaOffsetRangeCalculator(
val minPartitions: Option[Int],
val maxRecordsPerPartition: Option[Long]) {
require(minPartitions.isEmpty || minPartitions.get > 0)
require(maxRecordsPerPartition.isEmpty || maxRecordsPerPartition.get > 0)

/**
* Calculate the offset ranges that we are going to process this batch. If `minPartitions`
* is not set or is set less than or equal the number of `topicPartitions` that we're going to
* consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If
* `minPartitions` is set higher than the number of our `topicPartitions`, then we will split up
* the read tasks of the skewed partitions to multiple Spark tasks.
* The number of Spark tasks will be *approximately* `minPartitions`. It can be less or more
* depending on rounding errors or Kafka partitions that didn't receive any new data.
* Calculate the offset ranges that we are going to process this batch. If `minPartitions` is
* not set or is less than or equal to the number of `topicPartitions` that we're going to
* consume, and `maxRecordsPerPartition` is not set, then we fall back to a 1-1 mapping of Spark
* tasks to Kafka partitions. If `maxRecordsPerPartition` is set, then we will split up read
* tasks into multiple tasks as per the `maxRecordsPerPartition` value. If `minPartitions` is set
* higher than the number of our `topicPartitions`, then we will split up the read tasks of the
* skewed partitions to multiple Spark tasks. The number of Spark tasks will be *approximately*
* the max of `(recordsPerPartition / maxRecordsPerPartition)` and `minPartitions`. It can be
* less or more depending on rounding errors or Kafka partitions that didn't receive any new data.
*
* Empty (`KafkaOffsetRange.size == 0`) or invalid (`KafkaOffsetRange.size < 0`) ranges will be
* dropped.
@@ -47,51 +51,81 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int
val offsetRanges = ranges.filter(_.size > 0)

// If minPartitions not set or there are enough partitions to satisfy minPartitions
if (minPartitions.isEmpty || offsetRanges.size >= minPartitions.get) {
// and maxRecordsPerPartition is empty
if ((minPartitions.isEmpty || offsetRanges.size >= minPartitions.get)
&& maxRecordsPerPartition.isEmpty) {
// Assign preferred executor locations to each range such that the same topic-partition is
// preferentially read from the same executor and the KafkaConsumer can be reused.
offsetRanges.map { range =>
range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations))
}
} else {
val dividedOffsetRanges = if (maxRecordsPerPartition.isDefined) {
val maxRecords = maxRecordsPerPartition.get
offsetRanges
.flatMap { range =>
val size = range.size
// number of partitions to divvy up this topic partition to
val parts = math.ceil(size.toDouble / maxRecords).toInt
getDividedPartition(parts, range)
}
.filter(_.size > 0)
} else {
offsetRanges
}

// Splits offset ranges with relatively large amount of data to smaller ones.
val totalSize = offsetRanges.map(_.size).sum
if (minPartitions.isDefined && minPartitions.get > dividedOffsetRanges.size) {
// Splits offset ranges with relatively large amount of data to smaller ones.
val totalSize = dividedOffsetRanges.map(_.size).sum

// First distinguish between any small (i.e. unsplit) ranges and large (i.e. split) ranges,
// in order to exclude the contents of unsplit ranges from the proportional math applied to
// split ranges
val unsplitRanges = dividedOffsetRanges.filter { range =>
getPartCount(range.size, totalSize, minPartitions.get) == 1
}

// First distinguish between any small (i.e. unsplit) ranges and large (i.e. split) ranges,
// in order to exclude the contents of unsplit ranges from the proportional math applied to
// split ranges
val unsplitRanges = offsetRanges.filter { range =>
getPartCount(range.size, totalSize, minPartitions.get) == 1
val unsplitRangeTotalSize = unsplitRanges.map(_.size).sum
val splitRangeTotalSize = totalSize - unsplitRangeTotalSize
val unsplitRangeTopicPartitions = unsplitRanges.map(_.topicPartition).toSet
val splitRangeMinPartitions = math.max(minPartitions.get - unsplitRanges.size, 1)

// Now we can apply the main calculation logic
dividedOffsetRanges
.flatMap { range =>
val tp = range.topicPartition
val size = range.size
// number of partitions to divvy up this topic partition to
val parts = if (unsplitRangeTopicPartitions.contains(tp)) {
1
} else {
getPartCount(size, splitRangeTotalSize, splitRangeMinPartitions)
}
getDividedPartition(parts, range)
}
.filter(_.size > 0)
} else {
dividedOffsetRanges
}
}
}

val unsplitRangeTotalSize = unsplitRanges.map(_.size).sum
val splitRangeTotalSize = totalSize - unsplitRangeTotalSize
val unsplitRangeTopicPartitions = unsplitRanges.map(_.topicPartition).toSet
val splitRangeMinPartitions = math.max(minPartitions.get - unsplitRanges.size, 1)

// Now we can apply the main calculation logic
offsetRanges.flatMap { range =>
val tp = range.topicPartition
val size = range.size
// number of partitions to divvy up this topic partition to
val parts = if (unsplitRangeTopicPartitions.contains(tp)) {
1
} else {
getPartCount(size, splitRangeTotalSize, splitRangeMinPartitions)
}
var remaining = size
var startOffset = range.fromOffset
(0 until parts).map { part =>
// Fine to do integer division. Last partition will consume all the round off errors
val thisPartition = remaining / (parts - part)
remaining -= thisPartition
val endOffset = math.min(startOffset + thisPartition, range.untilOffset)
val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None)
startOffset = endOffset
offsetRange
}
}.filter(_.size > 0)
private def getDividedPartition(
parts: Int,
offsetRange: KafkaOffsetRange): IndexedSeq[KafkaOffsetRange] = {
var remaining = offsetRange.size
var startOffset = offsetRange.fromOffset
val tp = offsetRange.topicPartition
val untilOffset = offsetRange.untilOffset

(0 until parts).map { part =>
// Fine to do integer division. Last partition will consume all the round off errors
val thisPartition = remaining / (parts - part)
remaining -= thisPartition
val endOffset = math.min(startOffset + thisPartition, untilOffset)
val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None)
startOffset = endOffset
offsetRange
}
}

@@ -114,9 +148,12 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int
private[kafka010] object KafkaOffsetRangeCalculator {

def apply(options: CaseInsensitiveStringMap): KafkaOffsetRangeCalculator = {
val optionalValue = Option(options.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY))
val minPartition = Option(options.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY))
.map(_.toInt)
new KafkaOffsetRangeCalculator(optionalValue)
val maxRecordsPerPartition =
Option(options.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY))
.map(_.toLong)
new KafkaOffsetRangeCalculator(minPartition, maxRecordsPerPartition)
}
}

@@ -99,14 +99,18 @@ private[kafka010] class KafkaOffsetReaderAdmin(
*/
private val minPartitions =
readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt)
private val maxRecordsPerPartition =
readerOptions.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY).map(_.toLong)

private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions)
private val rangeCalculator =
new KafkaOffsetRangeCalculator(minPartitions, maxRecordsPerPartition)

/**
* Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks.
*/
private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = {
minPartitions.map(_ > numTopicPartitions).getOrElse(false)
private def shouldDivvyUpLargePartitions(offsetRanges: Seq[KafkaOffsetRange]): Boolean = {
minPartitions.map(_ > offsetRanges.size).getOrElse(false) ||
offsetRanges.exists(_.size > maxRecordsPerPartition.getOrElse(Long.MaxValue))
}

override def toString(): String = consumerStrategy.toString
@@ -397,7 +401,7 @@ private[kafka010] class KafkaOffsetReaderAdmin(
KafkaOffsetRange(tp, fromOffset, untilOffset, None)
}.toSeq

if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) {
if (shouldDivvyUpLargePartitions(offsetRangesBase)) {
val fromOffsetsMap =
offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap
val untilOffsetsMap =
@@ -98,17 +98,21 @@ private[kafka010] class KafkaOffsetReaderConsumer(
*/
private val minPartitions =
readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt)
private val maxRecordsPerPartition =
readerOptions.get(KafkaSourceProvider.MAX_RECORDS_PER_PARTITION_OPTION_KEY).map(_.toLong)

private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions)
private val rangeCalculator =
new KafkaOffsetRangeCalculator(minPartitions, maxRecordsPerPartition)

private[kafka010] val offsetFetchAttemptIntervalMs =
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong

/**
* Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks.
*/
private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = {
minPartitions.map(_ > numTopicPartitions).getOrElse(false)
private def shouldDivvyUpLargePartitions(offsetRanges: Seq[KafkaOffsetRange]): Boolean = {
minPartitions.map(_ > offsetRanges.size).getOrElse(false) ||
offsetRanges.exists(_.size > maxRecordsPerPartition.getOrElse(Long.MaxValue))
}

private def nextGroupId(): String = {
@@ -446,7 +450,7 @@ private[kafka010] class KafkaOffsetReaderConsumer(
KafkaOffsetRange(tp, fromOffset, untilOffset, None)
}.toSeq

if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) {
if (shouldDivvyUpLargePartitions(offsetRangesBase)) {
val fromOffsetsMap =
offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap
val untilOffsetsMap =
@@ -271,6 +271,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
}

if (params.contains(MAX_RECORDS_PER_PARTITION_OPTION_KEY)) {
val p = params(MAX_RECORDS_PER_PARTITION_OPTION_KEY).toLong
if (p <= 0) {
throw new IllegalArgumentException(
s"$MAX_RECORDS_PER_PARTITION_OPTION_KEY must be positive")
}
}

// Validate user-specified Kafka options

if (params.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
@@ -557,6 +565,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private[kafka010] val ENDING_TIMESTAMP_OPTION_KEY = "endingtimestamp"
private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
private[kafka010] val MAX_RECORDS_PER_PARTITION_OPTION_KEY = "maxrecordsperpartition"
private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger"
private[kafka010] val MIN_OFFSET_PER_TRIGGER = "minoffsetspertrigger"
private[kafka010] val MAX_TRIGGER_DELAY = "maxtriggerdelay"
@@ -34,6 +34,30 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite {
}
}

def testWithMaxRecordsPerPartition(name: String, maxRecordsPerPartition: Long)(
f: KafkaOffsetRangeCalculator => Unit): Unit = {
val options = new CaseInsensitiveStringMap(
Map("maxRecordsPerPartition" -> maxRecordsPerPartition.toString).asJava)
test(s"with maxRecordsPerPartition = $maxRecordsPerPartition: $name") {
f(KafkaOffsetRangeCalculator(options))
}
}

def testWithMinPartitionsAndMaxRecordsPerPartition(
name: String,
minPartitions: Int,
maxRecordsPerPartition: Long)(f: KafkaOffsetRangeCalculator => Unit): Unit = {
val options = new CaseInsensitiveStringMap(
Map(
"minPartitions" -> minPartitions.toString,
"maxRecordsPerPartition" -> maxRecordsPerPartition.toString).asJava)
test(
s"with minPartitions = $minPartitions " +
s"and maxRecordsPerPartition = $maxRecordsPerPartition: $name") {
f(KafkaOffsetRangeCalculator(options))
}
}

test("with no minPartition: N TopicPartitions to N offset ranges") {
val calc = KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty())
assert(
@@ -253,6 +277,59 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite {
KafkaOffsetRange(tp3, 7500, 10000, None)))
}

testWithMaxRecordsPerPartition("SPARK-49259: 1 TopicPartition to N offset ranges", 4) { calc =>
assert(
calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 5))) == Seq(KafkaOffsetRange(tp1, 1, 5, None)))

assert(
calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 2))) == Seq(KafkaOffsetRange(tp1, 1, 2, None)))

assert(
calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 6)), executorLocations = Seq("location")) ==
Seq(KafkaOffsetRange(tp1, 1, 3, None), KafkaOffsetRange(tp1, 3, 6, None))
) // location pref not set when maxRecordsPerPartition is set
}

testWithMaxRecordsPerPartition("SPARK-49259: N TopicPartition to N offset ranges", 20) { calc =>
assert(
calc.getRanges(
Seq(
KafkaOffsetRange(tp1, 1, 40),
KafkaOffsetRange(tp2, 1, 50),
KafkaOffsetRange(tp3, 1, 60))) ==
Seq(
KafkaOffsetRange(tp1, 1, 20, None),
KafkaOffsetRange(tp1, 20, 40, None),
KafkaOffsetRange(tp2, 1, 17, None),
KafkaOffsetRange(tp2, 17, 33, None),
KafkaOffsetRange(tp2, 33, 50, None),
KafkaOffsetRange(tp3, 1, 20, None),
KafkaOffsetRange(tp3, 20, 40, None),
KafkaOffsetRange(tp3, 40, 60, None)))
}

testWithMinPartitionsAndMaxRecordsPerPartition(
"SPARK-49259: 1 TopicPartition with low minPartitions value",
1,
20) { calc =>
assert(
calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 40))) ==
Seq(KafkaOffsetRange(tp1, 1, 20, None), KafkaOffsetRange(tp1, 20, 40, None)))
}

testWithMinPartitionsAndMaxRecordsPerPartition(
"SPARK-49259: 1 TopicPartition with high minPartitions value",
4,
20) { calc =>
assert(
calc.getRanges(Seq(KafkaOffsetRange(tp1, 1, 40))) ==
Seq(
KafkaOffsetRange(tp1, 1, 10, None),
KafkaOffsetRange(tp1, 10, 20, None),
KafkaOffsetRange(tp1, 20, 30, None),
KafkaOffsetRange(tp1, 30, 40, None)))
}

private val tp1 = new TopicPartition("t1", 1)
private val tp2 = new TopicPartition("t2", 1)
private val tp3 = new TopicPartition("t3", 1)
14 changes: 14 additions & 0 deletions docs/streaming/structured-streaming-kafka-integration.md
@@ -518,6 +518,20 @@ The following configurations are optional:
number of Spark tasks will be <strong>approximately</strong> <code>minPartitions</code>. It can be less or more depending on
rounding errors or Kafka partitions that didn't receive any new data.</td>
</tr>
<tr>
<td>maxRecordsPerPartition</td>
<td>long</td>
<td>none</td>
<td>streaming and batch</td>
<td>Limit the maximum number of records in each partition.
By default, Spark has a 1-1 mapping of topicPartitions to Spark partitions consuming from Kafka.
If you set this option, Spark will divvy up Kafka partitions into smaller pieces so that each partition
has up to <code>maxRecordsPerPartition</code> records. When both <code>minPartitions</code> and
<code>maxRecordsPerPartition</code> are set, the number of partitions will be <strong>approximately</strong>
the max of <code>(recordsPerPartition / maxRecordsPerPartition)</code> and <code>minPartitions</code>. In that case Spark
will divvy up partitions based on <code>maxRecordsPerPartition</code>, and if the resulting partition count is less than
<code>minPartitions</code>, it will divvy up partitions again based on <code>minPartitions</code>.</td>
</tr>
<tr>
<td>groupIdPrefix</td>
<td>string</td>