diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c6f3b7a8494f8..f3c5b420db8a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1031,7 +1031,7 @@ class SparkContext( * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) */ - private[spark] def clean[F <: AnyRef](f: F): F = { + private[spark] def clean[F <: AnyRef : ClassTag](f: F): F = { clean(f, true) } @@ -1039,9 +1039,8 @@ class SparkContext( * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) */ - private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean): F = { + private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean): F = { ClosureCleaner.clean(f, checkSerializable) - f } /** diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index fe33fa841a0d8..2db06548a1ac2 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,6 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.Map import scala.collection.mutable.Set +import scala.reflect.ClassTag + import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ @@ -103,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging { } } - def clean(func: AnyRef, checkSerializable: Boolean = true) { + def clean[F <: AnyRef : ClassTag](func: F, checkSerializable: Boolean = true): F = { // TODO: cache outerClasses / innerClasses / accessedFields val outerClasses = getOuterClasses(func) val innerClasses = getInnerClasses(func) @@ -155,12 +157,15 @@ private[spark] object ClosureCleaner extends Logging { if (checkSerializable) { ensureSerializable(func) + } else { + func } } - private def ensureSerializable(func: AnyRef) { + private def ensureSerializable[T: ClassTag](func: T) = { try { - SparkEnv.get.closureSerializer.newInstance().serialize(func) + val serializer = SparkEnv.get.closureSerializer.newInstance() + serializer.deserialize[T](serializer.serialize[T](func)) } catch { case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString) }