use socket to transfer data from JVM
Davies Liu committed Mar 6, 2015
1 parent 9517c8f commit ba54614
Showing 4 changed files with 53 additions and 58 deletions.
48 changes: 33 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,6 +38,8 @@ import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

+import scala.util.control.NonFatal
+
private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
@@ -340,29 +342,28 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
-   * This method will return an iterator of an array that contains all elements in the RDD
+   * This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
-      allowLocal: Boolean): Iterator[Array[Byte]] = {
+      allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
-    flattenedPartition.iterator
+    serveIterator(flattenedPartition.iterator)
}

/**
-   * A helper function to collect an RDD as an iterator, then it only export the Iterator
-   * object to Py4j, easily be GCed.
+   * A helper function to collect an RDD as an iterator, then serve it via socket
   */
-  def collectAsIterator[T](jrdd: JavaRDD[T]): Iterator[T] = {
-    jrdd.collect().iterator()
+  def collectAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.collect().iterator)
}

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -582,15 +583,32 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}

-  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
-    import scala.collection.JavaConverters._
-    writeToFile(items.asScala, filename)
-  }
+  private def serveIterator[T](items: Iterator[T]): Int = {
+    val serverSocket = new ServerSocket(0, 1)
+    serverSocket.setReuseAddress(true)
+    serverSocket.setSoTimeout(3000)
+
+    new Thread("serve iterator") {
+      setDaemon(true)
+      override def run() {
+        try {
+          val sock = serverSocket.accept()
+          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          try {
+            writeIteratorToStream(items, out)
+          } finally {
+            out.close()
+          }
+        } catch {
+          case NonFatal(e) =>
+            logError(s"Error while sending iterator: $e")
+        } finally {
+          serverSocket.close()
+        }
+      }
+    }.start()

-  def writeToFile[T](items: Iterator[T], filename: String) {
-    val file = new DataOutputStream(new FileOutputStream(filename))
-    writeIteratorToStream(items, file)
-    file.close()
+    serverSocket.getLocalPort
  }

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
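Reviewer note on the Scala change above: `serveIterator` binds a `ServerSocket` to an ephemeral port (port 0), hands the iterator to a daemon thread that writes it to the first client that connects, and returns the port number right away so the Python side can dial in. Below is a minimal, self-contained Python sketch of that serve-then-connect handoff; the helper names (`serve_iterator`, `load_from_socket`) and the length-prefixed pickle framing are illustrative stand-ins, not Spark's `writeIteratorToStream` wire format.

```python
import pickle
import socket
import struct
import threading


def serve_iterator(items):
    """Serve `items` once over an ephemeral localhost port and return that port."""
    server = socket.socket()
    server.bind(("localhost", 0))   # port 0: let the OS pick a free port
    server.listen(1)
    server.settimeout(3.0)          # mirrors setSoTimeout(3000): don't wait forever

    def _serve():
        try:
            conn, _ = server.accept()
            conn.settimeout(None)   # write in blocking mode
            out = conn.makefile("wb", 65536)
            try:
                for item in items:
                    data = pickle.dumps(item)
                    out.write(struct.pack("!i", len(data)))  # length-prefixed frame
                    out.write(data)
            finally:
                out.close()
                conn.close()
        finally:
            server.close()

    t = threading.Thread(target=_serve, name="serve iterator")
    t.daemon = True                 # never keep the process alive just for this
    t.start()
    return server.getsockname()[1]


def load_from_socket(port):
    """Connect to the port returned above and yield the items back."""
    sock = socket.socket()
    try:
        sock.connect(("localhost", port))
        rf = sock.makefile("rb", 65536)
        while True:
            header = rf.read(4)
            if len(header) < 4:     # writer closed its end: stream is done
                return
            (length,) = struct.unpack("!i", header)
            yield pickle.loads(rf.read(length))
    finally:
        sock.close()


if __name__ == "__main__":
    port = serve_iterator(iter(range(5)))
    print(list(load_from_socket(port)))   # [0, 1, 2, 3, 4]
```

The listening socket only ever accepts one connection, and the accept timeout releases the port if the client never shows up, the same safety valve `setSoTimeout(3000)` provides in the Scala version.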
19 changes: 3 additions & 16 deletions python/pyspark/context.py
@@ -21,9 +21,7 @@
from threading import Lock
from tempfile import NamedTemporaryFile

-from py4j.java_gateway import JavaObject
from py4j.java_collections import ListConverter
-import py4j.protocol

from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -34,7 +32,7 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
@@ -51,15 +49,6 @@
}


-# The implementation in Py4j will create 'Java' member for parameter (JavaObject)
-# because of circular reference between JavaObject and JavaMember, then the object
-# can not be released after used until GC kick-in.
-def is_python_proxy(parameter):
-    return not isinstance(parameter, JavaObject) and _old_is_python_proxy(parameter)
-_old_is_python_proxy = py4j.protocol.is_python_proxy
-py4j.protocol.is_python_proxy = is_python_proxy


class SparkContext(object):

"""
@@ -70,7 +59,6 @@ class SparkContext(object):

_gateway = None
_jvm = None
-    _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -232,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
-            SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

if instance:
if (SparkContext._active_spark_context and
@@ -851,8 +838,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
-        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
-        return list(mappedRDD._collect_iterator_through_file(it))
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))

def show_profiles(self):
""" Print the profile stats to stdout """
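With this change `SparkContext.runJob` gets back a port number from the JVM and drains the per-partition results through `_load_from_socket`, so no temp file and no Py4J-proxied iterator are involved; the public signature is unchanged. A small usage sketch, assuming a local master and this commit's PySpark on the path (the values in the comment depend on how `parallelize` splits the data):

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "runJob-demo")
rdd = sc.parallelize(list(range(100)), 4)

# partitionFunc runs on the executors; each partition's result is pickled in
# the JVM, served over an ephemeral local socket, and read back here.
partial_sums = sc.runJob(rdd, lambda it: [sum(it)], partitions=[0, 1])
print(partial_sums)   # e.g. [300, 925] with an even 25-element split

sc.stop()
```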
30 changes: 14 additions & 16 deletions python/pyspark/rdd.py
@@ -19,7 +19,6 @@
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
-import os
import sys
import shlex
from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@
import heapq
import bisect
import random
+import socket
from math import sqrt, log, isinf, isnan, pow, ceil

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


+def _load_from_socket(port, serializer):
+    sock = socket.socket()
+    try:
+        sock.connect(("localhost", port))
+        rf = sock.makefile("rb", 65536)
+        for item in serializer.load_stream(rf):
+            yield item
+    finally:
+        sock.close()
+

class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ def collect(self):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
-            bytesInJava = self.ctx._jvm.PythonRDD.collectAsIterator(self._jrdd)
-        return list(self._collect_iterator_through_file(bytesInJava))
-
-    def _collect_iterator_through_file(self, iterator):
-        # Transferring lots of data through Py4J can be slow because
-        # socket.readline() is inefficient. Instead, we'll dump the data to a
-        # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
-        tempFile.close()
-        self.ctx._writeToFile(iterator, tempFile.name)
-        # Read the data into Python and deserialize it:
-        with open(tempFile.name, 'rb') as tempFile:
-            for item in self._jrdd_deserializer.load_stream(tempFile):
-                yield item
-        os.unlink(tempFile.name)
+            port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+        return list(_load_from_socket(port, self._jrdd_deserializer))

def reduce(self, f):
"""
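The new `_load_from_socket` helper is deliberately small: connect to `localhost:port`, wrap the socket in a buffered file object, and let the serializer's `load_stream` yield items until the JVM closes its end of the connection. The test-style sketch below exercises that contract with a background thread standing in for the JVM's `serveIterator`; it assumes this commit's `pyspark` package is importable and uses only the serializer API shown in the diff.

```python
import socket
import threading

from pyspark.rdd import _load_from_socket
from pyspark.serializers import PickleSerializer

ser = PickleSerializer()

# Stand-in for the JVM side: bind an ephemeral port and stream framed,
# pickled items to the first client that connects.
server = socket.socket()
server.bind(("localhost", 0))
server.listen(1)
port = server.getsockname()[1]


def serve():
    conn, _ = server.accept()
    out = conn.makefile("wb", 65536)
    ser.dump_stream(iter(["a", "b", "c"]), out)   # same framing load_stream expects
    out.close()
    conn.close()
    server.close()


t = threading.Thread(target=serve)
t.daemon = True
t.start()

# This is exactly the call pattern collect() and runJob() now use.
print(list(_load_from_socket(port, ser)))   # ['a', 'b', 'c']
```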
14 changes: 3 additions & 11 deletions python/pyspark/sql/dataframe.py
@@ -19,13 +19,11 @@
import itertools
import warnings
import random
-import os
-from tempfile import NamedTemporaryFile

from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
-            bytesInJava = self._sc._jvm.PythonRDD.collectAsIterator(self._jdf.javaToPython())
-        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-        tempFile.close()
-        self._sc._writeToFile(bytesInJava, tempFile.name)
-        # Read the data into Python and deserialize it:
-        with open(tempFile.name, 'rb') as tempFile:
-            rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
-        os.unlink(tempFile.name)
+            port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
+        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema)
return [cls(r) for r in rs]

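`DataFrame.collect()` takes the same route: `collectAndServe` on the pickled-row RDD, then `_load_from_socket` with a `BatchedSerializer(PickleSerializer())`, so nothing changes for callers. A hedged usage sketch, assuming a Spark 1.3-style `SQLContext` on a local master:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local[2]", "df-collect-demo")
sqlCtx = SQLContext(sc)

df = sqlCtx.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])

# The rows are pickled in the JVM, streamed over a local socket, and
# rebuilt here as Row objects.
print(df.collect())   # [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]

sc.stop()
```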
