diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1ea7db0388689..b5079cf2763ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1797,14 +1797,14 @@ class Dataset[T] private[sql](
    */
   def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
     withNewExecutionId {
-      val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
       java.util.Arrays.asList(values : _*)
     }
   }
 
   private def collect(needCallback: Boolean): Array[T] = {
     def execute(): Array[T] = withNewExecutionId {
-      queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
     }
 
     if (needCallback) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a92c99e06ff43..e04683c499a32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,14 +17,15 @@
 
 package org.apache.spark.sql.execution
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.concurrent.duration._
 
-import org.apache.spark.Logging
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, Logging, SparkEnv}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -220,7 +221,47 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * Runs this query returning the result as an array.
    */
   def executeCollect(): Array[InternalRow] = {
-    execute().map(_.copy()).collect()
+    // Pack the UnsafeRows into byte arrays for faster serialization.
+    // The byte arrays are in the following format:
+    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+    //
+    // UnsafeRow is highly compressible (at least 8 bytes for any column), so the byte array is
+    // also compressed.
+    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+      val buffer = new Array[Byte](4 << 10)  // 4K
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val bos = new ByteArrayOutputStream()
+      val out = new DataOutputStream(codec.compressedOutputStream(bos))
+      while (iter.hasNext) {
+        val row = iter.next().asInstanceOf[UnsafeRow]
+        out.writeInt(row.getSizeInBytes)
+        row.writeToStream(out, buffer)
+      }
+      out.writeInt(-1)
+      out.flush()
+      out.close()
+      Iterator(bos.toByteArray)
+    }
+
+    // Collect the byte arrays back to the driver, then decode them as UnsafeRows.
+    val nFields = schema.length
+    val results = ArrayBuffer[InternalRow]()
+
+    byteArrayRdd.collect().foreach { bytes =>
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val bis = new ByteArrayInputStream(bytes)
+      val ins = new DataInputStream(codec.compressedInputStream(bis))
+      var sizeOfNextRow = ins.readInt()
+      while (sizeOfNextRow >= 0) {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(nFields)
+        row.pointTo(bs, sizeOfNextRow)
+        results += row
+        sizeOfNextRow = ins.readInt()
+      }
+    }
+    results.toArray
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 2c4b4f80ff9ed..b1987c690811d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -29,7 +29,9 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
   override protected def doExecute(): RDD[InternalRow] = {
     val str = Literal("so fast").value
     val row = new GenericInternalRow(Array[Any](str))
-    sparkContext.parallelize(Seq(row))
+    val unsafeProj = UnsafeProjection.create(schema)
+    val unsafeRow = unsafeProj(row).copy()
+    sparkContext.parallelize(Seq(unsafeRow))
   }
 
   override def producedAttributes: AttributeSet = outputSet
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 2d3e34d0e1292..9f33e4ab62298 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -428,4 +428,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
     benchmark.run()
   }
+
+  ignore("collect") {
+    val N = 1 << 20
+
+    val benchmark = new Benchmark("collect", N)
+    benchmark.addCase("collect 1 million") { iter =>
+      sqlContext.range(N).collect()
+    }
+    benchmark.addCase("collect 2 millions") { iter =>
+      sqlContext.range(N * 2).collect()
+    }
+    benchmark.addCase("collect 4 millions") { iter =>
+      sqlContext.range(N * 4).collect()
+    }
+    benchmark.run()
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    collect:                       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ---------------------------------------------------------------------------------------
+    collect 1 million                     775 / 1170          1.4         738.9       1.0X
+    collect 2 millions                   1153 / 1758          0.9        1099.3       0.7X
+    collect 4 millions                   4451 / 5124          0.2        4244.9       0.2X
+     */
+  }
 }
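For reference, below is a minimal standalone sketch of the size-prefixed framing that `executeCollect()` uses above: each payload is written as `[size][bytes]`, the stream is terminated by `-1`, and the whole buffer is compressed. This is an illustration only, not the PR's code: it assumes plain `java.util.zip` GZIP streams in place of Spark's `CompressionCodec`, and raw `Array[Byte]` payloads in place of `UnsafeRow`; the object name `SizePrefixedFraming` is hypothetical.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.ArrayBuffer

// Hypothetical sketch of the [size][bytes]...[-1] framing; GZIP stands in for Spark's codec.
object SizePrefixedFraming {

  // Encode: [size] [bytes] [size] [bytes] ... [-1], all run through a compressed stream.
  def pack(rows: Iterator[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(new GZIPOutputStream(bos))
    rows.foreach { bytes =>
      out.writeInt(bytes.length)
      out.write(bytes)
    }
    out.writeInt(-1)  // end-of-stream marker, mirrors out.writeInt(-1) in the patch
    out.close()       // finishes the GZIP stream so the trailer is written
    bos.toByteArray
  }

  // Decode: read size-prefixed payloads until the -1 marker is reached.
  def unpack(packed: Array[Byte]): Seq[Array[Byte]] = {
    val ins = new DataInputStream(new GZIPInputStream(new ByteArrayInputStream(packed)))
    val results = ArrayBuffer[Array[Byte]]()
    var sizeOfNextRow = ins.readInt()
    while (sizeOfNextRow >= 0) {
      val bs = new Array[Byte](sizeOfNextRow)
      ins.readFully(bs)
      results += bs
      sizeOfNextRow = ins.readInt()
    }
    results.toSeq
  }

  def main(args: Array[String]): Unit = {
    val payloads = Seq("so", "fast", "rows").map(_.getBytes("UTF-8"))
    val roundTripped = unpack(pack(payloads.iterator))
    assert(roundTripped.map(new String(_, "UTF-8")) == Seq("so", "fast", "rows"))
    println(s"round-tripped ${roundTripped.size} payloads")
  }
}
```

The design point the sketch illustrates is the same one the patch relies on: length-prefixing lets the driver decode rows without any per-row object serialization, and the terminating `-1` avoids having to record the row count up front.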