From c0d9b726f210f67e290c790c4c4165eae45fc8d3 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 7 Aug 2014 20:23:20 -0700
Subject: [PATCH] Avoid creating a HadoopRDD per partition. Add dirty hacks to
 retrieve partition values from the InputSplit.

---
 .../spark/sql/parquet/ParquetRelation.scala   |  7 +-
 .../sql/parquet/ParquetTableOperations.scala  | 74 ++++++++++++++-----
 .../spark/sql/hive/HiveStrategies.scala       | 43 +++++------
 3 files changed, 79 insertions(+), 45 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index b3bae5db0edbc..b3a12cdc74035 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
 private[sql] case class ParquetRelation(
     path: String,
     @transient conf: Option[Configuration],
-    @transient sqlContext: SQLContext)
+    @transient sqlContext: SQLContext,
+    partitioningAttributes: Seq[Attribute] = Nil)
   extends LeafNode with MultiInstanceRelation {
 
   self: Product =>
@@ -60,7 +61,9 @@ private[sql] case class ParquetRelation(
       .getSchema
 
   /** Attributes */
-  override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf)
+  override val output =
+    partitioningAttributes ++
+    ParquetTypesConverter.readSchemaFromFile(new Path(path.split(",").head), conf)
 
   override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
 
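
Side note on the path handling above: `relation.path` now carries a comma-separated list of partition directories, the schema is read from the first entry, and the scan in the next file registers every entry as an input path of a single Hadoop job instead of building one job (and one HadoopRDD) per partition. A minimal standalone sketch of that handling, using only the Hadoop mapreduce API; the `addInputPaths` helper and its parameters are illustrative, not part of the patch:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapreduce.Job
    import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}

    // Qualify each comma-separated directory against its FileSystem and add it
    // as an input path of one shared job, rather than one job per partition.
    def addInputPaths(job: Job, commaSeparatedPaths: String): Unit = {
      val conf: Configuration = job.getConfiguration
      commaSeparatedPaths.split(",").foreach { curPath =>
        val path = new Path(curPath)
        val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
        NewFileInputFormat.addInputPath(job, qualifiedPath)
      }
    }
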
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 759a2a586b926..68141ce83c796 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -42,7 +42,7 @@ import parquet.schema.MessageType
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 import org.apache.spark.{Logging, SerializableWritable, TaskContext}
 
@@ -59,11 +59,18 @@ case class ParquetTableScan(
   // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes
   // by exprId. note: output cannot be transient, see
   // https://issues.apache.org/jira/browse/SPARK-1367
-  val output = attributes.map { a =>
-    relation.output
-      .find(o => o.exprId == a.exprId)
-      .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}"))
-  }
+  val normalOutput =
+    attributes
+      .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId))
+      .flatMap(a => relation.output.find(o => o.exprId == a.exprId))
+
+  val partOutput =
+    attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId))
+
+  def output = partOutput ++ normalOutput
+
+  assert(normalOutput.size + partOutput.size == attributes.size,
+    s"$normalOutput + $partOutput != $attributes, ${relation.output}")
 
   override def execute(): RDD[Row] = {
     val sc = sqlContext.sparkContext
@@ -71,16 +78,19 @@ case class ParquetTableScan(
     ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport])
 
     val conf: Configuration = ContextUtil.getConfiguration(job)
-    val qualifiedPath = {
-      val path = new Path(relation.path)
-      path.getFileSystem(conf).makeQualified(path)
+
+    relation.path.split(",").foreach { curPath =>
+      val qualifiedPath = {
+        val path = new Path(curPath)
+        path.getFileSystem(conf).makeQualified(path)
+      }
+      NewFileInputFormat.addInputPath(job, qualifiedPath)
     }
-    NewFileInputFormat.addInputPath(job, qualifiedPath)
 
     // Store both requested and original schema in `Configuration`
     conf.set(
       RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
-      ParquetTypesConverter.convertToString(output))
+      ParquetTypesConverter.convertToString(normalOutput))
     conf.set(
       RowWriteSupport.SPARK_ROW_SCHEMA,
       ParquetTypesConverter.convertToString(relation.output))
@@ -96,13 +106,41 @@ case class ParquetTableScan(
       ParquetFilters.serializeFilterExpressions(columnPruningPred, conf)
     }
 
-    sc.newAPIHadoopRDD(
-      conf,
-      classOf[FilteringParquetRowInputFormat],
-      classOf[Void],
-      classOf[Row])
-      .map(_._2)
-      .filter(_ != null) // Parquet's record filters may produce null values
+    val baseRDD =
+      new org.apache.spark.rdd.NewHadoopRDD(
+        sc,
+        classOf[FilteringParquetRowInputFormat],
+        classOf[Void],
+        classOf[Row],
+        conf)
+
+    if (partOutput.nonEmpty) {
+      baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
+        val partValue = "([^=]+)=([^=]+)".r
+        val partValues =
+          split.asInstanceOf[parquet.hadoop.ParquetInputSplit]
+            .getPath
+            .toString
+            .split("/")
+            .flatMap {
+              case partValue(key, value) => Some(key -> value)
+              case _ => None
+            }.toMap
+
+        val partitionRowValues =
+          partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
+
+        new Iterator[Row] {
+          private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null)
+
+          def hasNext = iter.hasNext
+
+          def next() = joinedRow.withRight(iter.next()._2)
+        }
+      }
+    } else {
+      baseRDD.map(_._2)
+    }.filter(_ != null) // Parquet's record filters may produce null values
   }
 
   /**
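
Side note on the "dirty hack" above: the partition values come straight from each split's directory name, which Hive lays out as `key=value` path segments (e.g. `.../year=2014/country=us/part-00000.parquet`). A minimal standalone sketch of that parsing, mirroring the regex used in `ParquetTableScan`; the example path and key names are purely illustrative:

    // Pull key=value segments out of a partition directory path.
    val partValue = "([^=]+)=([^=]+)".r

    def partitionValuesFromPath(path: String): Map[String, String] =
      path.split("/").flatMap {
        case partValue(key, value) => Some(key -> value)
        case _ => None
      }.toMap

    // partitionValuesFromPath("hdfs://nn/warehouse/t/year=2014/country=us/part-00000.parquet")
    //   returns Map("year" -> "2014", "country" -> "us")
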
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index de301fc8d6254..f2be1eae410ef 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.parquet.ParquetTableScan
+import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan}
 
 import scala.collection.JavaConversions._
 
@@ -51,6 +51,13 @@ private[hive] trait HiveStrategies {
     implicit class LogicalPlanHacks(s: SchemaRDD) {
       def lowerCase =
         new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan))
+
+      def addPartitioningAttributes(attrs: Seq[Attribute]) =
+        new SchemaRDD(
+          s.sqlContext,
+          s.logicalPlan transform {
+            case p: ParquetRelation => p.copy(partitioningAttributes = attrs)
+          })
     }
 
     implicit class PhysicalPlanHacks(s: SparkPlan) {
@@ -76,8 +83,7 @@ private[hive] trait HiveStrategies {
         }).reduceOption(And).getOrElse(Literal(true))
 
         val unresolvedProjection = projectList.map(_ transform {
-          // Handle non-partitioning columns
-          case a: AttributeReference if !partitionKeyIds.contains(a.exprId) => UnresolvedAttribute(a.name)
+          case a: AttributeReference => UnresolvedAttribute(a.name)
         })
 
         if (relation.hiveQlTable.isPartitioned) {
@@ -109,28 +115,15 @@ private[hive] trait HiveStrategies {
             pruningCondition(inputData)
           }
 
-          org.apache.spark.sql.execution.Union(
-            partitions.par.map { p =>
-              val partValues = p.getValues()
-              val internalProjection = unresolvedProjection.map(_ transform {
-                // Handle partitioning columns
-                case a: AttributeReference if partitionKeyIds.contains(a.exprId) => {
-                  val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId)
-                  val key = relation.partitionKeys(idx)
-
-                  Alias(Cast(Literal(partValues.get(idx), StringType), key.dataType), a.name)()
-                }
-              })
-
-              hiveContext
-                .parquetFile(p.getLocation)
-                .lowerCase
-                .where(unresolvedOtherPredicates)
-                .select(internalProjection:_*)
-                .queryExecution
-                .executedPlan
-                .fakeOutput(projectList.map(_.toAttribute))
-            }.seq) :: Nil
+          hiveContext
+            .parquetFile(partitions.map(_.getLocation).mkString(","))
+            .addPartitioningAttributes(relation.partitionKeys)
+            .lowerCase
+            .where(unresolvedOtherPredicates)
+            .select(unresolvedProjection:_*)
+            .queryExecution
+            .executedPlan
+            .fakeOutput(projectList.map(_.toAttribute)):: Nil
         } else {
           hiveContext
             .parquetFile(relation.hiveQlTable.getDataLocation.getPath)
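
Side note on the partition values: they are recovered as strings, so both the old per-partition code path and the new split-based one cast them to each partition column's declared type via Catalyst's `Cast` and `Literal` expressions evaluated against `EmptyRow`. A minimal sketch of that conversion, assuming the Spark 1.1-era Catalyst packages (these moved in later releases); the helper name is illustrative:

    import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, Literal}
    import org.apache.spark.sql.catalyst.types.{DataType, IntegerType}

    // Convert a raw partition-value string to the column's Catalyst type, using
    // the same Cast(Literal(...), dataType).eval(EmptyRow) pattern as the patch.
    def typedPartitionValue(raw: String, dataType: DataType): Any =
      Cast(Literal(raw), dataType).eval(EmptyRow)

    // e.g. typedPartitionValue("2014", IntegerType) evaluates to the Int 2014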