From c89146475534d8e284afb24dfeb6d437bde158ec Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Mon, 17 Dec 2018 15:52:40 -0600
Subject: [PATCH] [SPARK-26019][PYSPARK] Allow insecure py4j gateways

Spark always creates secure py4j connections between java and python,
but it also allows users to pass in their own connection. This restores
the ability for users to pass in an _insecure_ connection, though it
forces them to set 'spark.python.allowInsecurePy4j=true' and still
issues a warning.

Added test cases verifying the failure without the extra configuration,
and verifying things still work with an insecure configuration (in
particular, accumulators, as those were broken with an insecure py4j
gateway before).
---
 .../api/python/PythonGatewayServer.scala      | 11 +++++--
 .../apache/spark/api/python/PythonRDD.scala   |  6 ++--
 python/pyspark/accumulators.py                |  7 +++--
 python/pyspark/context.py                     | 13 +++++++++
 python/pyspark/java_gateway.py                | 21 ++++++++++++--
 python/pyspark/tests.py                       | 29 +++++++++++++++++++
 6 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
index 9ddc4a4910180..a8256e6491f4c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
     // with the same secret, in case the app needs callbacks from the JVM to the underlying
     // python processes.
     val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
+    val builder = new GatewayServer.GatewayServerBuilder()
       .javaPort(0)
       .javaAddress(localhost)
       .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    if (sys.env.getOrElse("_PYSPARK_INSECURE_GATEWAY", "0") != "1") {
+      builder.authToken(secret)
+    } else {
+      assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
+        "Creating insecure java gateways only allowed for testing")
+    }
+    val gatewayServer: GatewayServer = builder.build()
 
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 5ed5070558af7..81494b167af50 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2(
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
       logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
-      // send the secret just for the initial authentication when opening a new connection
-      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      if (secretToken != null) {
+        // send the secret just for the initial authentication when opening a new connection
+        socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 00ec094e7e3b4..105ef7b325ed1 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -262,9 +262,10 @@ def authenticate_and_accum_updates():
                 raise Exception(
                     "The value of the provided token to the AccumulatorServer is not correct.")
 
-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
-        # now we've authenticated, don't need to check for the token anymore
+        if auth_token:
+            # first we keep polling till we've received the authentication token
+            poll(authenticate_and_accum_updates)
+        # now we've authenticated if needed, don't need to check for the token anymore
         poll(accum_updates)
 
 
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0924d3d95f044..94a6e4ce8e361 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -112,6 +112,19 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ValueError:...
         """
         self._callsite = first_spark_call() or CallSite(None, None, None)
+        # A user-supplied gateway without an auth token is insecure; require an explicit opt-in.
+        if gateway is not None and gateway.gateway_parameters.auth_token is None:
+            if conf and conf.get("spark.python.allowInsecurePy4j", "false") == "true":
+                warnings.warn(
+                    "You are passing in an insecure py4j gateway. This "
+                    "presents a security risk, and will be completely forbidden in Spark 3.0")
+            else:
+                raise Exception(
+                    "You are trying to pass an insecure py4j gateway to spark. This"
+                    " presents a security risk. If you are sure you understand and accept this"
+                    " risk, you can add the conf 'spark.python.allowInsecurePy4j=true', but"
+                    " note this option will be removed in Spark 3.0")
+
         SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index c8c5f801f89bb..ce261b9fcad04 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -43,6 +43,18 @@ def launch_gateway(conf=None):
     :param conf: spark configuration passed to spark-submit
     :return:
     """
+    return _launch_gateway(conf)
+
+
+def _launch_gateway(conf=None, insecure=False):
+    """
+    launch jvm gateway
+    :param conf: spark configuration passed to spark-submit
+    :param insecure: if True, connect without an auth token; requires SPARK_TESTING=1
+    :return:
+    """
+    if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
+        raise Exception("creating insecure gateways is only for testing")
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
         gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -74,6 +86,8 @@ def launch_gateway(conf=None):
 
             env = dict(os.environ)
             env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+            if insecure:
+                env["_PYSPARK_INSECURE_GATEWAY"] = "1"
 
             # Launch the Java gateway.
             # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -116,9 +130,10 @@ def killChild():
             atexit.register(killChild)
 
     # Connect to the gateway
-    gateway = JavaGateway(
-        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
-                                             auto_convert=True))
+    gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
+    if not insecure:
+        gateway_params.auth_token = gateway_secret
+    gateway = JavaGateway(gateway_parameters=gateway_params)
 
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 131c51e108cad..904c2f241d435 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -61,6 +61,7 @@
 from pyspark import keyword_only
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.java_gateway import _launch_gateway
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
@@ -2381,6 +2382,34 @@ def test_startTime(self):
         with SparkContext() as sc:
             self.assertGreater(sc.startTime, 0)
 
+    def test_forbid_insecure_gateway(self):
+        # By default, we fail immediately if you try to create a SparkContext
+        # with an insecure gateway
+        gateway = _launch_gateway(insecure=True)
+        with self.assertRaises(Exception) as context:
+            SparkContext(gateway=gateway)
+        message = str(context.exception)  # Exception.message only exists on Python 2
+        self.assertIn("insecure py4j gateway", message)
+        self.assertIn("spark.python.allowInsecurePy4j", message)
+        self.assertIn("removed in Spark 3.0", message)
+
+    def test_allow_insecure_gateway_with_conf(self):
+        with SparkContext._lock:
+            SparkContext._gateway = None
+            SparkContext._jvm = None
+        gateway = _launch_gateway(insecure=True)
+        conf = SparkConf()
+        conf.set("spark.python.allowInsecurePy4j", "true")
+        with SparkContext(conf=conf, gateway=gateway) as sc:
+            a = sc.accumulator(1)
+            rdd = sc.parallelize([1, 2, 3])
+
+            def f(x):
+                a.add(x)
+
+            rdd.foreach(f)
+            self.assertEqual(7, a.value)
+
 
 class ConfTests(unittest.TestCase):
     def test_memory_conf(self):