Merge pull request apache#571 from holdenk/switchtobinarysearch.
SPARK-1072 Use binary search when needed in RangePartitioner

Author: Holden Karau <holden@pigscanfly.ca>

Closes apache#571 and squashes the following commits:

f31a2e1 [Holden Karau] Switch to using CollectionsUtils in Partitioner
4c7a0c3 [Holden Karau] Add CollectionsUtil as suggested by aarondav
7099962 [Holden Karau] Add the binary search to only init once
1bef01d [Holden Karau] CR feedback
a21e097 [Holden Karau] Use binary search if we have more than 1000 elements inside of RangePartitioner
holdenk authored and rxin committed Feb 11, 2014
1 parent ba38d98 commit b0dab1b
Showing 3 changed files with 91 additions and 5 deletions.
21 changes: 18 additions & 3 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.util.CollectionsUtils
import org.apache.spark.util.Utils

/**
@@ -118,12 +119,26 @@ class RangePartitioner[K <% Ordered[K]: ClassTag, V](

def numPartitions = partitions

private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

def getPartition(key: Any): Int = {
// TODO: Use a binary search here if number of partitions is large
val k = key.asInstanceOf[K]
var partition = 0
while (partition < rangeBounds.length && k > rangeBounds(partition)) {
partition += 1
if (rangeBounds.length < 1000) {
// Fewer than 1000 range bounds: a naive linear scan is cheap enough
while (partition < rangeBounds.length && k > rangeBounds(partition)) {
partition += 1
}
} else {
// The type-specific binary search was selected once, when binarySearch was initialized.
partition = binarySearch(rangeBounds, k)
// binarySearch either returns the match location or -[insertion point]-1
if (partition < 0) {
partition = -partition-1
}
if (partition > rangeBounds.length) {
partition = rangeBounds.length
}
}
if (ascending) {
partition
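For reference, a minimal standalone sketch (not part of the commit) of the insertion-point arithmetic the else branch relies on: java.util.Arrays.binarySearch returns the match index on a hit and -(insertionPoint) - 1 on a miss, and that insertion point is exactly the partition the key belongs to. The bounds and keys below are made up for illustration.

import java.util.Arrays

object InsertionPointDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical bounds: partition i covers keys <= bounds(i);
    // keys above the last bound land in partition bounds.length.
    val bounds = Array(10, 20, 30)

    def partitionFor(key: Int): Int = {
      var p = Arrays.binarySearch(bounds, key)
      if (p < 0) p = -p - 1         // miss: recover the insertion point
      math.min(p, bounds.length)    // defensive clamp, mirroring the code above
    }

    assert(partitionFor(5) == 0)    // below the first bound
    assert(partitionFor(20) == 1)   // exact hit on a bound
    assert(partitionFor(25) == 2)   // falls between two bounds
    assert(partitionFor(99) == 3)   // past every bound
  }
}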
46 changes: 46 additions & 0 deletions core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.util

import scala.Array
import scala.reflect._

object CollectionsUtils {
def makeBinarySearch[K <% Ordered[K] : ClassTag] : (Array[K], K) => Int = {
classTag[K] match {
case ClassTag.Float =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Float]], x.asInstanceOf[Float])
case ClassTag.Double =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Double]], x.asInstanceOf[Double])
case ClassTag.Byte =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Byte]], x.asInstanceOf[Byte])
case ClassTag.Char =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Char]], x.asInstanceOf[Char])
case ClassTag.Short =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Short]], x.asInstanceOf[Short])
case ClassTag.Int =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Int]], x.asInstanceOf[Int])
case ClassTag.Long =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Long]], x.asInstanceOf[Long])
case _ =>
(l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[AnyRef]], x)
}
}
}
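A quick usage sketch (illustrative, not from the commit): the ClassTag match runs once, when makeBinarySearch is called, and the returned closure then dispatches directly to the primitive-specialized java.util.Arrays.binarySearch overload on every lookup.

import org.apache.spark.util.CollectionsUtils

object MakeBinarySearchUsage {
  def main(args: Array[String]): Unit = {
    // Resolve the Int-specialized search function a single time.
    val searchInt: (Array[Int], Int) => Int = CollectionsUtils.makeBinarySearch[Int]
    val bounds = Array(10, 20, 30)
    assert(searchInt(bounds, 20) == 1)   // hit: index of the matching element
    assert(searchInt(bounds, 25) == -3)  // miss: -(insertion point) - 1 = -(2) - 1
  }
}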
29 changes: 27 additions & 2 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark
import scala.math.abs
import scala.collection.mutable.ArrayBuffer

import org.scalatest.FunSuite
import org.scalatest.{FunSuite, PrivateMethodTester}

import org.apache.spark.SparkContext._
import org.apache.spark.util.StatCounter
import org.apache.spark.rdd.RDD

class PartitioningSuite extends FunSuite with SharedSparkContext {
class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester {

test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
@@ -67,6 +67,31 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
assert(descendingP4 != p4)
}

test("RangePartitioner getPartition") {
val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
// getPartition uses a linear scan below the 1000-element threshold on rangeBounds and a
// binary search at or above it, so exercise partition counts on both sides of the cutoff.
val partitionSizes = List(1, 2, 10, 100, 500, 1000, 1500)
val partitioners = partitionSizes.map(p => (p, new RangePartitioner(p, rdd)))
val decoratedRangeBounds = PrivateMethod[Array[Int]]('rangeBounds)
partitioners.map { case (numPartitions, partitioner) =>
val rangeBounds = partitioner.invokePrivate(decoratedRangeBounds())
1.to(1000).map { element => {
val partition = partitioner.getPartition(element)
if (numPartitions > 1) {
if (partition < rangeBounds.size) {
assert(element <= rangeBounds(partition))
}
if (partition > 0) {
assert(element > rangeBounds(partition - 1))
}
} else {
assert(partition === 0)
}
}}
}
}
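The 1000-element cutoff is a heuristic: below it the linear scan's tight loop tends to win on constant factors, while above it the O(log n) binary search pays off. A rough, self-contained sketch of how one might compare the two (sizes, names, and the measurement approach are illustrative, not from the commit):

import java.util.Arrays

object SearchCutoverSketch {
  def timeNanos(body: => Unit): Long = {
    val start = System.nanoTime()
    body
    System.nanoTime() - start
  }

  def main(args: Array[String]): Unit = {
    val bounds = Array.tabulate(1500)(i => i * 10)           // sorted range bounds
    val keys = Array.tabulate(100000)(i => (i * 7) % 15000)  // lookups spread over the range

    val linearNs = timeNanos {
      for (k <- keys) {
        var p = 0
        while (p < bounds.length && k > bounds(p)) p += 1
      }
    }
    val binaryNs = timeNanos {
      for (k <- keys) {
        var p = Arrays.binarySearch(bounds, k)
        if (p < 0) p = -p - 1
      }
    }
    println(s"linear: ${linearNs / 1000000} ms, binary: ${binaryNs / 1000000} ms")
  }
}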

test("HashPartitioner not equal to RangePartitioner") {
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
