diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 07ebbc90fce71..f28076999ddbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -30,6 +30,7 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
     Batch("ConstantFolding", Once,
       ConstantFolding,
       BooleanSimplification,
+      SimplifyFilters,
       SimplifyCasts) ::
     Batch("Filter Pushdown", Once,
       CombineFilters,
@@ -90,6 +91,22 @@ object CombineFilters extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Removes filters that can be evaluated trivially. This is done either by eliding the filter for
+ * cases where it will always evaluate to `true`, or substituting a dummy empty relation when the
+ * filter will always evaluate to `false`.
+ */
+object SimplifyFilters extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Filter(Literal(true, BooleanType), child) =>
+      child
+    case Filter(Literal(null, _), child) =>
+      LocalRelation(child.output)
+    case Filter(Literal(false, BooleanType), child) =>
+      LocalRelation(child.output)
+  }
+}
+
 /**
  * Pushes [[catalyst.plans.logical.Filter Filter]] operators through
  * [[catalyst.plans.logical.Project Project]] operators, in-lining any
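
As a quick illustration of the new rule (not part of the patch), the sketch below builds two trivial plans by hand and runs SimplifyFilters over them. The single-column relation and the object name are made up for the example; it only assumes the Catalyst classes the rule itself uses, on this branch's package layout.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.optimizer.SimplifyFilters
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}
import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}

object SimplifyFiltersExample extends App {
  val a = AttributeReference("a", IntegerType, nullable = true)()
  val relation = LocalRelation(Seq(a))

  // WHERE true: the Filter node is elided and the child comes back unchanged.
  println(SimplifyFilters(Filter(Literal(true, BooleanType), relation)))

  // WHERE false (and likewise a null literal): the subtree collapses to an empty
  // LocalRelation that still exposes the child's output attributes.
  println(SimplifyFilters(Filter(Literal(false, BooleanType), relation)))
}
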
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 3e98bd3ca627d..cf3c06acce5b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -118,9 +118,47 @@ class SQLContext(@transient val sparkContext: SparkContext)
       TopK ::
       PartialAggregation ::
       SparkEquiInnerJoin ::
+      ParquetOperations ::
       BasicOperators ::
       CartesianProduct ::
       BroadcastNestedLoopJoin :: Nil
+
+    /**
+     * Used to build table scan operators where complex projection and filtering are done using
+     * separate physical operators. This function returns the given scan operator with Project and
+     * Filter nodes added only when needed. For example, a Project operator is only used when the
+     * final desired output requires complex expressions to be evaluated or when columns can be
+     * further eliminated out after filtering has been done.
+     *
+     * The required attributes for both filtering and expression evaluation are passed to the
+     * provided `scanBuilder` function so that it can avoid unnecessary column materialization.
+     */
+    def pruneFilterProject(
+        projectList: Seq[NamedExpression],
+        filterPredicates: Seq[Expression],
+        scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {
+
+      val projectSet = projectList.flatMap(_.references).toSet
+      val filterSet = filterPredicates.flatMap(_.references).toSet
+      val filterCondition = filterPredicates.reduceLeftOption(And)
+
+      // Right now we still use a projection even if the only evaluation is applying an alias
+      // to a column. Since this is a no-op, it could be avoided. However, using this
+      // optimization with the current implementation would change the output schema.
+      // TODO: Decouple final output schema from expression evaluation so this copy can be
+      // avoided safely.
+
+      if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) {
+        // When it is possible to just use column pruning to get the right projection and
+        // when the columns of this projection are enough to evaluate all filter conditions,
+        // just do a scan followed by a filter, with no extra project.
+        val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]])
+        filterCondition.map(Filter(_, scan)).getOrElse(scan)
+      } else {
+        val scan = scanBuilder((projectSet ++ filterSet).toSeq)
+        Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan))
+      }
+    }
   }
 
   @transient
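
To make the branch taken by pruneFilterProject concrete, here is a small self-contained model that uses plain strings in place of Catalyst attributes and physical plans. Everything in it (the names, the Scan/Filter/Project strings) is illustrative only; it just mirrors the subset check performed by the real method.

object PruneFilterProjectSketch extends App {
  // `projectList` stands in for a projection of plain column references and
  // `filterReferences` for the columns mentioned by the filter predicates.
  def plan(projectList: Seq[String], filterReferences: Set[String]): String = {
    val projectSet = projectList.toSet
    if (filterReferences.subsetOf(projectSet)) {
      // Column pruning alone produces the right schema: scan + filter, no Project node.
      s"Filter(Scan(${projectList.mkString(", ")}))"
    } else {
      // The filter needs columns the projection drops, so scan the union of both sets
      // and add a Project on top to restore the requested output.
      val scanColumns = (projectSet ++ filterReferences).toSeq.sorted
      s"Project(${projectList.mkString(", ")}, Filter(Scan(${scanColumns.mkString(", ")})))"
    }
  }

  println(plan(Seq("key"), Set("key")))    // Filter(Scan(key))
  println(plan(Seq("value"), Set("key")))  // Project(value, Filter(Scan(key, value)))
}
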
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9eb1032113cd9..8a39ded0a9ec4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -18,19 +18,15 @@
 package org.apache.spark.sql
 package execution
 
-import org.apache.spark.SparkContext
-
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.parquet.InsertIntoParquetTable
+import org.apache.spark.sql.parquet._
 
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
-
-  val sparkContext: SparkContext
+  self: SQLContext#SparkPlanner =>
 
   object SparkEquiInnerJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
@@ -170,6 +166,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  object ParquetOperations extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      // TODO: need to support writing to other types of files. Unify the below code paths.
+      case logical.WriteToFile(path, child) =>
+        val relation =
+          ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None)
+        InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil
+      case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
+        InsertIntoParquetTable(table, planLater(child))(sparkContext) :: Nil
+      case PhysicalOperation(projectList, filters, relation: parquet.ParquetRelation) =>
+        // TODO: Should be pushing down filters as well.
+        pruneFilterProject(
+          projectList,
+          filters,
+          ParquetTableScan(_, relation, None)(sparkContext)) :: Nil
+      case _ => Nil
+    }
+  }
+
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
     // TODO: Set
@@ -185,14 +200,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         // This sort only sorts tuples within a partition. Its requiredDistribution will be
         // an UnspecifiedDistribution.
         execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
-      case logical.Project(projectList, r: ParquetRelation)
-          if projectList.forall(_.isInstanceOf[Attribute]) =>
-
-        // simple projection of data loaded from Parquet file
-        parquet.ParquetTableScan(
-          projectList.asInstanceOf[Seq[Attribute]],
-          r,
-          None)(sparkContext) :: Nil
       case logical.Project(projectList, child) =>
         execution.Project(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
@@ -216,12 +223,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ExistingRdd(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
-      case logical.WriteToFile(path, child) =>
-        val relation =
-          ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None)
-        InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil
-      case p: parquet.ParquetRelation =>
-        parquet.ParquetTableScan(p.output, p, None)(sparkContext) :: Nil
       case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
       case _ => Nil
     }
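
End to end, the new ParquetOperations strategy means a simple projection with a filter over a Parquet table should now plan as a ParquetTableScan over the pruned column set, optionally wrapped in a Filter, with no separate Project. Below is a hedged sketch of checking that: the file path and table name are invented, and it assumes a local SparkContext plus an existing Parquet file with `key` and `value` columns, using this branch's SQLContext/SchemaRDD API.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ParquetPlanningExample extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("parquet-planning"))
  val sqlContext = new SQLContext(sc)

  // Register an existing Parquet file as a table (the path is a placeholder).
  sqlContext.parquetFile("/tmp/people.parquet").registerAsTable("people")

  val query = sqlContext.sql("SELECT key FROM people WHERE key > 10")
  // Expect something like Filter(key > 10, ParquetTableScan(key)): the scan only
  // materializes the columns handed to it by pruneFilterProject.
  println(query.queryExecution.executedPlan)

  sc.stop()
}
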
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index af35c919df308..3bcf586662f2c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -133,8 +133,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     results
   }
 
-  // TODO: Move this.
-  SessionState.start(sessionState)
 
 
   /**
@@ -191,8 +189,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
 
     override val strategies: Seq[Strategy] = Seq(
       TopK,
-      ColumnPrunings,
-      PartitionPrunings,
+      ParquetOperations,
       HiveTableScans,
      DataSinks,
       Scripts,
@@ -217,7 +214,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   override lazy val optimizedPlan =
     optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
 
-  // TODO: We are loosing schema here.
   override lazy val toRdd: RDD[Row] =
     analyzed match {
       case NativeCommand(cmd) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 3dd0530225be3..141067247d736 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -21,9 +21,8 @@ package hive
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.parquet.{InsertIntoParquetTable, ParquetRelation, ParquetTableScan}
 
 trait HiveStrategies {
   // Possibly being too clever with types here... or not clever enough.
@@ -43,121 +42,31 @@ trait HiveStrategies {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
         InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
-      case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
-        InsertIntoParquetTable(table, planLater(child))(hiveContext.sparkContext) :: Nil
-      case _ => Nil
-    }
-  }
-
-  object HiveTableScans extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // Push attributes into table scan when possible.
-      case p @ logical.Project(projectList, m: MetastoreRelation) if isSimpleProject(projectList) =>
-        HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m, None)(hiveContext) :: Nil
-      case m: MetastoreRelation =>
-        HiveTableScan(m.output, m, None)(hiveContext) :: Nil
       case _ => Nil
     }
   }
 
   /**
-   * A strategy used to detect filtering predicates on top of a partitioned relation to help
-   * partition pruning.
-   *
-   * This strategy itself doesn't perform partition pruning, it just collects and combines all the
-   * partition pruning predicates and pass them down to the underlying [[HiveTableScan]] operator,
-   * which does the actual pruning work.
+   * Retrieves data using a HiveTableScan. Partition pruning predicates are also detected and
+   * applied.
    */
-  object PartitionPrunings extends Strategy {
+  object HiveTableScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case p @ FilteredOperation(predicates, relation: MetastoreRelation)
-        if relation.isPartitioned =>
-
+      case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
+        // Filter out all predicates that only deal with partition keys, these are given to the
+        // hive table scan operator to be used for partition pruning.
         val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
-
-        // Filter out all predicates that only deal with partition keys
         val (pruningPredicates, otherPredicates) = predicates.partition {
           _.references.map(_.exprId).subsetOf(partitionKeyIds)
-        }
-
-        val scan = HiveTableScan(
-          relation.output, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)
-        otherPredicates
-          .reduceLeftOption(And)
-          .map(Filter(_, scan))
-          .getOrElse(scan) :: Nil
-
-      case _ =>
-        Nil
-    }
-  }
-
-  /**
-   * A strategy that detects projects and filters over some relation and applies column pruning if
-   * possible. Partition pruning is applied first if the relation is partitioned.
-   */
-  object ColumnPrunings extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // TODO(andre): the current mix of HiveRelation and ParquetRelation
-      // here appears artificial; try to refactor to break it into two
-      case PhysicalOperation(projectList, predicates, relation: BaseRelation) =>
-        val predicateOpt = predicates.reduceOption(And)
-        val predicateRefs = predicateOpt.map(_.references).getOrElse(Set.empty)
-        val projectRefs = projectList.flatMap(_.references)
-
-        // To figure out what columns to preserve after column pruning, we need to consider:
-        //
-        // 1. Columns referenced by the project list (order preserved)
-        // 2. Columns referenced by filtering predicates but not by project list
-        // 3. Relation output
-        //
-        // Then the final result is ((1 union 2) intersect 3)
-        val prunedCols = (projectRefs ++ (predicateRefs -- projectRefs)).intersect(relation.output)
-
-        val filteredScans =
-          if (relation.isPartitioned) { // from here on relation must be a [[MetaStoreRelation]]
-            // Applies partition pruning first for partitioned table
-            val filteredRelation = predicateOpt.map(logical.Filter(_, relation)).getOrElse(relation)
-            PartitionPrunings(filteredRelation).view.map(_.transform {
-              case scan: HiveTableScan =>
-                scan.copy(attributes = prunedCols)(hiveContext)
-            })
-          } else {
-            val scan = relation match {
-              case MetastoreRelation(_, _, _) => {
-                HiveTableScan(
-                  prunedCols,
-                  relation.asInstanceOf[MetastoreRelation],
-                  None)(hiveContext)
-              }
-              case ParquetRelation(_, _) => {
-                ParquetTableScan(
-                  relation.output,
-                  relation.asInstanceOf[ParquetRelation],
-                  None)(hiveContext.sparkContext)
-                  .pruneColumns(prunedCols)
-              }
-            }
-            predicateOpt.map(execution.Filter(_, scan)).getOrElse(scan) :: Nil
-          }
-
-        if (isSimpleProject(projectList) && prunedCols == projectRefs) {
-          filteredScans
-        } else {
-          filteredScans.view.map(execution.Project(projectList, _))
         }
+        pruneFilterProject(
+          projectList,
+          otherPredicates,
+          HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
       case _ => Nil
     }
   }
-
-  /**
-   * Returns true if `projectList` only performs column pruning and does not evaluate other
-   * complex expressions.
-   */
-  def isSimpleProject(projectList: Seq[NamedExpression]) = {
-    projectList.forall(_.isInstanceOf[Attribute])
-  }
 }
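
The heart of the new HiveTableScans strategy is the predicate split shown above: predicates that touch only partition keys are handed to HiveTableScan for partition pruning, while everything else stays in a Filter above the scan. Here is a simplified, self-contained sketch of that split, with strings standing in for Catalyst expressions and invented column names.

object PartitionPruningSketch extends App {
  // A stand-in for a Catalyst predicate: its SQL text plus the columns it references.
  case class Predicate(sql: String, references: Set[String])

  val partitionKeys = Set("ds", "hr")
  val predicates = Seq(
    Predicate("ds = '2008-04-08'", Set("ds")),  // partition keys only
    Predicate("hr < 12", Set("hr")),            // partition keys only
    Predicate("key > 10", Set("key")))          // needs a data column

  // Same shape as the partition() call in the strategy above.
  val (pruningPredicates, otherPredicates) =
    predicates.partition(_.references.subsetOf(partitionKeys))

  println(s"pushed into HiveTableScan for pruning: ${pruningPredicates.map(_.sql)}")
  println(s"left in a Filter above the scan:       ${otherPredicates.map(_.sql)}")
}
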
mismatch") + assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") assert( actualPartValues.length === expectedPartValues.length,