diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index aabbbd958080a..7f99d38771ce8 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -18,7 +18,7 @@
 import sys
 
 from py4j.java_collections import ListConverter
-from py4j.java_gateway import java_import
+from py4j.java_gateway import java_import, JavaObject
 
 from pyspark import RDD, SparkConf
 from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
@@ -38,6 +38,8 @@ def _daemonize_callback_server():
     from exiting if it's not shutdown. The following code replace `start()`
     of CallbackServer with a new version, which set daemon=True for this
     thread.
+
+    Also, it will update the port number (0) with real port
     """
     # TODO: create a patch for Py4J
     import socket
@@ -54,8 +56,11 @@ def start(self):
                                       1)
         try:
             self.server_socket.bind((self.address, self.port))
-        except Exception:
-            msg = 'An error occurred while trying to start the callback server'
+            if not self.port:
+                # update port with real port
+                self.port = self.server_socket.getsockname()[1]
+        except Exception as e:
+            msg = 'An error occurred while trying to start the callback server: %s' % e
             logger.exception(msg)
             raise Py4JNetworkError(msg)
 
@@ -105,15 +110,24 @@ def _jduration(self, seconds):
     def _ensure_initialized(cls):
         SparkContext._ensure_initialized()
         gw = SparkContext._gateway
-        # start callback server
-        # getattr will fallback to JVM
-        if "_callback_server" not in gw.__dict__:
-            _daemonize_callback_server()
-            gw._start_callback_server(gw._python_proxy_port)
 
         java_import(gw.jvm, "org.apache.spark.streaming.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
+
+        # start callback server
+        # getattr will fallback to JVM, so we cannot test by hasattr()
+        if "_callback_server" not in gw.__dict__:
+            _daemonize_callback_server()
+            # use random port
+            gw._start_callback_server(0)
+            # gateway with real port
+            gw._python_proxy_port = gw._callback_server.port
+            # get the GatewayServer object in JVM by ID
+            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
+            # update the port of CallbackClient with real port
+            gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
+
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
         cls._transformerSerializer = TransformFunctionSerializer(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 96b84b45b2ebf..e171fb5730616 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -24,6 +24,8 @@ import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
 import scala.language.existentials
 
+import py4j.GatewayServer
+
 import org.apache.spark.api.java._
 import org.apache.spark.api.python._
 import org.apache.spark.rdd.RDD
@@ -88,10 +90,14 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun
  */
 private[python] object PythonTransformFunctionSerializer {
 
-  // A serializer in Python, used to serialize PythonTransformFunction
+  /**
+   * A serializer in Python, used to serialize PythonTransformFunction
+   */
   private var serializer: PythonTransformFunctionSerializer = _
 
-  // Register a serializer from Python, should be called during initialization
+  /*
+   * Register a serializer from Python, should be called during initialization
+   */
   def register(ser: PythonTransformFunctionSerializer): Unit = {
     serializer = ser
   }
@@ -117,20 +123,36 @@ private[python] object PythonTransformFunctionSerializer {
  */
 private[python] object PythonDStream {
 
-  // can not access PythonTransformFunctionSerializer.register() via Py4j
-  // Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
+  /**
+   * can not access PythonTransformFunctionSerializer.register() via Py4j
+   * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
+   */
   def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
     PythonTransformFunctionSerializer.register(ser)
   }
 
-  // helper function for DStream.foreachRDD(),
-  // cannot be `foreachRDD`, it will confusing py4j
+  /**
+   * Update the port of callback client to `port`
+   */
+  def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
+    val cl = gws.getCallbackClient
+    val f = cl.getClass.getDeclaredField("port")
+    f.setAccessible(true)
+    f.setInt(cl, port)
+  }
+
+  /**
+   * helper function for DStream.foreachRDD(),
+   * cannot be `foreachRDD`, it will confusing py4j
+   */
   def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
     val func = new TransformFunction((pfunc))
     jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
   }
 
-  // convert list of RDD into queue of RDDs, for ssc.queueStream()
+  /**
+   * convert list of RDD into queue of RDDs, for ssc.queueStream()
+   */
   def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
     val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
     rdds.forall(queue.add(_))