diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 105ef7b325ed1..855d8fb4a859f 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -262,7 +262,7 @@ def authenticate_and_accum_updates(): raise Exception( "The value of the provided token to the AccumulatorServer is not correct.") - if auth_token: + if auth_token is not None: # first we keep polling till we've received the authentication token poll(authenticate_and_accum_updates) # now we've authenticated if needed, don't need to check for the token anymore diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ad4333b5d18dd..6d99e9823f001 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -119,7 +119,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, "You are passing in an insecure Py4j gateway. This " "presents a security risk, and will be completely forbidden in Spark 3.0") else: - raise Exception( + raise ValueError( "You are trying to pass an insecure Py4j gateway to Spark. This" " presents a security risk. If you are sure you understand and accept this" " risk, you can set the environment variable" diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3ef498b62225d..feb6b7bd6aa3d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -53,8 +53,8 @@ def _launch_gateway(conf=None, insecure=False): :param insecure: True to create an insecure gateway; only for testing :return: a JVM gateway """ - if insecure and not os.environ.get("SPARK_TESTING", "0") == "1": - raise Exception("creating insecure gateways is only for testing") + if insecure and os.environ.get("SPARK_TESTING", "0") != "1": + raise ValueError("creating insecure gateways is only for testing") if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6267b682f3e2b..a2d825ba36256 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2386,11 +2386,17 @@ def test_forbid_insecure_gateway(self): # By default, we fail immediately if you try to create a SparkContext # with an insecure gateway gateway = _launch_gateway(insecure=True) - with self.assertRaises(Exception) as context: - SparkContext(gateway=gateway) - self.assertIn("insecure Py4j gateway", str(context.exception)) - self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception)) - self.assertIn("removed in Spark 3.0", str(context.exception)) + log4j = gateway.jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + try: + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + with self.assertRaises(Exception) as context: + SparkContext(gateway=gateway) + self.assertIn("insecure Py4j gateway", str(context.exception)) + self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception)) + self.assertIn("removed in Spark 3.0", str(context.exception)) + finally: + log4j.LogManager.getRootLogger().setLevel(old_level) def test_allow_insecure_gateway_with_conf(self): with SparkContext._lock: @@ -2400,7 +2406,6 @@ def test_allow_insecure_gateway_with_conf(self): try: os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1" with SparkContext(gateway=gateway) as sc: - print("sc created, about to create accum") a = sc.accumulator(1) rdd = sc.parallelize([1, 2, 3]) rdd.foreach(lambda x: a.add(x))