diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 76305237b03d5..545807ffbce55 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1002,7 +1002,9 @@ class SparkContext(config: SparkConf) extends Logging {
       require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
     }
     val callSite = getCallSite
-    val cleanedFunc = clean(func)
+    // There's no need to check this function for serializability,
+    // since it will be run right away.
+    val cleanedFunc = clean(func, false)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1135,14 +1137,18 @@ class SparkContext(config: SparkConf) extends Logging {
   def cancelAllJobs() {
     dagScheduler.cancelAllJobs()
   }
-  
+
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
+   *
+   * @param f closure to be cleaned and optionally serialized
+   * @param captureNow whether or not to serialize this closure and capture any free
+   *   variables immediately; defaults to true. If this is set and f is not serializable,
+   *   it will raise an exception.
    */
-  private[spark] def clean[F <: AnyRef](f: F): F = {
-    ClosureCleaner.clean(f)
-    f
+  private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
+    ClosureCleaner.clean(f, captureNow)
   }
 
   /**
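Note (not part of the patch): with the new default captureNow = true, clean() both cleans the closure and immediately serializes and deserializes it, so free variables are captured when a transformation is defined and non-serializable closures fail right away. A minimal sketch of the resulting user-visible semantics, assuming a local SparkContext named sc:

  // Sketch only: under capture-at-definition, later mutation of a free variable
  // does not affect an already-defined transformation.
  var factor = 2
  val doubled = sc.parallelize(1 to 5).map(_ * factor)  // closure cleaned and serialized here
  factor = 10                                           // no effect on `doubled`
  assert(doubled.collect().toSeq == Seq(2, 4, 6, 8, 10))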
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 3437b2cac19c2..e363ea777d8eb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -660,14 +660,16 @@ abstract class RDD[T: ClassTag](
    * Applies a function f to all elements of this RDD.
    */
   def foreach(f: T => Unit) {
-    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
+    val cleanF = sc.clean(f)
+    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
   }
 
   /**
    * Applies a function f to each partition of this RDD.
    */
   def foreachPartition(f: Iterator[T] => Unit) {
-    sc.runJob(this, (iter: Iterator[T]) => f(iter))
+    val cleanF = sc.clean(f)
+    sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index cdbbc65292188..e474b1a850d65 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,14 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
 import scala.collection.mutable.Map
 import scala.collection.mutable.Set
+import scala.reflect.ClassTag
+
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
 import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+import org.apache.spark.SparkException
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  def clean(func: AnyRef) {
+  def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
     // TODO: cache outerClasses / innerClasses / accessedFields
     val outerClasses = getOuterClasses(func)
     val innerClasses = getInnerClasses(func)
@@ -150,6 +154,21 @@ private[spark] object ClosureCleaner extends Logging {
       field.setAccessible(true)
       field.set(func, outer)
     }
+
+    if (captureNow) {
+      cloneViaSerializing(func)
+    } else {
+      func
+    }
+  }
+
+  private def cloneViaSerializing[T: ClassTag](func: T): T = {
+    try {
+      val serializer = SparkEnv.get.closureSerializer.newInstance()
+      serializer.deserialize[T](serializer.serialize[T](func))
+    } catch {
+      case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
+    }
   }
 
   private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
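Note (illustration, not part of the patch): cloneViaSerializing implements the capture by a serialize/deserialize round trip through the configured closure serializer, and any failure surfaces as a SparkException at definition time. The same round-trip idea, sketched with plain Java serialization instead of SparkEnv.get.closureSerializer:

  import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

  // Sketch only: clone a value by writing it out and reading it back, so the clone
  // holds copies of whatever the original closure referenced at this moment.
  def roundTripClone[T](value: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(value)  // throws NotSerializableException for unserializable closures
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[T]
  }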
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 12dbebcb28644..4f9300419e6f8 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     FailureSuiteState.clear()
   }
 
-  test("failure because task closure is not serializable") {
+  test("failure because closure in final-stage task is not serializable") {
     sc = new SparkContext("local[1,1]", "test")
     val a = new NonSerializable
 
@@ -118,6 +118,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     assert(thrown.getClass === classOf[SparkException])
     assert(thrown.getMessage.contains("NotSerializableException"))
 
+    FailureSuiteState.clear()
+  }
+
+  test("failure because closure in early-stage task is not serializable") {
+    sc = new SparkContext("local[1,1]", "test")
+    val a = new NonSerializable
+
     // Non-serializable closure in an earlier stage
     val thrown1 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -125,6 +132,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     assert(thrown1.getClass === classOf[SparkException])
     assert(thrown1.getMessage.contains("NotSerializableException"))
 
+    FailureSuiteState.clear()
+  }
+
+  test("failure because closure in foreach task is not serializable") {
+    sc = new SparkContext("local[1,1]", "test")
+    val a = new NonSerializable
+
     // Non-serializable closure in foreach function
     val thrown2 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -135,5 +149,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
 
     FailureSuiteState.clear()
   }
 
+  // TODO: Need to add tests with shuffle fetch failures.
 }
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
new file mode 100644
index 0000000000000..76662264e7e94
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer;
+
+import java.io.NotSerializableException
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
+import org.apache.spark.SharedSparkContext
+
+/* A trivial (but unserializable) container for trivial functions */
+class UnserializableClass {
+  def op[T](x: T) = x.toString
+
+  def pred[T](x: T) = x.toString.length % 2 == 0
+}
+
+class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
+
+  def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
+
+  test("throws expected serialization exceptions on actions") {
+    val (data, uc) = fixture
+
+    val ex = intercept[SparkException] {
+      data.map(uc.op(_)).count
+    }
+
+    assert(ex.getMessage.matches(".*Task not serializable.*"))
+  }
+
+  // There is probably a cleaner way to eliminate boilerplate here, but we're
+  // iterating over a map from transformation names to functions that perform that
+  // transformation on a given RDD, creating one test case for each
+
+  for (transformation <-
+      Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
+          "mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
+          "mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
+    val (name, xf) = transformation
+
+    test(s"$name transformations throw proactive serialization exceptions") {
+      val (data, uc) = fixture
+
+      val ex = intercept[SparkException] {
+        xf(data, uc)
+      }
+
+      assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
+    }
+  }
+
+  def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.map(y => uc.op(y))
+
+  def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapWith(x => x.toString)((x,y) => x + uc.op(y))
+
+  def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.flatMap(y=>Seq(uc.op(y)))
+
+  def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filter(y=>uc.pred(y))
+
+  def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filterWith(x => x.toString)((x,y) => uc.pred(y))
+
+  def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitions(_.map(y => uc.op(y)))
+
+  def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
+
+  def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
+
+}
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 439e5644e20a3..c635da6cacd70 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,6 +50,27 @@ class ClosureCleanerSuite extends FunSuite {
     val obj = new TestClassWithNesting(1)
     assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
   }
+
+  test("capturing free variables in closures at RDD definition") {
+    val obj = new TestCaptureVarClass()
+    val (ones, onesPlusZeroes) = obj.run()
+
+    assert(ones === onesPlusZeroes)
+  }
+
+  test("capturing free variable fields in closures at RDD definition") {
+    val obj = new TestCaptureFieldClass()
+    val (ones, onesPlusZeroes) = obj.run()
+
+    assert(ones === onesPlusZeroes)
+  }
+
+  test("capturing arrays in closures at RDD definition") {
+    val obj = new TestCaptureArrayEltClass()
+    val (observed, expected) = obj.run()
+
+    assert(observed === expected)
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -143,3 +164,50 @@ class TestClassWithNesting(val y: Int) extends Serializable {
     }
   }
 }
+
+class TestCaptureFieldClass extends Serializable {
+  class ZeroBox extends Serializable {
+    var zero = 0
+  }
+
+  def run(): (Int, Int) = {
+    val zb = new ZeroBox
+
+    withSpark(new SparkContext("local", "test")) {sc =>
+      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+      val onesPlusZeroes = ones.map(_ + zb.zero)
+
+      zb.zero = 5
+
+      (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+    }
+  }
+}
+
+class TestCaptureArrayEltClass extends Serializable {
+  def run(): (Int, Int) = {
+    withSpark(new SparkContext("local", "test")) {sc =>
+      val rdd = sc.parallelize(1 to 10)
+      val data = Array(1, 2, 3)
+      val expected = data(0)
+      val mapped = rdd.map(x => data(0))
+      data(0) = 4
+      (mapped.first, expected)
+    }
+  }
+}
+
+class TestCaptureVarClass extends Serializable {
+  def run(): (Int, Int) = {
+    var zero = 0
+
+    withSpark(new SparkContext("local", "test")) {sc =>
+      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+      val onesPlusZeroes = ones.map(_ + zero)
+
+      zero = 5
+
+      (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+    }
+  }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 28d34dd9a1a41..c65e36636fe10 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     assert( graph.edges.count() === rawEdges.size )
     // Vertices not explicitly provided but referenced by edges should be created automatically
     assert( graph.vertices.count() === 100)
-    graph.triplets.map { et =>
+    graph.triplets.collect.map { et =>
       assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
       assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
     }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index d043200f71a0b..4759b629a9931 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,7 +539,7 @@ abstract class DStream[T: ClassTag] (
    * on each RDD of 'this' DStream.
    */
  def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
-    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
+    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
   }
 
   /**
@@ -547,7 +547,7 @@ abstract class DStream[T: ClassTag] (
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
-    val cleanedF = context.sparkContext.clean(transformFunc)
+    val cleanedF = context.sparkContext.clean(transformFunc, false)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 1)
       cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -562,7 +562,7 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc)
+    val cleanedF = ssc.sparkContext.clean(transformFunc, false)
     transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
   }
 
@@ -573,7 +573,7 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc)
+    val cleanedF = ssc.sparkContext.clean(transformFunc, false)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 2)
       val rdd1 = rdds(0).asInstanceOf[RDD[T]]
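A short usage sketch (not part of the patch) of the behavior ProactiveClosureSerializationSuite exercises: with capture-at-definition, a non-serializable closure is rejected when the transformation is defined, not when an action finally submits a job. The class and value names below are hypothetical; assumes a ScalaTest body with a running SparkContext sc:

  class NotSerializableOp { def op(x: Int): String = x.toString }  // does not extend Serializable
  val holder = new NotSerializableOp

  val ex = intercept[SparkException] {
    sc.parallelize(1 to 10).map(holder.op)  // fails here, before any action runs
  }
  assert(ex.getMessage.contains("Task not serializable"))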