diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index cfba43dec3111..ad9988226470c 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.CollectionsUtils
 import org.apache.spark.util.Utils
 
 /**
@@ -118,12 +119,26 @@ class RangePartitioner[K <% Ordered[K]: ClassTag, V](
 
   def numPartitions = partitions
 
+  // Determine which binary search method to use only once, when the partitioner is created.
+  private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+
   def getPartition(key: Any): Int = {
-    // TODO: Use a binary search here if number of partitions is large
     val k = key.asInstanceOf[K]
     var partition = 0
-    while (partition < rangeBounds.length && k > rangeBounds(partition)) {
-      partition += 1
+    if (rangeBounds.length < 1000) {
+      // With fewer than 1000 range bounds, a naive linear scan is cheap enough.
+      while (partition < rangeBounds.length && k > rangeBounds(partition)) {
+        partition += 1
+      }
+    } else {
+      partition = binarySearch(rangeBounds, k)
+      // binarySearch either returns the match location or -[insertion point]-1.
+      if (partition < 0) {
+        partition = -partition - 1
+      }
+      if (partition > rangeBounds.length) {
+        partition = rangeBounds.length
+      }
     }
     if (ascending) {
       partition
diff --git a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
new file mode 100644
index 0000000000000..db3db87e6618e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util
+
+import scala.Array
+import scala.reflect._
+
+object CollectionsUtils {
+  def makeBinarySearch[K <% Ordered[K] : ClassTag]: (Array[K], K) => Int = {
+    classTag[K] match {
+      case ClassTag.Float =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Float]], x.asInstanceOf[Float])
+      case ClassTag.Double =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Double]], x.asInstanceOf[Double])
+      case ClassTag.Byte =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Byte]], x.asInstanceOf[Byte])
+      case ClassTag.Char =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Char]], x.asInstanceOf[Char])
+      case ClassTag.Short =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Short]], x.asInstanceOf[Short])
+      case ClassTag.Int =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Int]], x.asInstanceOf[Int])
+      case ClassTag.Long =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Long]], x.asInstanceOf[Long])
+      case _ =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[AnyRef]], x)
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 1374d01774693..1c5d5ea4364f5 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark
 import scala.math.abs
 import scala.collection.mutable.ArrayBuffer
 
-import org.scalatest.FunSuite
+import org.scalatest.{FunSuite, PrivateMethodTester}
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.StatCounter
 import org.apache.spark.rdd.RDD
 
-class PartitioningSuite extends FunSuite with SharedSparkContext {
+class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester {
 
   test("HashPartitioner equality") {
     val p2 = new HashPartitioner(2)
@@ -67,6 +67,31 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
     assert(descendingP4 != p4)
   }
 
+  test("RangePartitioner getPartition") {
+    val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
+    // getPartition does a naive linear scan below 1000 range bounds and a binary search at or
+    // above that threshold, so exercise partition counts on both sides of it.
+    val partitionSizes = List(1, 2, 10, 100, 500, 1000, 1500)
+    val partitioners = partitionSizes.map(p => (p, new RangePartitioner(p, rdd)))
+    val decoratedRangeBounds = PrivateMethod[Array[Int]]('rangeBounds)
+    partitioners.foreach { case (numPartitions, partitioner) =>
+      val rangeBounds = partitioner.invokePrivate(decoratedRangeBounds())
+      1.to(1000).foreach { element =>
+        val partition = partitioner.getPartition(element)
+        if (numPartitions > 1) {
+          if (partition < rangeBounds.size) {
+            assert(element <= rangeBounds(partition))
+          }
+          if (partition > 0) {
+            assert(element > rangeBounds(partition - 1))
+          }
+        } else {
+          assert(partition === 0)
+        }
+      }
+    }
+  }
+
   test("HashPartitioner not equal to RangePartitioner") {
     val rdd = sc.parallelize(1 to 10).map(x => (x, x))
     val rangeP2 = new RangePartitioner(2, rdd)
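
Note on the getPartition change: the sketch below is not part of the patch, but walks through the insertion-point arithmetic the binary-search branch relies on. The rangeBounds values and the partitionFor helper are invented for illustration; java.util.Arrays.binarySearch is the only real API used.

```scala
import java.util.Arrays

object InsertionPointDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical bounds: the upper bound of each of the first four
    // partitions, as rangeBounds stores them (numPartitions - 1 entries).
    val rangeBounds = Array(10, 20, 30, 40)

    // Mirrors the patched binary-search branch of getPartition.
    def partitionFor(key: Int): Int = {
      var partition = Arrays.binarySearch(rangeBounds, key)
      // A miss returns -(insertion point) - 1; recover the insertion point.
      if (partition < 0) {
        partition = -partition - 1
      }
      // Clamp keys larger than every bound into the last partition.
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
      partition
    }

    assert(partitionFor(5) == 0)   // below the first bound
    assert(partitionFor(10) == 0)  // equal to a bound stays in that partition
    assert(partitionFor(25) == 2)  // between bounds: the insertion point
    assert(partitionFor(99) == 4)  // above all bounds: the last partition
  }
}
```

Both branches agree on ties: the linear scan advances only while k > rangeBounds(partition), and binarySearch returns the match index itself, so a key equal to a bound lands in that bound's partition either way. Since rangeBounds holds numPartitions - 1 entries, the insertion point is at most rangeBounds.length, which makes the final clamp purely defensive.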
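A quick usage sketch for CollectionsUtils.makeBinarySearch, assuming a build of spark-core that includes this patch is on the classpath. It shows the one-time dispatch: primitive key types get the matching non-boxing java.util.Arrays.binarySearch overload, while any other key type falls back to the Object[] overload and the keys' natural ordering.

```scala
import org.apache.spark.util.CollectionsUtils

object MakeBinarySearchDemo {
  def main(args: Array[String]): Unit = {
    // Dispatch on the ClassTag happens here, once; the returned closure
    // calls the int[] overload directly on every lookup, with no boxing.
    val searchInt = CollectionsUtils.makeBinarySearch[Int]
    assert(searchInt(Array(1, 5, 9), 5) == 1)   // hit: index of the match
    assert(searchInt(Array(1, 5, 9), 6) == -3)  // miss: -(insertion point) - 1

    // Non-primitive keys take the Object[] fallback (natural ordering).
    val searchString = CollectionsUtils.makeBinarySearch[String]
    assert(searchString(Array("ant", "cat", "emu"), "cat") == 1)
  }
}
```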