address comments using sort
adrian-wang committed Apr 10, 2015
1 parent 068c35d commit 645c70b
Showing 5 changed files with 42 additions and 77 deletions.
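
In short, this patch drops the combined ClusteredOrderedDistribution / HashSortedPartitioning abstractions and treats in-partition sort order as its own physical property: each operator reports the ordering it produces (outputOrdering) and the ordering it needs from each child (requiredInPartitionOrdering), and the planner inserts a local Sort below an operator only when the child's ordering does not already match. SortMergeJoin then needs only a plain ClusteredDistribution on its inputs.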
@@ -75,21 +75,6 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
-/**
- * Represents data where tuples have been ordered according to the `clustering`
- * [[Expression Expressions]]. This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as this will ensure that tuples in a single partition are sorted
- * by the expressions.
- */
-case class ClusteredOrderedDistribution(clustering: Seq[Expression])
-  extends Distribution {
-  require(
-    clustering != Nil,
-    "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
-      "An AllTuples should be used to represent a distribution that only has " +
-      "a single partition.")
-}
-
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
@@ -177,40 +162,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
-/**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
- * in the same partition. And rows within the same partition are sorted by the expressions.
- */
-case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression
-  with Partitioning {
-
-  override def children: Seq[Expression] = expressions
-  override def nullable: Boolean = false
-  override def dataType: DataType = IntegerType
-
-  private[this] lazy val clusteringSet = expressions.toSet
-
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case ClusteredOrderedDistribution(requiredClustering) =>
-      clusteringSet.subsetOf(requiredClustering.toSet)
-    case ClusteredDistribution(requiredClustering) =>
-      clusteringSet.subsetOf(requiredClustering.toSet)
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning) = other match {
-    case BroadcastPartitioning => true
-    case h: HashSortedPartitioning if h == this => true
-    case _ => false
-  }
-
-  override def eval(input: Row = null): EvaluatedType =
-    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-}
-
 /**
  * Represents a partitioning where rows are split across partitions based on some total ordering of
  * the expressions specified in `ordering`. When data is partitioned in this manner the following
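
With HashSortedPartitioning gone, a plain HashPartitioning is enough to satisfy a sort-merge operator's clustering requirement, and the sorted-within-partition guarantee travels separately as a Seq[SortOrder]. A minimal sketch of the resulting check, assuming the 1.3-era catalyst classes in this tree are on the classpath (the snippet is illustrative, not part of the commit):

    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
    import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
    import org.apache.spark.sql.types.IntegerType

    object SatisfiesSketch extends App {
      val key = AttributeReference("key", IntegerType, nullable = false)()

      // Hash partitioning on `key` still satisfies clustering on `key`...
      val part = HashPartitioning(key :: Nil, numPartitions = 8)
      assert(part.satisfies(ClusteredDistribution(key :: Nil)))

      // ...while the in-partition sort guarantee is now carried as a plain
      // Seq[SortOrder] (outputOrdering / requiredInPartitionOrdering) rather
      // than being baked into the Distribution/Partitioning hierarchy.
      val ordering: Seq[SortOrder] = SortOrder(key, Ascending) :: Nil
    }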
@@ -72,29 +72,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
-      case HashSortedPartitioning(expressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            iter.map(r => (hashExpressions(r).copy(), r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            val mutablePair = new MutablePair[Row, Row]()
-            iter.map(r => mutablePair.update(hashExpressions(r), r))
-          }
-        }
-        val sortingExpressions = expressions.zipWithIndex.map {
-          case (exp, index) =>
-            new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
-        }
-        val ordering = new RowOrdering(sortingExpressions, child.output)
-        val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-        shuffled.map(_._2)
-
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val rdd = if (sortBasedShuffleOn) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
@@ -184,6 +161,11 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
       def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
         if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
 
+      // Check if the ordering we want to ensure is the same as the child's output
+      // ordering. If so, we do not need to add the Sort operator.
+      def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan =
+        if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child
+
       if (meetsRequirements && compatible) {
         operator
       } else {
@@ -195,14 +177,18 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
             addExchangeIfNecessary(SinglePartition, child)
           case (ClusteredDistribution(clustering), child) =>
             addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
-          case (ClusteredOrderedDistribution(clustering), child) =>
-            addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
           case (OrderedDistribution(ordering), child) =>
             addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
           case (UnspecifiedDistribution, child) => child
           case (dist, _) => sys.error(s"Don't know how to ensure $dist")
         }
-        operator.withNewChildren(repartitionedChildren)
+        val reorderedChildren = operator.requiredInPartitionOrdering.zip(repartitionedChildren).map {
+          case (Nil, child) =>
+            child
+          case (ordering, child) =>
+            addSortIfNecessary(ordering, child)
+        }
+        operator.withNewChildren(reorderedChildren)
       }
   }
 }
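
The rule now ensures requirements in two passes: first repartition children to meet the required distribution, then sort within partitions to meet the required ordering. A self-contained toy sketch of that second pass (toy Plan types standing in for SparkPlan/Exchange/Sort; purely illustrative, not Spark's classes):

    object SortPassSketch extends App {
      sealed trait Plan { def outputOrdering: Seq[String] = Nil }
      case class Scan(table: String) extends Plan
      case class Shuffle(keys: Seq[String], child: Plan) extends Plan
      case class LocalSort(ordering: Seq[String], child: Plan) extends Plan {
        // A local (per-partition) sort guarantees its own output ordering.
        override def outputOrdering: Seq[String] = ordering
      }

      // Mirrors addSortIfNecessary: add a sort only when the child is not
      // already ordered the way the parent requires.
      def addSortIfNecessary(required: Seq[String], child: Plan): Plan =
        if (required.isEmpty || child.outputOrdering == required) child
        else LocalSort(required, child)

      // Shuffling on the join keys fixes distribution but not ordering,
      // so a local sort is layered on top of the exchange:
      val shuffled = Shuffle(Seq("key"), Scan("t"))
      assert(addSortIfNecessary(Seq("key"), shuffled) == LocalSort(Seq("key"), shuffled))

      // A second pass over an already-sorted child is a no-op:
      val sorted = LocalSort(Seq("key"), shuffled)
      assert(addSortIfNecessary(Seq("key"), sorted) == sorted)
    }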
@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 
+  /** Specifies how data is ordered in each partition. */
+  def outputOrdering: Seq[SortOrder] = Nil
+
+  /** Specifies, for each child, the sort order required within each partition of its output. */
+  def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
   /**
    * Runs this query returning the result as an RDD.
    */
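
Any physical operator can opt into these hooks. A hypothetical unary operator (KeepSortedOp is not in this commit) that requires its input sorted on some keys and preserves that order might look like the following; it assumes it lives inside org.apache.spark.sql.execution, since UnaryNode is private[sql] in this tree:

    package org.apache.spark.sql.execution

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}

    // Hypothetical operator: needs each input partition sorted ascending on
    // `keys` and passes rows (and their order) through unchanged.
    case class KeepSortedOp(keys: Seq[Expression], child: SparkPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output

      // AddExchange will insert Sort(ordering, global = false, child) below
      // this operator only if the child is not already sorted this way.
      override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] =
        keys.map(SortOrder(_, Ascending)) :: Nil

      // Advertise the same order downstream so a parent with the same
      // requirement gets it for free.
      override def outputOrdering: Seq[SortOrder] = keys.map(SortOrder(_, Ascending))

      override def execute(): RDD[Row] = child.execute()
    }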
@@ -39,16 +39,25 @@ case class SortMergeJoin(
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
-  override def outputPartitioning: Partitioning = HashSortedPartitioning(leftKeys, 0)
+  override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map {
     case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending)
   }
   private val ordering: RowOrdering = new RowOrdering(orders, left.output)
 
+  private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map {
+    k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending)
+  }
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left)
+
+  override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil
+
   @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
   @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
 
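
Taken together, the join no longer fabricates HashSortedPartitioning(leftKeys, 0) as its output partitioning; it asks for plain clustering plus a per-child sort order and lets AddExchange assemble the plan. The expected physical shape (illustrative sketch using the operator names from this commit):

    SortMergeJoin(leftKeys, rightKeys)
      Sort(leftKeys asc, global = false)           // from addSortIfNecessary
        Exchange(HashPartitioning(leftKeys, n))    // from addExchangeIfNecessary
          [left child]
      Sort(rightKeys asc, global = false)
        Exchange(HashPartitioning(rightKeys, n))
          [right child]

Because outputPartitioning is now left.outputPartitioning and outputOrdering echoes requiredOrders(leftKeys, left), a parent that needs the same clustering and order can reuse the join's output without another exchange or sort.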
@@ -70,6 +70,19 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
     "auto_sortmerge_join_7",
     "auto_sortmerge_join_8",
     "auto_sortmerge_join_9",
+    "correlationoptimizer1",
+    "correlationoptimizer10",
+    "correlationoptimizer11",
+    "correlationoptimizer13",
+    "correlationoptimizer14",
+    "correlationoptimizer15",
+    "correlationoptimizer2",
+    "correlationoptimizer3",
+    "correlationoptimizer4",
+    "correlationoptimizer6",
+    "correlationoptimizer7",
+    "correlationoptimizer8",
+    "correlationoptimizer9",
     "join0",
     "join1",
     "join10",
