Skip to content

Commit

Permalink
Exclude the processing time in records.hasNext from the serialization…
Browse files Browse the repository at this point in the history
… time estimation

Signed-off-by: Jihoon Son <ghoonson@gmail.com>
  • Loading branch information
jihoonson committed Jul 12, 2024
1 parent be34c6a commit 6cf48ab
Showing 1 changed file with 25 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,26 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
val diskBlockObjectWriters = new mutable.HashMap[Int, (Int, DiskBlockObjectWriter)]()

override def write(records: Iterator[Product2[K, V]]): Unit = {
/**
 * Iterator wrapper over `records` that accumulates the wall-clock time spent
 * inside the underlying iterator.
 *
 * Both `hasNext` and `next()` are timed: every call adds its elapsed
 * nanoseconds to [[iterateTimeNs]], so code that measures an enclosing
 * interval can subtract the time attributable to pulling records from
 * `records`.
 *
 * `iterateTimeNs` is a plain `var` and this class performs no
 * synchronization of its own.
 */
class TimeTrackingIterator extends Iterator[Product2[K, V]] {
  /** Total nanoseconds spent so far in `records.hasNext` and `records.next()`. */
  var iterateTimeNs: Long = 0L

  /**
   * Delegates to `records.hasNext`, adding the elapsed time of that call to
   * [[iterateTimeNs]].
   *
   * @return whatever `records.hasNext` returned
   */
  override def hasNext: Boolean = {
    val start = System.nanoTime()
    val ret = records.hasNext
    iterateTimeNs += System.nanoTime() - start
    ret
  }

  /**
   * Delegates to `records.next()`, adding the elapsed time of that call to
   * [[iterateTimeNs]].
   *
   * @return the record produced by `records.next()`
   */
  override def next(): Product2[K, V] = {
    val start = System.nanoTime()
    // `next()` is side-effecting, so call it with explicit parentheses rather
    // than relying on auto-application of the empty argument list.
    val ret = records.next()
    iterateTimeNs += System.nanoTime() - start
    ret
  }
}

val timeTrackingIterator = new TimeTrackingIterator

withResource(new NvtxRange("ThreadedWriter.write", NvtxColor.RED)) { _ =>
withResource(new NvtxRange("compute", NvtxColor.GREEN)) { _ =>
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
Expand All @@ -283,7 +303,7 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
numPartitions)
try {
var openTimeNs = 0L
val partLengths = if (!records.hasNext) {
val partLengths = if (!timeTrackingIterator.hasNext) {
commitAllPartitions(mapOutputWriter, true /*empty checksum*/)
} else {
// per reduce partition id
Expand All @@ -307,13 +327,10 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
val writeFutures = new mutable.Queue[Future[Unit]]
val writeTimeStart: Long = System.nanoTime()
val recordWriteTime: AtomicLong = new AtomicLong(0L)
var computeTime: Long = 0L
try {
while (records.hasNext) {
while (timeTrackingIterator.hasNext) {
// get the record
val computeStartTime = System.nanoTime()
val record = records.next()
computeTime += System.nanoTime() - computeStartTime
val record = timeTrackingIterator.next()
val key = record._1
val value = record._2
val reducePartitionId: Int = partitioner.getPartition(key)
Expand Down Expand Up @@ -373,7 +390,8 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](

// writeTime is the amount of time it took to push bytes through the stream
// minus the amount of time it took to get the batch from the upstream execs
val writeTimeNs = (System.nanoTime() - writeTimeStart) - computeTime
val writeTimeNs = (System.nanoTime() - writeTimeStart) -
timeTrackingIterator.iterateTimeNs

val combineTimeStart = System.nanoTime()
val pl = writePartitionedData(mapOutputWriter)
Expand Down

0 comments on commit 6cf48ab

Please sign in to comment.