diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index bc0be07bfb36e..5d46b92e27b5c 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -110,10 +110,14 @@ def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
 
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this certain accumulator was deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:
+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum
 
 
 class Accumulator(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3c5fc97c921bc..dc5ed198f4c50 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2932,6 +2932,31 @@ def test_create_dateframe_from_pandas_with_dst(self):
             os.environ['TZ'] = orig_env_tz
             time.tzset()
 
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
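
For context on what the guard above prevents: when a task runs multiple Python UDFs that reference the same accumulator, the worker deserializes each UDF closure separately, so `_deserialize_accumulator` is called more than once with the same `aid`. Without the `aid in _accumulatorRegistry` check, the second call replaced the registry entry with a fresh zero-valued accumulator, so updates already applied through the first reference were dropped when the worker reported its registry back to the driver. Below is a minimal standalone sketch of that failure mode; the `Accumulator` and `_accumulatorRegistry` names mirror `pyspark.accumulators`, but this is a simplified stand-in, not the real PySpark implementation.

```python
# Hypothetical, simplified stand-in for pyspark.accumulators, just to
# demonstrate the overwrite bug -- not the real PySpark code.
_accumulatorRegistry = {}


class Accumulator(object):
    def __init__(self, aid, zero_value):
        self.aid = aid
        self.value = zero_value

    def add(self, term):
        self.value += term


def _deserialize_accumulator_old(aid, zero_value):
    # Pre-fix behavior: always create a fresh accumulator and overwrite
    # whatever is already registered under this aid.
    accum = Accumulator(aid, zero_value)
    _accumulatorRegistry[aid] = accum
    return accum


# Two UDF closures arrive at the worker, both referencing accumulator id 0.
first_ref = _deserialize_accumulator_old(0, 0)
first_ref.add(1)                                 # update via the first UDF

second_ref = _deserialize_accumulator_old(0, 0)  # overwrites the entry
second_ref.add(100)                              # update via the second UDF

# The worker sends _accumulatorRegistry back to the driver, so the +1 is
# lost: only 100 survives instead of the expected 101.
print(_accumulatorRegistry[0].value)  # prints 100
```

With the patched `_deserialize_accumulator`, the second call returns the already-registered accumulator, both `add` calls land on the same object, and the test above sees the expected value of 101.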