Skip to content

Commit

Permalink
updated to use environment variable to allow insecure gateways
Browse files Browse the repository at this point in the history
  • Loading branch information
squito committed Dec 19, 2018
1 parent e83b160 commit 9cc545b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private[spark] object PythonGatewayServer extends Logging {
.javaPort(0)
.javaAddress(localhost)
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
if (sys.env.getOrElse("_PYSPARK_INSECURE_GATEWAY", "0") != "1") {
if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
builder.authToken(secret)
} else {
assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,17 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
"""
self._callsite = first_spark_call() or CallSite(None, None, None)
if gateway is not None and gateway.gateway_parameters.auth_token is None:
if conf and conf.get("spark.python.allowInsecurePy4j", "false") == "true":
allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0")
if allow_insecure_env == "1" or allow_insecure_env.lower() == "true":
warnings.warn(
"You are passing in an insecure Py4j gateway. This "
"presents a security risk, and will be completely forbidden in Spark 3.0")
else:
raise Exception(
"You are trying to pass an insecure Py4j gateway to Spark. This"
" presents a security risk. If you are sure you understand and accept this"
" risk, you can add the conf 'spark.python.allowInsecurePy4j=true', but"
" risk, you can set the environment variable"
" 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but"
" note this option will be removed in Spark 3.0")

SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/java_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _launch_gateway(conf=None, insecure=False):
"""
launch jvm gateway
:param conf: spark configuration passed to spark-submit
:param insecure: True to create an insecure gateway; only for testing
:return: a JVM gateway
"""
if insecure and not os.environ.get("SPARK_TESTING", "0") == "1":
Expand Down Expand Up @@ -86,7 +87,7 @@ def _launch_gateway(conf=None, insecure=False):
env = dict(os.environ)
env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
if insecure:
env["_PYSPARK_INSECURE_GATEWAY"] = "1"
env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"

# Launch the Java gateway.
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
Expand Down
20 changes: 11 additions & 9 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2389,22 +2389,24 @@ def test_forbid_insecure_gateway(self):
with self.assertRaises(Exception) as context:
SparkContext(gateway=gateway)
self.assertIn("insecure Py4j gateway", str(context.exception))
self.assertIn("spark.python.allowInsecurePy4j", str(context.exception))
self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception))
self.assertIn("removed in Spark 3.0", str(context.exception))

def test_allow_insecure_gateway_with_conf(self):
with SparkContext._lock:
SparkContext._gateway = None
SparkContext._jvm = None
gateway = _launch_gateway(insecure=True)
conf = SparkConf()
conf.set("spark.python.allowInsecurePy4j", "true")
with SparkContext(conf=conf, gateway=gateway) as sc:
print("sc created, about to create accum")
a = sc.accumulator(1)
rdd = sc.parallelize([1, 2, 3])
rdd.foreach(lambda x: a.add(x))
self.assertEqual(7, a.value)
try:
os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
with SparkContext(gateway=gateway) as sc:
print("sc created, about to create accum")
a = sc.accumulator(1)
rdd = sc.parallelize([1, 2, 3])
rdd.foreach(lambda x: a.add(x))
self.assertEqual(7, a.value)
finally:
os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)


class ConfTests(unittest.TestCase):
Expand Down

0 comments on commit 9cc545b

Please sign in to comment.