-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-13353] [SQL] fast serialization for collecting DataFrame/Dataset #11664
Changes from 1 commit
ac1a40b
c5bca23
a859392
4f9cf91
f12216f
5f00d67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
|
||
package org.apache.spark.sql.execution | ||
|
||
import java.nio.ByteBuffer | ||
import java.util.concurrent.atomic.AtomicBoolean | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
@@ -34,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan | |
import org.apache.spark.sql.catalyst.plans.physical._ | ||
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} | ||
import org.apache.spark.sql.types.DataType | ||
import org.apache.spark.unsafe.Platform | ||
import org.apache.spark.util.ThreadUtils | ||
|
||
/** | ||
|
@@ -220,7 +222,61 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ | |
* Runs this query returning the result as an array. | ||
*/ | ||
/**
 * Runs this query returning the result as an array.
 *
 * Instead of serializing every InternalRow individually, each partition packs its
 * UnsafeRows into large byte arrays, which are collected to the driver and decoded
 * back into UnsafeRows there. Each byte array has the layout:
 *   [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
 * where every size is a 4-byte int and -1 marks the end of the buffer.
 */
def executeCollect(): Array[InternalRow] = {
  val byteArrayRdd = execute().mapPartitionsInternal { iter =>
    new Iterator[Array[Byte]] {
      // Holds a row that did not fit into the previous buffer; it is written
      // at the front of the next buffer produced by next().
      private var row: UnsafeRow = _

      override def hasNext: Boolean = row != null || iter.hasNext

      // Parentheses kept deliberately: this call mutates iterator state.
      override def next(): Array[Byte] = {
        var cap = 1 << 20 // default buffer size: 1 MB
        if (row != null) {
          // The carried-over row could be larger than the default buffer size;
          // reserve 4 bytes for its size prefix plus 4 for the ending mark (-1).
          cap = Math.max(cap, 4 + row.getSizeInBytes + 4)
        }
        val buffer = ByteBuffer.allocate(cap)
        if (row != null) {
          buffer.putInt(row.getSizeInBytes)
          row.writeTo(buffer)
          row = null
        }
        while (iter.hasNext) {
          // Assumes the upstream iterator yields UnsafeRows at this point in
          // planning — NOTE(review): confirm this holds for all plans.
          row = iter.next().asInstanceOf[UnsafeRow]
          // Reserve the last 4 bytes of the buffer for the ending mark.
          if (4 + row.getSizeInBytes + 4 <= buffer.remaining()) {
            buffer.putInt(row.getSizeInBytes)
            row.writeTo(buffer)
            row = null
          } else {
            // Buffer is full: terminate it and keep `row` for the next buffer.
            buffer.putInt(-1)
            return usedBytes(buffer)
          }
        }
        buffer.putInt(-1)
        usedBytes(buffer)
      }

      // Copies only the bytes written so far, so the array shipped to the driver
      // is as small as possible. Note: position(), not limit() — for a buffer
      // still being written, limit() equals capacity, so copying limit() bytes
      // would ship the entire (mostly empty) allocation.
      private def usedBytes(buffer: ByteBuffer): Array[Byte] = {
        val bytes = new Array[Byte](buffer.position())
        System.arraycopy(buffer.array(), 0, bytes, 0, buffer.position())
        bytes
      }
    }
  }

  // Collect the byte arrays back to the driver, then decode them as UnsafeRows.
  val nFields = schema.length
  byteArrayRdd.collect().flatMap { bytes =>
    val buffer = ByteBuffer.wrap(bytes)
    new Iterator[InternalRow] {
      // Size of the next row, or -1 once the ending mark has been read.
      private var sizeInBytes = buffer.getInt()

      override def hasNext: Boolean = sizeInBytes >= 0

      // Parentheses kept deliberately: this call mutates iterator state.
      override def next(): InternalRow = {
        val row = new UnsafeRow(nFields)
        // Point the row directly at the backing array — no copy needed, since
        // the bytes already live on the driver and are not reused.
        row.pointTo(buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.position(), sizeInBytes)
        buffer.position(buffer.position() + sizeInBytes)
        sizeInBytes = buffer.getInt()
        row
      }
    }
  }
}
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i also find this more understandable if you just write it imperatively within the map partitions; something like