From 1e99f4ec5d030b80971603f090afa4e51079c5e7 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 3 Jan 2019 11:10:55 +0800 Subject: [PATCH] [SPARK-26019][PYSPARK] Allow insecure py4j gateways Spark always creates secure py4j connections between java and python, but it also allows users to pass in their own connection. This restores the ability for users to pass in an _insecure_ connection, though it forces them to set the env variable 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', and still issues a warning. Added test cases verifying the failure without the extra configuration, and verifying things still work with an insecure configuration (in particular, accumulators, as those were broken with an insecure py4j gateway before). For the tests, I added ways to create insecure gateways, but I tried to put in protections to make sure that wouldn't get used incorrectly. Closes #23337 from squito/SPARK-26019. Authored-by: Imran Rashid Signed-off-by: Hyukjin Kwon --- .../api/python/PythonGatewayServer.scala | 11 +++++-- .../apache/spark/api/python/PythonRDD.scala | 6 ++-- python/pyspark/accumulators.py | 7 ++-- python/pyspark/context.py | 14 ++++++++ python/pyspark/java_gateway.py | 23 ++++++++++--- python/pyspark/tests.py | 32 +++++++++++++++++++ 6 files changed, 81 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 9ddc4a4910180..17c65f6170d67 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging { // with the same secret, in case the app needs callbacks from the JVM to the underlying // python processes. val localhost = InetAddress.getLoopbackAddress() - val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() - .authToken(secret) + val builder = new GatewayServer.GatewayServerBuilder() .javaPort(0) .javaAddress(localhost) .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) - .build() + if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") { + builder.authToken(secret) + } else { + assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1", + "Creating insecure Java gateways only allowed for testing") + } + val gatewayServer: GatewayServer = builder.build() gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 5ed5070558af7..81494b167af50 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2( if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort") - // send the secret just for the initial authentication when opening a new connection - socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8)) + if (secretToken != null) { + // send the secret just for the initial authentication when opening a new connection + socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8)) + } } socket } diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 00ec094e7e3b4..855d8fb4a859f 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -262,9 +262,10 @@ def authenticate_and_accum_updates(): raise Exception( "The value of the provided token to the AccumulatorServer is not correct.") - # first we keep polling till we've received the authentication token - poll(authenticate_and_accum_updates) - # now we've authenticated, don't need to check for the token anymore + if auth_token is not None: + # first we keep polling till we've received the authentication token + poll(authenticate_and_accum_updates) + # now we've authenticated if needed, don't need to check for the token anymore poll(accum_updates) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0924d3d95f044..6d99e9823f001 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ValueError:... """ self._callsite = first_spark_call() or CallSite(None, None, None) + if gateway is not None and gateway.gateway_parameters.auth_token is None: + allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0") + if allow_insecure_env == "1" or allow_insecure_env.lower() == "true": + warnings.warn( + "You are passing in an insecure Py4j gateway. This " + "presents a security risk, and will be completely forbidden in Spark 3.0") + else: + raise ValueError( + "You are trying to pass an insecure Py4j gateway to Spark. This" + " presents a security risk. If you are sure you understand and accept this" + " risk, you can set the environment variable" + " 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but" + " note this option will be removed in Spark 3.0") + SparkContext._ensure_initialized(self, gateway=gateway, conf=conf) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c8c5f801f89bb..feb6b7bd6aa3d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -41,8 +41,20 @@ def launch_gateway(conf=None): """ launch jvm gateway :param conf: spark configuration passed to spark-submit - :return: + :return: a JVM gateway """ + return _launch_gateway(conf) + + +def _launch_gateway(conf=None, insecure=False): + """ + launch jvm gateway + :param conf: spark configuration passed to spark-submit + :param insecure: True to create an insecure gateway; only for testing + :return: a JVM gateway + """ + if insecure and os.environ.get("SPARK_TESTING", "0") != "1": + raise ValueError("creating insecure gateways is only for testing") if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] @@ -74,6 +86,8 @@ def launch_gateway(conf=None): env = dict(os.environ) env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + if insecure: + env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1" # Launch the Java gateway. # We open a pipe to stdin so that the Java gateway can die when the pipe is broken @@ -116,9 +130,10 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway( - gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, - auto_convert=True)) + gateway_params = GatewayParameters(port=gateway_port, auto_convert=True) + if not insecure: + gateway_params.auth_token = gateway_secret + gateway = JavaGateway(gateway_parameters=gateway_params) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 131c51e108cad..a2d825ba36256 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -61,6 +61,7 @@ from pyspark import keyword_only from pyspark.conf import SparkConf from pyspark.context import SparkContext +from pyspark.java_gateway import _launch_gateway from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ @@ -2381,6 +2382,37 @@ def test_startTime(self): with SparkContext() as sc: self.assertGreater(sc.startTime, 0) + def test_forbid_insecure_gateway(self): + # By default, we fail immediately if you try to create a SparkContext + # with an insecure gateway + gateway = _launch_gateway(insecure=True) + log4j = gateway.jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + try: + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + with self.assertRaises(Exception) as context: + SparkContext(gateway=gateway) + self.assertIn("insecure Py4j gateway", str(context.exception)) + self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception)) + self.assertIn("removed in Spark 3.0", str(context.exception)) + finally: + log4j.LogManager.getRootLogger().setLevel(old_level) + + def test_allow_insecure_gateway_with_conf(self): + with SparkContext._lock: + SparkContext._gateway = None + SparkContext._jvm = None + gateway = _launch_gateway(insecure=True) + try: + os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1" + with SparkContext(gateway=gateway) as sc: + a = sc.accumulator(1) + rdd = sc.parallelize([1, 2, 3]) + rdd.foreach(lambda x: a.add(x)) + self.assertEqual(7, a.value) + finally: + os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None) + class ConfTests(unittest.TestCase): def test_memory_conf(self):