Skip to content

Commit

Permalink
fix memory leak in collect()
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Mar 6, 2015
1 parent eb48fd6 commit 9517c8f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
19 changes: 13 additions & 6 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,22 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}

import org.apache.spark.input.PortableDataStream
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}

import org.apache.spark._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -358,6 +357,14 @@ private[spark] object PythonRDD extends Logging {
flattenedPartition.iterator
}

/**
 * Collects the elements of the given RDD and returns an iterator over them.
 *
 * This helper exists so that only the Iterator object is exported to Py4j, which
 * allows it to be garbage-collected easily.
 *
 * @param jrdd the RDD whose elements are collected
 * @return an iterator over the collected elements
 */
def collectAsIterator[T](jrdd: JavaRDD[T]): Iterator[T] = {
  val collected = jrdd.collect()
  collected.iterator()
}

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
Expand Down
15 changes: 13 additions & 2 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from threading import Lock
from tempfile import NamedTemporaryFile

from py4j.java_gateway import JavaObject
from py4j.java_collections import ListConverter
import py4j.protocol

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
Expand All @@ -35,8 +39,6 @@
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler

from py4j.java_collections import ListConverter


__all__ = ['SparkContext']

Expand All @@ -49,6 +51,15 @@
}


# Py4j's own implementation of is_python_proxy creates a 'Java' member on the
# parameter when it is a JavaObject. Because of the circular reference between
# JavaObject and JavaMember, the object then cannot be released after use until
# the garbage collector kicks in. The wrapper below avoids this by never handing
# a JavaObject to the original implementation.

# Keep a handle on Py4j's original implementation before it is replaced.
_old_is_python_proxy = py4j.protocol.is_python_proxy


def is_python_proxy(parameter):
    """Return False for any JavaObject; otherwise defer to Py4j's original check."""
    if isinstance(parameter, JavaObject):
        return False
    return _old_is_python_proxy(parameter)


# Install the wrapper in place of Py4j's implementation.
py4j.protocol.is_python_proxy = is_python_proxy


class SparkContext(object):

"""
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def collect(self):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
bytesInJava = self.ctx._jvm.PythonRDD.collectAsIterator(self._jrdd)
return list(self._collect_iterator_through_file(bytesInJava))

def _collect_iterator_through_file(self, iterator):
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
bytesInJava = self._jdf.javaToPython().collect().iterator()
bytesInJava = self._sc._jvm.PythonRDD.collectAsIterator(self._jdf.javaToPython())
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile.close()
self._sc._writeToFile(bytesInJava, tempFile.name)
Expand Down

0 comments on commit 9517c8f

Please sign in to comment.