From 3af423c92f117b5dd4dc6832dc50911cedb29abc Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 7 May 2015 20:59:42 -0700 Subject: [PATCH] [SPARK-6986] [SQL] Use Serializer2 in more cases. With https://github.com/apache/spark/commit/0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92, the serialization stream and deserialization stream has enough information to determine it is handling a key-value pari, a key, or a value. It is safe to use `SparkSqlSerializer2` in more cases. Author: Yin Huai Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits: 53a5eaa [Yin Huai] Josh's comments. 487f540 [Yin Huai] Use BufferedOutputStream. 8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join. c7e2129 [Yin Huai] Update tests. 4513d13 [Yin Huai] Use Serializer2 in more places. --- .../apache/spark/sql/execution/Exchange.scala | 23 ++---- .../sql/execution/SparkSqlSerializer2.scala | 74 ++++++++++++------- .../execution/SparkSqlSerializer2Suite.scala | 30 ++++---- 3 files changed, 69 insertions(+), 58 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 5b2e46962cd3b..f0d54cd6cd94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -84,18 +84,8 @@ case class Exchange( def serializer( keySchema: Array[DataType], valueSchema: Array[DataType], + hasKeyOrdering: Boolean, numPartitions: Int): Serializer = { - // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out - // through write(key) and then write(value) instead of write((key, value)). Because - // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use - // it when spillToMergeableFile in ExternalSorter will be used. - // So, we will not use SparkSqlSerializer2 when - // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater - // then the bypassMergeThreshold; or - // - newOrdering is defined. - val cannotUseSqlSerializer2 = - (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty - // It is true when there is no field that needs to be write out. // For now, we will not use SparkSqlSerializer2 when noField is true. val noField = @@ -104,14 +94,13 @@ case class Exchange( val useSqlSerializer2 = child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - !cannotUseSqlSerializer2 && // Safe to use Serializer2. SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. !noField val serializer = if (useSqlSerializer2) { logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(keySchema, valueSchema) + new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering) } else { logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) @@ -154,7 +143,8 @@ case class Exchange( } val keySchema = expressions.map(_.dataType).toArray val valueSchema = child.output.map(_.dataType).toArray - shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.setSerializer( + serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions)) shuffled.map(_._2) @@ -179,7 +169,8 @@ case class Exchange( new ShuffledRDD[Row, Null, Null](rdd, part) } val keySchema = child.output.map(_.dataType).toArray - shuffled.setSerializer(serializer(keySchema, null, numPartitions)) + shuffled.setSerializer( + serializer(keySchema, null, newOrdering.nonEmpty, numPartitions)) shuffled.map(_._1) @@ -199,7 +190,7 @@ case class Exchange( val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) val valueSchema = child.output.map(_.dataType).toArray - shuffled.setSerializer(serializer(null, valueSchema, 1)) + shuffled.setSerializer(serializer(null, valueSchema, false, 1)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 35ad987eb1a63..256d527d7b636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.serializer._ import org.apache.spark.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow} import org.apache.spark.sql.types._ /** @@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream( out: OutputStream) extends SerializationStream with Logging { - val rowOut = new DataOutputStream(out) - val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) - val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) + private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { val kv = t.asInstanceOf[Product2[Row, Row]] @@ -86,31 +86,44 @@ private[sql] class Serializer2SerializationStream( private[sql] class Serializer2DeserializationStream( keySchema: Array[DataType], valueSchema: Array[DataType], + hasKeyOrdering: Boolean, in: InputStream) extends DeserializationStream with Logging { - val rowIn = new DataInputStream(new BufferedInputStream(in)) + private val rowIn = new DataInputStream(new BufferedInputStream(in)) + + private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = { + if (schema == null) { + () => null + } else { + if (hasKeyOrdering) { + // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row. + () => new GenericMutableRow(schema.length) + } else { + // It is safe to reuse the mutable row. + val mutableRow = new SpecificMutableRow(schema) + () => mutableRow + } + } + } - val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null - val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null - val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) - val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + // Functions used to return rows for key and value. + private val getKey = rowGenerator(keySchema) + private val getValue = rowGenerator(valueSchema) + // Functions used to read a serialized row from the InputStream and deserialize it. + private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn) + private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn) override def readObject[T: ClassTag](): T = { - readKeyFunc() - readValueFunc() - - (key, value).asInstanceOf[T] + (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T] } override def readKey[T: ClassTag](): T = { - readKeyFunc() - key.asInstanceOf[T] + readKeyFunc(getKey()).asInstanceOf[T] } override def readValue[T: ClassTag](): T = { - readValueFunc() - value.asInstanceOf[T] + readValueFunc(getValue()).asInstanceOf[T] } override def close(): Unit = { @@ -118,9 +131,10 @@ private[sql] class Serializer2DeserializationStream( } } -private[sql] class ShuffleSerializerInstance( +private[sql] class SparkSqlSerializer2Instance( keySchema: Array[DataType], - valueSchema: Array[DataType]) + valueSchema: Array[DataType], + hasKeyOrdering: Boolean) extends SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer = @@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance( } def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(keySchema, valueSchema, s) + new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s) } } @@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance( * The schema of keys is represented by `keySchema` and that of values is represented by * `valueSchema`. */ -private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) +private[sql] class SparkSqlSerializer2( + keySchema: Array[DataType], + valueSchema: Array[DataType], + hasKeyOrdering: Boolean) extends Serializer with Logging with Serializable{ - def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) + def newInstance(): SerializerInstance = + new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering) override def supportsRelocationOfSerializedObjects: Boolean = { // SparkSqlSerializer2 is stateless and writes no stream headers @@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 { */ def createDeserializationFunction( schema: Array[DataType], - in: DataInputStream, - mutableRow: SpecificMutableRow): () => Unit = { - () => { - // If the schema is null, the returned function does nothing when it get called. - if (schema != null) { + in: DataInputStream): (MutableRow) => Row = { + if (schema == null) { + (mutableRow: MutableRow) => null + } else { + (mutableRow: MutableRow) => { var i = 0 while (i < schema.length) { schema(i) match { @@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 { } i += 1 } + + mutableRow } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 27f063d73a9a9..15337c4045436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll table("shuffle").collect()) } + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } + test("value schema is null") { val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") checkSerializer(df.queryExecution.executedPlan, serializerClass) @@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { override def beforeAll(): Unit = { super.beforeAll() // Sort merge will not be triggered. - sql("set spark.sql.shuffle.partitions = 200") - } - - test("key schema is null") { - val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = sql(s"SELECT $aggregations FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + val bypassMergeThreshold = + sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") } } /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { - // We are expecting SparkSqlSerializer. - override val serializerClass: Class[Serializer] = - classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] - override def beforeAll(): Unit = { super.beforeAll() // To trigger the sort merge. - sql("set spark.sql.shuffle.partitions = 201") + val bypassMergeThreshold = + sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") } }