Commit ae1b36f

Expand to cover Maps returned from other Java API methods as well

srowen committed Oct 15, 2014
1 parent 51c26c2 · commit ae1b36f

Showing 2 changed files with 16 additions and 11 deletions.
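For context, the pre-existing code already wrapped the result of JavaPairRDD.collectAsMap in a SerializableMapWrapper; this commit routes every other Java API method that returns a Map through the same wrapper. The sketch below is illustrative, not Spark's exact code: it extends java.util.AbstractMap directly instead of the Scala-library MapWrapper, and all names are hypothetical. It shows the underlying problem and fix: a java.util.Map view of a Scala map that survives Java serialization.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import java.util.AbstractMap

    // Serializable java.util.Map view over an (immutable, hence serializable) Scala map.
    // Stand-in for the SerializableMapWrapper in this commit, which mixes Serializable
    // into the Scala-library MapWrapper instead.
    class SerializableMapWrapper[A, B](underlying: Map[A, B])
      extends AbstractMap[A, B] with Serializable {

      // AbstractMap implements get/size/containsKey/etc. on top of entrySet.
      override def entrySet(): java.util.Set[java.util.Map.Entry[A, B]] = {
        val set = new java.util.HashSet[java.util.Map.Entry[A, B]]()
        underlying.foreach { case (k, v) =>
          set.add(new AbstractMap.SimpleImmutableEntry[A, B](k, v))
        }
        set
      }
    }

    object SerializableMapSketch {
      // Java-serialize and deserialize, as happens when a returned Map is shipped around.
      private def roundTrip[T](obj: T): T = {
        val buf = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(buf)
        out.writeObject(obj)
        out.close()
        new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
          .readObject().asInstanceOf[T]
      }

      def main(args: Array[String]): Unit = {
        val m: java.util.Map[String, Int] = new SerializableMapWrapper(Map("a" -> 1, "b" -> 2))
        println(roundTrip(m).get("a"))  // prints 1; a plain mapAsJavaMap view would fail to serialize
      }
    }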
core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala (5 additions, 8 deletions)
@@ -17,7 +17,6 @@
 
 package org.apache.spark.api.java
 
-import java.io.Serializable
 import java.util.{Comparator, List => JList, Map => JMap}
 import java.lang.{Iterable => JIterable}
 
@@ -266,10 +265,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    * before sending results to a reducer, similarly to a "combiner" in MapReduce.
    */
   def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
-    mapAsJavaMap(rdd.reduceByKeyLocally(func))
+    mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func))
 
   /** Count the number of elements for each key, and return the result to the master as a Map. */
-  def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+  def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey())
 
   /**
    * :: Experimental ::
@@ -278,7 +277,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    */
   @Experimental
   def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
-    rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+    rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap)
 
   /**
    * :: Experimental ::
@@ -288,7 +287,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   @Experimental
   def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
   : PartialResult[java.util.Map[K, BoundedDouble]] =
-    rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+    rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap)
 
   /**
    * Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -615,10 +614,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
    */
-  def collectAsMap(): java.util.Map[K, V] = new SerializableMapWrapper(rdd.collectAsMap())
+  def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap())
 
-  class SerializableMapWrapper(underlying: collection.Map[K, V])
-    extends MapWrapper(underlying) with Serializable
-
   /**
    * Pass each value in the key-value pair RDD through a map function without changing the keys;
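With this change in place, every map-returning method on the Java pair RDD API funnels through mapAsSerializableJavaMap. A hedged usage sketch of what previously failed with NotSerializableException; it assumes a Spark build of this vintage on the classpath, and the local-mode setup and object name are illustrative only:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    import org.apache.spark.SparkConf
    import org.apache.spark.api.java.JavaSparkContext

    object CountByKeyIsSerializable {
      def main(args: Array[String]): Unit = {
        val jsc = new JavaSparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
        val pairs = java.util.Arrays.asList(("a", 1), ("a", 2), ("b", 3))

        // countByKey now returns a Map wrapped by mapAsSerializableJavaMap.
        val counts: java.util.Map[String, Long] = jsc.parallelizePairs(pairs).countByKey()

        val out = new ObjectOutputStream(new ByteArrayOutputStream())
        out.writeObject(counts)  // succeeds; before this commit it threw NotSerializableException
        out.close()
        jsc.stop()
      }
    }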
core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala (11 additions, 3 deletions)
@@ -17,6 +17,7 @@
 
 package org.apache.spark.api.java
 
+import java.io.Serializable
 import java.util.{Comparator, List => JList, Iterator => JIterator}
 import java.lang.{Iterable => JIterable, Long => JLong}
 
@@ -390,7 +391,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * combine step happens locally on the master, equivalent to running a single reduce task.
    */
   def countByValue(): java.util.Map[T, java.lang.Long] =
-    mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+    mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
 
   /**
    * (Experimental) Approximate version of countByValue().
@@ -399,13 +400,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
       timeout: Long,
       confidence: Double
     ): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+    rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap)
 
   /**
    * (Experimental) Approximate version of countByValue().
    */
   def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+    rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap)
 
   /**
    * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
@@ -587,4 +588,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     rdd.foreachAsync(x => f.call(x))
   }
 
+  private[java] def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+    new SerializableMapWrapper(underlying)
+
+  private class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+    extends MapWrapper(underlying) with Serializable
+
+
 }
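The *Approx methods return a PartialResult rather than a finished Map, which is why those hunks apply the conversion with .map(mapAsSerializableJavaMap): the wrapping has to happen to the eventual value. A simplified stand-in illustrating the pattern; this is not Spark's org.apache.spark.partial.PartialResult, which also handles partial values, callbacks, and timeouts, and all names here are hypothetical:

    // Container for a result that is only available later.
    class EventualResult[R](compute: () => R) {
      def getFinalValue(): R = compute()

      // Lazily transform the eventual value, mirroring PartialResult.map in the diff above.
      def map[T](f: R => T): EventualResult[T] =
        new EventualResult(() => f(getFinalValue()))
    }

    object EventualResultDemo {
      def main(args: Array[String]): Unit = {
        // Scala-side result, as rdd.countByValueApprox would produce.
        val scalaSide = new EventualResult[Map[String, Double]](() => Map("a" -> 1.0))

        // Java-side view, as the Java API returns: convert the eventual Scala Map to a
        // java.util.Map (Spark uses mapAsSerializableJavaMap at this point).
        val javaSide: EventualResult[java.util.Map[String, Double]] =
          scalaSide.map { m =>
            val jm = new java.util.HashMap[String, Double]()
            m.foreach { case (k, v) => jm.put(k, v) }
            jm
          }

        println(javaSide.getFinalValue().get("a"))  // prints 1.0
      }
    }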
