SPARK-729: Closures not always serialized at capture time
[SPARK-729](https://spark-project.atlassian.net/browse/SPARK-729) concerns when free variables in closure arguments to transformations are captured.  Currently, it is possible for closures to get the environment in which they are serialized (not the environment in which they are created).  There are a few possible approaches to solving this problem, and this PR will discuss some of them.  The approach I took has the advantage of being simple, obviously correct, and minimally invasive, but it preserves something that has been bothering me about Spark's closure handling, so I'd like to discuss an alternative and get some feedback on whether or not it is worth pursuing.
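
As a concrete sketch of the current behavior (hypothetical driver code, mirroring the variable-capture tests added in this PR):

```scala
// Minimal sketch of the problem, assuming a SparkContext named sc.
var zero = 0
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZero = ones.map(_ + zero)  // intuitively, zero == 0 is captured here

zero = 5  // mutated before any job actually runs

// Today the closure's environment is captured when a job is serialized, so the
// reduce sees zero == 5 and yields 30 rather than the expected 5.
onesPlusZero.reduce(_ + _)
```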

## What I did

The basic approach I took depends on the work I did for #143, and so this PR is based atop that.  Specifically: #143 modifies `ClosureCleaner.clean` to preemptively determine whether or not closures are serializable immediately upon closure cleaning (rather than waiting for a job involving that closure to be scheduled).  Thus non-serializable closure exceptions will be triggered by the line defining the closure rather than at the point where the closure is used.
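
As a sketch of the user-visible effect (illustrative only; `NonSerializable` is the helper class used in `FailureSuite` below):

```scala
// Illustrative sketch: where the failure surfaces once closures are checked eagerly.
val a = new NonSerializable

// Before #143, this line succeeded and the "Task not serializable" SparkException
// only appeared later, when an action such as count() scheduled a job using the
// closure.  With proactive checking, clean() attempts serialization here, so the
// exception is raised on the line that defines the closure.
val pairs = sc.parallelize(1 to 10, 2).map(x => (x, a))
```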

Since the easiest way to determine whether or not a closure is serializable is to attempt to serialize it, the code in #143 already creates a serialized closure as part of `ClosureCleaner.clean`.  `clean` currently modifies its argument in place, and the method in `SparkContext` that wraps it returns a value (a reference to the modified-in-place argument).  This branch modifies `ClosureCleaner.clean` so that it returns a value:  if it is cleaning a serializable closure, it returns the result of deserializing its serialized argument, and therefore returns a closure with an environment captured at cleaning time.  `SparkContext.clean` then returns the result of `ClosureCleaner.clean`, rather than a reference to its modified-in-place argument.

I've added tests for this behavior (777a1bc).  The pull request as it stands, given the changes in #143, is nearly trivial.  There is some overhead from deserializing the closure, but it is minimal and the benefit of obvious operational correctness (vs. a more sophisticated but harder-to-validate transformation in `ClosureCleaner`) seems pretty important.  I think this is a fine way to solve this problem, but it's not perfect.

## What we might want to do

The thing that has been bothering me about Spark's handling of closures is that it seems like we should be able to statically ensure that cleaning and serialization happen exactly once for a given closure.  If we serialize a closure in order to determine whether or not it is serializable, we should be able to hang on to the generated byte buffer and use it instead of re-serializing the closure later.  By replacing closures with instances of a sum type that encodes whether a closure has been cleaned or serialized, we could handle to-be-cleaned, cleaned, and serialized closures separately with case matches.  Here's a somewhat-concrete sketch (taken from my git stash) of what this might look like:

```scala
package org.apache.spark.util

import java.nio.ByteBuffer

// Encodes how far along the clean/serialize pipeline a closure is.
sealed abstract class ClosureBox[T] { def func: T }
// Not yet cleaned.
final case class RawClosure[T](func: T) extends ClosureBox[T]
// Cleaned, but not yet checked for serializability.
final case class CleanedClosure[T](func: T) extends ClosureBox[T]
// Cleaned and serialized; the buffer can be reused when shipping the closure.
final case class SerializedClosure[T](func: T, bytebuf: ByteBuffer) extends ClosureBox[T]

object ClosureBoxImplicits {
  // Allows bare function values to be passed where a ClosureBox is expected.
  implicit def closureBoxFromFunc[T <: AnyRef](fun: T): RawClosure[T] = new RawClosure[T](fun)
}
```

With these types declared, we'd be able to change `ClosureCleaner.clean` to take a `ClosureBox[T=>U]` (possibly generated by implicit conversion) and return a `ClosureBox[T=>U]` (either a `CleanedClosure[T=>U]` or a `SerializedClosure[T=>U]`, depending on whether or not serializability-checking was enabled) instead of a `T=>U`.  A case match could thus short-circuit cleaning or serializing closures that had already been cleaned or serialized (both in `ClosureCleaner` and in the closure serializer).  Cleaned-and-serialized closures would be represented by a boxed tuple of the original closure and a serialized copy (complete with an environment quiesced at transformation time).  Additional implicit conversions could convert from `ClosureBox` instances to the underlying function type where appropriate.  Tracking this sort of state in the type system seems like the right thing to do to me.
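
As a rough sketch of how `ClosureCleaner.clean` might dispatch on these cases (not implemented in this PR; `doClean` stands in for the existing cleaning logic and the signature is illustrative):

```scala
// Illustrative only: short-circuit closures that have already been processed.
def clean[T <: AnyRef : ClassTag](boxed: ClosureBox[T],
                                  captureNow: Boolean = true): ClosureBox[T] =
  boxed match {
    case RawClosure(f) =>
      doClean(f)  // the existing in-place cleaning of $outer references
      if (captureNow) {
        val serializer = SparkEnv.get.closureSerializer.newInstance()
        val buf = serializer.serialize(f)
        // Keep the buffer around so the closure never has to be re-serialized later.
        SerializedClosure(serializer.deserialize[T](buf), buf)
      } else {
        CleanedClosure(f)
      }
    case alreadyProcessed =>
      alreadyProcessed  // already cleaned or serialized: nothing more to do
  }
```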

### Why we might not want to do that

_It's pretty invasive._  Every function type used by every `RDD` subclass would have to change to reflect that it expects a `ClosureBox[T=>U]` instead of a `T=>U`.  This obscures what's going on and is not a little ugly.  Although I really like the idea of using the type system to enforce the clean-or-serialize-once discipline, it might not be worth adding another layer of types (even if we could hide some of the extra boilerplate with judicious application of implicit conversions).
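
For instance, every transformation signature would need roughly this kind of change (a toy sketch, not the real `RDD` class):

```scala
// Illustrative only: the sort of signature churn ClosureBox would impose.
trait RDDSketch[T] {
  def map[U](f: T => U): RDDSketch[U]                   // today
  def mapBoxed[U](f: ClosureBox[T => U]): RDDSketch[U]  // with ClosureBox
}
```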

_It statically guarantees a property whose absence is unlikely to cause any serious problems as it stands._  It appears that all closures are currently dynamically cleaned exactly once, and it's not obvious that repeated closure-cleaning is likely to be a problem in the future.  Furthermore, serializing closures is relatively cheap, so doing it once to check for serializability and once again to actually ship them across the wire doesn't seem like a big deal.

Taken together, these seem like a high price to pay for statically guaranteeing that closures are operated upon only once.

## Other possibilities

I felt like the serialize-and-deserialize approach was best due to its obvious simplicity.  But it would be possible to do a more sophisticated transformation within `ClosureCleaner.clean`.  It might also be possible for `clean` to modify its argument so that whether or not a given closure had been cleaned would be apparent upon inspection; this would buy us some of the operational benefits of the `ClosureBox` approach but not the static cleanliness.

I'm interested in any feedback or discussion on whether or not the problems with the type-based approach indeed outweigh its advantages, as well as on other approaches to this issue and to closure handling in general.

Author: William Benton <willb@redhat.com>

Closes #189 from willb/spark-729 and squashes the following commits:

f4cafa0 [William Benton] Stylistic changes and cleanups
b3d9c86 [William Benton] Fixed style issues in tests
9b56ce0 [William Benton] Added array-element capture test
97e9d91 [William Benton] Split closure-serializability failure tests
12ef6e3 [William Benton] Skip proactive closure capture for runJob
8ee3ee7 [William Benton] Predictable closure environment capture
12c63a7 [William Benton] Added tests for variable capture in closures
d6e8dd6 [William Benton] Don't check serializability of DStream transforms.
4ecf841 [William Benton] Make proactive serializability checking optional.
d8df3db [William Benton] Adds proactive closure-serializablilty checking
21b4b06 [William Benton] Test cases for SPARK-897.
d5947b3 [William Benton] Ensure assertions in Graph.apply are asserted.
willb authored and mateiz committed Apr 10, 2014
1 parent 0adc932 commit 8ca3b2b
Showing 8 changed files with 218 additions and 14 deletions.
16 changes: 11 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1002,7 +1002,9 @@ class SparkContext(config: SparkConf) extends Logging {
require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
// There's no need to check this function for serializability,
// since it will be run right away.
val cleanedFunc = clean(func, false)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1135,14 +1137,18 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}

/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*
* @param f closure to be cleaned and optionally serialized
* @param captureNow whether or not to serialize this closure and capture any free
* variables immediately; defaults to true. If this is set and f is not serializable,
* it will raise an exception.
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
f
private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
ClosureCleaner.clean(f, captureNow)
}

/**
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -660,14 +660,16 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
sc.runJob(this, (iter: Iterator[T]) => f(iter))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}

/**
21 changes: 20 additions & 1 deletion core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,14 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set

import scala.reflect.ClassTag

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.SparkException

private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}

def clean(func: AnyRef) {
def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -150,6 +154,21 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}

if (captureNow) {
cloneViaSerializing(func)
} else {
func
}
}

private def cloneViaSerializing[T: ClassTag](func: T): T = {
try {
val serializer = SparkEnv.get.closureSerializer.newInstance()
serializer.deserialize[T](serializer.serialize[T](func))
} catch {
case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
}
}

private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
17 changes: 16 additions & 1 deletion core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}

test("failure because task closure is not serializable") {
test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

@@ -118,13 +118,27 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))

FailureSuiteState.clear()
}

test("failure because closure in early-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
}
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))

FailureSuiteState.clear()
}

test("failure because closure in foreach task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -135,5 +149,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}


// TODO: Need to add tests with shuffle fetch failures.
}
94 changes: 94 additions & 0 deletions core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer;

import java.io.NotSerializableException

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkException
import org.apache.spark.SharedSparkContext

/* A trivial (but unserializable) container for trivial functions */
class UnserializableClass {
def op[T](x: T) = x.toString

def pred[T](x: T) = x.toString.length % 2 == 0
}

class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {

def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)

test("throws expected serialization exceptions on actions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
data.map(uc.op(_)).count
}

assert(ex.getMessage.matches(".*Task not serializable.*"))
}

// There is probably a cleaner way to eliminate boilerplate here, but we're
// iterating over a map from transformation names to functions that perform that
// transformation on a given RDD, creating one test case for each

for (transformation <-
Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
"mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
"mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
val (name, xf) = transformation

test(s"$name transformations throw proactive serialization exceptions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
xf(data, uc)
}

assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
}
}

def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y => uc.op(y))

def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y) => x + uc.op(y))

def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))

def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))

def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y) => uc.pred(y))

def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y => uc.op(y)))

def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))

def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))

}
68 changes: 68 additions & 0 deletions core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,6 +50,27 @@ class ClosureCleanerSuite extends FunSuite {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}

test("capturing free variables in closures at RDD definition") {
val obj = new TestCaptureVarClass()
val (ones, onesPlusZeroes) = obj.run()

assert(ones === onesPlusZeroes)
}

test("capturing free variable fields in closures at RDD definition") {
val obj = new TestCaptureFieldClass()
val (ones, onesPlusZeroes) = obj.run()

assert(ones === onesPlusZeroes)
}

test("capturing arrays in closures at RDD definition") {
val obj = new TestCaptureArrayEltClass()
val (observed, expected) = obj.run()

assert(observed === expected)
}
}

// A non-serializable class we create in closures to make sure that we aren't
@@ -143,3 +164,50 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}

class TestCaptureFieldClass extends Serializable {
class ZeroBox extends Serializable {
var zero = 0
}

def run(): (Int, Int) = {
val zb = new ZeroBox

withSpark(new SparkContext("local", "test")) {sc =>
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZeroes = ones.map(_ + zb.zero)

zb.zero = 5

(ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
}
}
}

class TestCaptureArrayEltClass extends Serializable {
def run(): (Int, Int) = {
withSpark(new SparkContext("local", "test")) {sc =>
val rdd = sc.parallelize(1 to 10)
val data = Array(1, 2, 3)
val expected = data(0)
val mapped = rdd.map(x => data(0))
data(0) = 4
(mapped.first, expected)
}
}
}

class TestCaptureVarClass extends Serializable {
def run(): (Int, Int) = {
var zero = 0

withSpark(new SparkContext("local", "test")) {sc =>
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZeroes = ones.map(_ + zero)

zero = 5

(ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
}
}
}
2 changes: 1 addition & 1 deletion graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert( graph.edges.count() === rawEdges.size )
// Vertices not explicitly provided but referenced by edges should be created automatically
assert( graph.vertices.count() === 100)
graph.triplets.map { et =>
graph.triplets.collect.map { et =>
assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
}
8 changes: 4 additions & 4 deletions streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,15 +539,15 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
}

/**
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
val cleanedF = context.sparkContext.clean(transformFunc)
val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -562,7 +562,7 @@
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}

@@ -573,7 +573,7 @@
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]
