diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b3b60578c92e8..2ee29549ea48d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -741,9 +741,10 @@ abstract class RDD[T: ClassTag](
   def mapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
     mapPartitionsWithIndex((index, iter) => {
       val a = constructA(index)
-      iter.map(t => f(t, a))
+      iter.map(t => cleanF(t, a))
     }, preservesPartitioning)
   }
 
@@ -756,9 +757,10 @@ abstract class RDD[T: ClassTag](
   def flatMapWith[A, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = withScope {
+    val cleanF = sc.clean(f)
     mapPartitionsWithIndex((index, iter) => {
       val a = constructA(index)
-      iter.flatMap(t => f(t, a))
+      iter.flatMap(t => cleanF(t, a))
     }, preservesPartitioning)
   }
 
@@ -769,9 +771,10 @@ abstract class RDD[T: ClassTag](
    */
   @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
   def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
+    val cleanF = sc.clean(f)
     mapPartitionsWithIndex { (index, iter) =>
       val a = constructA(index)
-      iter.map(t => {f(t, a); t})
+      iter.map(t => {cleanF(t, a); t})
     }
   }
 
@@ -901,7 +904,8 @@ abstract class RDD[T: ClassTag](
    * Return an RDD that contains all matching values by applying `f`.
    */
   def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
-    filter(f.isDefinedAt).map(f)
+    val cleanF = sc.clean(f)
+    filter(cleanF.isDefinedAt).map(cleanF)
   }
 
   /**
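
For context, here is a minimal driver-side sketch (not part of the patch) that exercises the fixed `mapWith` path. The object name, app name, and `local[2]` master are illustrative choices, not anything from the diff; `mapWith` itself is deprecated, so the call emits a deprecation warning. The point of the change above is that `f` now goes through `sc.clean` (Spark's ClosureCleaner) before the closure is serialized, which nulls out unused references captured from enclosing scopes and fails fast at clean time if the closure is still not serializable, rather than at task launch.

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative driver program (hypothetical names throughout).
object CleanClosureDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("clean-closure-demo").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 4, numSlices = 2)

    // constructA builds a per-partition value from the partition index, and
    // f combines each element with it. With this patch, f is passed through
    // sc.clean(f) before the closure is serialized and shipped to executors.
    val result = rdd.mapWith(index => index * 100)((t, a) => t + a).collect()
    println(result.mkString(", "))  // with 2 partitions of 1 to 4: 1, 2, 103, 104

    sc.stop()
  }
}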