Commit 078d69b
address comments: add comments, do sort in shuffle, and others
adrian-wang committed Apr 12, 2015
1 parent 3af6ba5 commit 078d69b
Showing 9 changed files with 93 additions and 64 deletions.
15 changes: 8 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -27,7 +27,6 @@ private[spark] object SQLConf {
val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val CODEGEN_ENABLED = "spark.sql.codegen"
@@ -46,6 +45,7 @@ private[spark] object SQLConf {
// Options that control which operators can be chosen by the query planner. These should be
// considered hints and may be ignored by future versions of Spark SQL.
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"

// This is only used for the thriftserver
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -123,6 +123,13 @@ private[sql] class SQLConf extends Serializable {
/** When true the planner will use the external sort, which may spill to disk. */
private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean

/**
* Sort merge join sorts both sides of the join first, and then iterates both sides together
* only once to get all matches. Using sort merge join can save a lot of memory compared
* to HashJoin.
*/
private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean

/**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
@@ -144,12 +151,6 @@ private[sql] class SQLConf extends Serializable {
private[spark] def autoBroadcastJoinThreshold: Int =
getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt

/**
* By default not choose sort merge join.
*/
private[spark] def autoSortMergeJoin: Boolean =
getConf(AUTO_SORTMERGEJOIN, false.toString).toBoolean

/**
* The default size in bytes to assign to a logical operator's estimation statistics. By default,
* it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator
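
The new doc comment for sortMergeJoinEnabled claims a memory saving over hash join. As a rough, Spark-free sketch of that trade-off (plain Scala collections and invented data, not the operators' actual code): a hash join materializes one whole side in a map, while a merge join over key-sorted inputs only buffers the rows for the key currently under the merge cursor.

object JoinMemorySketch {
  def main(args: Array[String]): Unit = {
    val right: Seq[(Int, String)] = (1 to 100000).map(i => (i % 100, s"r$i"))

    // Hash join style: the entire build side sits in a hash map before any
    // probe row is processed.
    val hashTable: Map[Int, Seq[String]] =
      right.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2) }
    println(s"hash join keeps ${hashTable.valuesIterator.map(_.size).sum} rows resident")

    // Sort merge join style: with both sides sorted by key, only the right
    // rows sharing the key under the merge cursor have to stay buffered.
    val sortedRight = right.sortBy(_._1)
    val currentKey = sortedRight.head._1
    val buffered = sortedRight.takeWhile(_._1 == currentKey)
    println(s"sort merge join buffers ${buffered.size} rows for key $currentKey")
  }
}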
@@ -60,8 +60,6 @@ case class Aggregate(

override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)

override def outputOrdering: Seq[SortOrder] = Nil

/**
* An aggregate that needs to be computed for each row in a group.
*
@@ -32,7 +32,11 @@ import org.apache.spark.util.MutablePair
* :: DeveloperApi ::
*/
@DeveloperApi
case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
case class Exchange(
newPartitioning: Partitioning,
child: SparkPlan,
sort: Boolean = false)
extends UnaryNode {

override def outputPartitioning: Partitioning = newPartitioning

@@ -68,7 +72,16 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
}
}
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
val shuffled = sort match {
case false => new ShuffledRDD[Row, Row, Row](rdd, part)
case true =>
val sortingExpressions = expressions.zipWithIndex.map {
case (exp, index) =>
new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
}
val ordering = new RowOrdering(sortingExpressions, child.output)
new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
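
For reference, the sort branch above leans on plain Spark-core behaviour: attaching a key ordering to a ShuffledRDD makes the shuffle deliver each output partition's records sorted by key, so no separate Sort operator is needed afterwards. A small standalone sketch of that mechanism (local SparkContext and invented data rather than Spark SQL Rows; names are made up for the example):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD

object SortingShuffleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sort-shuffle"))
    val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b", 1 -> "aa"), numSlices = 2)

    // Same pattern as the sort branch in Exchange: hash-partition by key and
    // additionally ask the shuffle to order keys within each partition.
    val shuffled = new ShuffledRDD[Int, String, String](pairs, new HashPartitioner(2))
      .setKeyOrdering(Ordering.Int)

    // Each output partition now yields its keys in ascending order.
    shuffled.mapPartitionsWithIndex { (idx, iter) =>
      Iterator(idx -> iter.map(_._1).toList)
    }.collect().foreach(println)

    sc.stop()
  }
}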

@@ -158,37 +171,35 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl

// Check if the partitioning we want to ensure is the same as the child's output
// partitioning. If so, we do not need to add the Exchange operator.
def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child

// Check if the ordering we want to ensure is the same as the child's output
// ordering. If so, we do not need to add the Sort operator.
def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan =
if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child
def addExchangeIfNecessary(
partitioning: Partitioning,
child: SparkPlan,
rowOrdering: Option[Ordering[Row]] = None): SparkPlan =
if (child.outputPartitioning != partitioning) {
Exchange(partitioning, child, sort = child.outputOrdering != rowOrdering)
} else {
child
}

if (meetsRequirements && compatible) {
operator
} else {
// At least one child does not satisfy its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
case (AllTuples, child) =>
val repartitionedChildren = operator.requiredChildDistribution.zip(
operator.children.zip(operator.requiredChildOrdering)
).map {
case (AllTuples, (child, _)) =>
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), child) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
case (OrderedDistribution(ordering), child) =>
case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
case (OrderedDistribution(ordering), (child, _)) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, child) => child
case (UnspecifiedDistribution, (child, _)) => child
case (dist, _) => sys.error(s"Don't know how to ensure $dist")
}
val reorderedChildren =
operator.requiredInPartitionOrdering.zip(repartitionedChildren).map {
case (Nil, child) => child
case (ordering, child) =>
addSortIfNecessary(ordering, child)
}
operator.withNewChildren(reorderedChildren)
operator.withNewChildren(repartitionedChildren)
}
}
}
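
Modeled loosely, the per-child decision the rewritten rule makes is: shuffle only when the child's partitioning does not already satisfy the requirement, and ask that shuffle to sort only when the child's ordering differs from the required one. A Spark-free sketch of that decision (every type and name below is a stand-in invented for the illustration, not Spark's classes):

object AddExchangeSketch {
  sealed trait Partitioning
  case object SinglePartition extends Partitioning
  case class HashPartitioning(keys: Seq[String]) extends Partitioning

  // A toy physical plan node carrying just the properties the rule inspects.
  case class Plan(
      name: String,
      partitioning: Partitioning,
      ordering: Option[Seq[String]],
      children: Seq[Plan] = Nil)

  // Mirrors addExchangeIfNecessary: no exchange if the partitioning already
  // matches; otherwise shuffle, sorting only when the ordering also differs.
  def addExchangeIfNecessary(
      partitioning: Partitioning,
      child: Plan,
      rowOrdering: Option[Seq[String]] = None): Plan =
    if (child.partitioning != partitioning) {
      Plan(s"Exchange(sort=${child.ordering != rowOrdering})",
        partitioning, rowOrdering, Seq(child))
    } else {
      child
    }

  def main(args: Array[String]): Unit = {
    val scan = Plan("Scan", HashPartitioning(Seq("other")), ordering = None)
    // A sort merge join requires clustering and ordering on the join key, so
    // this wraps the scan in an Exchange with sort = true.
    println(addExchangeIfNecessary(HashPartitioning(Seq("key")), scan, Some(Seq("key"))))
  }
}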
@@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
Seq.fill(children.size)(UnspecifiedDistribution)

/** Specifies how data is ordered in each partition. */
def outputOrdering: Seq[SortOrder] = Nil
def outputOrdering: Option[Ordering[Row]] = None

/** Specifies, for each child, the sort order required on its input data within each partition. */
def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None)

/**
* Runs this query returning the result as an RDD.
@@ -183,7 +183,6 @@ private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
self: Product =>
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {
@@ -92,7 +92,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

// for now let's support inner join first, then add outer join
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.autoSortMergeJoin =>
if sqlContext.conf.sortMergeJoinEnabled =>
val mergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
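
With the renamed flag, the planner only considers SortMergeJoin for inner equi-joins, and only when the option is turned on. A hedged usage sketch, assuming a Spark 1.3-style SQLContext; the temp tables, column names and object name are invented for the example:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SortMergeJoinUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("smj-usage"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    sc.parallelize(1 to 100).map(i => (i, s"a$i")).toDF("key", "value").registerTempTable("t1")
    sc.parallelize(1 to 100).map(i => (i, s"b$i")).toDF("key2", "value2").registerTempTable("t2")

    // Opt in to the new planner rule; it defaults to false.
    sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")

    val joined = sqlContext.sql("SELECT * FROM t1 JOIN t2 ON key = key2")
    // The physical plan should now contain SortMergeJoin for this inner equi-join.
    println(joined.queryExecution.executedPlan)

    sc.stop()
  }
}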
@@ -41,6 +41,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
val reusableProjection = buildProjection()
iter.map(reusableProjection)
}

/**
* The outputOrdering of Project is not always the same as the child's outputOrdering, because
* the projection may prune the very columns the child's ordering depends on. If an ordering
* column is pruned, operators above must not rely on the child's ordering here; only when the
* projection keeps the ordering columns unchanged could that ordering be propagated.
* TODO: we may utilize this later to avoid some unnecessary sorting.
*/
override def outputOrdering: Option[Ordering[Row]] = None
}
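
To see the pruning problem the comment describes, here is a tiny Spark-free illustration (plain tuples standing in for rows, with invented column roles): once the sort column is projected away, the surviving column carries no ordering that an upper operator could rely on.

object ProjectOrderingSketch {
  def main(args: Array[String]): Unit = {
    // Rows (ts, value) sorted by ts.
    val sortedByTs = Seq((1, "c"), (2, "a"), (3, "b"))
    // Project away ts: the remaining values come out as List(c, a, b), which
    // is not sorted, so Project conservatively reports outputOrdering = None.
    val projected = sortedByTs.map(_._2)
    println(projected)
  }
}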

/**
@@ -55,6 +64,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def execute(): RDD[Row] = child.execute().mapPartitions { iter =>
iter.filter(conditionEvaluator)
}

override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
}

/**
@@ -70,8 +81,6 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
override def execute(): RDD[Row] = {
child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
}

override def outputOrdering: Seq[SortOrder] = Nil
}

/**
@@ -104,6 +113,8 @@ case class Limit(limit: Int, child: SparkPlan)
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition

override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering

override def executeCollect(): Array[Row] = child.executeTake(limit)

override def execute(): RDD[Row] = {
@@ -149,7 +160,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Pick num splits based on |limit|.
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)

override def outputOrdering: Seq[SortOrder] = sortOrder
override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
}

/**
@@ -176,7 +187,7 @@ case class Sort(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Seq[SortOrder] = sortOrder
override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
}

/**
@@ -208,7 +219,7 @@ case class ExternalSort(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Seq[SortOrder] = sortOrder
override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
}

/**
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.joins

import java.util.NoSuchElementException

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
@@ -47,16 +49,16 @@ case class SortMergeJoin(
private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map {
case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending)
}
private val ordering: RowOrdering = new RowOrdering(orders, left.output)
// manually construct an ordering that can be used to compare keys from both sides
private val keyOrdering: RowOrdering = new RowOrdering(orders)

private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map {
k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending)
}
private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] =
newOrdering(keys.map(SortOrder(_, Ascending)), side.output)

override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left)
override def outputOrdering: Option[Ordering[Row]] = Some(requiredOrders(leftKeys, left))

override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil
override def requiredChildOrdering: Seq[Option[Ordering[Row]]] =
Some(requiredOrders(leftKeys, left)) :: Some(requiredOrders(rightKeys, right)) :: Nil

@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
@@ -78,24 +80,28 @@ case class SortMergeJoin(
private[this] var stop: Boolean = false
private[this] var matchKey: Row = _

// initialize iterator
initialize()

override final def hasNext: Boolean = nextMatchingPair()

override final def next(): Row = {
if (hasNext) {
// use the buffered right rows and run down the left iterator
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
rightPosition += 1
if (rightPosition >= rightMatches.size) {
rightPosition = 0
fetchLeft()
if (leftElement == null || ordering.compare(leftKey, matchKey) != 0) {
if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
stop = false
rightMatches = null
}
}
joinedRow
} else {
// according to Scala doc, this is undefined
null
// no more results
throw new NoSuchElementException
}
}

@@ -121,33 +127,36 @@
fetchLeft()
fetchRight()
}
// initialize iterator
initialize()

/**
* Searches the left/right iterator for the next rows that matches.
* Searches the right iterator for the next rows that have matches in the left side, and stores
* them in a buffer.
*
* @return true if the search is successful, and false if the left/right iterator runs out
* of tuples.
* @return true if the search is successful, and false if the right iterator runs out of
* tuples.
*/
private def nextMatchingPair(): Boolean = {
if (!stop && rightElement != null) {
// run both sides to get the first matching pair
while (!stop && leftElement != null && rightElement != null) {
stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull
if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) {
val comparing = keyOrdering.compare(leftKey, rightKey)
// for inner join, we need to filter out rows with null keys
stop = comparing == 0 && !leftKey.anyNull
if (comparing > 0 || rightKey.anyNull) {
fetchRight()
} else if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) {
} else if (comparing < 0 || leftKey.anyNull) {
fetchLeft()
}
}
rightMatches = new CompactBuffer[Row]()
if (stop) {
stop = false
// iterate the right side to buffer all rows that match
// since the records are ordered, exit when we meet the first one that does not match
while (!stop && rightElement != null) {
rightMatches += rightElement
fetchRight()
// exit loop when run out of right matches
stop = ordering.compare(leftKey, rightKey) != 0
stop = keyOrdering.compare(leftKey, rightKey) != 0
}
if (rightMatches.size > 0) {
rightPosition = 0
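
As a standalone illustration of the iterator shape used above — buffer the right rows for one key, replay them against every left row with that key, and throw NoSuchElementException when the inputs are exhausted — here is a hedged, Spark-free sketch over plain sorted iterators (ordinary Int keys instead of Rows and RowOrdering; all names are invented for the example and this is not the operator's actual code):

object MergeJoinIteratorSketch {
  // Inner-joins two inputs that are already sorted by key, mirroring the
  // buffered-right-matches structure of SortMergeJoin's iterator.
  def mergeJoin[L, R](
      left: Iterator[(Int, L)],
      right: Iterator[(Int, R)]): Iterator[(L, R)] = new Iterator[(L, R)] {
    private val bufferedLeft = left.buffered
    private val bufferedRight = right.buffered
    private var rightMatches: Seq[(Int, R)] = Seq.empty
    private var rightPosition = 0

    override def hasNext: Boolean = nextMatchingPair()

    override def next(): (L, R) = {
      if (!hasNext) throw new NoSuchElementException("end of merge join") // no more results
      val joined = (bufferedLeft.head._2, rightMatches(rightPosition)._2)
      rightPosition += 1
      if (rightPosition >= rightMatches.size) {
        rightPosition = 0
        // Finished replaying the buffer for this left row; advance the left side.
        val finishedKey = bufferedLeft.next()._1
        if (!bufferedLeft.hasNext || bufferedLeft.head._1 != finishedKey) {
          rightMatches = Seq.empty // the buffered group no longer applies
        }
      }
      joined
    }

    // Finds the next key present on both sides and buffers its right-side rows.
    private def nextMatchingPair(): Boolean = {
      if (rightMatches.nonEmpty) return true
      while (bufferedLeft.hasNext && bufferedRight.hasNext && rightMatches.isEmpty) {
        val cmp = bufferedLeft.head._1.compare(bufferedRight.head._1)
        if (cmp < 0) bufferedLeft.next()
        else if (cmp > 0) bufferedRight.next()
        else {
          val key = bufferedLeft.head._1
          val buf = Seq.newBuilder[(Int, R)]
          while (bufferedRight.hasNext && bufferedRight.head._1 == key) {
            buf += bufferedRight.next()
          }
          rightMatches = buf.result()
          rightPosition = 0
        }
      }
      rightMatches.nonEmpty
    }
  }

  def main(args: Array[String]): Unit = {
    val l = Iterator(1 -> "a", 2 -> "b", 2 -> "c", 4 -> "d")
    val r = Iterator(2 -> "x", 2 -> "y", 3 -> "z")
    // Prints (b,x), (b,y), (c,x), (c,y)
    mergeJoin(l, r).foreach(println)
  }
}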
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -64,7 +64,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("join operator selection") {
cacheManager.clearCache()

val AUTO_SORTMERGEJOIN: Boolean = conf.autoSortMergeJoin
val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
conf.setConf("spark.sql.autoSortMergeJoin", "false")
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
@@ -103,7 +103,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
} finally {
conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString)
conf.setConf("spark.sql.autoSortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
}
}

@@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive
class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
override def beforeAll() {
super.beforeAll()
TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "true")
TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true")
}

override def afterAll() {
TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "false")
TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false")
super.afterAll()
}
