diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 148c12e64d2ce..8b5abb250f78d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -113,8 +113,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
    * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
    * The caller is responsible for calling rehashIfNeeded.
    *
-   * Use (retval & POSITION_MASK) to get the actual position, and
-   * (retval & EXISTENCE_MASK) != 0 for prior existence.
+   * Use (retval & _mask) to get the actual position, and
+   * (retval & NONEXISTENCE_MASK) == 0 for prior existence.
    *
    * @return The position where the key is placed, plus the highest order bit is set if the key
    *         exists previously.
@@ -151,7 +151,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
   *                 to a new position (in the new data array).
   */
-  def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+  def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit = grow,
+      moveFunc: (Int, Int) => Unit = move) {
     if (_size > _growThreshold) {
       rehash(k, allocateFunc, moveFunc)
     }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index ff4a98f5dcd4a..c9880f0c4a85a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -73,6 +73,9 @@ class OpenHashSetSuite extends FunSuite with ShouldMatchers {
     assert(set.contains(50))
     assert(set.contains(999))
     assert(!set.contains(10000))
+
+    assert((set.addWithoutResize(50) & OpenHashSet.NONEXISTENCE_MASK) === 0)
+    assert((set.addWithoutResize(10000) & OpenHashSet.NONEXISTENCE_MASK) != 0)
   }
 
   test("primitive long") {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 1d029bf009e8c..3a52198d97b4b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -20,6 +20,7 @@ package org.apache.spark.graphx.impl
 import scala.reflect.{classTag, ClassTag}
 
 import org.apache.spark.util.collection.PrimitiveVector
+import org.apache.spark.util.collection.OpenHashSet
 import org.apache.spark.{HashPartitioner, Partitioner}
 import org.apache.spark.SparkContext._
 import org.apache.spark.graphx._
@@ -388,9 +389,21 @@ object GraphImpl {
   private def collectVertexIdsFromEdges(
       edges: EdgeRDD[_],
       partitioner: Partitioner): RDD[(VertexId, Int)] = {
-    // TODO: Consider doing map side distinct before shuffle.
     new ShuffledRDD[VertexId, Int, (VertexId, Int)](
-      edges.collectVertexIds.map(vid => (vid, 0)), partitioner)
+      edges.collectVertexIds.mapPartitions { vids =>
+        val present = new OpenHashSet[VertexId]()
+        vids.filter { vid =>
+          // We can't simply call add here: add returns Unit and doesn't report prior existence.
+          val isPresent = ((present.addWithoutResize(vid) & OpenHashSet.NONEXISTENCE_MASK) == 0)
+          if (!isPresent) {
+            present.rehashIfNeeded(vid)
+            true
+          } else {
+            false
+          }
+        }.map(vid => (vid, 0))
+      },
+      partitioner)
       .setSerializer(classOf[VertexIdMsgSerializer].getName)
   }
 } // end of object GraphImpl
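For reference, here is a minimal standalone sketch of the dedup idiom the GraphImpl hunk introduces: addWithoutResize returns the slot position with the high bit (OpenHashSet.NONEXISTENCE_MASK) set only when the key was NOT already present, and the caller is then responsible for calling rehashIfNeeded itself (which, with this patch's new grow/move defaults, needs no callbacks). The object and helper names (DistinctIdsSketch, distinctIds) are illustrative only, and the snippet assumes it compiles inside an org.apache.spark package, since OpenHashSet is package-private to Spark.

```scala
import org.apache.spark.util.collection.OpenHashSet

object DistinctIdsSketch {
  // Hypothetical helper mirroring the filter added to collectVertexIdsFromEdges:
  // lazily drops duplicate ids from an iterator using an OpenHashSet.
  def distinctIds(ids: Iterator[Long]): Iterator[Long] = {
    val seen = new OpenHashSet[Long]()
    ids.filter { id =>
      // High bit set => the key did not exist before this call, i.e. it was just inserted.
      val wasAbsent = (seen.addWithoutResize(id) & OpenHashSet.NONEXISTENCE_MASK) != 0
      if (wasAbsent) {
        // addWithoutResize never grows the table, so trigger rehashing ourselves.
        // Relies on the default grow/move arguments this patch adds to rehashIfNeeded.
        seen.rehashIfNeeded(id)
      }
      wasAbsent
    }
  }

  def main(args: Array[String]): Unit = {
    // Prints 1, 2, 3 -- duplicates are filtered out in a single pass.
    distinctIds(Iterator(1L, 2L, 1L, 3L, 2L)).foreach(println)
  }
}
```

The patch runs this same pattern inside mapPartitions, so each map task deduplicates the vertex ids in its partition before the shuffle, which reduces shuffle volume whenever ids repeat across edges in a partition.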