Skip to content

Commit

Permalink
use random port for callback server
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Oct 2, 2014
1 parent d05871e commit 37fe06f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
30 changes: 22 additions & 8 deletions python/pyspark/streaming/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import sys

from py4j.java_collections import ListConverter
from py4j.java_gateway import java_import
from py4j.java_gateway import java_import, JavaObject

from pyspark import RDD, SparkConf
from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
Expand All @@ -38,6 +38,8 @@ def _daemonize_callback_server():
from exiting if it's not shutdown. The following code replace `start()`
of CallbackServer with a new version, which set daemon=True for this
thread.
Also, it will update the port number (0) with real port
"""
# TODO: create a patch for Py4J
import socket
Expand All @@ -54,8 +56,11 @@ def start(self):
1)
try:
self.server_socket.bind((self.address, self.port))
except Exception:
msg = 'An error occurred while trying to start the callback server'
if not self.port:
# update port with real port
self.port = self.server_socket.getsockname()[1]
except Exception as e:
msg = 'An error occurred while trying to start the callback server: %s' % e
logger.exception(msg)
raise Py4JNetworkError(msg)

Expand Down Expand Up @@ -105,15 +110,24 @@ def _jduration(self, seconds):
def _ensure_initialized(cls):
SparkContext._ensure_initialized()
gw = SparkContext._gateway
# start callback server
# getattr will fallback to JVM
if "_callback_server" not in gw.__dict__:
_daemonize_callback_server()
gw._start_callback_server(gw._python_proxy_port)

java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

# start callback server
# getattr will fallback to JVM, so we cannot test by hasattr()
if "_callback_server" not in gw.__dict__:
_daemonize_callback_server()
# use random port
gw._start_callback_server(0)
# gateway with real port
gw._python_proxy_port = gw._callback_server.port
# get the GatewayServer object in JVM by ID
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)

# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = TransformFunctionSerializer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.language.existentials

import py4j.GatewayServer

import org.apache.spark.api.java._
import org.apache.spark.api.python._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -88,10 +90,14 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun
*/
private[python] object PythonTransformFunctionSerializer {

// A serializer in Python, used to serialize PythonTransformFunction
/**
* A serializer in Python, used to serialize PythonTransformFunction
*/
private var serializer: PythonTransformFunctionSerializer = _

// Register a serializer from Python, should be called during initialization
/*
* Register a serializer from Python, should be called during initialization
*/
def register(ser: PythonTransformFunctionSerializer): Unit = {
serializer = ser
}
Expand All @@ -117,20 +123,36 @@ private[python] object PythonTransformFunctionSerializer {
*/
private[python] object PythonDStream {

// can not access PythonTransformFunctionSerializer.register() via Py4j
// Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
/**
* can not access PythonTransformFunctionSerializer.register() via Py4j
* Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
*/
def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
PythonTransformFunctionSerializer.register(ser)
}

// helper function for DStream.foreachRDD(),
// cannot be `foreachRDD`, it will confusing py4j
/**
* Update the port of callback client to `port`
*/
def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
val cl = gws.getCallbackClient
val f = cl.getClass.getDeclaredField("port")
f.setAccessible(true)
f.setInt(cl, port)
}

/**
* helper function for DStream.foreachRDD(),
* cannot be `foreachRDD`, it will confusing py4j
*/
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
val func = new TransformFunction((pfunc))
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
}

// convert list of RDD into queue of RDDs, for ssc.queueStream()
/**
* convert list of RDD into queue of RDDs, for ssc.queueStream()
*/
def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
rdds.forall(queue.add(_))
Expand Down

0 comments on commit 37fe06f

Please sign in to comment.