diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
index 9ddc4a4910180..17c65f6170d67 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
     // with the same secret, in case the app needs callbacks from the JVM to the underlying
     // python processes.
     val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
+    val builder = new GatewayServer.GatewayServerBuilder()
       .javaPort(0)
       .javaAddress(localhost)
       .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
+      builder.authToken(secret)
+    } else {
+      assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
+        "Creating insecure Java gateways only allowed for testing")
+    }
+    val gatewayServer: GatewayServer = builder.build()
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
     if (boundPort == -1) {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 5ed5070558af7..81494b167af50 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2(
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
       logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
-      // send the secret just for the initial authentication when opening a new connection
-      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      if (secretToken != null) {
+        // send the secret just for the initial authentication when opening a new connection
+        socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 00ec094e7e3b4..855d8fb4a859f 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -262,9 +262,10 @@ def authenticate_and_accum_updates():
                 raise Exception(
                     "The value of the provided token to the AccumulatorServer is not correct.")

-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
-        # now we've authenticated, don't need to check for the token anymore
+        if auth_token is not None:
+            # first we keep polling till we've received the authentication token
+            poll(authenticate_and_accum_updates)
+        # now we've authenticated if needed, don't need to check for the token anymore
         poll(accum_updates)


diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0924d3d95f044..6d99e9823f001 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ValueError:...
""" self._callsite = first_spark_call() or CallSite(None, None, None) + if gateway is not None and gateway.gateway_parameters.auth_token is None: + allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0") + if allow_insecure_env == "1" or allow_insecure_env.lower() == "true": + warnings.warn( + "You are passing in an insecure Py4j gateway. This " + "presents a security risk, and will be completely forbidden in Spark 3.0") + else: + raise ValueError( + "You are trying to pass an insecure Py4j gateway to Spark. This" + " presents a security risk. If you are sure you understand and accept this" + " risk, you can set the environment variable" + " 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but" + " note this option will be removed in Spark 3.0") + SparkContext._ensure_initialized(self, gateway=gateway, conf=conf) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c8c5f801f89bb..feb6b7bd6aa3d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -41,8 +41,20 @@ def launch_gateway(conf=None): """ launch jvm gateway :param conf: spark configuration passed to spark-submit - :return: + :return: a JVM gateway """ + return _launch_gateway(conf) + + +def _launch_gateway(conf=None, insecure=False): + """ + launch jvm gateway + :param conf: spark configuration passed to spark-submit + :param insecure: True to create an insecure gateway; only for testing + :return: a JVM gateway + """ + if insecure and os.environ.get("SPARK_TESTING", "0") != "1": + raise ValueError("creating insecure gateways is only for testing") if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] @@ -74,6 +86,8 @@ def launch_gateway(conf=None): env = dict(os.environ) env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + if insecure: + env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1" # Launch the Java gateway. 
             # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -116,9 +130,10 @@ def killChild():
             atexit.register(killChild)

     # Connect to the gateway
-    gateway = JavaGateway(
-        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
-                                             auto_convert=True))
+    gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
+    if not insecure:
+        gateway_params.auth_token = gateway_secret
+    gateway = JavaGateway(gateway_parameters=gateway_params)

     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 131c51e108cad..a2d825ba36256 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -61,6 +61,7 @@
 from pyspark import keyword_only
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.java_gateway import _launch_gateway
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
@@ -2381,6 +2382,37 @@ def test_startTime(self):
         with SparkContext() as sc:
             self.assertGreater(sc.startTime, 0)

+    def test_forbid_insecure_gateway(self):
+        # By default, we fail immediately if you try to create a SparkContext
+        # with an insecure gateway
+        gateway = _launch_gateway(insecure=True)
+        log4j = gateway.jvm.org.apache.log4j
+        old_level = log4j.LogManager.getRootLogger().getLevel()
+        try:
+            log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
+            with self.assertRaises(Exception) as context:
+                SparkContext(gateway=gateway)
+            self.assertIn("insecure Py4j gateway", str(context.exception))
+            self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception))
+            self.assertIn("removed in Spark 3.0", str(context.exception))
+        finally:
+            log4j.LogManager.getRootLogger().setLevel(old_level)
+
+    def test_allow_insecure_gateway_with_conf(self):
+        with SparkContext._lock:
+            SparkContext._gateway = None
+            SparkContext._jvm = None
+        gateway = _launch_gateway(insecure=True)
+        try:
+            os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
+            with SparkContext(gateway=gateway) as sc:
+                a = sc.accumulator(1)
+                rdd = sc.parallelize([1, 2, 3])
+                rdd.foreach(lambda x: a.add(x))
+                self.assertEqual(7, a.value)
+        finally:
+            os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)
+

 class ConfTests(unittest.TestCase):
     def test_memory_conf(self):