[SPARK-13353] [SQL] fast serialization for collecting DataFrame/Dataset #11664

Status: Closed. Wants to merge 6 commits (changes shown from 1 commit).
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1803,14 +1803,14 @@ class Dataset[T] private[sql](
    */
   def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
     withNewExecutionId {
-      val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
       java.util.Arrays.asList(values : _*)
     }
   }

   private def collect(needCallback: Boolean): Array[T] = {
     def execute(): Array[T] = withNewExecutionId {
-      queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
     }

     if (needCallback) {
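A note on the change above: the old path called `.copy()` on every row before collecting because physical operators routinely reuse a single `UnsafeRow` buffer for all of their output rows, so collecting the raw iterator would hand back many references to one mutable buffer. Below is a minimal sketch of that reuse behavior, illustrative only (the `schema` and `toUnsafe` values are made up for the example; it uses Catalyst's `UnsafeProjection` and can be run in a `spark-shell`, it is not code from this patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// A generated UnsafeProjection writes every output row into one reusable
// UnsafeRow, so references obtained from earlier calls get overwritten.
val schema = StructType(Seq(StructField("id", LongType)))
val toUnsafe = UnsafeProjection.create(schema)

val first: UnsafeRow = toUnsafe(InternalRow(1L)).copy() // detached snapshot holding 1L
val reused: UnsafeRow = toUnsafe(InternalRow(2L))       // the shared buffer now holds 2L
// Without the copy() above, `first` would also read back 2L here. The old
// collect() path guarded against exactly this by copying each row; the new
// executeCollect() below writes each row's bytes into a buffer as it is produced.
```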
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import java.nio.ByteBuffer
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.ArrayBuffer
@@ -34,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric}
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -220,7 +222,61 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
    * Runs this query returning the result as an array.
    */
   def executeCollect(): Array[InternalRow] = {
-    execute().map(_.copy()).collect()
+    // Packing the UnsafeRows into byte array for faster serialization.
+    // The byte arrays are in the following format:
+    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+      new Iterator[Array[Byte]] {

Review comment (Contributor): I also find this more understandable if you just write it imperatively within the mapPartitions; something like:

execute().mapPartitionsInternal { iter =>
  while (iter.hasNext) {
    // write each row to a buffer
  }
  Iterator(buffer)
}

+        private var row: UnsafeRow = _
+        override def hasNext: Boolean = row != null || iter.hasNext
+        override def next: Array[Byte] = {

Review comment (Contributor): next() rather than next, since this is not side effect free.

+          var cap = 1 << 20 // 1 MB
+          if (row != null) {
+            // the buffered row could be larger than the default buffer size
+            cap = Math.max(cap, 4 + row.getSizeInBytes + 4) // reserve 4 bytes for the ending mark (-1)
+          }
+          val buffer = ByteBuffer.allocate(cap)
+          if (row != null) {
+            buffer.putInt(row.getSizeInBytes)
+            row.writeTo(buffer)
+            row = null
+          }
+          while (iter.hasNext) {
+            row = iter.next().asInstanceOf[UnsafeRow]

Review comment (Contributor): Are we always taking UnsafeRow now?

+            // Reserve last 4 bytes for ending mark
+            if (4 + row.getSizeInBytes + 4 <= buffer.remaining()) {
+              buffer.putInt(row.getSizeInBytes)
+              row.writeTo(buffer)
+              row = null
+            } else {
+              buffer.putInt(-1)
+              return buffer.array()
+            }
+          }
+          buffer.putInt(-1)
+          // copy the used bytes to make it smaller
+          val bytes = new Array[Byte](buffer.limit())
+          System.arraycopy(buffer.array(), 0, bytes, 0, buffer.limit())
+          bytes
+        }
+      }
+    }
+    // Collect the byte arrays back to driver, then decode them as UnsafeRows.
+    val nFields = schema.length
+    byteArrayRdd.collect().flatMap { bytes =>

Review comment (Contributor): I think this block would be more readable if we just write it imperatively, e.g.:

val results = new ArrayBuffer[InternalRow]

byteArrayRdd.collect().foreach { bytes =>
  val buffer = ByteBuffer.wrap(bytes)
  var sizeOfNextRow = buffer.getInt()
  while (sizeOfNextRow >= 0) {
    val row = new UnsafeRow(nFields)
    row.pointTo(buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.position(), sizeOfNextRow)
    buffer.position(buffer.position() + sizeOfNextRow)
    results += row
    sizeOfNextRow = buffer.getInt()
  }
}
results.toArray

+      val buffer = ByteBuffer.wrap(bytes)
+      new Iterator[InternalRow] {
+        private var sizeInBytes = buffer.getInt()
+        override def hasNext: Boolean = sizeInBytes >= 0
+        override def next: InternalRow = {
+          val row = new UnsafeRow(nFields)
+          row.pointTo(buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.position(), sizeInBytes)
+          buffer.position(buffer.position() + sizeInBytes)
+          sizeInBytes = buffer.getInt()
+          row
+        }
+      }
+    }
+  }
 
   /**
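For readers who want to experiment with the framing outside of Spark, here is a self-contained sketch of the same [size][payload]...[-1] chunk layout over plain byte arrays. It is illustrative only: the `FramingSketch` object and its `encode`/`decode` helpers are assumptions made for this example, not code from the patch, and it frames arbitrary byte payloads instead of UnsafeRows.

```scala
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

object FramingSketch {
  // Frame variable-length payloads into chunks of roughly `capacity` bytes.
  // Chunk layout: [size][payload][size][payload]...[-1]
  def encode(payloads: Iterator[Array[Byte]], capacity: Int = 1 << 20): Array[Array[Byte]] = {
    val chunks = ArrayBuffer.empty[Array[Byte]]
    var buffer: ByteBuffer = null

    def flush(): Unit = if (buffer != null) {
      buffer.putInt(-1) // ending mark
      val used = new Array[Byte](buffer.position())
      System.arraycopy(buffer.array(), 0, used, 0, buffer.position())
      chunks += used
      buffer = null
    }

    payloads.foreach { p =>
      val needed = 4 + p.length + 4 // size prefix + payload + reserved 4 bytes for the ending mark
      if (buffer == null) {
        buffer = ByteBuffer.allocate(math.max(capacity, needed))
      } else if (needed > buffer.remaining()) {
        // The current chunk is full: seal it and start a new one, growing past
        // `capacity` only when a single payload would not fit on its own.
        flush()
        buffer = ByteBuffer.allocate(math.max(capacity, needed))
      }
      buffer.putInt(p.length)
      buffer.put(p)
    }
    flush()
    chunks.toArray
  }

  // Read size prefixes until the -1 sentinel to recover the original payloads.
  def decode(chunks: Array[Array[Byte]]): Array[Array[Byte]] = {
    chunks.flatMap { chunk =>
      val buffer = ByteBuffer.wrap(chunk)
      val out = ArrayBuffer.empty[Array[Byte]]
      var size = buffer.getInt()
      while (size >= 0) {
        val payload = new Array[Byte](size)
        buffer.get(payload)
        out += payload
        size = buffer.getInt()
      }
      out
    }
  }

  def main(args: Array[String]): Unit = {
    // Round trip: decoded payloads come back in the original order.
    val inputs = Array("a", "bb", "ccc").map(_.getBytes("UTF-8"))
    val decoded = decode(encode(inputs.iterator)).map(new String(_, "UTF-8")).toSeq
    assert(decoded == Seq("a", "bb", "ccc"))
    println(decoded)
  }
}
```

The -1 sentinel plays the same role as in the patch: the reading side can stop decoding a chunk without knowing up front how many entries it contains.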