From 476770babc4aece60e08d9dab5bd46c739ec9e66 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 17 Sep 2014 11:19:23 -0700 Subject: [PATCH] ShippableVertexPartition.initFrom: Don't run mergeFunc on default values --- .../org/apache/spark/graphx/VertexRDD.scala | 2 -- .../impl/ShippableVertexPartition.scala | 20 ++++++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index befa3b7398eb3..210217df55ec2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -406,8 +406,6 @@ object VertexRDD { * @param edges the [[EdgeRDD]] that these vertices may be joined with * @param defaultVal the vertex attribute to use when creating missing vertices * @param mergeFunc the commutative, associative duplicate vertex attribute merge function - * note that all vertices with default value created upon construction in VertexPartition - * so it will appear as b in (a, b) pair for mergeFunc. */ def apply[VD: ClassTag]( vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index f0834c317fcdb..5412d720475dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -44,7 +44,7 @@ object ShippableVertexPartition { */ def apply[VD: ClassTag]( iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD) - : ShippableVertexPartition[VD] = + : ShippableVertexPartition[VD] = apply(iter, routingTable, defaultVal, (a, b) => a) /** @@ -54,12 +54,18 @@ object ShippableVertexPartition { */ def apply[VD: ClassTag]( iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD, - mergeFunc: (VD, VD) => VD - ) - : ShippableVertexPartition[VD] = { - val fullIter = iter ++ routingTable.iterator.map(vid => (vid, defaultVal)) - val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, mergeFunc) - new ShippableVertexPartition(index, values, mask, routingTable) + mergeFunc: (VD, VD) => VD): ShippableVertexPartition[VD] = { + val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] + // Merge the given vertices using mergeFunc + iter.foreach { pair => + map.setMerge(pair._1, pair._2, mergeFunc) + } + // Fill in missing vertices mentioned in the routing table + routingTable.iterator.foreach { vid => + map.changeValue(vid, defaultVal, identity) + } + + new ShippableVertexPartition(map.keySet, map._values, map.keySet.getBitSet, routingTable) } import scala.language.implicitConversions